From 667cf04085e79ccb637b8d56c95b80f770b26ef6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Fri, 18 Mar 2022 15:15:34 +0100 Subject: [PATCH] ws: remove mio implementation --- Cargo.lock | 86 --- aquatic_ws/Cargo.toml | 22 +- aquatic_ws/src/{common/mod.rs => common.rs} | 9 +- aquatic_ws/src/config.rs | 143 ++--- aquatic_ws/src/glommio/common.rs | 10 - aquatic_ws/src/glommio/mod.rs | 107 ---- aquatic_ws/src/glommio/request.rs | 128 ---- aquatic_ws/src/lib.rs | 128 +++- aquatic_ws/src/mio/common.rs | 51 -- aquatic_ws/src/mio/mod.rs | 218 ------- aquatic_ws/src/mio/request.rs | 103 ---- aquatic_ws/src/mio/socket/connection.rs | 577 ------------------ aquatic_ws/src/mio/socket/mod.rs | 403 ------------ aquatic_ws/src/workers/mod.rs | 2 + .../handlers.rs => workers/request.rs} | 125 +++- aquatic_ws/src/{glommio => workers}/socket.rs | 2 - 16 files changed, 283 insertions(+), 1831 deletions(-) rename aquatic_ws/src/{common/mod.rs => common.rs} (96%) delete mode 100644 aquatic_ws/src/glommio/common.rs delete mode 100644 aquatic_ws/src/glommio/mod.rs delete mode 100644 aquatic_ws/src/glommio/request.rs delete mode 100644 aquatic_ws/src/mio/common.rs delete mode 100644 aquatic_ws/src/mio/mod.rs delete mode 100644 aquatic_ws/src/mio/request.rs delete mode 100644 aquatic_ws/src/mio/socket/connection.rs delete mode 100644 aquatic_ws/src/mio/socket/mod.rs create mode 100644 aquatic_ws/src/workers/mod.rs rename aquatic_ws/src/{common/handlers.rs => workers/request.rs} (59%) rename aquatic_ws/src/{glommio => workers}/socket.rs (99%) diff --git a/Cargo.lock b/Cargo.lock index 641b919..46e4cf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -264,18 +264,14 @@ dependencies = [ "aquatic_ws_protocol", "async-tungstenite", "cfg-if", - "crossbeam-channel", "either", "futures", "futures-lite", "futures-rustls", "glommio", "hashbrown 0.12.0", - "histogram", "log", "mimalloc", - "mio", - "parking_lot", "privdrop", "quickcheck", "quickcheck_macros", @@ -285,7 +281,6 @@ dependencies = [ "serde", "signal-hook", "slab", - "socket2 0.4.4", "tungstenite", ] @@ -1066,12 +1061,6 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" -[[package]] -name = "histogram" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12cb882ccb290b8646e554b157ab0b71e64e8d5bef775cd66b6531e52d302669" - [[package]] name = "http" version = "0.2.6" @@ -1471,29 +1460,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72" -[[package]] -name = "parking_lot" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-sys", -] - [[package]] name = "percent-encoding" version = "2.1.0" @@ -1687,15 +1653,6 @@ dependencies = [ "num_cpus", ] -[[package]] -name = "redox_syscall" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff" -dependencies = [ - "bitflags 1.3.2", -] - [[package]] name = "regex" version = "1.5.4" @@ -2397,46 +2354,3 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-sys" -version = "0.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6" -dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_msvc" -version = "0.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5" - -[[package]] -name = "windows_i686_gnu" -version = "0.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615" - -[[package]] -name = "windows_i686_msvc" -version = "0.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316" diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index 255cfa0..ea08336 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -14,10 +14,7 @@ name = "aquatic_ws" name = "aquatic_ws" [features] -default = ["with-mio"] cpu-pinning = ["aquatic_common/cpu-pinning"] -with-glommio = ["cpu-pinning", "async-tungstenite", "futures-lite", "futures", "futures-rustls", "glommio"] -with-mio = ["crossbeam-channel", "histogram", "mio", "parking_lot", "socket2"] [dependencies] aquatic_cli_helpers = "0.1.0" @@ -39,20 +36,11 @@ serde = { version = "1", features = ["derive"] } signal-hook = { version = "0.3" } slab = "0.4" tungstenite = "0.17" - -# mio -crossbeam-channel = { version = "0.5", optional = true } -histogram = { version = "0.6", optional = true } -mio = { version = "0.8", features = ["net", "os-poll"], optional = true } -parking_lot = { version = "0.12", optional = true } -socket2 = { version = "0.4", features = ["all"], optional = true } - -# glommio -async-tungstenite = { version = "0.17", optional = true } -futures-lite = { version = "1", optional = true } -futures = { version = "0.3", optional = true } -futures-rustls = { version = "0.22", optional = true } -glommio = { version = "0.7", optional = true } +async-tungstenite = "0.17" +futures-lite = "1" +futures = "0.3" +futures-rustls = "0.22" +glommio = "0.7" [dev-dependencies] quickcheck = "1" diff --git a/aquatic_ws/src/common/mod.rs b/aquatic_ws/src/common.rs similarity index 96% rename from aquatic_ws/src/common/mod.rs rename to aquatic_ws/src/common.rs index 0f43d2b..f7ced95 100644 --- a/aquatic_ws/src/common/mod.rs +++ b/aquatic_ws/src/common.rs @@ -1,5 +1,3 @@ -pub mod handlers; - use std::fs::File; use std::io::BufReader; use std::sync::Arc; @@ -14,6 +12,13 @@ use aquatic_ws_protocol::*; use crate::config::Config; +pub type TlsConfig = futures_rustls::rustls::ServerConfig; + +#[derive(Default, Clone)] +pub struct State { + pub access_list: Arc, +} + #[derive(Copy, Clone, Debug)] pub struct PendingScrapeId(pub usize); diff --git a/aquatic_ws/src/config.rs b/aquatic_ws/src/config.rs index 76c7f7a..10ae6e1 100644 --- a/aquatic_ws/src/config.rs +++ b/aquatic_ws/src/config.rs @@ -23,15 +23,28 @@ pub struct Config { pub log_level: LogLevel, pub network: NetworkConfig, pub protocol: ProtocolConfig, - #[cfg(feature = "with-mio")] - pub handlers: HandlerConfig, pub cleaning: CleaningConfig, pub privileges: PrivilegeConfig, pub access_list: AccessListConfig, #[cfg(feature = "cpu-pinning")] pub cpu_pinning: CpuPinningConfig, - #[cfg(feature = "with-mio")] - pub statistics: StatisticsConfig, +} + +impl Default for Config { + fn default() -> Self { + Self { + socket_workers: 1, + request_workers: 1, + log_level: LogLevel::default(), + network: NetworkConfig::default(), + protocol: ProtocolConfig::default(), + cleaning: CleaningConfig::default(), + privileges: PrivilegeConfig::default(), + access_list: AccessListConfig::default(), + #[cfg(feature = "cpu-pinning")] + cpu_pinning: Default::default(), + } + } } impl aquatic_cli_helpers::Config for Config { @@ -55,81 +68,6 @@ pub struct NetworkConfig { pub websocket_max_message_size: usize, pub websocket_max_frame_size: usize, - - #[cfg(feature = "with-mio")] - pub poll_event_capacity: usize, - #[cfg(feature = "with-mio")] - pub poll_timeout_microseconds: u64, -} - -#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] -#[serde(default)] -pub struct ProtocolConfig { - /// Maximum number of torrents to accept in scrape request - pub max_scrape_torrents: usize, - /// Maximum number of offers to accept in announce request - pub max_offers: usize, - /// Ask peers to announce this often (seconds) - pub peer_announce_interval: usize, -} - -#[cfg(feature = "with-mio")] -#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] -#[serde(default)] -pub struct HandlerConfig { - /// Maximum number of requests to receive from channel before locking - /// mutex and starting work - pub max_requests_per_iter: usize, - pub channel_recv_timeout_microseconds: u64, -} - -#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] -#[serde(default)] -pub struct CleaningConfig { - /// Clean peers this often (seconds) - pub torrent_cleaning_interval: u64, - /// Remove peers that have not announced for this long (seconds) - pub max_peer_age: u64, - - // Clean connections this often (seconds) - #[cfg(feature = "with-glommio")] - pub connection_cleaning_interval: u64, - /// Close connections if no responses have been sent to them for this long (seconds) - #[cfg(feature = "with-glommio")] - pub max_connection_idle: u64, - - /// Remove connections that are older than this (seconds) - #[cfg(feature = "with-mio")] - pub max_connection_age: u64, -} - -#[cfg(feature = "with-mio")] -#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] -#[serde(default)] -pub struct StatisticsConfig { - /// Print statistics this often (seconds). Do not print when set to zero. - pub interval: u64, -} - -impl Default for Config { - fn default() -> Self { - Self { - socket_workers: 1, - request_workers: 1, - log_level: LogLevel::default(), - network: NetworkConfig::default(), - protocol: ProtocolConfig::default(), - #[cfg(feature = "with-mio")] - handlers: Default::default(), - cleaning: CleaningConfig::default(), - privileges: PrivilegeConfig::default(), - access_list: AccessListConfig::default(), - #[cfg(feature = "cpu-pinning")] - cpu_pinning: Default::default(), - #[cfg(feature = "with-mio")] - statistics: Default::default(), - } - } } impl Default for NetworkConfig { @@ -143,15 +81,21 @@ impl Default for NetworkConfig { websocket_max_message_size: 64 * 1024, websocket_max_frame_size: 16 * 1024, - - #[cfg(feature = "with-mio")] - poll_event_capacity: 4096, - #[cfg(feature = "with-mio")] - poll_timeout_microseconds: 200_000, } } } +#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] +#[serde(default)] +pub struct ProtocolConfig { + /// Maximum number of torrents to accept in scrape request + pub max_scrape_torrents: usize, + /// Maximum number of offers to accept in announce request + pub max_offers: usize, + /// Ask peers to announce this often (seconds) + pub peer_announce_interval: usize, +} + impl Default for ProtocolConfig { fn default() -> Self { Self { @@ -162,14 +106,17 @@ impl Default for ProtocolConfig { } } -#[cfg(feature = "with-mio")] -impl Default for HandlerConfig { - fn default() -> Self { - Self { - max_requests_per_iter: 256, - channel_recv_timeout_microseconds: 200, - } - } +#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] +#[serde(default)] +pub struct CleaningConfig { + /// Clean peers this often (seconds) + pub torrent_cleaning_interval: u64, + /// Remove peers that have not announced for this long (seconds) + pub max_peer_age: u64, + // Clean connections this often (seconds) + pub connection_cleaning_interval: u64, + /// Close connections if no responses have been sent to them for this long (seconds) + pub max_connection_idle: u64, } impl Default for CleaningConfig { @@ -177,24 +124,12 @@ impl Default for CleaningConfig { Self { torrent_cleaning_interval: 30, max_peer_age: 1800, - #[cfg(feature = "with-glommio")] max_connection_idle: 60 * 5, - - #[cfg(feature = "with-mio")] - max_connection_age: 1800, - #[cfg(feature = "with-glommio")] connection_cleaning_interval: 30, } } } -#[cfg(feature = "with-mio")] -impl Default for StatisticsConfig { - fn default() -> Self { - Self { interval: 0 } - } -} - #[cfg(test)] mod tests { use super::Config; diff --git a/aquatic_ws/src/glommio/common.rs b/aquatic_ws/src/glommio/common.rs deleted file mode 100644 index 539176e..0000000 --- a/aquatic_ws/src/glommio/common.rs +++ /dev/null @@ -1,10 +0,0 @@ -use std::sync::Arc; - -use aquatic_common::access_list::AccessListArcSwap; - -pub type TlsConfig = futures_rustls::rustls::ServerConfig; - -#[derive(Default, Clone)] -pub struct State { - pub access_list: Arc, -} diff --git a/aquatic_ws/src/glommio/mod.rs b/aquatic_ws/src/glommio/mod.rs deleted file mode 100644 index f1ceb69..0000000 --- a/aquatic_ws/src/glommio/mod.rs +++ /dev/null @@ -1,107 +0,0 @@ -pub mod common; -pub mod request; -pub mod socket; - -use std::sync::{atomic::AtomicUsize, Arc}; - -use crate::{common::create_tls_config, config::Config}; -#[cfg(feature = "cpu-pinning")] -use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; -use aquatic_common::privileges::drop_privileges_after_socket_binding; - -use self::common::*; - -use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; - -pub const SHARED_IN_CHANNEL_SIZE: usize = 1024; - -pub fn run(config: Config, state: State) -> anyhow::Result<()> { - let num_peers = config.socket_workers + config.request_workers; - - let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); - let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); - - let num_bound_sockets = Arc::new(AtomicUsize::new(0)); - - let tls_config = Arc::new(create_tls_config(&config).unwrap()); - - let mut executors = Vec::new(); - - for i in 0..(config.socket_workers) { - let config = config.clone(); - let state = state.clone(); - let tls_config = tls_config.clone(); - let request_mesh_builder = request_mesh_builder.clone(); - let response_mesh_builder = response_mesh_builder.clone(); - let num_bound_sockets = num_bound_sockets.clone(); - - let builder = LocalExecutorBuilder::default().name("socket"); - - let executor = builder.spawn(move || async move { - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::SocketWorker(i), - ); - - socket::run_socket_worker( - config, - state, - tls_config, - request_mesh_builder, - response_mesh_builder, - num_bound_sockets, - ) - .await - }); - - executors.push(executor); - } - - for i in 0..(config.request_workers) { - let config = config.clone(); - let state = state.clone(); - let request_mesh_builder = request_mesh_builder.clone(); - let response_mesh_builder = response_mesh_builder.clone(); - - let builder = LocalExecutorBuilder::default().name("request"); - - let executor = builder.spawn(move || async move { - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::RequestWorker(i), - ); - - request::run_request_worker(config, state, request_mesh_builder, response_mesh_builder) - .await - }); - - executors.push(executor); - } - - drop_privileges_after_socket_binding( - &config.privileges, - num_bound_sockets, - config.socket_workers, - ) - .unwrap(); - - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::Other, - ); - - for executor in executors { - executor - .expect("failed to spawn local executor") - .join() - .unwrap(); - } - - Ok(()) -} diff --git a/aquatic_ws/src/glommio/request.rs b/aquatic_ws/src/glommio/request.rs deleted file mode 100644 index bf9cdfb..0000000 --- a/aquatic_ws/src/glommio/request.rs +++ /dev/null @@ -1,128 +0,0 @@ -use std::cell::RefCell; -use std::rc::Rc; -use std::time::Duration; - -use futures::StreamExt; -use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; -use glommio::enclose; -use glommio::prelude::*; -use glommio::timer::TimerActionRepeat; -use rand::{rngs::SmallRng, SeedableRng}; - -use aquatic_ws_protocol::*; - -use crate::common::handlers::*; -use crate::common::*; -use crate::config::Config; - -use super::common::State; -use super::SHARED_IN_CHANNEL_SIZE; - -pub async fn run_request_worker( - config: Config, - state: State, - in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, - out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, -) { - let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); - let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); - - let out_message_senders = Rc::new(out_message_senders); - - let torrents = Rc::new(RefCell::new(TorrentMaps::default())); - let access_list = state.access_list; - - // Periodically clean torrents - TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { - enclose!((config, torrents, access_list) move || async move { - torrents.borrow_mut().clean(&config, &access_list); - - Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval)) - })() - })); - - let mut handles = Vec::new(); - - for (_, receiver) in in_message_receivers.streams() { - let handle = spawn_local(handle_request_stream( - config.clone(), - torrents.clone(), - out_message_senders.clone(), - receiver, - )) - .detach(); - - handles.push(handle); - } - - for handle in handles { - handle.await; - } -} - -async fn handle_request_stream( - config: Config, - torrents: Rc>, - out_message_senders: Rc>, - stream: S, -) where - S: futures_lite::Stream + ::std::marker::Unpin, -{ - let rng = Rc::new(RefCell::new(SmallRng::from_entropy())); - - let max_peer_age = config.cleaning.max_peer_age; - let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age))); - - TimerActionRepeat::repeat(enclose!((peer_valid_until) move || { - enclose!((peer_valid_until) move || async move { - *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age); - - Some(Duration::from_secs(1)) - })() - })); - - let config = &config; - let torrents = &torrents; - let peer_valid_until = &peer_valid_until; - let rng = &rng; - let out_message_senders = &out_message_senders; - - stream - .for_each_concurrent( - SHARED_IN_CHANNEL_SIZE, - move |(meta, in_message)| async move { - let mut out_messages = Vec::new(); - - match in_message { - InMessage::AnnounceRequest(request) => handle_announce_request( - &config, - &mut rng.borrow_mut(), - &mut torrents.borrow_mut(), - &mut out_messages, - peer_valid_until.borrow().to_owned(), - meta, - request, - ), - InMessage::ScrapeRequest(request) => handle_scrape_request( - &config, - &mut torrents.borrow_mut(), - &mut out_messages, - meta, - request, - ), - }; - - for (meta, out_message) in out_messages.drain(..) { - ::log::info!("request worker trying to send OutMessage to socket worker"); - - out_message_senders - .send_to(meta.out_message_consumer_id.0, (meta, out_message)) - .await - .expect("failed sending out_message to socket worker"); - - ::log::info!("request worker sent OutMessage to socket worker"); - } - }, - ) - .await; -} diff --git a/aquatic_ws/src/lib.rs b/aquatic_ws/src/lib.rs index 2a4f0d9..a9a496c 100644 --- a/aquatic_ws/src/lib.rs +++ b/aquatic_ws/src/lib.rs @@ -1,28 +1,26 @@ use aquatic_common::access_list::update_access_list; #[cfg(feature = "cpu-pinning")] use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; -use cfg_if::cfg_if; +use common::State; use signal_hook::{consts::SIGUSR1, iterator::Signals}; -use crate::config::Config; +use std::sync::{atomic::AtomicUsize, Arc}; + +use crate::{common::create_tls_config, config::Config}; +use aquatic_common::privileges::drop_privileges_after_socket_binding; + +use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; pub mod common; pub mod config; -#[cfg(feature = "with-glommio")] -pub mod glommio; -#[cfg(feature = "with-mio")] -pub mod mio; +pub mod workers; pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; +pub const SHARED_IN_CHANNEL_SIZE: usize = 1024; + pub fn run(config: Config) -> ::anyhow::Result<()> { - cfg_if!( - if #[cfg(feature = "with-glommio")] { - let state = glommio::common::State::default(); - } else { - let state = mio::common::State::default(); - } - ); + let state = State::default(); update_access_list(&config.access_list, &state.access_list)?; @@ -32,13 +30,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let config = config.clone(); let state = state.clone(); - cfg_if!( - if #[cfg(feature = "with-glommio")] { - ::std::thread::spawn(move || glommio::run(config, state)); - } else { - ::std::thread::spawn(move || mio::run(config, state)); - } - ); + ::std::thread::spawn(move || run_workers(config, state)); } #[cfg(feature = "cpu-pinning")] @@ -59,3 +51,99 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { Ok(()) } + +pub fn run_workers(config: Config, state: State) -> anyhow::Result<()> { + let num_peers = config.socket_workers + config.request_workers; + + let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); + let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); + + let num_bound_sockets = Arc::new(AtomicUsize::new(0)); + + let tls_config = Arc::new(create_tls_config(&config).unwrap()); + + let mut executors = Vec::new(); + + for i in 0..(config.socket_workers) { + let config = config.clone(); + let state = state.clone(); + let tls_config = tls_config.clone(); + let request_mesh_builder = request_mesh_builder.clone(); + let response_mesh_builder = response_mesh_builder.clone(); + let num_bound_sockets = num_bound_sockets.clone(); + + let builder = LocalExecutorBuilder::default().name("socket"); + + let executor = builder.spawn(move || async move { + #[cfg(feature = "cpu-pinning")] + pin_current_if_configured_to( + &config.cpu_pinning, + config.socket_workers, + WorkerIndex::SocketWorker(i), + ); + + workers::socket::run_socket_worker( + config, + state, + tls_config, + request_mesh_builder, + response_mesh_builder, + num_bound_sockets, + ) + .await + }); + + executors.push(executor); + } + + for i in 0..(config.request_workers) { + let config = config.clone(); + let state = state.clone(); + let request_mesh_builder = request_mesh_builder.clone(); + let response_mesh_builder = response_mesh_builder.clone(); + + let builder = LocalExecutorBuilder::default().name("request"); + + let executor = builder.spawn(move || async move { + #[cfg(feature = "cpu-pinning")] + pin_current_if_configured_to( + &config.cpu_pinning, + config.socket_workers, + WorkerIndex::RequestWorker(i), + ); + + workers::request::run_request_worker( + config, + state, + request_mesh_builder, + response_mesh_builder, + ) + .await + }); + + executors.push(executor); + } + + drop_privileges_after_socket_binding( + &config.privileges, + num_bound_sockets, + config.socket_workers, + ) + .unwrap(); + + #[cfg(feature = "cpu-pinning")] + pin_current_if_configured_to( + &config.cpu_pinning, + config.socket_workers, + WorkerIndex::Other, + ); + + for executor in executors { + executor + .expect("failed to spawn local executor") + .join() + .unwrap(); + } + + Ok(()) +} diff --git a/aquatic_ws/src/mio/common.rs b/aquatic_ws/src/mio/common.rs deleted file mode 100644 index 5337ca7..0000000 --- a/aquatic_ws/src/mio/common.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::sync::Arc; - -use aquatic_common::access_list::AccessListArcSwap; -use aquatic_ws_protocol::*; -use crossbeam_channel::{Receiver, Sender}; -use log::error; -use mio::Token; -use parking_lot::Mutex; - -use crate::common::*; - -pub const LISTENER_TOKEN: Token = Token(0); -pub const CHANNEL_TOKEN: Token = Token(1); - -#[derive(Clone)] -pub struct State { - pub access_list: Arc, - pub torrent_maps: Arc>, -} - -impl Default for State { - fn default() -> Self { - Self { - access_list: Arc::new(Default::default()), - torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), - } - } -} - -pub type InMessageSender = Sender<(ConnectionMeta, InMessage)>; -pub type InMessageReceiver = Receiver<(ConnectionMeta, InMessage)>; -pub type OutMessageReceiver = Receiver<(ConnectionMeta, OutMessage)>; - -#[derive(Clone)] -pub struct OutMessageSender(Vec>); - -impl OutMessageSender { - pub fn new(senders: Vec>) -> Self { - Self(senders) - } - - #[inline] - pub fn send(&self, meta: ConnectionMeta, message: OutMessage) { - if let Err(err) = self.0[meta.out_message_consumer_id.0].send((meta, message)) { - error!("OutMessageSender: couldn't send message: {:?}", err); - } - } -} - -pub type SocketWorkerStatus = Option>; -pub type SocketWorkerStatuses = Arc>>; diff --git a/aquatic_ws/src/mio/mod.rs b/aquatic_ws/src/mio/mod.rs deleted file mode 100644 index 54a2fa4..0000000 --- a/aquatic_ws/src/mio/mod.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::sync::Arc; -use std::thread::Builder; -use std::time::Duration; - -use anyhow::Context; -#[cfg(feature = "cpu-pinning")] -use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; -use histogram::Histogram; -use mio::{Poll, Waker}; -use parking_lot::Mutex; -use privdrop::PrivDrop; - -pub mod common; -pub mod request; -pub mod socket; - -use crate::{common::create_tls_config, config::Config}; -use common::*; - -pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; - -const SHARED_IN_CHANNEL_SIZE: usize = 1024; - -pub fn run(config: Config, state: State) -> anyhow::Result<()> { - start_workers(config.clone(), state.clone()).expect("couldn't start workers"); - - // TODO: privdrop here instead - - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::Other, - ); - - loop { - ::std::thread::sleep(Duration::from_secs( - config.cleaning.torrent_cleaning_interval, - )); - - state.torrent_maps.lock().clean(&config, &state.access_list); - } -} - -pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { - let tls_config = Arc::new(create_tls_config(&config)?); - - let (in_message_sender, in_message_receiver) = - ::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE); - - let mut out_message_senders = Vec::new(); - let mut wakers = Vec::new(); - - let socket_worker_statuses: SocketWorkerStatuses = { - let mut statuses = Vec::new(); - - for _ in 0..config.socket_workers { - statuses.push(None); - } - - Arc::new(Mutex::new(statuses)) - }; - - for i in 0..config.socket_workers { - let config = config.clone(); - let state = state.clone(); - let socket_worker_statuses = socket_worker_statuses.clone(); - let in_message_sender = in_message_sender.clone(); - let tls_config = tls_config.clone(); - let poll = Poll::new()?; - let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?); - - let (out_message_sender, out_message_receiver) = - ::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE * 16); - - out_message_senders.push(out_message_sender); - wakers.push(waker); - - Builder::new() - .name(format!("socket-{:02}", i + 1)) - .spawn(move || { - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::SocketWorker(i), - ); - - socket::run_socket_worker( - config, - state, - i, - socket_worker_statuses, - poll, - in_message_sender, - out_message_receiver, - tls_config, - ); - })?; - } - - // Wait for socket worker statuses. On error from any, quit program. - // On success from all, drop privileges if corresponding setting is set - // and continue program. - loop { - ::std::thread::sleep(::std::time::Duration::from_millis(10)); - - if let Some(statuses) = socket_worker_statuses.try_lock() { - for opt_status in statuses.iter() { - if let Some(Err(err)) = opt_status { - return Err(::anyhow::anyhow!(err.to_owned())); - } - } - - if statuses.iter().all(Option::is_some) { - if config.privileges.drop_privileges { - PrivDrop::default() - .chroot(config.privileges.chroot_path.clone()) - .user(config.privileges.user.clone()) - .apply() - .context("Couldn't drop root privileges")?; - } - - break; - } - } - } - - let out_message_sender = OutMessageSender::new(out_message_senders); - - for i in 0..config.request_workers { - let config = config.clone(); - let state = state.clone(); - let in_message_receiver = in_message_receiver.clone(); - let out_message_sender = out_message_sender.clone(); - let wakers = wakers.clone(); - - Builder::new() - .name(format!("request-{:02}", i + 1)) - .spawn(move || { - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::RequestWorker(i), - ); - - request::run_request_worker( - config, - state, - in_message_receiver, - out_message_sender, - wakers, - ); - })?; - } - - if config.statistics.interval != 0 { - let state = state.clone(); - let config = config.clone(); - - Builder::new() - .name("statistics".to_string()) - .spawn(move || { - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - WorkerIndex::Other, - ); - - loop { - ::std::thread::sleep(Duration::from_secs(config.statistics.interval)); - - print_statistics(&state); - } - }) - .expect("spawn statistics thread"); - } - - Ok(()) -} - -fn print_statistics(state: &State) { - let mut peers_per_torrent = Histogram::new(); - - { - let torrents = &mut state.torrent_maps.lock(); - - for torrent in torrents.ipv4.values() { - let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64; - - if let Err(err) = peers_per_torrent.increment(num_peers) { - eprintln!("error incrementing peers_per_torrent histogram: {}", err) - } - } - for torrent in torrents.ipv6.values() { - let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64; - - if let Err(err) = peers_per_torrent.increment(num_peers) { - eprintln!("error incrementing peers_per_torrent histogram: {}", err) - } - } - } - - if peers_per_torrent.entries() != 0 { - println!( - "peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}", - peers_per_torrent.minimum().unwrap(), - peers_per_torrent.percentile(50.0).unwrap(), - peers_per_torrent.percentile(75.0).unwrap(), - peers_per_torrent.percentile(90.0).unwrap(), - peers_per_torrent.percentile(99.0).unwrap(), - peers_per_torrent.percentile(99.9).unwrap(), - peers_per_torrent.maximum().unwrap(), - ); - } -} diff --git a/aquatic_ws/src/mio/request.rs b/aquatic_ws/src/mio/request.rs deleted file mode 100644 index d3b905a..0000000 --- a/aquatic_ws/src/mio/request.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; - -use mio::Waker; -use parking_lot::MutexGuard; -use rand::{rngs::SmallRng, SeedableRng}; - -use aquatic_ws_protocol::*; - -use crate::common::handlers::{handle_announce_request, handle_scrape_request}; -use crate::common::*; -use crate::config::Config; - -use super::common::*; - -pub fn run_request_worker( - config: Config, - state: State, - in_message_receiver: InMessageReceiver, - out_message_sender: OutMessageSender, - wakers: Vec>, -) { - let mut wake_socket_workers: Vec = (0..config.socket_workers).map(|_| false).collect(); - - let mut announce_requests = Vec::new(); - let mut scrape_requests = Vec::new(); - let mut out_messages = Vec::new(); - - let mut rng = SmallRng::from_entropy(); - - let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds); - - loop { - let mut opt_torrent_map_guard: Option> = None; - - for i in 0..config.handlers.max_requests_per_iter { - let opt_in_message = if i == 0 { - in_message_receiver.recv().ok() - } else { - in_message_receiver.recv_timeout(timeout).ok() - }; - - match opt_in_message { - Some((meta, InMessage::AnnounceRequest(r))) => { - announce_requests.push((meta, r)); - } - Some((meta, InMessage::ScrapeRequest(r))) => { - scrape_requests.push((meta, r)); - } - None => { - if let Some(torrent_guard) = state.torrent_maps.try_lock() { - opt_torrent_map_guard = Some(torrent_guard); - - break; - } - } - } - } - - let mut torrent_map_guard = - opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock()); - - let valid_until = ValidUntil::new(config.cleaning.max_peer_age); - - for (meta, request) in announce_requests.drain(..) { - handle_announce_request( - &config, - &mut rng, - &mut torrent_map_guard, - &mut out_messages, - valid_until, - meta, - request, - ); - } - for (meta, request) in scrape_requests.drain(..) { - handle_scrape_request( - &config, - &mut torrent_map_guard, - &mut out_messages, - meta, - request, - ); - } - - ::std::mem::drop(torrent_map_guard); - - for (meta, out_message) in out_messages.drain(..) { - wake_socket_workers[meta.out_message_consumer_id.0] = true; - out_message_sender.send(meta, out_message); - } - - for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() { - if *wake { - if let Err(err) = wakers[worker_index].wake() { - ::log::error!("request handler couldn't wake poll: {:?}", err); - } - - *wake = false; - } - } - } -} diff --git a/aquatic_ws/src/mio/socket/connection.rs b/aquatic_ws/src/mio/socket/connection.rs deleted file mode 100644 index 50f5ca2..0000000 --- a/aquatic_ws/src/mio/socket/connection.rs +++ /dev/null @@ -1,577 +0,0 @@ -use std::{collections::VecDeque, io::ErrorKind, marker::PhantomData, net::Shutdown, sync::Arc}; - -use aquatic_common::ValidUntil; -use aquatic_ws_protocol::{InMessage, OutMessage}; -use mio::{net::TcpStream, Interest, Poll, Token}; -use rustls::{ServerConfig, ServerConnection}; -use tungstenite::{ - handshake::{server::NoCallback, MidHandshake}, - protocol::WebSocketConfig, - HandshakeError, ServerHandshake, -}; - -use crate::common::ConnectionMeta; - -const MAX_PENDING_MESSAGES: usize = 16; - -type TlsStream = rustls::StreamOwned; - -type WsHandshakeResult = - Result, HandshakeError>>; - -type ConnectionReadResult = ::std::io::Result>; - -pub trait RegistryStatus {} - -pub struct Registered; - -impl RegistryStatus for Registered {} - -pub struct NotRegistered; - -impl RegistryStatus for NotRegistered {} - -enum ConnectionReadStatus { - Message(T, InMessage), - Ok(T), - WouldBlock(T), -} - -enum ConnectionState { - TlsHandshaking(TlsHandshaking), - WsHandshaking(WsHandshaking), - WsConnection(WsConnection), -} - -pub struct Connection { - pub valid_until: ValidUntil, - meta: ConnectionMeta, - state: ConnectionState, - pub message_queue: VecDeque, - pub interest: Interest, - phantom_data: PhantomData, -} - -impl Connection { - pub fn get_meta(&self) -> ConnectionMeta { - self.meta - } -} - -impl Connection { - pub fn new( - tls_config: Arc, - ws_config: WebSocketConfig, - tcp_stream: TcpStream, - valid_until: ValidUntil, - meta: ConnectionMeta, - ) -> Self { - let state = - ConnectionState::TlsHandshaking(TlsHandshaking::new(tls_config, ws_config, tcp_stream)); - - Self { - valid_until, - meta, - state, - message_queue: Default::default(), - interest: Interest::READABLE, - phantom_data: PhantomData::default(), - } - } - - /// Read until stream blocks (or error occurs) - /// - /// Requires Connection not to be registered, since it might be dropped on errors - pub fn read( - mut self, - message_handler: &mut F, - ) -> ::std::io::Result> - where - F: FnMut(ConnectionMeta, InMessage), - { - loop { - let result = match self.state { - ConnectionState::TlsHandshaking(inner) => inner.read(), - ConnectionState::WsHandshaking(inner) => inner.read(), - ConnectionState::WsConnection(inner) => inner.read(), - }; - - match result { - Ok(ConnectionReadStatus::Message(state, message)) => { - self.state = state; - - message_handler(self.meta, message); - - // Stop looping even if WouldBlock wasn't necessarily reached. Otherwise, - // we might get stuck reading from this connection only. Since we register - // the connection again upon reinsertion into the ConnectionMap, we should - // be getting new events anyway. - return Ok(self); - } - Ok(ConnectionReadStatus::Ok(state)) => { - self.state = state; - - ::log::debug!("read connection"); - } - Ok(ConnectionReadStatus::WouldBlock(state)) => { - self.state = state; - - ::log::debug!("reading connection would block"); - - return Ok(self); - } - Err(err) => { - ::log::debug!("Connection::read error: {}", err); - - return Err(err); - } - } - } - } - - pub fn register(self, poll: &mut Poll, token: Token) -> Connection { - let state = match self.state { - ConnectionState::TlsHandshaking(inner) => { - ConnectionState::TlsHandshaking(inner.register(poll, token, self.interest)) - } - ConnectionState::WsHandshaking(inner) => { - ConnectionState::WsHandshaking(inner.register(poll, token, self.interest)) - } - ConnectionState::WsConnection(inner) => { - ConnectionState::WsConnection(inner.register(poll, token, self.interest)) - } - }; - - Connection { - valid_until: self.valid_until, - meta: self.meta, - state, - message_queue: self.message_queue, - interest: self.interest, - phantom_data: PhantomData::default(), - } - } - - pub fn close(self) { - ::log::debug!("will close connection to {}", self.meta.peer_addr.get()); - - match self.state { - ConnectionState::TlsHandshaking(inner) => inner.close(), - ConnectionState::WsHandshaking(inner) => inner.close(), - ConnectionState::WsConnection(inner) => inner.close(), - } - } -} - -impl Connection { - pub fn write_or_queue_message( - &mut self, - poll: &mut Poll, - message: OutMessage, - ) -> ::std::io::Result<()> { - let message_clone = message.clone(); - - match self.write_message(message) { - Ok(()) => Ok(()), - Err(err) if err.kind() == ErrorKind::WouldBlock => { - if self.message_queue.len() < MAX_PENDING_MESSAGES { - self.message_queue.push_back(message_clone); - - if !self.interest.is_writable() { - self.interest = Interest::WRITABLE; - - self.reregister(poll)?; - } - } else { - ::log::info!("Connection::message_queue is full, dropping message"); - } - - Ok(()) - } - Err(err) => Err(err), - } - } - - pub fn write(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { - if let ConnectionState::WsConnection(_) = self.state { - while let Some(message) = self.message_queue.pop_front() { - let message_clone = message.clone(); - - match self.write_message(message) { - Ok(()) => {} - Err(err) if err.kind() == ErrorKind::WouldBlock => { - // Can't make message queue longer than it was before pop_front - self.message_queue.push_front(message_clone); - - return Ok(()); - } - Err(err) => { - return Err(err); - } - } - } - - if self.message_queue.is_empty() { - self.interest = Interest::READABLE; - } - - self.reregister(poll)?; - - Ok(()) - } else { - Err(std::io::Error::new( - ErrorKind::NotConnected, - "WebSocket connection not established", - )) - } - } - - fn write_message(&mut self, message: OutMessage) -> ::std::io::Result<()> { - if let ConnectionState::WsConnection(WsConnection { - ref mut web_socket, .. - }) = self.state - { - match web_socket.write_message(message.to_ws_message()) { - Ok(_) => {} - Err(tungstenite::Error::SendQueueFull(_message)) => { - return Err(std::io::Error::new( - ErrorKind::WouldBlock, - "Send queue full", - )) - } - Err(tungstenite::Error::Io(err)) => return Err(err), - Err(err) => return Err(std::io::Error::new(ErrorKind::Other, err))?, - } - - match web_socket.write_pending() { - Ok(()) => Ok(()), - Err(tungstenite::Error::Io(err)) => Err(err), - Err(err) => Err(std::io::Error::new(ErrorKind::Other, err))?, - } - } else { - Err(std::io::Error::new( - ErrorKind::NotConnected, - "WebSocket connection not established", - )) - } - } - - pub fn reregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { - let token = Token(self.meta.connection_id.0); - - match self.state { - ConnectionState::TlsHandshaking(ref mut inner) => { - inner.reregister(poll, token, self.interest) - } - ConnectionState::WsHandshaking(ref mut inner) => { - inner.reregister(poll, token, self.interest) - } - ConnectionState::WsConnection(ref mut inner) => { - inner.reregister(poll, token, self.interest) - } - } - } - - pub fn deregister(self, poll: &mut Poll) -> Connection { - let state = match self.state { - ConnectionState::TlsHandshaking(inner) => { - ConnectionState::TlsHandshaking(inner.deregister(poll)) - } - ConnectionState::WsHandshaking(inner) => { - ConnectionState::WsHandshaking(inner.deregister(poll)) - } - ConnectionState::WsConnection(inner) => { - ConnectionState::WsConnection(inner.deregister(poll)) - } - }; - - Connection { - valid_until: self.valid_until, - meta: self.meta, - state, - message_queue: self.message_queue, - interest: self.interest, - phantom_data: PhantomData::default(), - } - } -} - -struct TlsHandshaking { - tls_conn: ServerConnection, - ws_config: WebSocketConfig, - tcp_stream: TcpStream, - phantom_data: PhantomData, -} - -impl TlsHandshaking { - fn new(tls_config: Arc, ws_config: WebSocketConfig, stream: TcpStream) -> Self { - Self { - tls_conn: ServerConnection::new(tls_config).unwrap(), - ws_config, - tcp_stream: stream, - phantom_data: PhantomData::default(), - } - } - - fn read(mut self) -> ConnectionReadResult> { - match self.tls_conn.read_tls(&mut self.tcp_stream) { - Ok(0) => { - return Err(::std::io::Error::new( - ErrorKind::ConnectionReset, - "Connection closed", - )) - } - Ok(_) => match self.tls_conn.process_new_packets() { - Ok(_) => { - while self.tls_conn.wants_write() { - self.tls_conn.write_tls(&mut self.tcp_stream)?; - } - - if self.tls_conn.is_handshaking() { - Ok(ConnectionReadStatus::WouldBlock( - ConnectionState::TlsHandshaking(self), - )) - } else { - let tls_stream = TlsStream::new(self.tls_conn, self.tcp_stream); - - WsHandshaking::handle_handshake_result(tungstenite::accept_with_config( - tls_stream, - Some(self.ws_config), - )) - } - } - Err(err) => { - let _ = self.tls_conn.write_tls(&mut self.tcp_stream); - - Err(::std::io::Error::new(ErrorKind::InvalidData, err)) - } - }, - Err(err) if err.kind() == ErrorKind::WouldBlock => { - return Ok(ConnectionReadStatus::WouldBlock( - ConnectionState::TlsHandshaking(self), - )) - } - Err(err) => return Err(err), - } - } - - fn register( - mut self, - poll: &mut Poll, - token: Token, - interest: Interest, - ) -> TlsHandshaking { - poll.registry() - .register(&mut self.tcp_stream, token, interest) - .unwrap(); - - TlsHandshaking { - tls_conn: self.tls_conn, - ws_config: self.ws_config, - tcp_stream: self.tcp_stream, - phantom_data: PhantomData::default(), - } - } - - fn close(self) { - ::log::debug!("closing connection (TlsHandshaking state)"); - - let _ = self.tcp_stream.shutdown(Shutdown::Both); - } -} - -impl TlsHandshaking { - fn deregister(mut self, poll: &mut Poll) -> TlsHandshaking { - poll.registry().deregister(&mut self.tcp_stream).unwrap(); - - TlsHandshaking { - tls_conn: self.tls_conn, - ws_config: self.ws_config, - tcp_stream: self.tcp_stream, - phantom_data: PhantomData::default(), - } - } - - fn reregister( - &mut self, - poll: &mut Poll, - token: Token, - interest: Interest, - ) -> std::io::Result<()> { - poll.registry() - .reregister(&mut self.tcp_stream, token, interest) - } -} - -struct WsHandshaking { - mid_handshake: MidHandshake>, - phantom_data: PhantomData, -} - -impl WsHandshaking { - fn read(self) -> ConnectionReadResult> { - Self::handle_handshake_result(self.mid_handshake.handshake()) - } - - fn handle_handshake_result( - handshake_result: WsHandshakeResult, - ) -> ConnectionReadResult> { - match handshake_result { - Ok(web_socket) => { - let conn = ConnectionState::WsConnection(WsConnection { - web_socket, - phantom_data: PhantomData::default(), - }); - - Ok(ConnectionReadStatus::Ok(conn)) - } - Err(HandshakeError::Interrupted(mid_handshake)) => { - let conn = ConnectionState::WsHandshaking(WsHandshaking { - mid_handshake, - phantom_data: PhantomData::default(), - }); - - Ok(ConnectionReadStatus::WouldBlock(conn)) - } - Err(HandshakeError::Failure(err)) => { - return Err(std::io::Error::new(ErrorKind::InvalidData, err)) - } - } - } - - fn register( - mut self, - poll: &mut Poll, - token: Token, - interest: Interest, - ) -> WsHandshaking { - let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; - - poll.registry() - .register(tcp_stream, token, interest) - .unwrap(); - - WsHandshaking { - mid_handshake: self.mid_handshake, - phantom_data: PhantomData::default(), - } - } - - fn close(mut self) { - ::log::debug!("closing connection (WsHandshaking state)"); - - let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; - - let _ = tcp_stream.shutdown(Shutdown::Both); - } -} - -impl WsHandshaking { - fn deregister(mut self, poll: &mut Poll) -> WsHandshaking { - let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; - - poll.registry().deregister(tcp_stream).unwrap(); - - WsHandshaking { - mid_handshake: self.mid_handshake, - phantom_data: PhantomData::default(), - } - } - - fn reregister( - &mut self, - poll: &mut Poll, - token: Token, - interest: Interest, - ) -> std::io::Result<()> { - let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; - - poll.registry().reregister(tcp_stream, token, interest) - } -} - -struct WsConnection { - web_socket: tungstenite::WebSocket, - phantom_data: PhantomData, -} - -impl WsConnection { - fn read(mut self) -> ConnectionReadResult> { - match self.web_socket.read_message() { - Ok( - message @ tungstenite::Message::Text(_) | message @ tungstenite::Message::Binary(_), - ) => match InMessage::from_ws_message(message) { - Ok(message) => { - ::log::debug!("received WebSocket message"); - - Ok(ConnectionReadStatus::Message( - ConnectionState::WsConnection(self), - message, - )) - } - Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)), - }, - Ok(message) => { - ::log::info!("received unexpected WebSocket message: {}", message); - - Err(std::io::Error::new( - ErrorKind::InvalidData, - "unexpected WebSocket message type", - )) - } - Err(tungstenite::Error::Io(err)) if err.kind() == ErrorKind::WouldBlock => { - let conn = ConnectionState::WsConnection(self); - - Ok(ConnectionReadStatus::WouldBlock(conn)) - } - Err(tungstenite::Error::Io(err)) => Err(err), - Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)), - } - } - - fn register( - mut self, - poll: &mut Poll, - token: Token, - interest: Interest, - ) -> WsConnection { - poll.registry() - .register(self.web_socket.get_mut().get_mut(), token, interest) - .unwrap(); - - WsConnection { - web_socket: self.web_socket, - phantom_data: PhantomData::default(), - } - } - - fn close(mut self) { - ::log::debug!("closing connection (WsConnection state)"); - - let _ = self.web_socket.close(None); - let _ = self.web_socket.write_pending(); - } -} - -impl WsConnection { - fn deregister(mut self, poll: &mut Poll) -> WsConnection { - poll.registry() - .deregister(self.web_socket.get_mut().get_mut()) - .unwrap(); - - WsConnection { - web_socket: self.web_socket, - phantom_data: PhantomData::default(), - } - } - - fn reregister( - &mut self, - poll: &mut Poll, - token: Token, - interest: Interest, - ) -> std::io::Result<()> { - poll.registry() - .reregister(self.web_socket.get_mut().get_mut(), token, interest) - } -} diff --git a/aquatic_ws/src/mio/socket/mod.rs b/aquatic_ws/src/mio/socket/mod.rs deleted file mode 100644 index 181c8f6..0000000 --- a/aquatic_ws/src/mio/socket/mod.rs +++ /dev/null @@ -1,403 +0,0 @@ -use std::io::ErrorKind; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use anyhow::Context; -use aquatic_common::access_list::AccessListQuery; -use aquatic_common::CanonicalSocketAddr; -use hashbrown::HashMap; -use mio::net::TcpListener; -use mio::{Events, Interest, Poll, Token}; -use socket2::{Domain, Protocol, Socket, Type}; -use tungstenite::protocol::WebSocketConfig; - -use aquatic_ws_protocol::*; - -use crate::common::*; -use crate::config::Config; - -pub mod connection; - -use super::common::*; - -use connection::{Connection, NotRegistered, Registered}; - -struct ConnectionMap { - token_counter: Token, - connections: HashMap>, -} - -impl Default for ConnectionMap { - fn default() -> Self { - Self { - token_counter: Token(2), - connections: Default::default(), - } - } -} - -impl ConnectionMap { - fn insert_and_register_new(&mut self, poll: &mut Poll, connection_creator: F) - where - F: FnOnce(Token) -> Connection, - { - self.token_counter.0 = self.token_counter.0.wrapping_add(1); - - // Don't assign LISTENER_TOKEN or CHANNEL_TOKEN - if self.token_counter.0 < 2 { - self.token_counter.0 = 2; - } - - let token = self.token_counter; - - // Remove, deregister and close any existing connection with this token. - // This shouldn't happen in practice. - if let Some(connection) = self.connections.remove(&token) { - ::log::warn!( - "removing existing connection {} because of token reuse", - token.0 - ); - - connection.deregister(poll).close(); - } - - let connection = connection_creator(token); - - self.insert_and_register(poll, token, connection); - } - - fn insert_and_register( - &mut self, - poll: &mut Poll, - key: Token, - conn: Connection, - ) { - self.connections.insert(key, conn.register(poll, key)); - } - - fn remove_and_deregister( - &mut self, - poll: &mut Poll, - key: &Token, - ) -> Option> { - if let Some(connection) = self.connections.remove(key) { - Some(connection.deregister(poll)) - } else { - None - } - } - - fn get_mut(&mut self, key: &Token) -> Option<&mut Connection> { - self.connections.get_mut(key) - } - - /// Close and remove inactive connections - fn clean(mut self, poll: &mut Poll) -> Self { - let now = Instant::now(); - - let mut retained_connections = HashMap::default(); - - for (token, connection) in self.connections.drain() { - if connection.valid_until.0 < now { - connection.deregister(poll).close(); - } else { - retained_connections.insert(token, connection); - } - } - - ConnectionMap { - connections: retained_connections, - ..self - } - } -} - -pub fn run_socket_worker( - config: Config, - state: State, - socket_worker_index: usize, - socket_worker_statuses: SocketWorkerStatuses, - poll: Poll, - in_message_sender: InMessageSender, - out_message_receiver: OutMessageReceiver, - tls_config: Arc, -) { - match create_listener(&config) { - Ok(listener) => { - socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(())); - - run_poll_loop( - config, - &state, - socket_worker_index, - poll, - in_message_sender, - out_message_receiver, - listener, - tls_config, - ); - } - Err(err) => { - socket_worker_statuses.lock()[socket_worker_index] = - Some(Err(format!("Couldn't open socket: {:#}", err))); - } - } -} - -fn run_poll_loop( - config: Config, - state: &State, - socket_worker_index: usize, - mut poll: Poll, - in_message_sender: InMessageSender, - out_message_receiver: OutMessageReceiver, - listener: ::std::net::TcpListener, - tls_config: Arc, -) { - let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds); - let ws_config = WebSocketConfig { - max_message_size: Some(config.network.websocket_max_message_size), - max_frame_size: Some(config.network.websocket_max_frame_size), - max_send_queue: Some(2), - ..Default::default() - }; - - let mut listener = TcpListener::from_std(listener); - let mut events = Events::with_capacity(config.network.poll_event_capacity); - - poll.registry() - .register(&mut listener, LISTENER_TOKEN, Interest::READABLE) - .unwrap(); - - let mut connections = ConnectionMap::default(); - let mut local_responses = Vec::new(); - - let mut iter_counter = 0usize; - - loop { - poll.poll(&mut events, Some(poll_timeout)) - .expect("failed polling"); - - let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - - for event in events.iter() { - let token = event.token(); - - match token { - LISTENER_TOKEN => { - accept_new_streams( - &tls_config, - ws_config, - socket_worker_index, - &mut listener, - &mut poll, - &mut connections, - valid_until, - ); - } - CHANNEL_TOKEN => { - write_or_queue_messages( - &mut poll, - out_message_receiver - .try_iter() - .take(out_message_receiver.len()), - &mut connections, - ); - } - token => { - if event.is_writable() { - let mut remove_connection = false; - - if let Some(connection) = connections.get_mut(&token) { - if let Err(err) = connection.write(&mut poll) { - ::log::debug!("Connection::write error: {}", err); - - remove_connection = true; - } - } - - if remove_connection { - if let Some(connection) = - connections.remove_and_deregister(&mut poll, &token) - { - connection.close(); - } - } - } - if event.is_readable() { - handle_stream_read_event( - &config, - state, - &mut local_responses, - &in_message_sender, - &mut poll, - &mut connections, - token, - valid_until, - ); - } - } - } - - write_or_queue_messages(&mut poll, local_responses.drain(..), &mut connections); - } - - // Remove inactive connections, but not every iteration - if iter_counter % 128 == 0 { - connections = connections.clean(&mut poll); - } - - iter_counter = iter_counter.wrapping_add(1); - } -} - -fn accept_new_streams( - tls_config: &Arc, - ws_config: WebSocketConfig, - socket_worker_index: usize, - listener: &mut TcpListener, - poll: &mut Poll, - connections: &mut ConnectionMap, - valid_until: ValidUntil, -) { - loop { - match listener.accept() { - Ok((stream, _)) => { - let peer_addr = if let Ok(peer_addr) = stream.peer_addr() { - CanonicalSocketAddr::new(peer_addr) - } else { - continue; - }; - - connections.insert_and_register_new(poll, move |token| { - let meta = ConnectionMeta { - out_message_consumer_id: ConsumerId(socket_worker_index), - connection_id: ConnectionId(token.0), - peer_addr, - pending_scrape_id: None, // FIXME - }; - - Connection::new(tls_config.clone(), ws_config, stream, valid_until, meta) - }); - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - ::log::info!("error while accepting streams: {}", err); - } - } - } -} - -fn handle_stream_read_event( - config: &Config, - state: &State, - local_responses: &mut Vec<(ConnectionMeta, OutMessage)>, - in_message_sender: &InMessageSender, - poll: &mut Poll, - connections: &mut ConnectionMap, - token: Token, - valid_until: ValidUntil, -) { - let access_list_mode = config.access_list.mode; - - if let Some(mut connection) = connections.remove_and_deregister(poll, &token) { - let message_handler = &mut |meta, message| match message { - InMessage::AnnounceRequest(ref request) - if !state - .access_list - .allows(access_list_mode, &request.info_hash.0) => - { - let out_message = OutMessage::ErrorResponse(ErrorResponse { - failure_reason: "Info hash not allowed".into(), - action: Some(ErrorResponseAction::Announce), - info_hash: Some(request.info_hash), - }); - - local_responses.push((meta, out_message)); - } - in_message => { - if let Err(err) = in_message_sender.send((meta, in_message)) { - ::log::info!("InMessageSender: couldn't send message: {:?}", err); - } - } - }; - - connection.valid_until = valid_until; - - match connection.read(message_handler) { - Ok(connection) => { - connections.insert_and_register(poll, token, connection); - } - Err(_) => {} - } - } -} - -fn write_or_queue_messages(poll: &mut Poll, responses: I, connections: &mut ConnectionMap) -where - I: Iterator, -{ - for (meta, out_message) in responses { - let token = Token(meta.connection_id.0); - - let mut remove_connection = false; - - if let Some(connection) = connections.get_mut(&token) { - if connection.get_meta().peer_addr != meta.peer_addr { - ::log::warn!( - "socket worker error: connection socket addr {} didn't match channel {}. Token: {}.", - connection.get_meta().peer_addr.get(), - meta.peer_addr.get(), - token.0 - ); - - remove_connection = true; - } else { - match connection.write_or_queue_message(poll, out_message) { - Ok(()) => {} - Err(err) => { - ::log::debug!("Connection::write_or_queue_message error: {}", err); - - remove_connection = true; - } - } - } - } - - if remove_connection { - connections.remove_and_deregister(poll, &token); - } - } -} - -pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> { - let builder = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)) - } else { - Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)) - } - .context("Couldn't create socket2::Socket")?; - - if config.network.ipv6_only { - builder - .set_only_v6(true) - .context("Couldn't put socket in ipv6 only mode")? - } - - builder - .set_nonblocking(true) - .context("Couldn't put socket in non-blocking mode")?; - builder - .set_reuse_port(true) - .context("Couldn't put socket in reuse_port mode")?; - builder - .bind(&config.network.address.into()) - .with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?; - builder - .listen(128) - .context("Couldn't listen for connections on socket")?; - - Ok(builder.into()) -} diff --git a/aquatic_ws/src/workers/mod.rs b/aquatic_ws/src/workers/mod.rs new file mode 100644 index 0000000..63fc0ec --- /dev/null +++ b/aquatic_ws/src/workers/mod.rs @@ -0,0 +1,2 @@ +pub mod request; +pub mod socket; diff --git a/aquatic_ws/src/common/handlers.rs b/aquatic_ws/src/workers/request.rs similarity index 59% rename from aquatic_ws/src/common/handlers.rs rename to aquatic_ws/src/workers/request.rs index 1d6ae0c..a5dd22f 100644 --- a/aquatic_ws/src/common/handlers.rs +++ b/aquatic_ws/src/workers/request.rs @@ -1,11 +1,130 @@ -use aquatic_common::extract_response_peers; -use hashbrown::HashMap; -use rand::rngs::SmallRng; +use std::cell::RefCell; +use std::rc::Rc; +use std::time::Duration; +use futures::StreamExt; +use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; +use glommio::enclose; +use glommio::prelude::*; +use glommio::timer::TimerActionRepeat; +use hashbrown::HashMap; +use rand::{rngs::SmallRng, SeedableRng}; + +use aquatic_common::extract_response_peers; use aquatic_ws_protocol::*; use crate::common::*; use crate::config::Config; +use crate::SHARED_IN_CHANNEL_SIZE; + +pub async fn run_request_worker( + config: Config, + state: State, + in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, + out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, +) { + let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); + let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); + + let out_message_senders = Rc::new(out_message_senders); + + let torrents = Rc::new(RefCell::new(TorrentMaps::default())); + let access_list = state.access_list; + + // Periodically clean torrents + TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { + enclose!((config, torrents, access_list) move || async move { + torrents.borrow_mut().clean(&config, &access_list); + + Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval)) + })() + })); + + let mut handles = Vec::new(); + + for (_, receiver) in in_message_receivers.streams() { + let handle = spawn_local(handle_request_stream( + config.clone(), + torrents.clone(), + out_message_senders.clone(), + receiver, + )) + .detach(); + + handles.push(handle); + } + + for handle in handles { + handle.await; + } +} + +async fn handle_request_stream( + config: Config, + torrents: Rc>, + out_message_senders: Rc>, + stream: S, +) where + S: futures_lite::Stream + ::std::marker::Unpin, +{ + let rng = Rc::new(RefCell::new(SmallRng::from_entropy())); + + let max_peer_age = config.cleaning.max_peer_age; + let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age))); + + TimerActionRepeat::repeat(enclose!((peer_valid_until) move || { + enclose!((peer_valid_until) move || async move { + *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age); + + Some(Duration::from_secs(1)) + })() + })); + + let config = &config; + let torrents = &torrents; + let peer_valid_until = &peer_valid_until; + let rng = &rng; + let out_message_senders = &out_message_senders; + + stream + .for_each_concurrent( + SHARED_IN_CHANNEL_SIZE, + move |(meta, in_message)| async move { + let mut out_messages = Vec::new(); + + match in_message { + InMessage::AnnounceRequest(request) => handle_announce_request( + &config, + &mut rng.borrow_mut(), + &mut torrents.borrow_mut(), + &mut out_messages, + peer_valid_until.borrow().to_owned(), + meta, + request, + ), + InMessage::ScrapeRequest(request) => handle_scrape_request( + &config, + &mut torrents.borrow_mut(), + &mut out_messages, + meta, + request, + ), + }; + + for (meta, out_message) in out_messages.drain(..) { + ::log::info!("request worker trying to send OutMessage to socket worker"); + + out_message_senders + .send_to(meta.out_message_consumer_id.0, (meta, out_message)) + .await + .expect("failed sending out_message to socket worker"); + + ::log::info!("request worker sent OutMessage to socket worker"); + } + }, + ) + .await; +} pub fn handle_announce_request( config: &Config, diff --git a/aquatic_ws/src/glommio/socket.rs b/aquatic_ws/src/workers/socket.rs similarity index 99% rename from aquatic_ws/src/glommio/socket.rs rename to aquatic_ws/src/workers/socket.rs index b9dc6fe..1de5d46 100644 --- a/aquatic_ws/src/glommio/socket.rs +++ b/aquatic_ws/src/workers/socket.rs @@ -29,8 +29,6 @@ use crate::config::Config; use crate::common::*; -use super::common::*; - const LOCAL_CHANNEL_SIZE: usize = 16; struct PendingScrapeResponse {