Run rustfmt, clean up aquatic_http_protocol/Cargo.toml

This commit is contained in:
Joakim Frostegård 2021-08-15 22:26:11 +02:00
parent 0cc312a78d
commit d0e716f80b
65 changed files with 1754 additions and 2590 deletions

View file

@ -1,15 +1,9 @@
use aquatic_cli_helpers::run_app_with_cli_and_config;
use aquatic_http::config::Config;
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
fn main(){
run_app_with_cli_and_config::<Config>(
aquatic_http::APP_NAME,
aquatic_http::run,
None
)
}
fn main() {
run_app_with_cli_and_config::<Config>(aquatic_http::APP_NAME, aquatic_http::run, None)
}

View file

@ -1,32 +1,29 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use crossbeam_channel::{Receiver, Sender};
use either::Either;
use crossbeam_channel::{Sender, Receiver};
use hashbrown::HashMap;
use indexmap::IndexMap;
use log::error;
use mio::Token;
use parking_lot::Mutex;
use smartstring::{SmartString, LazyCompact};
use smartstring::{LazyCompact, SmartString};
pub use aquatic_common::{ValidUntil, convert_ipv4_mapped_ipv6};
pub use aquatic_common::{convert_ipv4_mapped_ipv6, ValidUntil};
use aquatic_http_protocol::common::*;
use aquatic_http_protocol::request::Request;
use aquatic_http_protocol::response::{Response, ResponsePeer};
pub const LISTENER_TOKEN: Token = Token(0);
pub const CHANNEL_TOKEN: Token = Token(1);
pub trait Ip: ::std::fmt::Debug + Copy + Eq + ::std::hash::Hash {}
impl Ip for Ipv4Addr {}
impl Ip for Ipv6Addr {}
#[derive(Clone, Copy, Debug)]
pub struct ConnectionMeta {
/// Index of socket worker responsible for this connection. Required for
@ -36,7 +33,6 @@ pub struct ConnectionMeta {
pub poll_token: Token,
}
#[derive(Clone, Copy, Debug)]
pub struct PeerConnectionMeta<I: Ip> {
pub worker_index: usize,
@ -44,24 +40,19 @@ pub struct PeerConnectionMeta<I: Ip> {
pub peer_ip_address: I,
}
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum PeerStatus {
Seeding,
Leeching,
Stopped
Stopped,
}
impl PeerStatus {
/// Determine peer status from announce event and number of bytes left.
///
///
/// Likely, the last branch will be taken most of the time.
#[inline]
pub fn from_event_and_bytes_left(
event: AnnounceEvent,
opt_bytes_left: Option<usize>
) -> Self {
pub fn from_event_and_bytes_left(event: AnnounceEvent, opt_bytes_left: Option<usize>) -> Self {
if let AnnounceEvent::Stopped = event {
Self::Stopped
} else if let Some(0) = opt_bytes_left {
@ -72,7 +63,6 @@ impl PeerStatus {
}
}
#[derive(Debug, Clone, Copy)]
pub struct Peer<I: Ip> {
pub connection_meta: PeerConnectionMeta<I>,
@ -81,35 +71,30 @@ pub struct Peer<I: Ip> {
pub valid_until: ValidUntil,
}
impl <I: Ip>Peer<I> {
impl<I: Ip> Peer<I> {
pub fn to_response_peer(&self) -> ResponsePeer<I> {
ResponsePeer {
ip_address: self.connection_meta.peer_ip_address,
port: self.port
port: self.port,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PeerMapKey<I: Ip> {
pub peer_id: PeerId,
pub ip_or_key: Either<I, SmartString<LazyCompact>>
pub ip_or_key: Either<I, SmartString<LazyCompact>>,
}
pub type PeerMap<I> = IndexMap<PeerMapKey<I>, Peer<I>>;
pub struct TorrentData<I: Ip> {
pub peers: PeerMap<I>,
pub num_seeders: usize,
pub num_leechers: usize,
}
impl <I: Ip> Default for TorrentData<I> {
impl<I: Ip> Default for TorrentData<I> {
#[inline]
fn default() -> Self {
Self {
@ -120,23 +105,19 @@ impl <I: Ip> Default for TorrentData<I> {
}
}
pub type TorrentMap<I> = HashMap<InfoHash, TorrentData<I>>;
#[derive(Default)]
pub struct TorrentMaps {
pub ipv4: TorrentMap<Ipv4Addr>,
pub ipv6: TorrentMap<Ipv6Addr>,
}
#[derive(Clone)]
pub struct State {
pub torrent_maps: Arc<Mutex<TorrentMaps>>,
}
impl Default for State {
fn default() -> Self {
Self {
@ -145,39 +126,27 @@ impl Default for State {
}
}
pub type RequestChannelSender = Sender<(ConnectionMeta, Request)>;
pub type RequestChannelReceiver = Receiver<(ConnectionMeta, Request)>;
pub type ResponseChannelReceiver = Receiver<(ConnectionMeta, Response)>;
#[derive(Clone)]
pub struct ResponseChannelSender {
senders: Vec<Sender<(ConnectionMeta, Response)>>,
}
impl ResponseChannelSender {
pub fn new(
senders: Vec<Sender<(ConnectionMeta, Response)>>,
) -> Self {
Self {
senders,
}
pub fn new(senders: Vec<Sender<(ConnectionMeta, Response)>>) -> Self {
Self { senders }
}
#[inline]
pub fn send(
&self,
meta: ConnectionMeta,
message: Response
){
if let Err(err) = self.senders[meta.worker_index].send((meta, message)){
pub fn send(&self, meta: ConnectionMeta, message: Response) {
if let Err(err) = self.senders[meta.worker_index].send((meta, message)) {
error!("ResponseChannelSender: couldn't send message: {:?}", err);
}
}
}
pub type SocketWorkerStatus = Option<Result<(), String>>;
pub type SocketWorkerStatuses = Arc<Mutex<Vec<SocketWorkerStatus>>>;
pub type SocketWorkerStatuses = Arc<Mutex<Vec<SocketWorkerStatus>>>;

View file

@ -1,10 +1,9 @@
use std::net::SocketAddr;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use aquatic_cli_helpers::LogLevel;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct Config {
@ -24,14 +23,12 @@ pub struct Config {
pub privileges: PrivilegeConfig,
}
impl aquatic_cli_helpers::Config for Config {
fn get_log_level(&self) -> Option<LogLevel>{
fn get_log_level(&self) -> Option<LogLevel> {
Some(self.log_level)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct TlsConfig {
@ -40,7 +37,6 @@ pub struct TlsConfig {
pub tls_pkcs12_password: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct NetworkConfig {
@ -54,7 +50,6 @@ pub struct NetworkConfig {
pub poll_timeout_microseconds: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct ProtocolConfig {
@ -66,7 +61,6 @@ pub struct ProtocolConfig {
pub peer_announce_interval: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct HandlerConfig {
@ -76,7 +70,6 @@ pub struct HandlerConfig {
pub channel_recv_timeout_microseconds: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct CleaningConfig {
@ -88,7 +81,6 @@ pub struct CleaningConfig {
pub max_connection_age: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct StatisticsConfig {
@ -96,7 +88,6 @@ pub struct StatisticsConfig {
pub interval: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct PrivilegeConfig {
@ -108,7 +99,6 @@ pub struct PrivilegeConfig {
pub user: String,
}
impl Default for Config {
fn default() -> Self {
Self {
@ -125,7 +115,6 @@ impl Default for Config {
}
}
impl Default for NetworkConfig {
fn default() -> Self {
Self {
@ -139,7 +128,6 @@ impl Default for NetworkConfig {
}
}
impl Default for ProtocolConfig {
fn default() -> Self {
Self {
@ -150,7 +138,6 @@ impl Default for ProtocolConfig {
}
}
impl Default for HandlerConfig {
fn default() -> Self {
Self {
@ -160,7 +147,6 @@ impl Default for HandlerConfig {
}
}
impl Default for CleaningConfig {
fn default() -> Self {
Self {
@ -171,16 +157,12 @@ impl Default for CleaningConfig {
}
}
impl Default for StatisticsConfig {
fn default() -> Self {
Self {
interval: 0,
}
Self { interval: 0 }
}
}
impl Default for PrivilegeConfig {
fn default() -> Self {
Self {
@ -191,7 +173,6 @@ impl Default for PrivilegeConfig {
}
}
impl Default for TlsConfig {
fn default() -> Self {
Self {

View file

@ -1,13 +1,13 @@
use std::collections::BTreeMap;
use std::time::Duration;
use std::vec::Drain;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::time::Duration;
use std::vec::Drain;
use either::Either;
use mio::Waker;
use parking_lot::MutexGuard;
use rand::{Rng, SeedableRng, rngs::SmallRng};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use aquatic_common::extract_response_peers;
use aquatic_http_protocol::request::*;
@ -16,14 +16,13 @@ use aquatic_http_protocol::response::*;
use crate::common::*;
use crate::config::Config;
pub fn run_request_worker(
config: Config,
state: State,
request_channel_receiver: RequestChannelReceiver,
response_channel_sender: ResponseChannelSender,
wakers: Vec<Arc<Waker>>,
){
) {
let mut wake_socket_workers: Vec<bool> = (0..config.socket_workers).map(|_| false).collect();
let mut announce_requests = Vec::new();
@ -31,9 +30,7 @@ pub fn run_request_worker(
let mut rng = SmallRng::from_entropy();
let timeout = Duration::from_micros(
config.handlers.channel_recv_timeout_microseconds
);
let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);
loop {
let mut opt_torrent_map_guard: Option<MutexGuard<TorrentMaps>> = None;
@ -51,22 +48,22 @@ pub fn run_request_worker(
match opt_in_message {
Some((meta, Request::Announce(r))) => {
announce_requests.push((meta, r));
},
}
Some((meta, Request::Scrape(r))) => {
scrape_requests.push((meta, r));
},
}
None => {
if let Some(torrent_guard) = state.torrent_maps.try_lock(){
if let Some(torrent_guard) = state.torrent_maps.try_lock() {
opt_torrent_map_guard = Some(torrent_guard);
break
break;
}
}
}
}
let mut torrent_map_guard = opt_torrent_map_guard
.unwrap_or_else(|| state.torrent_maps.lock());
let mut torrent_map_guard =
opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock());
handle_announce_requests(
&config,
@ -74,7 +71,7 @@ pub fn run_request_worker(
&mut torrent_map_guard,
&response_channel_sender,
&mut wake_socket_workers,
announce_requests.drain(..)
announce_requests.drain(..),
);
handle_scrape_requests(
@ -82,12 +79,12 @@ pub fn run_request_worker(
&mut torrent_map_guard,
&response_channel_sender,
&mut wake_socket_workers,
scrape_requests.drain(..)
scrape_requests.drain(..),
);
for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate(){
for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() {
if *wake {
if let Err(err) = wakers[worker_index].wake(){
if let Err(err) = wakers[worker_index].wake() {
::log::error!("request handler couldn't wake poll: {:?}", err);
}
@ -97,7 +94,6 @@ pub fn run_request_worker(
}
}
pub fn handle_announce_requests(
config: &Config,
rng: &mut impl Rng,
@ -105,22 +101,19 @@ pub fn handle_announce_requests(
response_channel_sender: &ResponseChannelSender,
wake_socket_workers: &mut Vec<bool>,
requests: Drain<(ConnectionMeta, AnnounceRequest)>,
){
) {
let valid_until = ValidUntil::new(config.cleaning.max_peer_age);
for (meta, request) in requests {
let peer_ip = convert_ipv4_mapped_ipv6(
meta.peer_addr.ip()
);
let peer_ip = convert_ipv4_mapped_ipv6(meta.peer_addr.ip());
::log::debug!("peer ip: {:?}", peer_ip);
let response = match peer_ip {
IpAddr::V4(peer_ip_address) => {
let torrent_data: &mut TorrentData<Ipv4Addr> = torrent_maps.ipv4
.entry(request.info_hash)
.or_default();
let torrent_data: &mut TorrentData<Ipv4Addr> =
torrent_maps.ipv4.entry(request.info_hash).or_default();
let peer_connection_meta = PeerConnectionMeta {
worker_index: meta.worker_index,
poll_token: meta.poll_token,
@ -133,7 +126,7 @@ pub fn handle_announce_requests(
peer_connection_meta,
torrent_data,
request,
valid_until
valid_until,
);
let response = AnnounceResponse {
@ -145,16 +138,15 @@ pub fn handle_announce_requests(
};
Response::Announce(response)
},
}
IpAddr::V6(peer_ip_address) => {
let torrent_data: &mut TorrentData<Ipv6Addr> = torrent_maps.ipv6
.entry(request.info_hash)
.or_default();
let torrent_data: &mut TorrentData<Ipv6Addr> =
torrent_maps.ipv6.entry(request.info_hash).or_default();
let peer_connection_meta = PeerConnectionMeta {
worker_index: meta.worker_index,
poll_token: meta.poll_token,
peer_ip_address
peer_ip_address,
};
let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers(
@ -163,7 +155,7 @@ pub fn handle_announce_requests(
peer_connection_meta,
torrent_data,
request,
valid_until
valid_until,
);
let response = AnnounceResponse {
@ -175,15 +167,14 @@ pub fn handle_announce_requests(
};
Response::Announce(response)
},
}
};
response_channel_sender.send(meta, response);
wake_socket_workers[meta.worker_index] = true;
};
}
}
/// Insert/update peer. Return num_seeders, num_leechers and response peers
fn upsert_peer_and_get_response_peers<I: Ip>(
config: &Config,
@ -195,10 +186,8 @@ fn upsert_peer_and_get_response_peers<I: Ip>(
) -> (usize, usize, Vec<ResponsePeer<I>>) {
// Insert/update/remove peer who sent this request
let peer_status = PeerStatus::from_event_and_bytes_left(
request.event,
Some(request.bytes_left)
);
let peer_status =
PeerStatus::from_event_and_bytes_left(request.event, Some(request.bytes_left));
let peer = Peer {
connection_meta: request_sender_meta,
@ -209,11 +198,10 @@ fn upsert_peer_and_get_response_peers<I: Ip>(
::log::debug!("peer: {:?}", peer);
let ip_or_key = request.key
let ip_or_key = request
.key
.map(Either::Right)
.unwrap_or_else(||
Either::Left(request_sender_meta.peer_ip_address)
);
.unwrap_or_else(|| Either::Left(request_sender_meta.peer_ip_address));
let peer_map_key = PeerMapKey {
peer_id: request.peer_id,
@ -227,26 +215,24 @@ fn upsert_peer_and_get_response_peers<I: Ip>(
torrent_data.num_leechers += 1;
torrent_data.peers.insert(peer_map_key.clone(), peer)
},
}
PeerStatus::Seeding => {
torrent_data.num_seeders += 1;
torrent_data.peers.insert(peer_map_key.clone(), peer)
},
PeerStatus::Stopped => {
torrent_data.peers.remove(&peer_map_key)
}
PeerStatus::Stopped => torrent_data.peers.remove(&peer_map_key),
};
::log::debug!("opt_removed_peer: {:?}", opt_removed_peer);
match opt_removed_peer.map(|peer| peer.status){
match opt_removed_peer.map(|peer| peer.status) {
Some(PeerStatus::Leeching) => {
torrent_data.num_leechers -= 1;
},
}
Some(PeerStatus::Seeding) => {
torrent_data.num_seeders -= 1;
},
}
_ => {}
}
@ -262,38 +248,40 @@ fn upsert_peer_and_get_response_peers<I: Ip>(
&torrent_data.peers,
max_num_peers_to_take,
peer_map_key,
Peer::to_response_peer
Peer::to_response_peer,
);
(torrent_data.num_seeders, torrent_data.num_leechers, response_peers)
(
torrent_data.num_seeders,
torrent_data.num_leechers,
response_peers,
)
}
pub fn handle_scrape_requests(
config: &Config,
torrent_maps: &mut TorrentMaps,
response_channel_sender: &ResponseChannelSender,
wake_socket_workers: &mut Vec<bool>,
requests: Drain<(ConnectionMeta, ScrapeRequest)>,
){
) {
for (meta, request) in requests {
let num_to_take = request.info_hashes.len().min(
config.protocol.max_scrape_torrents
);
let num_to_take = request
.info_hashes
.len()
.min(config.protocol.max_scrape_torrents);
let mut response = ScrapeResponse {
files: BTreeMap::new(),
};
let peer_ip = convert_ipv4_mapped_ipv6(
meta.peer_addr.ip()
);
let peer_ip = convert_ipv4_mapped_ipv6(meta.peer_addr.ip());
// If request.info_hashes is empty, don't return scrape for all
// torrents, even though reference server does it. It is too expensive.
if peer_ip.is_ipv4(){
for info_hash in request.info_hashes.into_iter().take(num_to_take){
if let Some(torrent_data) = torrent_maps.ipv4.get(&info_hash){
if peer_ip.is_ipv4() {
for info_hash in request.info_hashes.into_iter().take(num_to_take) {
if let Some(torrent_data) = torrent_maps.ipv4.get(&info_hash) {
let stats = ScrapeStatistics {
complete: torrent_data.num_seeders,
downloaded: 0, // No implementation planned
@ -304,8 +292,8 @@ pub fn handle_scrape_requests(
}
}
} else {
for info_hash in request.info_hashes.into_iter().take(num_to_take){
if let Some(torrent_data) = torrent_maps.ipv6.get(&info_hash){
for info_hash in request.info_hashes.into_iter().take(num_to_take) {
if let Some(torrent_data) = torrent_maps.ipv6.get(&info_hash) {
let stats = ScrapeStatistics {
complete: torrent_data.num_seeders,
downloaded: 0, // No implementation planned
@ -317,8 +305,7 @@ pub fn handle_scrape_requests(
}
};
response_channel_sender.send(meta, Response::Scrape(response));
wake_socket_workers[meta.worker_index] = true;
};
}
}
}

View file

@ -1,6 +1,6 @@
use std::time::Duration;
use std::sync::Arc;
use std::thread::Builder;
use std::time::Duration;
use anyhow::Context;
use mio::{Poll, Waker};
@ -17,10 +17,8 @@ use common::*;
use config::Config;
use network::utils::create_tls_acceptor;
pub const APP_NAME: &str = "aquatic_http: HTTP/TLS BitTorrent tracker";
pub fn run(config: Config) -> anyhow::Result<()> {
let state = State::default();
@ -33,7 +31,6 @@ pub fn run(config: Config) -> anyhow::Result<()> {
}
}
pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
let opt_tls_acceptor = create_tls_acceptor(&config.network.tls)?;
@ -65,17 +62,19 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
out_message_senders.push(response_channel_sender);
wakers.push(waker);
Builder::new().name(format!("socket-{:02}", i + 1)).spawn(move || {
network::run_socket_worker(
config,
i,
socket_worker_statuses,
request_channel_sender,
response_channel_receiver,
opt_tls_acceptor,
poll
);
})?;
Builder::new()
.name(format!("socket-{:02}", i + 1))
.spawn(move || {
network::run_socket_worker(
config,
i,
socket_worker_statuses,
request_channel_sender,
response_channel_receiver,
opt_tls_acceptor,
poll,
);
})?;
}
// Wait for socket worker statuses. On error from any, quit program.
@ -84,14 +83,14 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
loop {
::std::thread::sleep(::std::time::Duration::from_millis(10));
if let Some(statuses) = socket_worker_statuses.try_lock(){
for opt_status in statuses.iter(){
if let Some(Err(err)) = opt_status {
if let Some(statuses) = socket_worker_statuses.try_lock() {
for opt_status in statuses.iter() {
if let Some(Err(err)) = opt_status {
return Err(::anyhow::anyhow!(err.to_owned()));
}
}
if statuses.iter().all(Option::is_some){
if statuses.iter().all(Option::is_some) {
if config.privileges.drop_privileges {
PrivDrop::default()
.chroot(config.privileges.chroot_path.clone())
@ -100,7 +99,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
.context("Couldn't drop root privileges")?;
}
break
break;
}
}
}
@ -114,32 +113,32 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
let response_channel_sender = response_channel_sender.clone();
let wakers = wakers.clone();
Builder::new().name(format!("request-{:02}", i + 1)).spawn(move || {
handler::run_request_worker(
config,
state,
request_channel_receiver,
response_channel_sender,
wakers,
);
})?;
Builder::new()
.name(format!("request-{:02}", i + 1))
.spawn(move || {
handler::run_request_worker(
config,
state,
request_channel_receiver,
response_channel_sender,
wakers,
);
})?;
}
if config.statistics.interval != 0 {
let state = state.clone();
let config = config.clone();
Builder::new().name("statistics".to_string()).spawn(move ||
loop {
::std::thread::sleep(Duration::from_secs(
config.statistics.interval
));
Builder::new()
.name("statistics".to_string())
.spawn(move || loop {
::std::thread::sleep(Duration::from_secs(config.statistics.interval));
tasks::print_statistics(&state);
}
).expect("spawn statistics thread");
})
.expect("spawn statistics thread");
}
Ok(())
}

View file

@ -1,12 +1,12 @@
use std::net::{SocketAddr};
use std::io::ErrorKind;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::sync::Arc;
use hashbrown::HashMap;
use mio::{Token, Poll};
use mio::net::TcpStream;
use native_tls::{TlsAcceptor, MidHandshakeTlsStream};
use mio::{Poll, Token};
use native_tls::{MidHandshakeTlsStream, TlsAcceptor};
use aquatic_http_protocol::request::{Request, RequestParseError};
@ -14,7 +14,6 @@ use crate::common::*;
use super::stream::Stream;
#[derive(Debug)]
pub enum RequestReadError {
NeedMoreData,
@ -23,7 +22,6 @@ pub enum RequestReadError {
Io(::std::io::Error),
}
pub struct EstablishedConnection {
stream: Stream,
pub peer_addr: SocketAddr,
@ -31,7 +29,6 @@ pub struct EstablishedConnection {
bytes_read: usize,
}
impl EstablishedConnection {
#[inline]
fn new(stream: Stream) -> Self {
@ -46,11 +43,11 @@ impl EstablishedConnection {
}
pub fn read_request(&mut self) -> Result<Request, RequestReadError> {
if (self.buf.len() - self.bytes_read < 512) & (self.buf.len() <= 3072){
if (self.buf.len() - self.bytes_read < 512) & (self.buf.len() <= 3072) {
self.buf.extend_from_slice(&[0; 1024]);
}
match self.stream.read(&mut self.buf[self.bytes_read..]){
match self.stream.read(&mut self.buf[self.bytes_read..]) {
Ok(0) => {
self.clear_buffer();
@ -60,10 +57,10 @@ impl EstablishedConnection {
self.bytes_read += bytes_read;
::log::debug!("read_request read {} bytes", bytes_read);
},
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
return Err(RequestReadError::NeedMoreData);
},
}
Err(err) => {
self.clear_buffer();
@ -71,20 +68,18 @@ impl EstablishedConnection {
}
}
match Request::from_bytes(&self.buf[..self.bytes_read]){
match Request::from_bytes(&self.buf[..self.bytes_read]) {
Ok(request) => {
self.clear_buffer();
Ok(request)
},
Err(RequestParseError::NeedMoreData) => {
Err(RequestReadError::NeedMoreData)
},
}
Err(RequestParseError::NeedMoreData) => Err(RequestReadError::NeedMoreData),
Err(RequestParseError::Invalid(err)) => {
self.clear_buffer();
Err(RequestReadError::Parse(err))
},
}
}
}
@ -92,9 +87,7 @@ impl EstablishedConnection {
let content_len = body.len() + 2; // 2 is for newlines at end
let content_len_num_digits = Self::num_digits_in_usize(content_len);
let mut response = Vec::with_capacity(
39 + content_len_num_digits + body.len()
);
let mut response = Vec::with_capacity(39 + content_len_num_digits + body.len());
response.extend_from_slice(b"HTTP/1.1 200 OK\r\nContent-Length: ");
::itoa::write(&mut response, content_len)?;
@ -130,40 +123,33 @@ impl EstablishedConnection {
}
#[inline]
pub fn clear_buffer(&mut self){
pub fn clear_buffer(&mut self) {
self.bytes_read = 0;
self.buf = Vec::new();
}
}
pub enum TlsHandshakeMachineError {
WouldBlock(TlsHandshakeMachine),
Failure(native_tls::Error)
Failure(native_tls::Error),
}
enum TlsHandshakeMachineInner {
TcpStream(TcpStream),
TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
}
pub struct TlsHandshakeMachine {
tls_acceptor: Arc<TlsAcceptor>,
inner: TlsHandshakeMachineInner,
}
impl <'a>TlsHandshakeMachine {
impl<'a> TlsHandshakeMachine {
#[inline]
fn new(
tls_acceptor: Arc<TlsAcceptor>,
tcp_stream: TcpStream
) -> Self {
fn new(tls_acceptor: Arc<TlsAcceptor>, tcp_stream: TcpStream) -> Self {
Self {
tls_acceptor,
inner: TlsHandshakeMachineInner::TcpStream(tcp_stream)
inner: TlsHandshakeMachineInner::TcpStream(tcp_stream),
}
}
@ -171,36 +157,28 @@ impl <'a>TlsHandshakeMachine {
/// the machine wrapped in an error for later attempts.
pub fn establish_tls(self) -> Result<EstablishedConnection, TlsHandshakeMachineError> {
let handshake_result = match self.inner {
TlsHandshakeMachineInner::TcpStream(stream) => {
self.tls_acceptor.accept(stream)
},
TlsHandshakeMachineInner::TlsMidHandshake(handshake) => {
handshake.handshake()
},
TlsHandshakeMachineInner::TcpStream(stream) => self.tls_acceptor.accept(stream),
TlsHandshakeMachineInner::TlsMidHandshake(handshake) => handshake.handshake(),
};
match handshake_result {
Ok(stream) => {
let established = EstablishedConnection::new(
Stream::TlsStream(stream)
);
let established = EstablishedConnection::new(Stream::TlsStream(stream));
::log::debug!("established tls connection");
Ok(established)
},
}
Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
let inner = TlsHandshakeMachineInner::TlsMidHandshake(
handshake
);
let inner = TlsHandshakeMachineInner::TlsMidHandshake(handshake);
let machine = Self {
tls_acceptor: self.tls_acceptor,
inner,
};
Err(TlsHandshakeMachineError::WouldBlock(machine))
},
}
Err(native_tls::HandshakeError::Failure(err)) => {
Err(TlsHandshakeMachineError::Failure(err))
}
@ -208,19 +186,16 @@ impl <'a>TlsHandshakeMachine {
}
}
enum ConnectionInner {
Established(EstablishedConnection),
InProgress(TlsHandshakeMachine),
}
pub struct Connection {
pub valid_until: ValidUntil,
inner: ConnectionInner,
}
impl Connection {
#[inline]
pub fn new(
@ -230,42 +205,29 @@ impl Connection {
) -> Self {
// Setup handshake machine if TLS is requested
let inner = if let Some(tls_acceptor) = opt_tls_acceptor {
ConnectionInner::InProgress(
TlsHandshakeMachine::new(tls_acceptor.clone(), tcp_stream)
)
ConnectionInner::InProgress(TlsHandshakeMachine::new(tls_acceptor.clone(), tcp_stream))
} else {
::log::debug!("established tcp connection");
ConnectionInner::Established(
EstablishedConnection::new(Stream::TcpStream(tcp_stream))
)
ConnectionInner::Established(EstablishedConnection::new(Stream::TcpStream(tcp_stream)))
};
Self { valid_until, inner }
}
#[inline]
pub fn from_established(valid_until: ValidUntil, established: EstablishedConnection) -> Self {
Self {
valid_until,
inner,
inner: ConnectionInner::Established(established),
}
}
#[inline]
pub fn from_established(
valid_until: ValidUntil,
established: EstablishedConnection,
) -> Self {
pub fn from_in_progress(valid_until: ValidUntil, machine: TlsHandshakeMachine) -> Self {
Self {
valid_until,
inner: ConnectionInner::Established(established)
}
}
#[inline]
pub fn from_in_progress(
valid_until: ValidUntil,
machine: TlsHandshakeMachine,
) -> Self {
Self {
valid_until,
inner: ConnectionInner::InProgress(machine)
inner: ConnectionInner::InProgress(machine),
}
}
@ -290,40 +252,30 @@ impl Connection {
pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
match &mut self.inner {
ConnectionInner::Established(established) => {
match &mut established.stream {
Stream::TcpStream(ref mut stream) => {
poll.registry().deregister(stream)
},
Stream::TlsStream(ref mut stream) => {
poll.registry().deregister(stream.get_mut())
},
}
ConnectionInner::Established(established) => match &mut established.stream {
Stream::TcpStream(ref mut stream) => poll.registry().deregister(stream),
Stream::TlsStream(ref mut stream) => poll.registry().deregister(stream.get_mut()),
},
ConnectionInner::InProgress(TlsHandshakeMachine { inner, ..}) => {
match inner {
TlsHandshakeMachineInner::TcpStream(ref mut stream) => {
poll.registry().deregister(stream)
},
TlsHandshakeMachineInner::TlsMidHandshake(ref mut mid_handshake) => {
poll.registry().deregister(mid_handshake.get_mut())
},
ConnectionInner::InProgress(TlsHandshakeMachine { inner, .. }) => match inner {
TlsHandshakeMachineInner::TcpStream(ref mut stream) => {
poll.registry().deregister(stream)
}
TlsHandshakeMachineInner::TlsMidHandshake(ref mut mid_handshake) => {
poll.registry().deregister(mid_handshake.get_mut())
}
},
}
}
}
pub type ConnectionMap = HashMap<Token, Connection>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_num_digits_in_usize(){
fn test_num_digits_in_usize() {
let f = EstablishedConnection::num_digits_in_usize;
assert_eq!(f(0), 1);
@ -336,4 +288,4 @@ mod tests {
assert_eq!(f(101), 3);
assert_eq!(f(1000), 4);
}
}
}

View file

@ -1,13 +1,13 @@
use std::time::{Duration, Instant};
use std::io::{ErrorKind, Cursor};
use std::io::{Cursor, ErrorKind};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::vec::Drain;
use hashbrown::HashMap;
use log::{info, debug, error};
use native_tls::TlsAcceptor;
use mio::{Events, Poll, Interest, Token};
use log::{debug, error, info};
use mio::net::TcpListener;
use mio::{Events, Interest, Poll, Token};
use native_tls::TlsAcceptor;
use aquatic_http_protocol::response::*;
@ -21,10 +21,8 @@ pub mod utils;
use connection::*;
use utils::*;
const CONNECTION_CLEAN_INTERVAL: usize = 2 ^ 22;
pub fn run_socket_worker(
config: Config,
socket_worker_index: usize,
@ -33,8 +31,8 @@ pub fn run_socket_worker(
response_channel_receiver: ResponseChannelReceiver,
opt_tls_acceptor: Option<TlsAcceptor>,
poll: Poll,
){
match create_listener(config.network.address, config.network.ipv6_only){
) {
match create_listener(config.network.address, config.network.ipv6_only) {
Ok(listener) => {
socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(()));
@ -47,16 +45,14 @@ pub fn run_socket_worker(
opt_tls_acceptor,
poll,
);
},
}
Err(err) => {
socket_worker_statuses.lock()[socket_worker_index] = Some(
Err(format!("Couldn't open socket: {:#}", err))
);
socket_worker_statuses.lock()[socket_worker_index] =
Some(Err(format!("Couldn't open socket: {:#}", err)));
}
}
}
pub fn run_poll_loop(
config: Config,
socket_worker_index: usize,
@ -65,10 +61,8 @@ pub fn run_poll_loop(
listener: ::std::net::TcpListener,
opt_tls_acceptor: Option<TlsAcceptor>,
mut poll: Poll,
){
let poll_timeout = Duration::from_micros(
config.network.poll_timeout_microseconds
);
) {
let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds);
let mut listener = TcpListener::from_std(listener);
let mut events = Events::with_capacity(config.network.poll_event_capacity);
@ -91,7 +85,7 @@ pub fn run_poll_loop(
poll.poll(&mut events, Some(poll_timeout))
.expect("failed polling");
for event in events.iter(){
for event in events.iter() {
let token = event.token();
if token == LISTENER_TOKEN {
@ -124,7 +118,7 @@ pub fn run_poll_loop(
&mut response_buffer,
local_responses.drain(..),
&response_channel_receiver,
&mut connections
&mut connections,
);
}
@ -137,7 +131,6 @@ pub fn run_poll_loop(
}
}
fn accept_new_streams(
config: &Config,
listener: &mut TcpListener,
@ -145,11 +138,11 @@ fn accept_new_streams(
connections: &mut ConnectionMap,
poll_token_counter: &mut Token,
opt_tls_acceptor: &Option<Arc<TlsAcceptor>>,
){
) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
loop {
match listener.accept(){
match listener.accept() {
Ok((mut stream, _)) => {
poll_token_counter.0 = poll_token_counter.0.wrapping_add(1);
@ -167,17 +160,13 @@ fn accept_new_streams(
.register(&mut stream, token, Interest::READABLE)
.unwrap();
let connection = Connection::new(
opt_tls_acceptor,
valid_until,
stream
);
let connection = Connection::new(opt_tls_acceptor, valid_until, stream);
connections.insert(token, connection);
},
}
Err(err) => {
if err.kind() == ErrorKind::WouldBlock {
break
break;
}
info!("error while accepting streams: {}", err);
@ -186,7 +175,6 @@ fn accept_new_streams(
}
}
/// On the stream given by poll_token, get TLS up and running if requested,
/// then read requests and pass on through channel.
pub fn handle_connection_read_event(
@ -197,119 +185,106 @@ pub fn handle_connection_read_event(
local_responses: &mut Vec<(ConnectionMeta, Response)>,
connections: &mut ConnectionMap,
poll_token: Token,
){
) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
loop {
// Get connection, updating valid_until
let connection = if let Some(c) = connections.get_mut(&poll_token){
let connection = if let Some(c) = connections.get_mut(&poll_token) {
c
} else {
// If there is no connection, there is no stream, so there
// shouldn't be any (relevant) poll events. In other words, it's
// safe to return here
return
return;
};
connection.valid_until = valid_until;
if let Some(established) = connection.get_established(){
match established.read_request(){
if let Some(established) = connection.get_established() {
match established.read_request() {
Ok(request) => {
let meta = ConnectionMeta {
worker_index: socket_worker_index,
poll_token,
peer_addr: established.peer_addr
peer_addr: established.peer_addr,
};
debug!("read request, sending to handler");
if let Err(err) = request_channel_sender
.send((meta, request))
{
error!(
"RequestChannelSender: couldn't send message: {:?}",
err
);
if let Err(err) = request_channel_sender.send((meta, request)) {
error!("RequestChannelSender: couldn't send message: {:?}", err);
}
break
},
break;
}
Err(RequestReadError::NeedMoreData) => {
info!("need more data");
// Stop reading data (defer to later events)
break;
},
}
Err(RequestReadError::Parse(err)) => {
info!("error reading request (invalid): {:#?}", err);
let meta = ConnectionMeta {
worker_index: socket_worker_index,
poll_token,
peer_addr: established.peer_addr
peer_addr: established.peer_addr,
};
let response = FailureResponse {
failure_reason: "invalid request".to_string()
failure_reason: "invalid request".to_string(),
};
local_responses.push(
(meta, Response::Failure(response))
);
local_responses.push((meta, Response::Failure(response)));
break;
},
}
Err(RequestReadError::StreamEnded) => {
::log::debug!("stream ended");
remove_connection(poll, connections, &poll_token);
break
},
break;
}
Err(RequestReadError::Io(err)) => {
::log::info!("error reading request (io): {}", err);
remove_connection(poll, connections, &poll_token);
break;
},
break;
}
}
} else if let Some(handshake_machine) = connections.remove(&poll_token)
} else if let Some(handshake_machine) = connections
.remove(&poll_token)
.and_then(Connection::get_in_progress)
{
match handshake_machine.establish_tls(){
match handshake_machine.establish_tls() {
Ok(established) => {
let connection = Connection::from_established(
valid_until,
established
);
let connection = Connection::from_established(valid_until, established);
connections.insert(poll_token, connection);
},
}
Err(TlsHandshakeMachineError::WouldBlock(machine)) => {
let connection = Connection::from_in_progress(
valid_until,
machine
);
let connection = Connection::from_in_progress(valid_until, machine);
connections.insert(poll_token, connection);
// Break and wait for more data
break
},
break;
}
Err(TlsHandshakeMachineError::Failure(err)) => {
info!("tls handshake error: {}", err);
// TLS negotiation failed
break
break;
}
}
}
}
}
/// Read responses from channel, send to peers
pub fn send_responses(
config: &Config,
@ -318,13 +293,13 @@ pub fn send_responses(
local_responses: Drain<(ConnectionMeta, Response)>,
channel_responses: &ResponseChannelReceiver,
connections: &mut ConnectionMap,
){
) {
let channel_responses_len = channel_responses.len();
let channel_responses_drain = channel_responses.try_iter()
.take(channel_responses_len);
let channel_responses_drain = channel_responses.try_iter().take(channel_responses_len);
for (meta, response) in local_responses.chain(channel_responses_drain){
if let Some(established) = connections.get_mut(&meta.poll_token)
for (meta, response) in local_responses.chain(channel_responses_drain) {
if let Some(established) = connections
.get_mut(&meta.poll_token)
.and_then(Connection::get_established)
{
if established.peer_addr != meta.peer_addr {
@ -337,7 +312,7 @@ pub fn send_responses(
let bytes_written = response.write(buffer).unwrap();
match established.send_response(&buffer.get_mut()[..bytes_written]){
match established.send_response(&buffer.get_mut()[..bytes_written]) {
Ok(()) => {
::log::debug!(
"sent response: {:?} with response string {}",
@ -348,33 +323,29 @@ pub fn send_responses(
if !config.network.keep_alive {
remove_connection(poll, connections, &meta.poll_token);
}
},
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
debug!("send response: would block");
},
}
Err(err) => {
info!("error sending response: {}", err);
remove_connection(poll, connections, &meta.poll_token);
},
}
}
}
}
}
// Close and remove inactive connections
pub fn remove_inactive_connections(
poll: &mut Poll,
connections: &mut ConnectionMap,
){
pub fn remove_inactive_connections(poll: &mut Poll, connections: &mut ConnectionMap) {
let now = Instant::now();
connections.retain(|_, connection| {
let keep = connection.valid_until.0 >= now;
if !keep {
if let Err(err) = connection.deregister(poll){
if let Err(err) = connection.deregister(poll) {
::log::error!("deregister connection error: {}", err);
}
}
@ -385,15 +356,10 @@ pub fn remove_inactive_connections(
connections.shrink_to_fit();
}
fn remove_connection(
poll: &mut Poll,
connections: &mut ConnectionMap,
connection_token: &Token,
){
if let Some(mut connection) = connections.remove(connection_token){
if let Err(err) = connection.deregister(poll){
fn remove_connection(poll: &mut Poll, connections: &mut ConnectionMap, connection_token: &Token) {
if let Some(mut connection) = connections.remove(connection_token) {
if let Err(err) = connection.deregister(poll) {
::log::error!("deregister connection error: {}", err);
}
}
}
}

View file

@ -1,16 +1,14 @@
use std::net::{SocketAddr};
use std::io::{Read, Write};
use std::net::SocketAddr;
use mio::net::TcpStream;
use native_tls::TlsStream;
pub enum Stream {
TcpStream(TcpStream),
TlsStream(TlsStream<TcpStream>),
}
impl Stream {
#[inline]
pub fn get_peer_addr(&self) -> SocketAddr {
@ -21,7 +19,6 @@ impl Stream {
}
}
impl Read for Stream {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> Result<usize, ::std::io::Error> {
@ -35,7 +32,7 @@ impl Read for Stream {
#[inline]
fn read_vectored(
&mut self,
bufs: &mut [::std::io::IoSliceMut<'_>]
bufs: &mut [::std::io::IoSliceMut<'_>],
) -> ::std::io::Result<usize> {
match self {
Self::TcpStream(stream) => stream.read_vectored(bufs),
@ -44,7 +41,6 @@ impl Read for Stream {
}
}
impl Write for Stream {
#[inline]
fn write(&mut self, buf: &[u8]) -> ::std::io::Result<usize> {
@ -56,10 +52,7 @@ impl Write for Stream {
/// Not used but provided for completeness
#[inline]
fn write_vectored(
&mut self,
bufs: &[::std::io::IoSlice<'_>]
) -> ::std::io::Result<usize> {
fn write_vectored(&mut self, bufs: &[::std::io::IoSlice<'_>]) -> ::std::io::Result<usize> {
match self {
Self::TcpStream(stream) => stream.write_vectored(bufs),
Self::TlsStream(stream) => stream.write_vectored(bufs),
@ -73,4 +66,4 @@ impl Write for Stream {
Self::TlsStream(stream) => stream.flush(),
}
}
}
}

View file

@ -4,26 +4,21 @@ use std::net::SocketAddr;
use anyhow::Context;
use native_tls::{Identity, TlsAcceptor};
use socket2::{Socket, Domain, Type, Protocol};
use socket2::{Domain, Protocol, Socket, Type};
use crate::config::TlsConfig;
pub fn create_tls_acceptor(
config: &TlsConfig,
) -> anyhow::Result<Option<TlsAcceptor>> {
pub fn create_tls_acceptor(config: &TlsConfig) -> anyhow::Result<Option<TlsAcceptor>> {
if config.use_tls {
let mut identity_bytes = Vec::new();
let mut file = File::open(&config.tls_pkcs12_path)
.context("Couldn't open pkcs12 identity file")?;
let mut file =
File::open(&config.tls_pkcs12_path).context("Couldn't open pkcs12 identity file")?;
file.read_to_end(&mut identity_bytes)
.context("Couldn't read pkcs12 identity file")?;
let identity = Identity::from_pkcs12(
&identity_bytes[..],
&config.tls_pkcs12_password
).context("Couldn't parse pkcs12 identity file")?;
let identity = Identity::from_pkcs12(&identity_bytes[..], &config.tls_pkcs12_password)
.context("Couldn't parse pkcs12 identity file")?;
let acceptor = TlsAcceptor::new(identity)
.context("Couldn't create TlsAcceptor from pkcs12 identity")?;
@ -34,31 +29,35 @@ pub fn create_tls_acceptor(
}
}
pub fn create_listener(
address: SocketAddr,
ipv6_only: bool
ipv6_only: bool,
) -> ::anyhow::Result<::std::net::TcpListener> {
let builder = if address.is_ipv4(){
let builder = if address.is_ipv4() {
Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp()))
} else {
Socket::new(Domain::ipv6(), Type::stream(), Some(Protocol::tcp()))
}.context("Couldn't create socket2::Socket")?;
}
.context("Couldn't create socket2::Socket")?;
if ipv6_only {
builder.set_only_v6(true)
builder
.set_only_v6(true)
.context("Couldn't put socket in ipv6 only mode")?
}
builder.set_nonblocking(true)
builder
.set_nonblocking(true)
.context("Couldn't put socket in non-blocking mode")?;
builder.set_reuse_port(true)
builder
.set_reuse_port(true)
.context("Couldn't put socket in reuse_port mode")?;
builder.bind(&address.into()).with_context(||
format!("Couldn't bind socket to address {}", address)
)?;
builder.listen(128)
builder
.bind(&address.into())
.with_context(|| format!("Couldn't bind socket to address {}", address))?;
builder
.listen(128)
.context("Couldn't listen for connections on socket")?;
Ok(builder.into_tcp_listener())
}
}

View file

@ -4,19 +4,14 @@ use histogram::Histogram;
use crate::common::*;
pub fn clean_torrents(state: &State){
pub fn clean_torrents(state: &State) {
let mut torrent_maps = state.torrent_maps.lock();
clean_torrent_map(&mut torrent_maps.ipv4);
clean_torrent_map(&mut torrent_maps.ipv6);
}
fn clean_torrent_map<I: Ip>(
torrent_map: &mut TorrentMap<I>,
){
fn clean_torrent_map<I: Ip>(torrent_map: &mut TorrentMap<I>) {
let now = Instant::now();
torrent_map.retain(|_, torrent_data| {
@ -30,10 +25,10 @@ fn clean_torrent_map<I: Ip>(
match peer.status {
PeerStatus::Seeding => {
*num_seeders -= 1;
},
}
PeerStatus::Leeching => {
*num_leechers -= 1;
},
}
_ => (),
};
}
@ -47,24 +42,23 @@ fn clean_torrent_map<I: Ip>(
torrent_map.shrink_to_fit();
}
pub fn print_statistics(state: &State){
pub fn print_statistics(state: &State) {
let mut peers_per_torrent = Histogram::new();
{
let torrents = &mut state.torrent_maps.lock();
for torrent in torrents.ipv4.values(){
for torrent in torrents.ipv4.values() {
let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;
if let Err(err) = peers_per_torrent.increment(num_peers){
if let Err(err) = peers_per_torrent.increment(num_peers) {
eprintln!("error incrementing peers_per_torrent histogram: {}", err)
}
}
for torrent in torrents.ipv6.values(){
for torrent in torrents.ipv6.values() {
let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;
if let Err(err) = peers_per_torrent.increment(num_peers){
if let Err(err) = peers_per_torrent.increment(num_peers) {
eprintln!("error incrementing peers_per_torrent histogram: {}", err)
}
}
@ -82,4 +76,4 @@ pub fn print_statistics(state: &State){
peers_per_torrent.maximum().unwrap(),
);
}
}
}