From 465cf5920debfe3b8ef074e24744531681277207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Fri, 5 Nov 2021 12:42:55 +0100 Subject: [PATCH] WIP: ws: put back mio implementation --- Cargo.lock | 156 ++++++++ aquatic_udp/src/lib/glommio/handlers.rs | 5 +- aquatic_udp/src/lib/glommio/network.rs | 30 +- aquatic_ws/Cargo.toml | 8 + aquatic_ws/src/lib/{ => common}/handlers.rs | 107 +----- .../src/lib/{common.rs => common/mod.rs} | 7 +- aquatic_ws/src/lib/config.rs | 53 +++ aquatic_ws/src/lib/glommio/common.rs | 8 + aquatic_ws/src/lib/glommio/handlers.rs | 114 ++++++ aquatic_ws/src/lib/glommio/mod.rs | 132 +++++++ aquatic_ws/src/lib/{ => glommio}/network.rs | 4 +- aquatic_ws/src/lib/lib.rs | 137 +------- aquatic_ws/src/lib/mio/common.rs | 51 +++ aquatic_ws/src/lib/mio/handlers.rs | 101 ++++++ aquatic_ws/src/lib/mio/mod.rs | 206 +++++++++++ aquatic_ws/src/lib/mio/network/connection.rs | 298 ++++++++++++++++ aquatic_ws/src/lib/mio/network/mod.rs | 332 ++++++++++++++++++ aquatic_ws/src/lib/mio/network/utils.rs | 66 ++++ aquatic_ws/src/lib/tasks.rs | 10 - 19 files changed, 1559 insertions(+), 266 deletions(-) rename aquatic_ws/src/lib/{ => common}/handlers.rs (64%) rename aquatic_ws/src/lib/{common.rs => common/mod.rs} (97%) create mode 100644 aquatic_ws/src/lib/glommio/common.rs create mode 100644 aquatic_ws/src/lib/glommio/handlers.rs create mode 100644 aquatic_ws/src/lib/glommio/mod.rs rename aquatic_ws/src/lib/{ => glommio}/network.rs (99%) create mode 100644 aquatic_ws/src/lib/mio/common.rs create mode 100644 aquatic_ws/src/lib/mio/handlers.rs create mode 100644 aquatic_ws/src/lib/mio/mod.rs create mode 100644 aquatic_ws/src/lib/mio/network/connection.rs create mode 100644 aquatic_ws/src/lib/mio/network/mod.rs create mode 100644 aquatic_ws/src/lib/mio/network/utils.rs diff --git a/Cargo.lock b/Cargo.lock index fc8e788..5fa200b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,14 +241,19 @@ dependencies = [ "aquatic_ws_protocol", "async-tungstenite", "core_affinity", + "crossbeam-channel", "either", "futures", "futures-lite", "futures-rustls", "glommio", "hashbrown 0.11.2", + "histogram", "log", "mimalloc", + "mio", + "native-tls", + "parking_lot", "privdrop", "quickcheck", "quickcheck_macros", @@ -257,6 +262,7 @@ dependencies = [ "serde", "signal-hook", "slab", + "socket2 0.4.2", "tungstenite", ] @@ -511,6 +517,22 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "core-foundation" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6888e10551bb93e424d8df1d07f1a8b4fceb0001a3a4b048bfc47554946f47b3" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" + [[package]] name = "core_affinity" version = "0.5.10" @@ -741,6 +763,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.0.1" @@ -1241,6 +1278,24 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "native-tls" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.23.0" @@ -1342,6 +1397,39 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "openssl" +version = "0.10.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-sys", +] + +[[package]] +name = "openssl-probe" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a" + +[[package]] +name = "openssl-sys" +version = "0.9.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6517987b3f8226b5da3661dad65ff7f300cc59fb5ea8333ca191fc65fde3edf" +dependencies = [ + "autocfg", + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "owned-alloc" version = "0.2.0" @@ -1397,6 +1485,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12295df4f294471248581bc09bef3c38a5e46f1e36d6a37353621a0c6c357e1f" + [[package]] name = "plotters" version = "0.3.1" @@ -1600,6 +1694,15 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "ring" version = "0.16.20" @@ -1675,6 +1778,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75" +dependencies = [ + "lazy_static", + "winapi 0.3.9", +] + [[package]] name = "scoped-tls" version = "1.0.0" @@ -1697,6 +1810,29 @@ dependencies = [ "untrusted", ] +[[package]] +name = "security-framework" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525bc1abfda2e1998d152c45cf13e696f76d0a4972310b22fac1658b05df7c87" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9dd14d83160b528b7bfd66439110573efcfbe281b17fc2ca9f39f550d619c7e" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.4" @@ -1902,6 +2038,20 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "tempfile" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" +dependencies = [ + "cfg-if", + "libc", + "rand", + "redox_syscall", + "remove_dir_all", + "winapi 0.3.9", +] + [[package]] name = "termcolor" version = "1.1.2" @@ -2120,6 +2270,12 @@ dependencies = [ "ryu", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.3" diff --git a/aquatic_udp/src/lib/glommio/handlers.rs b/aquatic_udp/src/lib/glommio/handlers.rs index 862710f..55adc4a 100644 --- a/aquatic_udp/src/lib/glommio/handlers.rs +++ b/aquatic_udp/src/lib/glommio/handlers.rs @@ -105,7 +105,10 @@ async fn handle_request_stream( ::log::debug!("preparing to send response to channel: {:?}", response); - if let Err(err) = response_senders.send_to(producer_index, (response, src)).await { + if let Err(err) = response_senders + .send_to(producer_index, (response, src)) + .await + { ::log::error!("response_sender.send: {:?}", err); } diff --git a/aquatic_udp/src/lib/glommio/network.rs b/aquatic_udp/src/lib/glommio/network.rs index c42ea57..f8eb462 100644 --- a/aquatic_udp/src/lib/glommio/network.rs +++ b/aquatic_udp/src/lib/glommio/network.rs @@ -240,14 +240,17 @@ async fn read_requests( let request_consumer_index = calculate_request_consumer_index(&config, request.info_hash); - if let Err(err) = request_senders.send_to( - request_consumer_index, - ( - response_consumer_index, - ConnectedRequest::Announce(request), - src, - ), - ).await { + if let Err(err) = request_senders + .send_to( + request_consumer_index, + ( + response_consumer_index, + ConnectedRequest::Announce(request), + src, + ), + ) + .await + { ::log::error!("request_sender.try_send failed: {:?}", err) } } else { @@ -300,10 +303,13 @@ async fn read_requests( original_indices, }; - if let Err(err) = request_senders.send_to( - consumer_index, - (response_consumer_index, request, src), - ).await { + if let Err(err) = request_senders + .send_to( + consumer_index, + (response_consumer_index, request, src), + ) + .await + { ::log::error!("request_sender.send failed: {:?}", err) } } diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index 4c58732..9fdd95b 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -38,6 +38,14 @@ signal-hook = { version = "0.3" } slab = "0.4" tungstenite = "0.15" + +crossbeam-channel = "0.5" +histogram = "0.6" +mio = { version = "0.7", features = ["tcp", "os-poll", "os-util"] } +native-tls = "0.2" +parking_lot = "0.11" +socket2 = { version = "0.4.1", features = ["all"] } + [dev-dependencies] quickcheck = "1.0" quickcheck_macros = "1.0" diff --git a/aquatic_ws/src/lib/handlers.rs b/aquatic_ws/src/lib/common/handlers.rs similarity index 64% rename from aquatic_ws/src/lib/handlers.rs rename to aquatic_ws/src/lib/common/handlers.rs index cd64de3..29677ba 100644 --- a/aquatic_ws/src/lib/handlers.rs +++ b/aquatic_ws/src/lib/common/handlers.rs @@ -1,117 +1,12 @@ -use std::cell::RefCell; -use std::rc::Rc; -use std::time::Duration; - use aquatic_common::extract_response_peers; -use futures_lite::StreamExt; -use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; -use glommio::enclose; -use glommio::prelude::*; -use glommio::timer::TimerActionRepeat; use hashbrown::HashMap; -use rand::{rngs::SmallRng, Rng, SeedableRng}; +use rand::Rng; use aquatic_ws_protocol::*; use crate::common::*; use crate::config::Config; -pub async fn run_request_worker( - config: Config, - state: State, - in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, - out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, -) { - let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); - let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); - - let out_message_senders = Rc::new(out_message_senders); - - let torrents = Rc::new(RefCell::new(TorrentMaps::default())); - let access_list = state.access_list; - - // Periodically clean torrents - TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { - enclose!((config, torrents, access_list) move || async move { - torrents.borrow_mut().clean(&config, &access_list); - - Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval)) - })() - })); - - let mut handles = Vec::new(); - - for (_, receiver) in in_message_receivers.streams() { - let handle = spawn_local(handle_request_stream( - config.clone(), - torrents.clone(), - out_message_senders.clone(), - receiver, - )) - .detach(); - - handles.push(handle); - } - - for handle in handles { - handle.await; - } -} - -async fn handle_request_stream( - config: Config, - torrents: Rc>, - out_message_senders: Rc>, - mut stream: S, -) where - S: futures_lite::Stream + ::std::marker::Unpin, -{ - let mut rng = SmallRng::from_entropy(); - - let max_peer_age = config.cleaning.max_peer_age; - let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age))); - - TimerActionRepeat::repeat(enclose!((peer_valid_until) move || { - enclose!((peer_valid_until) move || async move { - *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age); - - Some(Duration::from_secs(1)) - })() - })); - - let mut out_messages = Vec::new(); - - while let Some((meta, in_message)) = stream.next().await { - match in_message { - InMessage::AnnounceRequest(request) => handle_announce_request( - &config, - &mut rng, - &mut torrents.borrow_mut(), - &mut out_messages, - peer_valid_until.borrow().to_owned(), - meta, - request, - ), - InMessage::ScrapeRequest(request) => handle_scrape_request( - &config, - &mut torrents.borrow_mut(), - &mut out_messages, - meta, - request, - ), - }; - - for (meta, out_message) in out_messages.drain(..) { - out_message_senders - .send_to(meta.out_message_consumer_id.0, (meta, out_message)) - .await - .expect("failed sending out_message to socket worker"); - } - - yield_if_needed().await; - } -} - pub fn handle_announce_request( config: &Config, rng: &mut impl Rng, diff --git a/aquatic_ws/src/lib/common.rs b/aquatic_ws/src/lib/common/mod.rs similarity index 97% rename from aquatic_ws/src/lib/common.rs rename to aquatic_ws/src/lib/common/mod.rs index 41cea30..be82373 100644 --- a/aquatic_ws/src/lib/common.rs +++ b/aquatic_ws/src/lib/common/mod.rs @@ -1,3 +1,5 @@ +pub mod handlers; + use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Instant; @@ -142,8 +144,3 @@ impl TorrentMaps { torrent_map.shrink_to_fit(); } } - -#[derive(Default, Clone)] -pub struct State { - pub access_list: Arc, -} diff --git a/aquatic_ws/src/lib/config.rs b/aquatic_ws/src/lib/config.rs index a73e870..32ceb26 100644 --- a/aquatic_ws/src/lib/config.rs +++ b/aquatic_ws/src/lib/config.rs @@ -20,10 +20,12 @@ pub struct Config { pub log_level: LogLevel, pub network: NetworkConfig, pub protocol: ProtocolConfig, + pub handlers: HandlerConfig, pub cleaning: CleaningConfig, pub privileges: PrivilegeConfig, pub access_list: AccessListConfig, pub cpu_pinning: CpuPinningConfig, + pub statistics: StatisticsConfig, } impl aquatic_cli_helpers::Config for Config { @@ -38,10 +40,17 @@ pub struct NetworkConfig { /// Bind to this address pub address: SocketAddr, pub ipv6_only: bool, + pub tls_certificate_path: PathBuf, pub tls_private_key_path: PathBuf, pub websocket_max_message_size: usize, pub websocket_max_frame_size: usize, + + pub use_tls: bool, + pub tls_pkcs12_path: String, + pub tls_pkcs12_password: String, + pub poll_event_capacity: usize, + pub poll_timeout_microseconds: u64, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -55,6 +64,15 @@ pub struct ProtocolConfig { pub peer_announce_interval: usize, } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct HandlerConfig { + /// Maximum number of requests to receive from channel before locking + /// mutex and starting work + pub max_requests_per_iter: usize, + pub channel_recv_timeout_microseconds: u64, +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(default)] pub struct CleaningConfig { @@ -62,6 +80,16 @@ pub struct CleaningConfig { pub torrent_cleaning_interval: u64, /// Remove peers that haven't announced for this long (seconds) pub max_peer_age: u64, + + /// Remove connections that are older than this (seconds) + pub max_connection_age: u64, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct StatisticsConfig { + /// Print statistics this often (seconds). Don't print when set to zero. + pub interval: u64, } impl Default for Config { @@ -72,10 +100,12 @@ impl Default for Config { log_level: LogLevel::default(), network: NetworkConfig::default(), protocol: ProtocolConfig::default(), + handlers: Default::default(), cleaning: CleaningConfig::default(), privileges: PrivilegeConfig::default(), access_list: AccessListConfig::default(), cpu_pinning: Default::default(), + statistics: Default::default(), } } } @@ -89,6 +119,12 @@ impl Default for NetworkConfig { tls_private_key_path: "".into(), websocket_max_message_size: 64 * 1024, websocket_max_frame_size: 16 * 1024, + + use_tls: false, + tls_pkcs12_path: "".into(), + tls_pkcs12_password: "".into(), + poll_event_capacity: 4096, + poll_timeout_microseconds: 200_000, } } } @@ -103,11 +139,28 @@ impl Default for ProtocolConfig { } } +impl Default for HandlerConfig { + fn default() -> Self { + Self { + max_requests_per_iter: 10000, + channel_recv_timeout_microseconds: 200, + } + } +} + impl Default for CleaningConfig { fn default() -> Self { Self { torrent_cleaning_interval: 30, max_peer_age: 1800, + + max_connection_age: 1800, } } } + +impl Default for StatisticsConfig { + fn default() -> Self { + Self { interval: 0 } + } +} diff --git a/aquatic_ws/src/lib/glommio/common.rs b/aquatic_ws/src/lib/glommio/common.rs new file mode 100644 index 0000000..3506b09 --- /dev/null +++ b/aquatic_ws/src/lib/glommio/common.rs @@ -0,0 +1,8 @@ +use std::sync::Arc; + +use aquatic_common::access_list::AccessListArcSwap; + +#[derive(Default, Clone)] +pub struct State { + pub access_list: Arc, +} diff --git a/aquatic_ws/src/lib/glommio/handlers.rs b/aquatic_ws/src/lib/glommio/handlers.rs new file mode 100644 index 0000000..c3332e0 --- /dev/null +++ b/aquatic_ws/src/lib/glommio/handlers.rs @@ -0,0 +1,114 @@ +use std::cell::RefCell; +use std::rc::Rc; +use std::time::Duration; + +use futures_lite::StreamExt; +use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; +use glommio::enclose; +use glommio::prelude::*; +use glommio::timer::TimerActionRepeat; +use rand::{rngs::SmallRng, SeedableRng}; + +use aquatic_ws_protocol::*; + +use crate::common::handlers::*; +use crate::common::*; +use crate::config::Config; + +use super::common::State; + +pub async fn run_request_worker( + config: Config, + state: State, + in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, + out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, +) { + let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); + let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); + + let out_message_senders = Rc::new(out_message_senders); + + let torrents = Rc::new(RefCell::new(TorrentMaps::default())); + let access_list = state.access_list; + + // Periodically clean torrents + TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { + enclose!((config, torrents, access_list) move || async move { + torrents.borrow_mut().clean(&config, &access_list); + + Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval)) + })() + })); + + let mut handles = Vec::new(); + + for (_, receiver) in in_message_receivers.streams() { + let handle = spawn_local(handle_request_stream( + config.clone(), + torrents.clone(), + out_message_senders.clone(), + receiver, + )) + .detach(); + + handles.push(handle); + } + + for handle in handles { + handle.await; + } +} + +async fn handle_request_stream( + config: Config, + torrents: Rc>, + out_message_senders: Rc>, + mut stream: S, +) where + S: futures_lite::Stream + ::std::marker::Unpin, +{ + let mut rng = SmallRng::from_entropy(); + + let max_peer_age = config.cleaning.max_peer_age; + let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age))); + + TimerActionRepeat::repeat(enclose!((peer_valid_until) move || { + enclose!((peer_valid_until) move || async move { + *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age); + + Some(Duration::from_secs(1)) + })() + })); + + let mut out_messages = Vec::new(); + + while let Some((meta, in_message)) = stream.next().await { + match in_message { + InMessage::AnnounceRequest(request) => handle_announce_request( + &config, + &mut rng, + &mut torrents.borrow_mut(), + &mut out_messages, + peer_valid_until.borrow().to_owned(), + meta, + request, + ), + InMessage::ScrapeRequest(request) => handle_scrape_request( + &config, + &mut torrents.borrow_mut(), + &mut out_messages, + meta, + request, + ), + }; + + for (meta, out_message) in out_messages.drain(..) { + out_message_senders + .send_to(meta.out_message_consumer_id.0, (meta, out_message)) + .await + .expect("failed sending out_message to socket worker"); + } + + yield_if_needed().await; + } +} diff --git a/aquatic_ws/src/lib/glommio/mod.rs b/aquatic_ws/src/lib/glommio/mod.rs new file mode 100644 index 0000000..8dac78f --- /dev/null +++ b/aquatic_ws/src/lib/glommio/mod.rs @@ -0,0 +1,132 @@ +pub mod common; +pub mod handlers; +pub mod network; + +use std::{ + fs::File, + io::BufReader, + sync::{atomic::AtomicUsize, Arc}, +}; + +use crate::config::Config; +use aquatic_common::privileges::drop_privileges_after_socket_binding; + +use self::common::State; + +use super::common::TlsConfig; +use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; + +const SHARED_CHANNEL_SIZE: usize = 1024; + +pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { + if config.cpu_pinning.active { + core_affinity::set_for_current(core_affinity::CoreId { + id: config.cpu_pinning.offset, + }); + } + + let num_peers = config.socket_workers + config.request_workers; + + let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); + let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); + + let num_bound_sockets = Arc::new(AtomicUsize::new(0)); + + let tls_config = Arc::new(create_tls_config(&config).unwrap()); + + let mut executors = Vec::new(); + + for i in 0..(config.socket_workers) { + let config = config.clone(); + let state = state.clone(); + let tls_config = tls_config.clone(); + let request_mesh_builder = request_mesh_builder.clone(); + let response_mesh_builder = response_mesh_builder.clone(); + let num_bound_sockets = num_bound_sockets.clone(); + + let mut builder = LocalExecutorBuilder::default(); + + if config.cpu_pinning.active { + builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + i); + } + + let executor = builder.spawn(|| async move { + network::run_socket_worker( + config, + state, + tls_config, + request_mesh_builder, + response_mesh_builder, + num_bound_sockets, + ) + .await + }); + + executors.push(executor); + } + + for i in 0..(config.request_workers) { + let config = config.clone(); + let state = state.clone(); + let request_mesh_builder = request_mesh_builder.clone(); + let response_mesh_builder = response_mesh_builder.clone(); + + let mut builder = LocalExecutorBuilder::default(); + + if config.cpu_pinning.active { + builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + config.socket_workers + i); + } + + let executor = builder.spawn(|| async move { + handlers::run_request_worker(config, state, request_mesh_builder, response_mesh_builder) + .await + }); + + executors.push(executor); + } + + drop_privileges_after_socket_binding( + &config.privileges, + num_bound_sockets, + config.socket_workers, + ) + .unwrap(); + + for executor in executors { + executor + .expect("failed to spawn local executor") + .join() + .unwrap(); + } + + Ok(()) +} + +fn create_tls_config(config: &Config) -> anyhow::Result { + let certs = { + let f = File::open(&config.network.tls_certificate_path)?; + let mut f = BufReader::new(f); + + rustls_pemfile::certs(&mut f)? + .into_iter() + .map(|bytes| futures_rustls::rustls::Certificate(bytes)) + .collect() + }; + + let private_key = { + let f = File::open(&config.network.tls_private_key_path)?; + let mut f = BufReader::new(f); + + rustls_pemfile::pkcs8_private_keys(&mut f)? + .first() + .map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone())) + .ok_or(anyhow::anyhow!("No private keys in file"))? + }; + + let tls_config = futures_rustls::rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, private_key)?; + + Ok(tls_config) +} diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/glommio/network.rs similarity index 99% rename from aquatic_ws/src/lib/network.rs rename to aquatic_ws/src/lib/glommio/network.rs index 55ad023..0f84b11 100644 --- a/aquatic_ws/src/lib/network.rs +++ b/aquatic_ws/src/lib/glommio/network.rs @@ -27,7 +27,9 @@ use slab::Slab; use crate::config::Config; -use super::common::*; +use crate::common::*; + +use super::common::State; struct PendingScrapeResponse { pending_worker_out_messages: usize, diff --git a/aquatic_ws/src/lib/lib.rs b/aquatic_ws/src/lib/lib.rs index b367bab..9b96d00 100644 --- a/aquatic_ws/src/lib/lib.rs +++ b/aquatic_ws/src/lib/lib.rs @@ -1,27 +1,15 @@ -use std::{ - fs::File, - io::BufReader, - sync::{atomic::AtomicUsize, Arc}, -}; - -use aquatic_common::{ - access_list::update_access_list, privileges::drop_privileges_after_socket_binding, -}; -use common::{State, TlsConfig}; -use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; +use aquatic_common::access_list::update_access_list; use signal_hook::{consts::SIGUSR1, iterator::Signals}; use crate::config::Config; -mod common; +pub mod common; pub mod config; -mod handlers; -mod network; +pub mod glommio; +pub mod mio; pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; -const SHARED_CHANNEL_SIZE: usize = 1024; - pub fn run(config: Config) -> ::anyhow::Result<()> { if config.cpu_pinning.active { core_affinity::set_for_current(core_affinity::CoreId { @@ -29,7 +17,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { }); } - let state = State::default(); + let state = glommio::common::State::default(); update_access_list(&config.access_list, &state.access_list)?; @@ -39,7 +27,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let config = config.clone(); let state = state.clone(); - ::std::thread::spawn(move || run_inner(config, state)); + ::std::thread::spawn(move || glommio::run_inner(config, state)); } for signal in &mut signals { @@ -53,116 +41,3 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { Ok(()) } - -pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { - if config.cpu_pinning.active { - core_affinity::set_for_current(core_affinity::CoreId { - id: config.cpu_pinning.offset, - }); - } - - let num_peers = config.socket_workers + config.request_workers; - - let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); - let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); - - let num_bound_sockets = Arc::new(AtomicUsize::new(0)); - - let tls_config = Arc::new(create_tls_config(&config).unwrap()); - - let mut executors = Vec::new(); - - for i in 0..(config.socket_workers) { - let config = config.clone(); - let state = state.clone(); - let tls_config = tls_config.clone(); - let request_mesh_builder = request_mesh_builder.clone(); - let response_mesh_builder = response_mesh_builder.clone(); - let num_bound_sockets = num_bound_sockets.clone(); - - let mut builder = LocalExecutorBuilder::default(); - - if config.cpu_pinning.active { - builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + i); - } - - let executor = builder.spawn(|| async move { - network::run_socket_worker( - config, - state, - tls_config, - request_mesh_builder, - response_mesh_builder, - num_bound_sockets, - ) - .await - }); - - executors.push(executor); - } - - for i in 0..(config.request_workers) { - let config = config.clone(); - let state = state.clone(); - let request_mesh_builder = request_mesh_builder.clone(); - let response_mesh_builder = response_mesh_builder.clone(); - - let mut builder = LocalExecutorBuilder::default(); - - if config.cpu_pinning.active { - builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + config.socket_workers + i); - } - - let executor = builder.spawn(|| async move { - handlers::run_request_worker(config, state, request_mesh_builder, response_mesh_builder) - .await - }); - - executors.push(executor); - } - - drop_privileges_after_socket_binding( - &config.privileges, - num_bound_sockets, - config.socket_workers, - ) - .unwrap(); - - for executor in executors { - executor - .expect("failed to spawn local executor") - .join() - .unwrap(); - } - - Ok(()) -} - -fn create_tls_config(config: &Config) -> anyhow::Result { - let certs = { - let f = File::open(&config.network.tls_certificate_path)?; - let mut f = BufReader::new(f); - - rustls_pemfile::certs(&mut f)? - .into_iter() - .map(|bytes| futures_rustls::rustls::Certificate(bytes)) - .collect() - }; - - let private_key = { - let f = File::open(&config.network.tls_private_key_path)?; - let mut f = BufReader::new(f); - - rustls_pemfile::pkcs8_private_keys(&mut f)? - .first() - .map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone())) - .ok_or(anyhow::anyhow!("No private keys in file"))? - }; - - let tls_config = futures_rustls::rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, private_key)?; - - Ok(tls_config) -} diff --git a/aquatic_ws/src/lib/mio/common.rs b/aquatic_ws/src/lib/mio/common.rs new file mode 100644 index 0000000..5337ca7 --- /dev/null +++ b/aquatic_ws/src/lib/mio/common.rs @@ -0,0 +1,51 @@ +use std::sync::Arc; + +use aquatic_common::access_list::AccessListArcSwap; +use aquatic_ws_protocol::*; +use crossbeam_channel::{Receiver, Sender}; +use log::error; +use mio::Token; +use parking_lot::Mutex; + +use crate::common::*; + +pub const LISTENER_TOKEN: Token = Token(0); +pub const CHANNEL_TOKEN: Token = Token(1); + +#[derive(Clone)] +pub struct State { + pub access_list: Arc, + pub torrent_maps: Arc>, +} + +impl Default for State { + fn default() -> Self { + Self { + access_list: Arc::new(Default::default()), + torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), + } + } +} + +pub type InMessageSender = Sender<(ConnectionMeta, InMessage)>; +pub type InMessageReceiver = Receiver<(ConnectionMeta, InMessage)>; +pub type OutMessageReceiver = Receiver<(ConnectionMeta, OutMessage)>; + +#[derive(Clone)] +pub struct OutMessageSender(Vec>); + +impl OutMessageSender { + pub fn new(senders: Vec>) -> Self { + Self(senders) + } + + #[inline] + pub fn send(&self, meta: ConnectionMeta, message: OutMessage) { + if let Err(err) = self.0[meta.out_message_consumer_id.0].send((meta, message)) { + error!("OutMessageSender: couldn't send message: {:?}", err); + } + } +} + +pub type SocketWorkerStatus = Option>; +pub type SocketWorkerStatuses = Arc>>; diff --git a/aquatic_ws/src/lib/mio/handlers.rs b/aquatic_ws/src/lib/mio/handlers.rs new file mode 100644 index 0000000..8d01397 --- /dev/null +++ b/aquatic_ws/src/lib/mio/handlers.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; +use std::time::Duration; + +use mio::Waker; +use parking_lot::MutexGuard; +use rand::{rngs::SmallRng, SeedableRng}; + +use aquatic_ws_protocol::*; + +use crate::common::handlers::{handle_announce_request, handle_scrape_request}; +use crate::common::*; +use crate::config::Config; + +use super::common::*; + +pub fn run_request_worker( + config: Config, + state: State, + in_message_receiver: InMessageReceiver, + out_message_sender: OutMessageSender, + wakers: Vec>, +) { + let mut wake_socket_workers: Vec = (0..config.socket_workers).map(|_| false).collect(); + + let mut announce_requests = Vec::new(); + let mut scrape_requests = Vec::new(); + let mut out_messages = Vec::new(); + + let mut rng = SmallRng::from_entropy(); + + let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds); + + loop { + let mut opt_torrent_map_guard: Option> = None; + + for i in 0..config.handlers.max_requests_per_iter { + let opt_in_message = if i == 0 { + in_message_receiver.recv().ok() + } else { + in_message_receiver.recv_timeout(timeout).ok() + }; + + match opt_in_message { + Some((meta, InMessage::AnnounceRequest(r))) => { + announce_requests.push((meta, r)); + } + Some((meta, InMessage::ScrapeRequest(r))) => { + scrape_requests.push((meta, r)); + } + None => { + if let Some(torrent_guard) = state.torrent_maps.try_lock() { + opt_torrent_map_guard = Some(torrent_guard); + + break; + } + } + } + } + + let mut torrent_map_guard = + opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock()); + + let valid_until = ValidUntil::new(config.cleaning.max_peer_age); + + for (meta, request) in announce_requests.drain(..) { + handle_announce_request( + &config, + &mut rng, + &mut torrent_map_guard, + &mut out_messages, + valid_until, + meta, + request, + ); + } + for (meta, request) in scrape_requests.drain(..) { + handle_scrape_request( + &config, + &mut torrent_map_guard, + &mut out_messages, + meta, + request, + ); + } + + for (meta, out_message) in out_messages.drain(..) { + wake_socket_workers[meta.out_message_consumer_id.0] = true; + out_message_sender.send(meta, out_message); + } + + for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() { + if *wake { + if let Err(err) = wakers[worker_index].wake() { + ::log::error!("request handler couldn't wake poll: {:?}", err); + } + + *wake = false; + } + } + } +} diff --git a/aquatic_ws/src/lib/mio/mod.rs b/aquatic_ws/src/lib/mio/mod.rs new file mode 100644 index 0000000..b90ef8b --- /dev/null +++ b/aquatic_ws/src/lib/mio/mod.rs @@ -0,0 +1,206 @@ +use std::fs::File; +use std::io::Read; +use std::sync::Arc; +use std::thread::Builder; +use std::time::Duration; + +use anyhow::Context; +use histogram::Histogram; +use mio::{Poll, Waker}; +use native_tls::{Identity, TlsAcceptor}; +use parking_lot::Mutex; +use privdrop::PrivDrop; + +pub mod common; +pub mod handlers; +pub mod network; + +use crate::config::Config; +use common::*; + +pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; + +pub fn run(config: Config, state: State) -> anyhow::Result<()> { + start_workers(config.clone(), state.clone())?; + + // TODO: privdrop here instead + + loop { + ::std::thread::sleep(Duration::from_secs( + config.cleaning.torrent_cleaning_interval, + )); + + state.torrent_maps.lock().clean(&config, &state.access_list); + } +} + +pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { + let opt_tls_acceptor = create_tls_acceptor(&config)?; + + let (in_message_sender, in_message_receiver) = ::crossbeam_channel::unbounded(); + + let mut out_message_senders = Vec::new(); + let mut wakers = Vec::new(); + + let socket_worker_statuses: SocketWorkerStatuses = { + let mut statuses = Vec::new(); + + for _ in 0..config.socket_workers { + statuses.push(None); + } + + Arc::new(Mutex::new(statuses)) + }; + + for i in 0..config.socket_workers { + let config = config.clone(); + let state = state.clone(); + let socket_worker_statuses = socket_worker_statuses.clone(); + let in_message_sender = in_message_sender.clone(); + let opt_tls_acceptor = opt_tls_acceptor.clone(); + let poll = Poll::new()?; + let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?); + + let (out_message_sender, out_message_receiver) = ::crossbeam_channel::unbounded(); + + out_message_senders.push(out_message_sender); + wakers.push(waker); + + Builder::new() + .name(format!("socket-{:02}", i + 1)) + .spawn(move || { + network::run_socket_worker( + config, + state, + i, + socket_worker_statuses, + poll, + in_message_sender, + out_message_receiver, + opt_tls_acceptor, + ); + })?; + } + + // Wait for socket worker statuses. On error from any, quit program. + // On success from all, drop privileges if corresponding setting is set + // and continue program. + loop { + ::std::thread::sleep(::std::time::Duration::from_millis(10)); + + if let Some(statuses) = socket_worker_statuses.try_lock() { + for opt_status in statuses.iter() { + if let Some(Err(err)) = opt_status { + return Err(::anyhow::anyhow!(err.to_owned())); + } + } + + if statuses.iter().all(Option::is_some) { + if config.privileges.drop_privileges { + PrivDrop::default() + .chroot(config.privileges.chroot_path.clone()) + .user(config.privileges.user.clone()) + .apply() + .context("Couldn't drop root privileges")?; + } + + break; + } + } + } + + let out_message_sender = OutMessageSender::new(out_message_senders); + + for i in 0..config.request_workers { + let config = config.clone(); + let state = state.clone(); + let in_message_receiver = in_message_receiver.clone(); + let out_message_sender = out_message_sender.clone(); + let wakers = wakers.clone(); + + Builder::new() + .name(format!("request-{:02}", i + 1)) + .spawn(move || { + handlers::run_request_worker( + config, + state, + in_message_receiver, + out_message_sender, + wakers, + ); + })?; + } + + if config.statistics.interval != 0 { + let state = state.clone(); + let config = config.clone(); + + Builder::new() + .name("statistics".to_string()) + .spawn(move || loop { + ::std::thread::sleep(Duration::from_secs(config.statistics.interval)); + + print_statistics(&state); + }) + .expect("spawn statistics thread"); + } + + Ok(()) +} + +pub fn create_tls_acceptor(config: &Config) -> anyhow::Result> { + if config.network.use_tls { + let mut identity_bytes = Vec::new(); + let mut file = File::open(&config.network.tls_pkcs12_path) + .context("Couldn't open pkcs12 identity file")?; + + file.read_to_end(&mut identity_bytes) + .context("Couldn't read pkcs12 identity file")?; + + let identity = Identity::from_pkcs12(&identity_bytes, &config.network.tls_pkcs12_password) + .context("Couldn't parse pkcs12 identity file")?; + + let acceptor = TlsAcceptor::new(identity) + .context("Couldn't create TlsAcceptor from pkcs12 identity")?; + + Ok(Some(acceptor)) + } else { + Ok(None) + } +} + +fn print_statistics(state: &State) { + let mut peers_per_torrent = Histogram::new(); + + { + let torrents = &mut state.torrent_maps.lock(); + + for torrent in torrents.ipv4.values() { + let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64; + + if let Err(err) = peers_per_torrent.increment(num_peers) { + eprintln!("error incrementing peers_per_torrent histogram: {}", err) + } + } + for torrent in torrents.ipv6.values() { + let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64; + + if let Err(err) = peers_per_torrent.increment(num_peers) { + eprintln!("error incrementing peers_per_torrent histogram: {}", err) + } + } + } + + if peers_per_torrent.entries() != 0 { + println!( + "peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}", + peers_per_torrent.minimum().unwrap(), + peers_per_torrent.percentile(50.0).unwrap(), + peers_per_torrent.percentile(75.0).unwrap(), + peers_per_torrent.percentile(90.0).unwrap(), + peers_per_torrent.percentile(99.0).unwrap(), + peers_per_torrent.percentile(99.9).unwrap(), + peers_per_torrent.maximum().unwrap(), + ); + } +} diff --git a/aquatic_ws/src/lib/mio/network/connection.rs b/aquatic_ws/src/lib/mio/network/connection.rs new file mode 100644 index 0000000..4769dc6 --- /dev/null +++ b/aquatic_ws/src/lib/mio/network/connection.rs @@ -0,0 +1,298 @@ +use std::io::{Read, Write}; +use std::net::SocketAddr; + +use either::Either; +use hashbrown::HashMap; +use log::info; +use mio::net::TcpStream; +use mio::{Poll, Token}; +use native_tls::{MidHandshakeTlsStream, TlsAcceptor, TlsStream}; +use tungstenite::handshake::{server::NoCallback, HandshakeError, MidHandshake}; +use tungstenite::protocol::WebSocketConfig; +use tungstenite::ServerHandshake; +use tungstenite::WebSocket; + +use crate::common::*; + +pub enum Stream { + TcpStream(TcpStream), + TlsStream(TlsStream), +} + +impl Stream { + #[inline] + pub fn get_peer_addr(&self) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.peer_addr(), + Self::TlsStream(stream) => stream.get_ref().peer_addr(), + } + } + + #[inline] + pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { + match self { + Self::TcpStream(stream) => poll.registry().deregister(stream), + Self::TlsStream(stream) => poll.registry().deregister(stream.get_mut()), + } + } +} + +impl Read for Stream { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> Result { + match self { + Self::TcpStream(stream) => stream.read(buf), + Self::TlsStream(stream) => stream.read(buf), + } + } + + /// Not used but provided for completeness + #[inline] + fn read_vectored( + &mut self, + bufs: &mut [::std::io::IoSliceMut<'_>], + ) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.read_vectored(bufs), + Self::TlsStream(stream) => stream.read_vectored(bufs), + } + } +} + +impl Write for Stream { + #[inline] + fn write(&mut self, buf: &[u8]) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.write(buf), + Self::TlsStream(stream) => stream.write(buf), + } + } + + /// Not used but provided for completeness + #[inline] + fn write_vectored(&mut self, bufs: &[::std::io::IoSlice<'_>]) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.write_vectored(bufs), + Self::TlsStream(stream) => stream.write_vectored(bufs), + } + } + + #[inline] + fn flush(&mut self) -> ::std::io::Result<()> { + match self { + Self::TcpStream(stream) => stream.flush(), + Self::TlsStream(stream) => stream.flush(), + } + } +} + +enum HandshakeMachine { + TcpStream(TcpStream), + TlsStream(TlsStream), + TlsMidHandshake(MidHandshakeTlsStream), + WsMidHandshake(MidHandshake>), +} + +impl HandshakeMachine { + #[inline] + fn new(tcp_stream: TcpStream) -> Self { + Self::TcpStream(tcp_stream) + } + + #[inline] + fn advance( + self, + ws_config: WebSocketConfig, + opt_tls_acceptor: &Option, // If set, run TLS + ) -> (Option>, bool) { + // bool = stop looping + match self { + HandshakeMachine::TcpStream(stream) => { + if let Some(tls_acceptor) = opt_tls_acceptor { + Self::handle_tls_handshake_result(tls_acceptor.accept(stream)) + } else { + let handshake_result = ::tungstenite::accept_with_config( + Stream::TcpStream(stream), + Some(ws_config), + ); + + Self::handle_ws_handshake_result(handshake_result) + } + } + HandshakeMachine::TlsStream(stream) => { + let handshake_result = ::tungstenite::accept(Stream::TlsStream(stream)); + + Self::handle_ws_handshake_result(handshake_result) + } + HandshakeMachine::TlsMidHandshake(handshake) => { + Self::handle_tls_handshake_result(handshake.handshake()) + } + HandshakeMachine::WsMidHandshake(handshake) => { + Self::handle_ws_handshake_result(handshake.handshake()) + } + } + } + + #[inline] + fn handle_tls_handshake_result( + result: Result, ::native_tls::HandshakeError>, + ) -> (Option>, bool) { + match result { + Ok(stream) => { + ::log::trace!( + "established tls handshake with peer with addr: {:?}", + stream.get_ref().peer_addr() + ); + + (Some(Either::Right(Self::TlsStream(stream))), false) + } + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + (Some(Either::Right(Self::TlsMidHandshake(handshake))), true) + } + Err(native_tls::HandshakeError::Failure(err)) => { + info!("tls handshake error: {}", err); + + (None, false) + } + } + } + + #[inline] + fn handle_ws_handshake_result( + result: Result, HandshakeError>>, + ) -> (Option>, bool) { + match result { + Ok(mut ws) => match ws.get_mut().get_peer_addr() { + Ok(peer_addr) => { + ::log::trace!( + "established ws handshake with peer with addr: {:?}", + peer_addr + ); + + let established_ws = EstablishedWs { ws, peer_addr }; + + (Some(Either::Left(established_ws)), false) + } + Err(err) => { + ::log::info!( + "get_peer_addr failed during handshake, removing connection: {:?}", + err + ); + + (None, false) + } + }, + Err(HandshakeError::Interrupted(handshake)) => ( + Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), + true, + ), + Err(HandshakeError::Failure(err)) => { + info!("ws handshake error: {}", err); + + (None, false) + } + } + } +} + +pub struct EstablishedWs { + pub ws: WebSocket, + pub peer_addr: SocketAddr, +} + +pub struct Connection { + ws_config: WebSocketConfig, + pub valid_until: ValidUntil, + inner: Either, +} + +/// Create from TcpStream. Run `advance_handshakes` until `get_established_ws` +/// returns Some(EstablishedWs). +/// +/// advance_handshakes takes ownership of self because the TLS and WebSocket +/// handshake methods do. get_established_ws doesn't, since work can be done +/// on a mutable reference to a tungstenite websocket, and this way, the whole +/// Connection doesn't have to be removed from and reinserted into the +/// TorrentMap. This is also the reason for wrapping Container.inner in an +/// Either instead of combining all states into one structure just having a +/// single method for advancing handshakes and maybe returning a websocket. +impl Connection { + #[inline] + pub fn new(ws_config: WebSocketConfig, valid_until: ValidUntil, tcp_stream: TcpStream) -> Self { + Self { + ws_config, + valid_until, + inner: Either::Right(HandshakeMachine::new(tcp_stream)), + } + } + + #[inline] + pub fn get_established_ws(&mut self) -> Option<&mut EstablishedWs> { + match self.inner { + Either::Left(ref mut ews) => Some(ews), + Either::Right(_) => None, + } + } + + #[inline] + pub fn advance_handshakes( + self, + opt_tls_acceptor: &Option, + valid_until: ValidUntil, + ) -> (Option, bool) { + match self.inner { + Either::Left(_) => (Some(self), false), + Either::Right(machine) => { + let ws_config = self.ws_config; + + let (opt_inner, stop_loop) = machine.advance(ws_config, opt_tls_acceptor); + + let opt_new_self = opt_inner.map(|inner| Self { + ws_config, + valid_until, + inner, + }); + + (opt_new_self, stop_loop) + } + } + } + + #[inline] + pub fn close(&mut self) { + if let Either::Left(ref mut ews) = self.inner { + if ews.ws.can_read() { + if let Err(err) = ews.ws.close(None) { + ::log::info!("error closing ws: {}", err); + } + + // Required after ws.close() + if let Err(err) = ews.ws.write_pending() { + ::log::info!("error writing pending messages after closing ws: {}", err) + } + } + } + } + + pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { + use Either::{Left, Right}; + + match self.inner { + Left(EstablishedWs { ref mut ws, .. }) => ws.get_mut().deregister(poll), + Right(HandshakeMachine::TcpStream(ref mut stream)) => { + poll.registry().deregister(stream) + } + Right(HandshakeMachine::TlsMidHandshake(ref mut handshake)) => { + poll.registry().deregister(handshake.get_mut()) + } + Right(HandshakeMachine::TlsStream(ref mut stream)) => { + poll.registry().deregister(stream.get_mut()) + } + Right(HandshakeMachine::WsMidHandshake(ref mut handshake)) => { + handshake.get_mut().get_mut().deregister(poll) + } + } + } +} + +pub type ConnectionMap = HashMap; diff --git a/aquatic_ws/src/lib/mio/network/mod.rs b/aquatic_ws/src/lib/mio/network/mod.rs new file mode 100644 index 0000000..a109f79 --- /dev/null +++ b/aquatic_ws/src/lib/mio/network/mod.rs @@ -0,0 +1,332 @@ +use std::io::ErrorKind; +use std::time::Duration; +use std::vec::Drain; + +use aquatic_common::access_list::AccessListQuery; +use crossbeam_channel::Receiver; +use hashbrown::HashMap; +use log::{debug, error, info}; +use mio::net::TcpListener; +use mio::{Events, Interest, Poll, Token}; +use native_tls::TlsAcceptor; +use tungstenite::protocol::WebSocketConfig; + +use aquatic_common::convert_ipv4_mapped_ipv6; +use aquatic_ws_protocol::*; + +use crate::common::*; +use crate::config::Config; + +use super::common::*; + +pub mod connection; +pub mod utils; + +use connection::*; +use utils::*; + +pub fn run_socket_worker( + config: Config, + state: State, + socket_worker_index: usize, + socket_worker_statuses: SocketWorkerStatuses, + poll: Poll, + in_message_sender: InMessageSender, + out_message_receiver: OutMessageReceiver, + opt_tls_acceptor: Option, +) { + match create_listener(&config) { + Ok(listener) => { + socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(())); + + run_poll_loop( + config, + &state, + socket_worker_index, + poll, + in_message_sender, + out_message_receiver, + listener, + opt_tls_acceptor, + ); + } + Err(err) => { + socket_worker_statuses.lock()[socket_worker_index] = + Some(Err(format!("Couldn't open socket: {:#}", err))); + } + } +} + +pub fn run_poll_loop( + config: Config, + state: &State, + socket_worker_index: usize, + mut poll: Poll, + in_message_sender: InMessageSender, + out_message_receiver: OutMessageReceiver, + listener: ::std::net::TcpListener, + opt_tls_acceptor: Option, +) { + let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds); + let ws_config = WebSocketConfig { + max_message_size: Some(config.network.websocket_max_message_size), + max_frame_size: Some(config.network.websocket_max_frame_size), + max_send_queue: None, + ..Default::default() + }; + + let mut listener = TcpListener::from_std(listener); + let mut events = Events::with_capacity(config.network.poll_event_capacity); + + poll.registry() + .register(&mut listener, LISTENER_TOKEN, Interest::READABLE) + .unwrap(); + + let mut connections: ConnectionMap = HashMap::new(); + let mut local_responses = Vec::new(); + + let mut poll_token_counter = Token(0usize); + let mut iter_counter = 0usize; + + loop { + poll.poll(&mut events, Some(poll_timeout)) + .expect("failed polling"); + + let valid_until = ValidUntil::new(config.cleaning.max_connection_age); + + for event in events.iter() { + let token = event.token(); + + if token == LISTENER_TOKEN { + accept_new_streams( + ws_config, + &mut listener, + &mut poll, + &mut connections, + valid_until, + &mut poll_token_counter, + ); + } else if token != CHANNEL_TOKEN { + run_handshakes_and_read_messages( + &config, + state, + socket_worker_index, + &mut local_responses, + &in_message_sender, + &opt_tls_acceptor, + &mut poll, + &mut connections, + token, + valid_until, + ); + } + + send_out_messages( + &mut poll, + local_responses.drain(..), + &out_message_receiver, + &mut connections, + ); + } + + // Remove inactive connections, but not every iteration + if iter_counter % 128 == 0 { + remove_inactive_connections(&mut connections); + } + + iter_counter = iter_counter.wrapping_add(1); + } +} + +fn accept_new_streams( + ws_config: WebSocketConfig, + listener: &mut TcpListener, + poll: &mut Poll, + connections: &mut ConnectionMap, + valid_until: ValidUntil, + poll_token_counter: &mut Token, +) { + loop { + match listener.accept() { + Ok((mut stream, _)) => { + poll_token_counter.0 = poll_token_counter.0.wrapping_add(1); + + if poll_token_counter.0 < 2 { + poll_token_counter.0 = 2; + } + + let token = *poll_token_counter; + + remove_connection_if_exists(poll, connections, token); + + poll.registry() + .register(&mut stream, token, Interest::READABLE) + .unwrap(); + + let connection = Connection::new(ws_config, valid_until, stream); + + connections.insert(token, connection); + } + Err(err) => { + if err.kind() == ErrorKind::WouldBlock { + break; + } + + info!("error while accepting streams: {}", err); + } + } + } +} + +/// On the stream given by poll_token, get TLS (if requested) and tungstenite +/// up and running, then read messages and pass on through channel. +pub fn run_handshakes_and_read_messages( + config: &Config, + state: &State, + socket_worker_index: usize, + local_responses: &mut Vec<(ConnectionMeta, OutMessage)>, + in_message_sender: &InMessageSender, + opt_tls_acceptor: &Option, // If set, run TLS + poll: &mut Poll, + connections: &mut ConnectionMap, + poll_token: Token, + valid_until: ValidUntil, +) { + let access_list_mode = config.access_list.mode; + + loop { + if let Some(established_ws) = connections + .get_mut(&poll_token) + .map(|c| { + // Ugly but works + c.valid_until = valid_until; + + c + }) + .and_then(Connection::get_established_ws) + { + use ::tungstenite::Error::Io; + + match established_ws.ws.read_message() { + Ok(ws_message) => { + let naive_peer_addr = established_ws.peer_addr; + let converted_peer_ip = convert_ipv4_mapped_ipv6(naive_peer_addr.ip()); + + let meta = ConnectionMeta { + out_message_consumer_id: ConsumerId(socket_worker_index), + connection_id: ConnectionId(poll_token.0), + naive_peer_addr, + converted_peer_ip, + pending_scrape_id: None, // FIXME + }; + + debug!("read message"); + + match InMessage::from_ws_message(ws_message) { + Ok(InMessage::AnnounceRequest(ref request)) + if !state + .access_list + .allows(access_list_mode, &request.info_hash.0) => + { + let out_message = OutMessage::ErrorResponse(ErrorResponse { + failure_reason: "Info hash not allowed".into(), + action: Some(ErrorResponseAction::Announce), + info_hash: Some(request.info_hash), + }); + + local_responses.push((meta, out_message)); + } + Ok(in_message) => { + if let Err(err) = in_message_sender.send((meta, in_message)) { + error!("InMessageSender: couldn't send message: {:?}", err); + } + } + Err(_) => { + // FIXME: maybe this condition just occurs when enough data hasn't been recevied? + /* + info!("error parsing message: {:?}", err); + let out_message = OutMessage::ErrorResponse(ErrorResponse { + failure_reason: "Error parsing message".into(), + action: None, + info_hash: None, + }); + local_responses.push((meta, out_message)); + */ + } + } + } + Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { + break; + } + Err(tungstenite::Error::ConnectionClosed) => { + remove_connection_if_exists(poll, connections, poll_token); + + break; + } + Err(err) => { + info!("error reading messages: {}", err); + + remove_connection_if_exists(poll, connections, poll_token); + + break; + } + } + } else if let Some(connection) = connections.remove(&poll_token) { + let (opt_new_connection, stop_loop) = + connection.advance_handshakes(opt_tls_acceptor, valid_until); + + if let Some(connection) = opt_new_connection { + connections.insert(poll_token, connection); + } + + if stop_loop { + break; + } + } else { + break; + } + } +} + +/// Read messages from channel, send to peers +pub fn send_out_messages( + poll: &mut Poll, + local_responses: Drain<(ConnectionMeta, OutMessage)>, + out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>, + connections: &mut ConnectionMap, +) { + let len = out_message_receiver.len(); + + for (meta, out_message) in local_responses.chain(out_message_receiver.try_iter().take(len)) { + let opt_established_ws = connections + .get_mut(&Token(meta.connection_id.0)) + .and_then(Connection::get_established_ws); + + if let Some(established_ws) = opt_established_ws { + if established_ws.peer_addr != meta.naive_peer_addr { + info!("socket worker error: peer socket addrs didn't match"); + + continue; + } + + use ::tungstenite::Error::Io; + + let ws_message = out_message.to_ws_message(); + + match established_ws.ws.write_message(ws_message) { + Ok(()) => { + debug!("sent message"); + } + Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {} + Err(tungstenite::Error::ConnectionClosed) => { + remove_connection_if_exists(poll, connections, Token(meta.connection_id.0)); + } + Err(err) => { + info!("error writing ws message: {}", err); + + remove_connection_if_exists(poll, connections, Token(meta.connection_id.0)); + } + } + } + } +} diff --git a/aquatic_ws/src/lib/mio/network/utils.rs b/aquatic_ws/src/lib/mio/network/utils.rs new file mode 100644 index 0000000..2568c97 --- /dev/null +++ b/aquatic_ws/src/lib/mio/network/utils.rs @@ -0,0 +1,66 @@ +use std::time::Instant; + +use anyhow::Context; +use mio::{Poll, Token}; +use socket2::{Domain, Protocol, Socket, Type}; + +use crate::config::Config; + +use super::connection::*; + +pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> { + let builder = if config.network.address.is_ipv4() { + Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)) + } else { + Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)) + } + .context("Couldn't create socket2::Socket")?; + + if config.network.ipv6_only { + builder + .set_only_v6(true) + .context("Couldn't put socket in ipv6 only mode")? + } + + builder + .set_nonblocking(true) + .context("Couldn't put socket in non-blocking mode")?; + builder + .set_reuse_port(true) + .context("Couldn't put socket in reuse_port mode")?; + builder + .bind(&config.network.address.into()) + .with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?; + builder + .listen(128) + .context("Couldn't listen for connections on socket")?; + + Ok(builder.into()) +} + +pub fn remove_connection_if_exists(poll: &mut Poll, connections: &mut ConnectionMap, token: Token) { + if let Some(mut connection) = connections.remove(&token) { + connection.close(); + + if let Err(err) = connection.deregister(poll) { + ::log::error!("couldn't deregister stream: {}", err); + } + } +} + +// Close and remove inactive connections +pub fn remove_inactive_connections(connections: &mut ConnectionMap) { + let now = Instant::now(); + + connections.retain(|_, connection| { + if connection.valid_until.0 < now { + connection.close(); + + false + } else { + true + } + }); + + connections.shrink_to_fit(); +} diff --git a/aquatic_ws/src/lib/tasks.rs b/aquatic_ws/src/lib/tasks.rs index 59ad0f4..e4139fa 100644 --- a/aquatic_ws/src/lib/tasks.rs +++ b/aquatic_ws/src/lib/tasks.rs @@ -4,13 +4,3 @@ use histogram::Histogram; use crate::common::*; use crate::config::Config; -pub fn update_access_list(config: &Config, state: &State) { - match config.access_list.mode { - AccessListMode::White | AccessListMode::Black => { - if let Err(err) = state.access_list.update_from_path(&config.access_list.path) { - ::log::error!("Couldn't update access list: {:?}", err); - } - } - AccessListMode::Off => {} - } -}