From 24bfaf67c0357f09224565f62fa0f09303b51800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Mon, 1 Nov 2021 19:08:00 +0100 Subject: [PATCH] aquatic_ws: rewrite to use glommio --- Cargo.lock | 270 +++++++-------- aquatic_ws/Cargo.toml | 13 +- aquatic_ws/src/lib/common.rs | 106 +++--- aquatic_ws/src/lib/config.rs | 38 +-- aquatic_ws/src/lib/handler.rs | 282 --------------- aquatic_ws/src/lib/handlers.rs | 308 +++++++++++++++++ aquatic_ws/src/lib/lib.rs | 256 ++++++-------- aquatic_ws/src/lib/network.rs | 416 +++++++++++++++++++++++ aquatic_ws/src/lib/network/connection.rs | 298 ---------------- aquatic_ws/src/lib/network/mod.rs | 331 ------------------ aquatic_ws/src/lib/network/utils.rs | 66 ---- 11 files changed, 1026 insertions(+), 1358 deletions(-) delete mode 100644 aquatic_ws/src/lib/handler.rs create mode 100644 aquatic_ws/src/lib/handlers.rs create mode 100644 aquatic_ws/src/lib/network.rs delete mode 100644 aquatic_ws/src/lib/network/connection.rs delete mode 100644 aquatic_ws/src/lib/network/mod.rs delete mode 100644 aquatic_ws/src/lib/network/utils.rs diff --git a/Cargo.lock b/Cargo.lock index d806f52..c96bfdf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -239,22 +239,25 @@ dependencies = [ "aquatic_cli_helpers", "aquatic_common", "aquatic_ws_protocol", - "crossbeam-channel", + "async-tungstenite", + "core_affinity", "either", + "futures", + "futures-lite", + "futures-rustls", + "glommio", "hashbrown 0.11.2", "histogram", "indexmap", "log", "mimalloc", - "mio", - "native-tls", - "parking_lot", "privdrop", "quickcheck", "quickcheck_macros", "rand", + "rustls-pemfile", "serde", - "socket2 0.4.2", + "slab", "tungstenite", ] @@ -308,6 +311,19 @@ dependencies = [ "nodrop", ] +[[package]] +name = "async-tungstenite" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "742cc7dcb20b2f84a42f4691aa999070ec7e78f8e7e7438bf14be7017b44907e" +dependencies = [ + "futures-io", + "futures-util", + "log", + "pin-project-lite", + "tungstenite", +] + [[package]] name = "atty" version = "0.2.14" @@ -487,22 +503,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "core-foundation" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6888e10551bb93e424d8df1d07f1a8b4fceb0001a3a4b048bfc47554946f47b3" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" - [[package]] name = "core_affinity" version = "0.5.10" @@ -733,21 +733,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.0.1" @@ -758,12 +743,48 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12aa0eb539080d55c3f2d45a67c3b58b6b0773c1a3ca2dfec66d58c97fd66ca" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "futures-core" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" +[[package]] +name = "futures-executor" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.17" @@ -785,6 +806,19 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb" +dependencies = [ + "autocfg", + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-rustls" version = "0.22.0" @@ -796,6 +830,39 @@ dependencies = [ "webpki", ] +[[package]] +name = "futures-sink" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11" + +[[package]] +name = "futures-task" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99" + +[[package]] +name = "futures-util" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481" +dependencies = [ + "autocfg", + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "proc-macro-hack", + "proc-macro-nested", + "slab", +] + [[package]] name = "generic-array" version = "0.14.4" @@ -1150,24 +1217,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "native-tls" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nix" version = "0.23.0" @@ -1269,39 +1318,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" -[[package]] -name = "openssl" -version = "0.10.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9facdb76fec0b73c406f125d44d86fdad818d66fef0531eec9233ca425ff4a" -dependencies = [ - "bitflags", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-sys", -] - -[[package]] -name = "openssl-probe" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a" - -[[package]] -name = "openssl-sys" -version = "0.9.67" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69df2d8dfc6ce3aaf44b40dec6f487d5a886516cf6879c49e98e0710f310a058" -dependencies = [ - "autocfg", - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "owned-alloc" version = "0.2.0" @@ -1352,10 +1368,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" [[package]] -name = "pkg-config" -version = "0.3.20" +name = "pin-utils" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c9b1041b4387893b91ee6746cddfc28516aff326a3519fb2adf820932c5e6cb" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "plotters" @@ -1401,6 +1417,18 @@ dependencies = [ "nix", ] +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + +[[package]] +name = "proc-macro-nested" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" + [[package]] name = "proc-macro2" version = "1.0.30" @@ -1548,15 +1576,6 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi 0.3.9", -] - [[package]] name = "ring" version = "0.16.20" @@ -1632,16 +1651,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75" -dependencies = [ - "lazy_static", - "winapi 0.3.9", -] - [[package]] name = "scoped-tls" version = "1.0.0" @@ -1664,29 +1673,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "security-framework" -version = "2.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525bc1abfda2e1998d152c45cf13e696f76d0a4972310b22fac1658b05df7c87" -dependencies = [ - "bitflags", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9dd14d83160b528b7bfd66439110573efcfbe281b17fc2ca9f39f550d619c7e" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.4" @@ -1873,20 +1859,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "tempfile" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" -dependencies = [ - "cfg-if", - "libc", - "rand", - "redox_syscall", - "remove_dir_all", - "winapi 0.3.9", -] - [[package]] name = "termcolor" version = "1.1.2" @@ -2105,12 +2077,6 @@ dependencies = [ "ryu", ] -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.3" diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index 1087208..953bbd6 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -17,23 +17,26 @@ path = "src/bin/main.rs" [dependencies] anyhow = "1" +async-tungstenite = "0.15" aquatic_cli_helpers = "0.1.0" aquatic_common = "0.1.0" aquatic_ws_protocol = "0.1.0" -crossbeam-channel = "0.5" +core_affinity = "0.5" either = "1" +futures-lite = "1" +futures = "0.3" +futures-rustls = "0.22" +glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5" } hashbrown = { version = "0.11.2", features = ["serde"] } histogram = "0.6" indexmap = "1" log = "0.4" mimalloc = { version = "0.1", default-features = false } -mio = { version = "0.7", features = ["tcp", "os-poll", "os-util"] } -native-tls = "0.2" -parking_lot = "0.11" privdrop = "0.5" rand = { version = "0.8", features = ["small_rng"] } +rustls-pemfile = "0.2" serde = { version = "1", features = ["derive"] } -socket2 = { version = "0.4.1", features = ["all"] } +slab = "0.4" tungstenite = "0.15" [dev-dependencies] diff --git a/aquatic_ws/src/lib/common.rs b/aquatic_ws/src/lib/common.rs index 3d6b1df..76f08fc 100644 --- a/aquatic_ws/src/lib/common.rs +++ b/aquatic_ws/src/lib/common.rs @@ -1,14 +1,15 @@ +use std::borrow::Borrow; +use std::cell::RefCell; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; +use std::rc::Rc; use std::time::Instant; -use aquatic_common::access_list::{AccessList, AccessListArcSwap}; -use crossbeam_channel::{Receiver, Sender}; +use aquatic_common::access_list::AccessList; +use futures_lite::AsyncBufReadExt; +use glommio::io::{BufferedFile, StreamReaderBuilder}; +use glommio::yield_if_needed; use hashbrown::HashMap; use indexmap::IndexMap; -use log::error; -use mio::Token; -use parking_lot::Mutex; pub use aquatic_common::ValidUntil; @@ -16,19 +17,24 @@ use aquatic_ws_protocol::*; use crate::config::Config; -pub const LISTENER_TOKEN: Token = Token(0); -pub const CHANNEL_TOKEN: Token = Token(1); +pub type TlsConfig = futures_rustls::rustls::ServerConfig; + +#[derive(Copy, Clone, Debug)] +pub struct ConsumerId(pub usize); + +#[derive(Clone, Copy, Debug)] +pub struct ConnectionId(pub usize); #[derive(Clone, Copy, Debug)] pub struct ConnectionMeta { /// Index of socket worker responsible for this connection. Required for /// sending back response through correct channel to correct worker. - pub worker_index: usize, + pub out_message_consumer_id: ConsumerId, + pub connection_id: ConnectionId, /// Peer address as received from socket, meaning it wasn't converted to /// an IPv4 address if it was a IPv4-mapped IPv6 address pub naive_peer_addr: SocketAddr, pub converted_peer_ip: IpAddr, - pub poll_token: Token, } #[derive(PartialEq, Eq, Clone, Copy, Debug)] @@ -89,16 +95,12 @@ pub struct TorrentMaps { } impl TorrentMaps { - pub fn clean(&mut self, config: &Config, access_list: &Arc) { + pub fn clean(&mut self, config: &Config, access_list: &AccessList) { Self::clean_torrent_map(config, access_list, &mut self.ipv4); Self::clean_torrent_map(config, access_list, &mut self.ipv6); } - fn clean_torrent_map( - config: &Config, - access_list: &Arc, - torrent_map: &mut TorrentMap, - ) { + fn clean_torrent_map(config: &Config, access_list: &AccessList, torrent_map: &mut TorrentMap) { let now = Instant::now(); torrent_map.retain(|info_hash, torrent_data| { @@ -134,40 +136,44 @@ impl TorrentMaps { } } -#[derive(Clone)] -pub struct State { - pub access_list: Arc, - pub torrent_maps: Arc>, -} +pub async fn update_access_list>( + config: C, + access_list: Rc>, +) { + if config.borrow().access_list.mode.is_on() { + match BufferedFile::open(&config.borrow().access_list.path).await { + Ok(file) => { + let mut reader = StreamReaderBuilder::new(file).build(); + let mut new_access_list = AccessList::default(); -impl Default for State { - fn default() -> Self { - Self { - access_list: Arc::new(Default::default()), - torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), - } + loop { + let mut buf = String::with_capacity(42); + + match reader.read_line(&mut buf).await { + Ok(_) => { + if let Err(err) = new_access_list.insert_from_line(&buf) { + ::log::error!( + "Couln't parse access list line '{}': {:?}", + buf, + err + ); + } + } + Err(err) => { + ::log::error!("Couln't read access list line {:?}", err); + + break; + } + } + + yield_if_needed().await; + } + + *access_list.borrow_mut() = new_access_list; + } + Err(err) => { + ::log::error!("Couldn't open access list file: {:?}", err) + } + }; } } - -pub type InMessageSender = Sender<(ConnectionMeta, InMessage)>; -pub type InMessageReceiver = Receiver<(ConnectionMeta, InMessage)>; -pub type OutMessageReceiver = Receiver<(ConnectionMeta, OutMessage)>; - -#[derive(Clone)] -pub struct OutMessageSender(Vec>); - -impl OutMessageSender { - pub fn new(senders: Vec>) -> Self { - Self(senders) - } - - #[inline] - pub fn send(&self, meta: ConnectionMeta, message: OutMessage) { - if let Err(err) = self.0[meta.worker_index].send((meta, message)) { - error!("OutMessageSender: couldn't send message: {:?}", err); - } - } -} - -pub type SocketWorkerStatus = Option>; -pub type SocketWorkerStatuses = Arc>>; diff --git a/aquatic_ws/src/lib/config.rs b/aquatic_ws/src/lib/config.rs index c0cd032..df71319 100644 --- a/aquatic_ws/src/lib/config.rs +++ b/aquatic_ws/src/lib/config.rs @@ -1,5 +1,7 @@ use std::net::SocketAddr; +use std::path::PathBuf; +use aquatic_common::cpu_pinning::CpuPinningConfig; use aquatic_common::{access_list::AccessListConfig, privileges::PrivilegeConfig}; use serde::{Deserialize, Serialize}; @@ -18,11 +20,11 @@ pub struct Config { pub log_level: LogLevel, pub network: NetworkConfig, pub protocol: ProtocolConfig, - pub handlers: HandlerConfig, pub cleaning: CleaningConfig, pub statistics: StatisticsConfig, pub privileges: PrivilegeConfig, pub access_list: AccessListConfig, + pub cpu_pinning: CpuPinningConfig, } impl aquatic_cli_helpers::Config for Config { @@ -37,24 +39,12 @@ pub struct NetworkConfig { /// Bind to this address pub address: SocketAddr, pub ipv6_only: bool, - pub use_tls: bool, - pub tls_pkcs12_path: String, - pub tls_pkcs12_password: String, - pub poll_event_capacity: usize, - pub poll_timeout_microseconds: u64, + pub tls_certificate_path: PathBuf, + pub tls_private_key_path: PathBuf, pub websocket_max_message_size: usize, pub websocket_max_frame_size: usize, } -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(default)] -pub struct HandlerConfig { - /// Maximum number of requests to receive from channel before locking - /// mutex and starting work - pub max_requests_per_iter: usize, - pub channel_recv_timeout_microseconds: u64, -} - #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(default)] pub struct ProtocolConfig { @@ -92,11 +82,11 @@ impl Default for Config { log_level: LogLevel::default(), network: NetworkConfig::default(), protocol: ProtocolConfig::default(), - handlers: HandlerConfig::default(), cleaning: CleaningConfig::default(), statistics: StatisticsConfig::default(), privileges: PrivilegeConfig::default(), access_list: AccessListConfig::default(), + cpu_pinning: Default::default(), } } } @@ -106,11 +96,8 @@ impl Default for NetworkConfig { Self { address: SocketAddr::from(([0, 0, 0, 0], 3000)), ipv6_only: false, - use_tls: false, - tls_pkcs12_path: "".into(), - tls_pkcs12_password: "".into(), - poll_event_capacity: 4096, - poll_timeout_microseconds: 200_000, + tls_certificate_path: "".into(), + tls_private_key_path: "".into(), websocket_max_message_size: 64 * 1024, websocket_max_frame_size: 16 * 1024, } @@ -127,15 +114,6 @@ impl Default for ProtocolConfig { } } -impl Default for HandlerConfig { - fn default() -> Self { - Self { - max_requests_per_iter: 10000, - channel_recv_timeout_microseconds: 200, - } - } -} - impl Default for CleaningConfig { fn default() -> Self { Self { diff --git a/aquatic_ws/src/lib/handler.rs b/aquatic_ws/src/lib/handler.rs deleted file mode 100644 index 1f61a17..0000000 --- a/aquatic_ws/src/lib/handler.rs +++ /dev/null @@ -1,282 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; -use std::vec::Drain; - -use hashbrown::HashMap; -use mio::Waker; -use parking_lot::MutexGuard; -use rand::{rngs::SmallRng, Rng, SeedableRng}; - -use aquatic_common::extract_response_peers; -use aquatic_ws_protocol::*; - -use crate::common::*; -use crate::config::Config; - -pub fn run_request_worker( - config: Config, - state: State, - in_message_receiver: InMessageReceiver, - out_message_sender: OutMessageSender, - wakers: Vec>, -) { - let mut wake_socket_workers: Vec = (0..config.socket_workers).map(|_| false).collect(); - - let mut announce_requests = Vec::new(); - let mut scrape_requests = Vec::new(); - - let mut rng = SmallRng::from_entropy(); - - let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds); - - loop { - let mut opt_torrent_map_guard: Option> = None; - - for i in 0..config.handlers.max_requests_per_iter { - let opt_in_message = if i == 0 { - in_message_receiver.recv().ok() - } else { - in_message_receiver.recv_timeout(timeout).ok() - }; - - match opt_in_message { - Some((meta, InMessage::AnnounceRequest(r))) => { - announce_requests.push((meta, r)); - } - Some((meta, InMessage::ScrapeRequest(r))) => { - scrape_requests.push((meta, r)); - } - None => { - if let Some(torrent_guard) = state.torrent_maps.try_lock() { - opt_torrent_map_guard = Some(torrent_guard); - - break; - } - } - } - } - - let mut torrent_map_guard = - opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock()); - - handle_announce_requests( - &config, - &mut rng, - &mut torrent_map_guard, - &out_message_sender, - &mut wake_socket_workers, - announce_requests.drain(..), - ); - - handle_scrape_requests( - &config, - &mut torrent_map_guard, - &out_message_sender, - &mut wake_socket_workers, - scrape_requests.drain(..), - ); - - for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() { - if *wake { - if let Err(err) = wakers[worker_index].wake() { - ::log::error!("request handler couldn't wake poll: {:?}", err); - } - - *wake = false; - } - } - } -} - -pub fn handle_announce_requests( - config: &Config, - rng: &mut impl Rng, - torrent_maps: &mut TorrentMaps, - out_message_sender: &OutMessageSender, - wake_socket_workers: &mut Vec, - requests: Drain<(ConnectionMeta, AnnounceRequest)>, -) { - let valid_until = ValidUntil::new(config.cleaning.max_peer_age); - - for (request_sender_meta, request) in requests { - let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() { - torrent_maps.ipv4.entry(request.info_hash).or_default() - } else { - torrent_maps.ipv6.entry(request.info_hash).or_default() - }; - - // If there is already a peer with this peer_id, check that socket - // addr is same as that of request sender. Otherwise, ignore request. - // Since peers have access to each others peer_id's, they could send - // requests using them, causing all sorts of issues. Checking naive - // (non-converted) socket addresses is enough, since state is split - // on converted peer ip. - if let Some(previous_peer) = torrent_data.peers.get(&request.peer_id) { - if request_sender_meta.naive_peer_addr != previous_peer.connection_meta.naive_peer_addr - { - continue; - } - } - - ::log::trace!("received request from {:?}", request_sender_meta); - - // Insert/update/remove peer who sent this request - { - let peer_status = PeerStatus::from_event_and_bytes_left( - request.event.unwrap_or_default(), - request.bytes_left, - ); - - let peer = Peer { - connection_meta: request_sender_meta, - status: peer_status, - valid_until, - }; - - let opt_removed_peer = match peer_status { - PeerStatus::Leeching => { - torrent_data.num_leechers += 1; - - torrent_data.peers.insert(request.peer_id, peer) - } - PeerStatus::Seeding => { - torrent_data.num_seeders += 1; - - torrent_data.peers.insert(request.peer_id, peer) - } - PeerStatus::Stopped => torrent_data.peers.remove(&request.peer_id), - }; - - match opt_removed_peer.map(|peer| peer.status) { - Some(PeerStatus::Leeching) => { - torrent_data.num_leechers -= 1; - } - Some(PeerStatus::Seeding) => { - torrent_data.num_seeders -= 1; - } - _ => {} - } - } - - // If peer sent offers, send them on to random peers - if let Some(offers) = request.offers { - // FIXME: config: also maybe check this when parsing request - let max_num_peers_to_take = offers.len().min(config.protocol.max_offers); - - #[inline] - fn f(peer: &Peer) -> Peer { - *peer - } - - let offer_receivers: Vec = extract_response_peers( - rng, - &torrent_data.peers, - max_num_peers_to_take, - request.peer_id, - f, - ); - - for (offer, offer_receiver) in offers.into_iter().zip(offer_receivers) { - let middleman_offer = MiddlemanOfferToPeer { - action: AnnounceAction, - info_hash: request.info_hash, - peer_id: request.peer_id, - offer: offer.offer, - offer_id: offer.offer_id, - }; - - out_message_sender.send( - offer_receiver.connection_meta, - OutMessage::Offer(middleman_offer), - ); - ::log::trace!( - "sent middleman offer to {:?}", - offer_receiver.connection_meta - ); - wake_socket_workers[offer_receiver.connection_meta.worker_index] = true; - } - } - - // If peer sent answer, send it on to relevant peer - if let (Some(answer), Some(answer_receiver_id), Some(offer_id)) = - (request.answer, request.to_peer_id, request.offer_id) - { - if let Some(answer_receiver) = torrent_data.peers.get(&answer_receiver_id) { - let middleman_answer = MiddlemanAnswerToPeer { - action: AnnounceAction, - peer_id: request.peer_id, - info_hash: request.info_hash, - answer, - offer_id, - }; - - out_message_sender.send( - answer_receiver.connection_meta, - OutMessage::Answer(middleman_answer), - ); - ::log::trace!( - "sent middleman answer to {:?}", - answer_receiver.connection_meta - ); - wake_socket_workers[answer_receiver.connection_meta.worker_index] = true; - } - } - - let response = OutMessage::AnnounceResponse(AnnounceResponse { - action: AnnounceAction, - info_hash: request.info_hash, - complete: torrent_data.num_seeders, - incomplete: torrent_data.num_leechers, - announce_interval: config.protocol.peer_announce_interval, - }); - - out_message_sender.send(request_sender_meta, response); - wake_socket_workers[request_sender_meta.worker_index] = true; - } -} - -pub fn handle_scrape_requests( - config: &Config, - torrent_maps: &mut TorrentMaps, - out_message_sender: &OutMessageSender, - wake_socket_workers: &mut Vec, - requests: Drain<(ConnectionMeta, ScrapeRequest)>, -) { - for (meta, request) in requests { - let info_hashes = if let Some(info_hashes) = request.info_hashes { - info_hashes.as_vec() - } else { - continue; - }; - - let num_to_take = info_hashes.len().min(config.protocol.max_scrape_torrents); - - let mut response = ScrapeResponse { - action: ScrapeAction, - files: HashMap::with_capacity(num_to_take), - }; - - let torrent_map: &mut TorrentMap = if meta.converted_peer_ip.is_ipv4() { - &mut torrent_maps.ipv4 - } else { - &mut torrent_maps.ipv6 - }; - - // If request.info_hashes is empty, don't return scrape for all - // torrents, even though reference server does it. It is too expensive. - for info_hash in info_hashes.into_iter().take(num_to_take) { - if let Some(torrent_data) = torrent_map.get(&info_hash) { - let stats = ScrapeStatistics { - complete: torrent_data.num_seeders, - downloaded: 0, // No implementation planned - incomplete: torrent_data.num_leechers, - }; - - response.files.insert(info_hash, stats); - } - } - - out_message_sender.send(meta, OutMessage::ScrapeResponse(response)); - wake_socket_workers[meta.worker_index] = true; - } -} diff --git a/aquatic_ws/src/lib/handlers.rs b/aquatic_ws/src/lib/handlers.rs new file mode 100644 index 0000000..fccd51b --- /dev/null +++ b/aquatic_ws/src/lib/handlers.rs @@ -0,0 +1,308 @@ +use std::cell::RefCell; +use std::rc::Rc; +use std::time::Duration; + +use aquatic_common::access_list::AccessList; +use aquatic_common::extract_response_peers; +use futures_lite::StreamExt; +use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; +use glommio::enclose; +use glommio::prelude::*; +use glommio::timer::TimerActionRepeat; +use hashbrown::HashMap; +use rand::{rngs::SmallRng, Rng, SeedableRng}; + +use aquatic_ws_protocol::*; + +use crate::common::*; +use crate::config::Config; + +pub async fn run_request_worker( + config: Config, + in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, + out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, + access_list: AccessList, +) { + let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); + let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); + + let out_message_senders = Rc::new(out_message_senders); + + let torrents = Rc::new(RefCell::new(TorrentMaps::default())); + let access_list = Rc::new(RefCell::new(access_list)); + + // Periodically clean torrents and update access list + TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { + enclose!((config, torrents, access_list) move || async move { + update_access_list(&config, access_list.clone()).await; + + torrents.borrow_mut().clean(&config, &*access_list.borrow()); + + Some(Duration::from_secs(config.cleaning.interval)) + })() + })); + + let mut handles = Vec::new(); + + for (_, receiver) in in_message_receivers.streams() { + let handle = spawn_local(handle_request_stream( + config.clone(), + torrents.clone(), + out_message_senders.clone(), + receiver, + )) + .detach(); + + handles.push(handle); + } + + for handle in handles { + handle.await; + } +} + +async fn handle_request_stream( + config: Config, + torrents: Rc>, + out_message_senders: Rc>, + mut stream: S, +) where + S: futures_lite::Stream + ::std::marker::Unpin, +{ + let mut rng = SmallRng::from_entropy(); + + let max_peer_age = config.cleaning.max_peer_age; + let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age))); + + TimerActionRepeat::repeat(enclose!((peer_valid_until) move || { + enclose!((peer_valid_until) move || async move { + *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age); + + Some(Duration::from_secs(1)) + })() + })); + + while let Some((meta, in_message)) = stream.next().await { + match in_message { + InMessage::AnnounceRequest(request) => { + handle_announce_request( + &config, + &mut rng, + &mut torrents.borrow_mut(), + &out_message_senders, + peer_valid_until.borrow().to_owned(), + meta, + request, + ) + .await; + } + InMessage::ScrapeRequest(request) => { + handle_scrape_request( + &config, + &mut torrents.borrow_mut(), + &out_message_senders, + meta, + request, + ) + .await; + } + }; + + yield_if_needed().await; + } +} + +pub async fn handle_announce_request( + config: &Config, + rng: &mut impl Rng, + torrent_maps: &mut TorrentMaps, + out_message_senders: &Rc>, + valid_until: ValidUntil, + request_sender_meta: ConnectionMeta, + request: AnnounceRequest, +) { + let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() { + torrent_maps.ipv4.entry(request.info_hash).or_default() + } else { + torrent_maps.ipv6.entry(request.info_hash).or_default() + }; + + // If there is already a peer with this peer_id, check that socket + // addr is same as that of request sender. Otherwise, ignore request. + // Since peers have access to each others peer_id's, they could send + // requests using them, causing all sorts of issues. Checking naive + // (non-converted) socket addresses is enough, since state is split + // on converted peer ip. + if let Some(previous_peer) = torrent_data.peers.get(&request.peer_id) { + if request_sender_meta.naive_peer_addr != previous_peer.connection_meta.naive_peer_addr { + return; + } + } + + ::log::trace!("received request from {:?}", request_sender_meta); + + // Insert/update/remove peer who sent this request + { + let peer_status = PeerStatus::from_event_and_bytes_left( + request.event.unwrap_or_default(), + request.bytes_left, + ); + + let peer = Peer { + connection_meta: request_sender_meta, + status: peer_status, + valid_until, + }; + + let opt_removed_peer = match peer_status { + PeerStatus::Leeching => { + torrent_data.num_leechers += 1; + + torrent_data.peers.insert(request.peer_id, peer) + } + PeerStatus::Seeding => { + torrent_data.num_seeders += 1; + + torrent_data.peers.insert(request.peer_id, peer) + } + PeerStatus::Stopped => torrent_data.peers.remove(&request.peer_id), + }; + + match opt_removed_peer.map(|peer| peer.status) { + Some(PeerStatus::Leeching) => { + torrent_data.num_leechers -= 1; + } + Some(PeerStatus::Seeding) => { + torrent_data.num_seeders -= 1; + } + _ => {} + } + } + + // If peer sent offers, send them on to random peers + if let Some(offers) = request.offers { + // FIXME: config: also maybe check this when parsing request + let max_num_peers_to_take = offers.len().min(config.protocol.max_offers); + + #[inline] + fn f(peer: &Peer) -> Peer { + *peer + } + + let offer_receivers: Vec = extract_response_peers( + rng, + &torrent_data.peers, + max_num_peers_to_take, + request.peer_id, + f, + ); + + for (offer, offer_receiver) in offers.into_iter().zip(offer_receivers) { + let middleman_offer = MiddlemanOfferToPeer { + action: AnnounceAction, + info_hash: request.info_hash, + peer_id: request.peer_id, + offer: offer.offer, + offer_id: offer.offer_id, + }; + + out_message_senders.try_send_to( + offer_receiver.connection_meta.out_message_consumer_id.0, + ( + offer_receiver.connection_meta, + OutMessage::Offer(middleman_offer), + ), + ); + ::log::trace!( + "sent middleman offer to {:?}", + offer_receiver.connection_meta + ); + } + } + + // If peer sent answer, send it on to relevant peer + if let (Some(answer), Some(answer_receiver_id), Some(offer_id)) = + (request.answer, request.to_peer_id, request.offer_id) + { + if let Some(answer_receiver) = torrent_data.peers.get(&answer_receiver_id) { + let middleman_answer = MiddlemanAnswerToPeer { + action: AnnounceAction, + peer_id: request.peer_id, + info_hash: request.info_hash, + answer, + offer_id, + }; + + out_message_senders.try_send_to( + answer_receiver.connection_meta.out_message_consumer_id.0, + ( + answer_receiver.connection_meta, + OutMessage::Answer(middleman_answer), + ), + ); + ::log::trace!( + "sent middleman answer to {:?}", + answer_receiver.connection_meta + ); + } + } + + let out_message = OutMessage::AnnounceResponse(AnnounceResponse { + action: AnnounceAction, + info_hash: request.info_hash, + complete: torrent_data.num_seeders, + incomplete: torrent_data.num_leechers, + announce_interval: config.protocol.peer_announce_interval, + }); + + out_message_senders.send_to( + request_sender_meta.out_message_consumer_id.0, + (request_sender_meta, out_message), + ); +} + +pub async fn handle_scrape_request( + config: &Config, + torrent_maps: &mut TorrentMaps, + out_message_senders: &Rc>, + meta: ConnectionMeta, + request: ScrapeRequest, +) { + let info_hashes = if let Some(info_hashes) = request.info_hashes { + info_hashes.as_vec() + } else { + return; + }; + + let num_to_take = info_hashes.len().min(config.protocol.max_scrape_torrents); + + let mut out_message = ScrapeResponse { + action: ScrapeAction, + files: HashMap::with_capacity(num_to_take), + }; + + let torrent_map: &mut TorrentMap = if meta.converted_peer_ip.is_ipv4() { + &mut torrent_maps.ipv4 + } else { + &mut torrent_maps.ipv6 + }; + + // If request.info_hashes is empty, don't return scrape for all + // torrents, even though reference server does it. It is too expensive. + for info_hash in info_hashes.into_iter().take(num_to_take) { + if let Some(torrent_data) = torrent_map.get(&info_hash) { + let stats = ScrapeStatistics { + complete: torrent_data.num_seeders, + downloaded: 0, // No implementation planned + incomplete: torrent_data.num_leechers, + }; + + out_message.files.insert(info_hash, stats); + } + } + + out_message_senders.try_send_to( + meta.out_message_consumer_id.0, + (meta, OutMessage::ScrapeResponse(out_message)), + ); +} diff --git a/aquatic_ws/src/lib/lib.rs b/aquatic_ws/src/lib/lib.rs index 6141466..c1de1d2 100644 --- a/aquatic_ws/src/lib/lib.rs +++ b/aquatic_ws/src/lib/lib.rs @@ -1,176 +1,144 @@ -use std::fs::File; -use std::io::Read; -use std::sync::Arc; -use std::thread::Builder; -use std::time::Duration; +use std::{ + fs::File, + io::BufReader, + sync::{atomic::AtomicUsize, Arc}, +}; -use anyhow::Context; -use mio::{Poll, Waker}; -use native_tls::{Identity, TlsAcceptor}; -use parking_lot::Mutex; -use privdrop::PrivDrop; +use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding}; +use common::TlsConfig; +use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; -pub mod common; +use crate::config::Config; + +mod common; pub mod config; -pub mod handler; -pub mod network; -pub mod tasks; - -use common::*; -use config::Config; +mod handlers; +mod network; pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; +const SHARED_CHANNEL_SIZE: usize = 1024; + pub fn run(config: Config) -> anyhow::Result<()> { - let state = State::default(); - - tasks::update_access_list(&config, &state); - - start_workers(config.clone(), state.clone())?; - - loop { - ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); - - tasks::update_access_list(&config, &state); - - state - .torrent_maps - .lock() - .clean(&config, &state.access_list.load_full()); + if config.cpu_pinning.active { + core_affinity::set_for_current(core_affinity::CoreId { + id: config.cpu_pinning.offset, + }); } -} -pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { - let opt_tls_acceptor = create_tls_acceptor(&config)?; - - let (in_message_sender, in_message_receiver) = ::crossbeam_channel::unbounded(); - - let mut out_message_senders = Vec::new(); - let mut wakers = Vec::new(); - - let socket_worker_statuses: SocketWorkerStatuses = { - let mut statuses = Vec::new(); - - for _ in 0..config.socket_workers { - statuses.push(None); - } - - Arc::new(Mutex::new(statuses)) + let access_list = if config.access_list.mode.is_on() { + AccessList::create_from_path(&config.access_list.path).expect("Load access list") + } else { + AccessList::default() }; - for i in 0..config.socket_workers { + let num_peers = config.socket_workers + config.request_workers; + + let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); + let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); + + let num_bound_sockets = Arc::new(AtomicUsize::new(0)); + + let tls_config = Arc::new(create_tls_config(&config).unwrap()); + + let mut executors = Vec::new(); + + for i in 0..(config.socket_workers) { let config = config.clone(); - let state = state.clone(); - let socket_worker_statuses = socket_worker_statuses.clone(); - let in_message_sender = in_message_sender.clone(); - let opt_tls_acceptor = opt_tls_acceptor.clone(); - let poll = Poll::new()?; - let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?); + let tls_config = tls_config.clone(); + let request_mesh_builder = request_mesh_builder.clone(); + let response_mesh_builder = response_mesh_builder.clone(); + let num_bound_sockets = num_bound_sockets.clone(); + let access_list = access_list.clone(); - let (out_message_sender, out_message_receiver) = ::crossbeam_channel::unbounded(); + let mut builder = LocalExecutorBuilder::default(); - out_message_senders.push(out_message_sender); - wakers.push(waker); - - Builder::new() - .name(format!("socket-{:02}", i + 1)) - .spawn(move || { - network::run_socket_worker( - config, - state, - i, - socket_worker_statuses, - poll, - in_message_sender, - out_message_receiver, - opt_tls_acceptor, - ); - })?; - } - - // Wait for socket worker statuses. On error from any, quit program. - // On success from all, drop privileges if corresponding setting is set - // and continue program. - loop { - ::std::thread::sleep(::std::time::Duration::from_millis(10)); - - if let Some(statuses) = socket_worker_statuses.try_lock() { - for opt_status in statuses.iter() { - if let Some(Err(err)) = opt_status { - return Err(::anyhow::anyhow!(err.to_owned())); - } - } - - if statuses.iter().all(Option::is_some) { - if config.privileges.drop_privileges { - PrivDrop::default() - .chroot(config.privileges.chroot_path.clone()) - .user(config.privileges.user.clone()) - .apply() - .context("Couldn't drop root privileges")?; - } - - break; - } + if config.cpu_pinning.active { + builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + i); } + + let executor = builder.spawn(|| async move { + network::run_socket_worker( + config, + tls_config, + request_mesh_builder, + response_mesh_builder, + num_bound_sockets, + access_list, + ) + .await + }); + + executors.push(executor); } - let out_message_sender = OutMessageSender::new(out_message_senders); - - for i in 0..config.request_workers { + for i in 0..(config.request_workers) { let config = config.clone(); - let state = state.clone(); - let in_message_receiver = in_message_receiver.clone(); - let out_message_sender = out_message_sender.clone(); - let wakers = wakers.clone(); + let request_mesh_builder = request_mesh_builder.clone(); + let response_mesh_builder = response_mesh_builder.clone(); + let access_list = access_list.clone(); - Builder::new() - .name(format!("request-{:02}", i + 1)) - .spawn(move || { - handler::run_request_worker( - config, - state, - in_message_receiver, - out_message_sender, - wakers, - ); - })?; + let mut builder = LocalExecutorBuilder::default(); + + if config.cpu_pinning.active { + builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + config.socket_workers + i); + } + + let executor = builder.spawn(|| async move { + handlers::run_request_worker( + config, + request_mesh_builder, + response_mesh_builder, + access_list, + ) + .await + }); + + executors.push(executor); } - if config.statistics.interval != 0 { - let state = state.clone(); - let config = config.clone(); + drop_privileges_after_socket_binding( + &config.privileges, + num_bound_sockets, + config.socket_workers, + ) + .unwrap(); - Builder::new() - .name("statistics".to_string()) - .spawn(move || loop { - ::std::thread::sleep(Duration::from_secs(config.statistics.interval)); - - tasks::print_statistics(&state); - }) - .expect("spawn statistics thread"); + for executor in executors { + executor + .expect("failed to spawn local executor") + .join() + .unwrap(); } Ok(()) } -pub fn create_tls_acceptor(config: &Config) -> anyhow::Result> { - if config.network.use_tls { - let mut identity_bytes = Vec::new(); - let mut file = File::open(&config.network.tls_pkcs12_path) - .context("Couldn't open pkcs12 identity file")?; +fn create_tls_config(config: &Config) -> anyhow::Result { + let certs = { + let f = File::open(&config.network.tls_certificate_path)?; + let mut f = BufReader::new(f); - file.read_to_end(&mut identity_bytes) - .context("Couldn't read pkcs12 identity file")?; + rustls_pemfile::certs(&mut f)? + .into_iter() + .map(|bytes| futures_rustls::rustls::Certificate(bytes)) + .collect() + }; - let identity = Identity::from_pkcs12(&identity_bytes, &config.network.tls_pkcs12_password) - .context("Couldn't parse pkcs12 identity file")?; + let private_key = { + let f = File::open(&config.network.tls_private_key_path)?; + let mut f = BufReader::new(f); - let acceptor = TlsAcceptor::new(identity) - .context("Couldn't create TlsAcceptor from pkcs12 identity")?; + rustls_pemfile::pkcs8_private_keys(&mut f)? + .first() + .map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone())) + .ok_or(anyhow::anyhow!("No private keys in file"))? + }; - Ok(Some(acceptor)) - } else { - Ok(None) - } + let tls_config = futures_rustls::rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, private_key)?; + + Ok(tls_config) } diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/network.rs new file mode 100644 index 0000000..6d64740 --- /dev/null +++ b/aquatic_ws/src/lib/network.rs @@ -0,0 +1,416 @@ +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::net::SocketAddr; +use std::rc::Rc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use aquatic_common::access_list::AccessList; +use aquatic_common::convert_ipv4_mapped_ipv6; +use aquatic_ws_protocol::*; +use async_tungstenite::WebSocketStream; +use either::Either; +use futures::stream::{SplitSink, SplitStream}; +use futures_lite::StreamExt; +use futures_rustls::server::TlsStream; +use futures_rustls::TlsAcceptor; +use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; +use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; +use glommio::channels::shared_channel::ConnectedReceiver; +use glommio::net::{TcpListener, TcpStream}; +use glommio::timer::TimerActionRepeat; +use glommio::{enclose, prelude::*}; +use hashbrown::HashMap; +use slab::Slab; + +use crate::config::Config; + +use super::common::*; + +struct PendingScrapeResponse { + pending_worker_out_messages: usize, + stats: HashMap, +} + +struct ConnectionReference { + out_message_sender: LocalSender<(ConnectionMeta, OutMessage)>, +} + +struct Connection { + config: Rc, + access_list: Rc>, + in_message_senders: Rc>, + out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, + out_message_consumer_id: ConsumerId, + ws_out: SplitSink>, tungstenite::Message>, + ws_in: SplitStream>>, + peer_addr: SocketAddr, + connection_id: ConnectionId, +} + +pub async fn run_socket_worker( + config: Config, + tls_config: Arc, + in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, + out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, + num_bound_sockets: Arc, + access_list: AccessList, +) { + let config = Rc::new(config); + let access_list = Rc::new(RefCell::new(access_list)); + + let listener = TcpListener::bind(config.network.address).expect("bind socket"); + num_bound_sockets.fetch_add(1, Ordering::SeqCst); + + let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); + let in_message_senders = Rc::new(in_message_senders); + + let (_, mut out_message_receivers) = + out_message_mesh_builder.join(Role::Consumer).await.unwrap(); + let out_message_consumer_id = ConsumerId(out_message_receivers.consumer_id().unwrap()); + + let connection_slab = Rc::new(RefCell::new(Slab::new())); + let connections_to_remove = Rc::new(RefCell::new(Vec::new())); + + // Periodically update access list + TimerActionRepeat::repeat(enclose!((config, access_list) move || { + enclose!((config, access_list) move || async move { + update_access_list(config.clone(), access_list.clone()).await; + + Some(Duration::from_secs(config.cleaning.interval)) + })() + })); + + // Periodically remove closed connections + TimerActionRepeat::repeat( + enclose!((config, connection_slab, connections_to_remove) move || { + remove_closed_connections( + config.clone(), + connection_slab.clone(), + connections_to_remove.clone(), + ) + }), + ); + + for (_, out_message_receiver) in out_message_receivers.streams() { + spawn_local(receive_out_messages( + out_message_receiver, + connection_slab.clone(), + )) + .detach(); + } + + let mut incoming = listener.incoming(); + + while let Some(stream) = incoming.next().await { + match stream { + Ok(stream) => { + let (out_message_sender, out_message_receiver) = + new_bounded(config.request_workers); + let key = connection_slab + .borrow_mut() + .insert(ConnectionReference { out_message_sender }); + + spawn_local(enclose!((config, access_list, in_message_senders, tls_config, connections_to_remove) async move { + if let Err(err) = Connection::run( + config, + access_list, + in_message_senders, + out_message_receiver, + out_message_consumer_id, + ConnectionId(key), + tls_config, + stream + ).await { + ::log::debug!("Connection::run() error: {:?}", err); + } + + connections_to_remove.borrow_mut().push(key); + })) + .detach(); + } + Err(err) => { + ::log::error!("accept connection: {:?}", err); + } + } + } +} + +async fn remove_closed_connections( + config: Rc, + connection_slab: Rc>>, + connections_to_remove: Rc>>, +) -> Option { + let connections_to_remove = connections_to_remove.replace(Vec::new()); + + for connection_id in connections_to_remove { + if let Some(_) = connection_slab.borrow_mut().try_remove(connection_id) { + ::log::debug!("removed connection with id {}", connection_id); + } else { + ::log::error!( + "couldn't remove connection with id {}, it is not in connection slab", + connection_id + ); + } + } + + Some(Duration::from_secs(config.cleaning.interval)) +} + +async fn receive_out_messages( + mut out_message_receiver: ConnectedReceiver<(ConnectionMeta, OutMessage)>, + connection_references: Rc>>, +) { + while let Some(channel_out_message) = out_message_receiver.next().await { + if let Some(reference) = connection_references + .borrow() + .get(channel_out_message.0.connection_id.0) + { + if let Err(err) = reference.out_message_sender.try_send(channel_out_message) { + ::log::error!("Couldn't send out_message to local receiver: {:?}", err); + } + } + } +} + +impl Connection { + async fn run( + config: Rc, + access_list: Rc>, + in_message_senders: Rc>, + out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, + out_message_consumer_id: ConsumerId, + connection_id: ConnectionId, + tls_config: Arc, + stream: TcpStream, + ) -> anyhow::Result<()> { + let peer_addr = stream + .peer_addr() + .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?; + + let tls_acceptor: TlsAcceptor = tls_config.into(); + let stream = tls_acceptor.accept(stream).await?; + + let ws_config = tungstenite::protocol::WebSocketConfig { + max_frame_size: Some(config.network.websocket_max_frame_size), + max_message_size: Some(config.network.websocket_max_message_size), + ..Default::default() + }; + let stream = async_tungstenite::accept_async_with_config(stream, Some(ws_config)).await?; + let (ws_out, ws_in) = futures::StreamExt::split(stream); + + let mut conn = Connection { + config: config.clone(), + access_list: access_list.clone(), + in_message_senders: in_message_senders.clone(), + out_message_receiver, + out_message_consumer_id, + ws_out, + ws_in, + peer_addr, + connection_id, + }; + + conn.run_message_loop().await?; + + Ok(()) + } + + async fn run_message_loop(&mut self) -> anyhow::Result<()> { + loop { + let out_message = match self.read_in_message().await? { + Either::Left(out_message) => OutMessage::ErrorResponse(out_message), + Either::Right(in_message) => self.handle_in_message(in_message).await?, + }; + + self.write_out_message(&out_message).await?; + + if matches!(out_message, OutMessage::ErrorResponse(_)) { + // TODO: shut down? + + break; + } + } + + Ok(()) + } + + async fn read_in_message(&mut self) -> anyhow::Result> { + loop { + ::log::debug!("read"); + + let message = self.ws_in.next().await.unwrap()?; + + match InMessage::from_ws_message(message) { + Ok(in_message) => { + ::log::debug!("received in_message: {:?}", in_message); + + return Ok(Either::Right(in_message)); + } + Err(err) => { + ::log::debug!("Couldn't parse in_message: {:?}", err); + + let out_message = ErrorResponse { + action: None, + failure_reason: "Invalid request".into(), + info_hash: None, + }; + + return Ok(Either::Left(out_message)); + } + } + } + } + + async fn handle_in_message(&self, in_message: InMessage) -> anyhow::Result { + match in_message { + InMessage::AnnounceRequest(announce_request) => { + let info_hash = announce_request.info_hash; + + if self + .access_list + .borrow() + .allows(self.config.access_list.mode, &info_hash.0) + { + let meta = ConnectionMeta { + connection_id: self.connection_id, + out_message_consumer_id: self.out_message_consumer_id, + naive_peer_addr: self.peer_addr, + converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), + }; + let in_message = InMessage::AnnounceRequest(announce_request); + + let consumer_index = + calculate_in_message_consumer_index(&self.config, info_hash); + + // Only fails when receiver is closed + self.in_message_senders + .send_to(consumer_index, (meta, in_message)) + .await + .unwrap(); + + self.wait_for_out_message(None).await + } else { + let out_message = OutMessage::ErrorResponse(ErrorResponse { + action: Some(ErrorResponseAction::Announce), + failure_reason: "Info hash not allowed".into(), + info_hash: Some(info_hash), + }); + + Ok(out_message) + } + } + InMessage::ScrapeRequest(ScrapeRequest { info_hashes, .. }) => { + let info_hashes = if let Some(info_hashes) = info_hashes { + info_hashes + } else { + let out_message = OutMessage::ErrorResponse(ErrorResponse { + action: Some(ErrorResponseAction::Scrape), + failure_reason: "Full scrapes are not allowed".into(), + info_hash: None, + }); + + return Ok(out_message); + }; + + let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); + + for info_hash in info_hashes.as_vec() { + let info_hashes = info_hashes_by_worker + .entry(calculate_in_message_consumer_index(&self.config, info_hash)) + .or_default(); + + info_hashes.push(info_hash); + } + + let pending_worker_out_messages = info_hashes_by_worker.len(); + + let meta = ConnectionMeta { + connection_id: self.connection_id, + out_message_consumer_id: self.out_message_consumer_id, + naive_peer_addr: self.peer_addr, + converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), + }; + + for (consumer_index, info_hashes) in info_hashes_by_worker { + let in_message = InMessage::ScrapeRequest(ScrapeRequest { + action: ScrapeAction, + info_hashes: Some(ScrapeRequestInfoHashes::Multiple(info_hashes)), + }); + + // Only fails when receiver is closed + self.in_message_senders + .send_to(consumer_index, (meta, in_message)) + .await + .unwrap(); + } + + let pending_scrape_out_message = PendingScrapeResponse { + pending_worker_out_messages, + stats: Default::default(), + }; + + self.wait_for_out_message(Some(pending_scrape_out_message)) + .await + } + } + } + + /// Wait for announce out_message or partial scrape out_messages to arrive, + /// return full out_message + async fn wait_for_out_message( + &self, + mut opt_pending_scrape_out_message: Option, + ) -> anyhow::Result { + loop { + let (meta, out_message) = self + .out_message_receiver + .recv() + .await + .expect("wait_for_out_message: can't receive out_message, sender is closed"); + + if meta.naive_peer_addr != self.peer_addr { + return Err(anyhow::anyhow!("peer addresses didn't match")); + } + + match out_message { + OutMessage::ScrapeResponse(out_message) => { + if let Some(mut pending) = opt_pending_scrape_out_message.take() { + pending.stats.extend(out_message.files); + pending.pending_worker_out_messages -= 1; + + if pending.pending_worker_out_messages == 0 { + let out_message = OutMessage::ScrapeResponse(ScrapeResponse { + action: ScrapeAction, + files: pending.stats, + }); + + break Ok(out_message); + } else { + opt_pending_scrape_out_message = Some(pending); + } + } else { + return Err(anyhow::anyhow!( + "received channel scrape out_message without pending scrape out_message" + )); + } + } + out_message => { + break Ok(out_message); + } + }; + } + } + + async fn write_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> { + futures::SinkExt::send(&mut self.ws_out, out_message.to_ws_message()).await?; + futures::SinkExt::flush(&mut self.ws_out).await?; + + Ok(()) + } +} + +fn calculate_in_message_consumer_index(config: &Config, info_hash: InfoHash) -> usize { + (info_hash.0[0] as usize) % config.request_workers +} diff --git a/aquatic_ws/src/lib/network/connection.rs b/aquatic_ws/src/lib/network/connection.rs deleted file mode 100644 index 4769dc6..0000000 --- a/aquatic_ws/src/lib/network/connection.rs +++ /dev/null @@ -1,298 +0,0 @@ -use std::io::{Read, Write}; -use std::net::SocketAddr; - -use either::Either; -use hashbrown::HashMap; -use log::info; -use mio::net::TcpStream; -use mio::{Poll, Token}; -use native_tls::{MidHandshakeTlsStream, TlsAcceptor, TlsStream}; -use tungstenite::handshake::{server::NoCallback, HandshakeError, MidHandshake}; -use tungstenite::protocol::WebSocketConfig; -use tungstenite::ServerHandshake; -use tungstenite::WebSocket; - -use crate::common::*; - -pub enum Stream { - TcpStream(TcpStream), - TlsStream(TlsStream), -} - -impl Stream { - #[inline] - pub fn get_peer_addr(&self) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.peer_addr(), - Self::TlsStream(stream) => stream.get_ref().peer_addr(), - } - } - - #[inline] - pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { - match self { - Self::TcpStream(stream) => poll.registry().deregister(stream), - Self::TlsStream(stream) => poll.registry().deregister(stream.get_mut()), - } - } -} - -impl Read for Stream { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> Result { - match self { - Self::TcpStream(stream) => stream.read(buf), - Self::TlsStream(stream) => stream.read(buf), - } - } - - /// Not used but provided for completeness - #[inline] - fn read_vectored( - &mut self, - bufs: &mut [::std::io::IoSliceMut<'_>], - ) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.read_vectored(bufs), - Self::TlsStream(stream) => stream.read_vectored(bufs), - } - } -} - -impl Write for Stream { - #[inline] - fn write(&mut self, buf: &[u8]) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.write(buf), - Self::TlsStream(stream) => stream.write(buf), - } - } - - /// Not used but provided for completeness - #[inline] - fn write_vectored(&mut self, bufs: &[::std::io::IoSlice<'_>]) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.write_vectored(bufs), - Self::TlsStream(stream) => stream.write_vectored(bufs), - } - } - - #[inline] - fn flush(&mut self) -> ::std::io::Result<()> { - match self { - Self::TcpStream(stream) => stream.flush(), - Self::TlsStream(stream) => stream.flush(), - } - } -} - -enum HandshakeMachine { - TcpStream(TcpStream), - TlsStream(TlsStream), - TlsMidHandshake(MidHandshakeTlsStream), - WsMidHandshake(MidHandshake>), -} - -impl HandshakeMachine { - #[inline] - fn new(tcp_stream: TcpStream) -> Self { - Self::TcpStream(tcp_stream) - } - - #[inline] - fn advance( - self, - ws_config: WebSocketConfig, - opt_tls_acceptor: &Option, // If set, run TLS - ) -> (Option>, bool) { - // bool = stop looping - match self { - HandshakeMachine::TcpStream(stream) => { - if let Some(tls_acceptor) = opt_tls_acceptor { - Self::handle_tls_handshake_result(tls_acceptor.accept(stream)) - } else { - let handshake_result = ::tungstenite::accept_with_config( - Stream::TcpStream(stream), - Some(ws_config), - ); - - Self::handle_ws_handshake_result(handshake_result) - } - } - HandshakeMachine::TlsStream(stream) => { - let handshake_result = ::tungstenite::accept(Stream::TlsStream(stream)); - - Self::handle_ws_handshake_result(handshake_result) - } - HandshakeMachine::TlsMidHandshake(handshake) => { - Self::handle_tls_handshake_result(handshake.handshake()) - } - HandshakeMachine::WsMidHandshake(handshake) => { - Self::handle_ws_handshake_result(handshake.handshake()) - } - } - } - - #[inline] - fn handle_tls_handshake_result( - result: Result, ::native_tls::HandshakeError>, - ) -> (Option>, bool) { - match result { - Ok(stream) => { - ::log::trace!( - "established tls handshake with peer with addr: {:?}", - stream.get_ref().peer_addr() - ); - - (Some(Either::Right(Self::TlsStream(stream))), false) - } - Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - (Some(Either::Right(Self::TlsMidHandshake(handshake))), true) - } - Err(native_tls::HandshakeError::Failure(err)) => { - info!("tls handshake error: {}", err); - - (None, false) - } - } - } - - #[inline] - fn handle_ws_handshake_result( - result: Result, HandshakeError>>, - ) -> (Option>, bool) { - match result { - Ok(mut ws) => match ws.get_mut().get_peer_addr() { - Ok(peer_addr) => { - ::log::trace!( - "established ws handshake with peer with addr: {:?}", - peer_addr - ); - - let established_ws = EstablishedWs { ws, peer_addr }; - - (Some(Either::Left(established_ws)), false) - } - Err(err) => { - ::log::info!( - "get_peer_addr failed during handshake, removing connection: {:?}", - err - ); - - (None, false) - } - }, - Err(HandshakeError::Interrupted(handshake)) => ( - Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), - true, - ), - Err(HandshakeError::Failure(err)) => { - info!("ws handshake error: {}", err); - - (None, false) - } - } - } -} - -pub struct EstablishedWs { - pub ws: WebSocket, - pub peer_addr: SocketAddr, -} - -pub struct Connection { - ws_config: WebSocketConfig, - pub valid_until: ValidUntil, - inner: Either, -} - -/// Create from TcpStream. Run `advance_handshakes` until `get_established_ws` -/// returns Some(EstablishedWs). -/// -/// advance_handshakes takes ownership of self because the TLS and WebSocket -/// handshake methods do. get_established_ws doesn't, since work can be done -/// on a mutable reference to a tungstenite websocket, and this way, the whole -/// Connection doesn't have to be removed from and reinserted into the -/// TorrentMap. This is also the reason for wrapping Container.inner in an -/// Either instead of combining all states into one structure just having a -/// single method for advancing handshakes and maybe returning a websocket. -impl Connection { - #[inline] - pub fn new(ws_config: WebSocketConfig, valid_until: ValidUntil, tcp_stream: TcpStream) -> Self { - Self { - ws_config, - valid_until, - inner: Either::Right(HandshakeMachine::new(tcp_stream)), - } - } - - #[inline] - pub fn get_established_ws(&mut self) -> Option<&mut EstablishedWs> { - match self.inner { - Either::Left(ref mut ews) => Some(ews), - Either::Right(_) => None, - } - } - - #[inline] - pub fn advance_handshakes( - self, - opt_tls_acceptor: &Option, - valid_until: ValidUntil, - ) -> (Option, bool) { - match self.inner { - Either::Left(_) => (Some(self), false), - Either::Right(machine) => { - let ws_config = self.ws_config; - - let (opt_inner, stop_loop) = machine.advance(ws_config, opt_tls_acceptor); - - let opt_new_self = opt_inner.map(|inner| Self { - ws_config, - valid_until, - inner, - }); - - (opt_new_self, stop_loop) - } - } - } - - #[inline] - pub fn close(&mut self) { - if let Either::Left(ref mut ews) = self.inner { - if ews.ws.can_read() { - if let Err(err) = ews.ws.close(None) { - ::log::info!("error closing ws: {}", err); - } - - // Required after ws.close() - if let Err(err) = ews.ws.write_pending() { - ::log::info!("error writing pending messages after closing ws: {}", err) - } - } - } - } - - pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { - use Either::{Left, Right}; - - match self.inner { - Left(EstablishedWs { ref mut ws, .. }) => ws.get_mut().deregister(poll), - Right(HandshakeMachine::TcpStream(ref mut stream)) => { - poll.registry().deregister(stream) - } - Right(HandshakeMachine::TlsMidHandshake(ref mut handshake)) => { - poll.registry().deregister(handshake.get_mut()) - } - Right(HandshakeMachine::TlsStream(ref mut stream)) => { - poll.registry().deregister(stream.get_mut()) - } - Right(HandshakeMachine::WsMidHandshake(ref mut handshake)) => { - handshake.get_mut().get_mut().deregister(poll) - } - } - } -} - -pub type ConnectionMap = HashMap; diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs deleted file mode 100644 index bbb293d..0000000 --- a/aquatic_ws/src/lib/network/mod.rs +++ /dev/null @@ -1,331 +0,0 @@ -use std::io::ErrorKind; -use std::time::Duration; -use std::vec::Drain; - -use aquatic_common::access_list::AccessListQuery; -use crossbeam_channel::Receiver; -use hashbrown::HashMap; -use log::{debug, error, info}; -use mio::net::TcpListener; -use mio::{Events, Interest, Poll, Token}; -use native_tls::TlsAcceptor; -use tungstenite::protocol::WebSocketConfig; - -use aquatic_common::convert_ipv4_mapped_ipv6; -use aquatic_ws_protocol::*; - -use crate::common::*; -use crate::config::Config; - -pub mod connection; -pub mod utils; - -use connection::*; -use utils::*; - -pub fn run_socket_worker( - config: Config, - state: State, - socket_worker_index: usize, - socket_worker_statuses: SocketWorkerStatuses, - poll: Poll, - in_message_sender: InMessageSender, - out_message_receiver: OutMessageReceiver, - opt_tls_acceptor: Option, -) { - match create_listener(&config) { - Ok(listener) => { - socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(())); - - run_poll_loop( - config, - &state, - socket_worker_index, - poll, - in_message_sender, - out_message_receiver, - listener, - opt_tls_acceptor, - ); - } - Err(err) => { - socket_worker_statuses.lock()[socket_worker_index] = - Some(Err(format!("Couldn't open socket: {:#}", err))); - } - } -} - -pub fn run_poll_loop( - config: Config, - state: &State, - socket_worker_index: usize, - mut poll: Poll, - in_message_sender: InMessageSender, - out_message_receiver: OutMessageReceiver, - listener: ::std::net::TcpListener, - opt_tls_acceptor: Option, -) { - let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds); - let ws_config = WebSocketConfig { - max_message_size: Some(config.network.websocket_max_message_size), - max_frame_size: Some(config.network.websocket_max_frame_size), - max_send_queue: None, - ..Default::default() - }; - - let mut listener = TcpListener::from_std(listener); - let mut events = Events::with_capacity(config.network.poll_event_capacity); - - poll.registry() - .register(&mut listener, LISTENER_TOKEN, Interest::READABLE) - .unwrap(); - - let mut connections: ConnectionMap = HashMap::new(); - let mut local_responses = Vec::new(); - - let mut poll_token_counter = Token(0usize); - let mut iter_counter = 0usize; - - loop { - poll.poll(&mut events, Some(poll_timeout)) - .expect("failed polling"); - - let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - - for event in events.iter() { - let token = event.token(); - - if token == LISTENER_TOKEN { - accept_new_streams( - ws_config, - &mut listener, - &mut poll, - &mut connections, - valid_until, - &mut poll_token_counter, - ); - } else if token != CHANNEL_TOKEN { - run_handshakes_and_read_messages( - &config, - state, - socket_worker_index, - &mut local_responses, - &in_message_sender, - &opt_tls_acceptor, - &mut poll, - &mut connections, - token, - valid_until, - ); - } - - send_out_messages( - &mut poll, - local_responses.drain(..), - &out_message_receiver, - &mut connections, - ); - } - - // Remove inactive connections, but not every iteration - if iter_counter % 128 == 0 { - remove_inactive_connections(&mut connections); - } - - iter_counter = iter_counter.wrapping_add(1); - } -} - -fn accept_new_streams( - ws_config: WebSocketConfig, - listener: &mut TcpListener, - poll: &mut Poll, - connections: &mut ConnectionMap, - valid_until: ValidUntil, - poll_token_counter: &mut Token, -) { - loop { - match listener.accept() { - Ok((mut stream, _)) => { - poll_token_counter.0 = poll_token_counter.0.wrapping_add(1); - - if poll_token_counter.0 < 2 { - poll_token_counter.0 = 2; - } - - let token = *poll_token_counter; - - remove_connection_if_exists(poll, connections, token); - - poll.registry() - .register(&mut stream, token, Interest::READABLE) - .unwrap(); - - let connection = Connection::new(ws_config, valid_until, stream); - - connections.insert(token, connection); - } - Err(err) => { - if err.kind() == ErrorKind::WouldBlock { - break; - } - - info!("error while accepting streams: {}", err); - } - } - } -} - -/// On the stream given by poll_token, get TLS (if requested) and tungstenite -/// up and running, then read messages and pass on through channel. -pub fn run_handshakes_and_read_messages( - config: &Config, - state: &State, - socket_worker_index: usize, - local_responses: &mut Vec<(ConnectionMeta, OutMessage)>, - in_message_sender: &InMessageSender, - opt_tls_acceptor: &Option, // If set, run TLS - poll: &mut Poll, - connections: &mut ConnectionMap, - poll_token: Token, - valid_until: ValidUntil, -) { - let access_list_mode = config.access_list.mode; - - loop { - if let Some(established_ws) = connections - .get_mut(&poll_token) - .map(|c| { - // Ugly but works - c.valid_until = valid_until; - - c - }) - .and_then(Connection::get_established_ws) - { - use ::tungstenite::Error::Io; - - match established_ws.ws.read_message() { - Ok(ws_message) => { - let naive_peer_addr = established_ws.peer_addr; - let converted_peer_ip = convert_ipv4_mapped_ipv6(naive_peer_addr.ip()); - - let meta = ConnectionMeta { - worker_index: socket_worker_index, - poll_token, - naive_peer_addr, - converted_peer_ip, - }; - - debug!("read message"); - - match InMessage::from_ws_message(ws_message) { - Ok(InMessage::AnnounceRequest(ref request)) - if !state - .access_list - .allows(access_list_mode, &request.info_hash.0) => - { - let out_message = OutMessage::ErrorResponse(ErrorResponse { - failure_reason: "Info hash not allowed".into(), - action: Some(ErrorResponseAction::Announce), - info_hash: Some(request.info_hash), - }); - - local_responses.push((meta, out_message)); - } - Ok(in_message) => { - if let Err(err) = in_message_sender.send((meta, in_message)) { - error!("InMessageSender: couldn't send message: {:?}", err); - } - } - Err(_) => { - // FIXME: maybe this condition just occurs when enough data hasn't been recevied? - /* - info!("error parsing message: {:?}", err); - - let out_message = OutMessage::ErrorResponse(ErrorResponse { - failure_reason: "Error parsing message".into(), - action: None, - info_hash: None, - }); - - local_responses.push((meta, out_message)); - */ - } - } - } - Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(tungstenite::Error::ConnectionClosed) => { - remove_connection_if_exists(poll, connections, poll_token); - - break; - } - Err(err) => { - info!("error reading messages: {}", err); - - remove_connection_if_exists(poll, connections, poll_token); - - break; - } - } - } else if let Some(connection) = connections.remove(&poll_token) { - let (opt_new_connection, stop_loop) = - connection.advance_handshakes(opt_tls_acceptor, valid_until); - - if let Some(connection) = opt_new_connection { - connections.insert(poll_token, connection); - } - - if stop_loop { - break; - } - } else { - break; - } - } -} - -/// Read messages from channel, send to peers -pub fn send_out_messages( - poll: &mut Poll, - local_responses: Drain<(ConnectionMeta, OutMessage)>, - out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>, - connections: &mut ConnectionMap, -) { - let len = out_message_receiver.len(); - - for (meta, out_message) in local_responses.chain(out_message_receiver.try_iter().take(len)) { - let opt_established_ws = connections - .get_mut(&meta.poll_token) - .and_then(Connection::get_established_ws); - - if let Some(established_ws) = opt_established_ws { - if established_ws.peer_addr != meta.naive_peer_addr { - info!("socket worker error: peer socket addrs didn't match"); - - continue; - } - - use ::tungstenite::Error::Io; - - let ws_message = out_message.to_ws_message(); - - match established_ws.ws.write_message(ws_message) { - Ok(()) => { - debug!("sent message"); - } - Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {} - Err(tungstenite::Error::ConnectionClosed) => { - remove_connection_if_exists(poll, connections, meta.poll_token); - } - Err(err) => { - info!("error writing ws message: {}", err); - - remove_connection_if_exists(poll, connections, meta.poll_token); - } - } - } - } -} diff --git a/aquatic_ws/src/lib/network/utils.rs b/aquatic_ws/src/lib/network/utils.rs deleted file mode 100644 index 2568c97..0000000 --- a/aquatic_ws/src/lib/network/utils.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::time::Instant; - -use anyhow::Context; -use mio::{Poll, Token}; -use socket2::{Domain, Protocol, Socket, Type}; - -use crate::config::Config; - -use super::connection::*; - -pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> { - let builder = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)) - } else { - Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)) - } - .context("Couldn't create socket2::Socket")?; - - if config.network.ipv6_only { - builder - .set_only_v6(true) - .context("Couldn't put socket in ipv6 only mode")? - } - - builder - .set_nonblocking(true) - .context("Couldn't put socket in non-blocking mode")?; - builder - .set_reuse_port(true) - .context("Couldn't put socket in reuse_port mode")?; - builder - .bind(&config.network.address.into()) - .with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?; - builder - .listen(128) - .context("Couldn't listen for connections on socket")?; - - Ok(builder.into()) -} - -pub fn remove_connection_if_exists(poll: &mut Poll, connections: &mut ConnectionMap, token: Token) { - if let Some(mut connection) = connections.remove(&token) { - connection.close(); - - if let Err(err) = connection.deregister(poll) { - ::log::error!("couldn't deregister stream: {}", err); - } - } -} - -// Close and remove inactive connections -pub fn remove_inactive_connections(connections: &mut ConnectionMap) { - let now = Instant::now(); - - connections.retain(|_, connection| { - if connection.valid_until.0 < now { - connection.close(); - - false - } else { - true - } - }); - - connections.shrink_to_fit(); -}