diff --git a/Cargo.lock b/Cargo.lock index f6af11f..7380242 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -245,12 +245,12 @@ dependencies = [ "log", "mimalloc", "mio", - "native-tls", "parking_lot", "privdrop", "quickcheck", "quickcheck_macros", "rand", + "rustls", "rustls-pemfile", "serde", "signal-hook", @@ -272,6 +272,7 @@ dependencies = [ "futures-rustls", "glommio 0.6.0 (git+https://github.com/DataDog/glommio.git?rev=2efe2f2a08f54394a435b674e8e0125057cbff03)", "hashbrown 0.11.2", + "log", "mimalloc", "quickcheck", "quickcheck_macros", @@ -517,22 +518,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "core-foundation" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6888e10551bb93e424d8df1d07f1a8b4fceb0001a3a4b048bfc47554946f47b3" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" - [[package]] name = "cpufeatures" version = "0.2.1" @@ -772,21 +757,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.0.1" @@ -1323,24 +1293,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "native-tls" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nix" version = "0.23.0" @@ -1464,39 +1416,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" -[[package]] -name = "openssl" -version = "0.10.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95" -dependencies = [ - "bitflags 1.3.2", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-sys", -] - -[[package]] -name = "openssl-probe" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a" - -[[package]] -name = "openssl-sys" -version = "0.9.71" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7df13d165e607909b363a4757a6f133f8a818a74e9d3a98d09c6128e15fa4c73" -dependencies = [ - "autocfg", - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "owned-alloc" version = "0.2.0" @@ -1761,15 +1680,6 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi 0.3.9", -] - [[package]] name = "ring" version = "0.16.20" @@ -1845,16 +1755,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75" -dependencies = [ - "lazy_static", - "winapi 0.3.9", -] - [[package]] name = "scoped-tls" version = "1.0.0" @@ -1877,29 +1777,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "security-framework" -version = "2.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525bc1abfda2e1998d152c45cf13e696f76d0a4972310b22fac1658b05df7c87" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9dd14d83160b528b7bfd66439110573efcfbe281b17fc2ca9f39f550d619c7e" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.4" @@ -2105,20 +1982,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "tempfile" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" -dependencies = [ - "cfg-if", - "libc", - "rand", - "redox_syscall", - "remove_dir_all", - "winapi 0.3.9", -] - [[package]] name = "termcolor" version = "1.1.2" @@ -2337,12 +2200,6 @@ dependencies = [ "ryu", ] -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.3" diff --git a/README.md b/README.md index 174d9aa..c9e97ae 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ of sub-implementations for different protocols: |--------------|--------------------------------------------|------------------------------------------------------------| | aquatic_udp | [BitTorrent over UDP] | Unix-like | | aquatic_http | [BitTorrent over HTTP] with TLS ([rustls]) | Linux 5.8+ | -| aquatic_ws | [WebTorrent] | Unix-like with [mio] (default) / Linux 5.8+ with [glommio] | +| aquatic_ws | [WebTorrent] over TLS ([rustls]) | Unix-like with [mio] (default) / Linux 5.8+ with [glommio] | ## Usage @@ -166,28 +166,11 @@ tls_private_key_path = './key.pem' Aims for compatibility with [WebTorrent](https://github.com/webtorrent) clients, with some exceptions: + * Only runs over TLS * Doesn't track of the number of torrent downloads (0 is always sent). * Doesn't allow full scrapes, i.e. of all registered info hashes - -#### TLS: mio version - -To run over TLS, a pkcs12 file (`.pkx`) is needed. It can be generated from -Let's Encrypt certificates as follows, assuming you are in the directory where -they are stored: - -```sh -openssl pkcs12 -export -out identity.pfx -inkey privkey.pem -in cert.pem -certfile fullchain.pem -``` - -Enter a password when prompted. Then move `identity.pfx` somewhere suitable, -and enter the path into the tracker configuration field `tls_pkcs12_path`. Set -the password in the field `tls_pkcs12_password` and set `use_tls` to true. - -#### TLS: glommio version - -The glommio version only runs over TLS. For setup instructions, please see -`aquatic_http` TLS section above. +For TLS setup instructions, please see `aquatic_http` TLS section above. #### Benchmarks diff --git a/TODO.md b/TODO.md index e29a54f..2f49cf2 100644 --- a/TODO.md +++ b/TODO.md @@ -40,6 +40,15 @@ messages can be sent back (e.g., "full scrapes are not supported") * aquatic_ws + * mio + * shard torrent state. this could decrease dropped messages too, since + request handlers won't send large batches of them + * connection cleaning interval + * use access list cache + * use write event interest for handshakes too + * deregistering before closing is required by mio, but it hurts performance + * blocked on https://github.com/snapview/tungstenite-rs/issues/51 + * connection closing: send tls close message etc? * glommio * proper cpu set pinning * RES memory still high after traffic stops, even if torrent maps and connection slabs go down to 0 len and capacity diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index 4bab56b..d159bb6 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -16,8 +16,8 @@ name = "aquatic_ws" [features] default = ["with-mio"] cpu-pinning = ["aquatic_common/cpu-pinning"] -with-glommio = ["cpu-pinning", "async-tungstenite", "futures-lite", "futures", "futures-rustls", "glommio", "rustls-pemfile"] -with-mio = ["crossbeam-channel", "histogram", "mio", "native-tls", "parking_lot", "socket2"] +with-glommio = ["cpu-pinning", "async-tungstenite", "futures-lite", "futures", "futures-rustls", "glommio"] +with-mio = ["crossbeam-channel", "histogram", "mio", "parking_lot", "socket2"] [dependencies] anyhow = "1" @@ -31,6 +31,8 @@ log = "0.4" mimalloc = { version = "0.1", default-features = false } privdrop = "0.5" rand = { version = "0.8", features = ["small_rng"] } +rustls = "0.20" +rustls-pemfile = "0.2" serde = { version = "1", features = ["derive"] } signal-hook = { version = "0.3" } slab = "0.4" @@ -40,7 +42,6 @@ tungstenite = "0.16" crossbeam-channel = { version = "0.5", optional = true } histogram = { version = "0.6", optional = true } mio = { version = "0.8", features = ["net", "os-poll"], optional = true } -native-tls = { version = "0.2", optional = true } parking_lot = { version = "0.11", optional = true } socket2 = { version = "0.4", features = ["all"], optional = true } @@ -50,7 +51,6 @@ futures-lite = { version = "1", optional = true } futures = { version = "0.3", optional = true } futures-rustls = { version = "0.22", optional = true } glommio = { git = "https://github.com/DataDog/glommio.git", rev = "2efe2f2a08f54394a435b674e8e0125057cbff03", optional = true } -rustls-pemfile = { version = "0.2", optional = true } [dev-dependencies] quickcheck = "1" diff --git a/aquatic_ws/src/common/mod.rs b/aquatic_ws/src/common/mod.rs index 2b7d1a6..0d58fee 100644 --- a/aquatic_ws/src/common/mod.rs +++ b/aquatic_ws/src/common/mod.rs @@ -1,5 +1,7 @@ pub mod handlers; +use std::fs::File; +use std::io::BufReader; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Instant; @@ -142,3 +144,32 @@ impl TorrentMaps { torrent_map.shrink_to_fit(); } } + +pub fn create_tls_config(config: &Config) -> anyhow::Result { + let certs = { + let f = File::open(&config.network.tls_certificate_path)?; + let mut f = BufReader::new(f); + + rustls_pemfile::certs(&mut f)? + .into_iter() + .map(|bytes| rustls::Certificate(bytes)) + .collect() + }; + + let private_key = { + let f = File::open(&config.network.tls_private_key_path)?; + let mut f = BufReader::new(f); + + rustls_pemfile::pkcs8_private_keys(&mut f)? + .first() + .map(|bytes| rustls::PrivateKey(bytes.clone())) + .ok_or(anyhow::anyhow!("No private keys in file"))? + }; + + let tls_config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, private_key)?; + + Ok(tls_config) +} diff --git a/aquatic_ws/src/config.rs b/aquatic_ws/src/config.rs index 8037b66..8b4839e 100644 --- a/aquatic_ws/src/config.rs +++ b/aquatic_ws/src/config.rs @@ -1,5 +1,4 @@ use std::net::SocketAddr; -#[cfg(feature = "with-glommio")] use std::path::PathBuf; #[cfg(feature = "cpu-pinning")] @@ -48,17 +47,9 @@ pub struct NetworkConfig { pub websocket_max_message_size: usize, pub websocket_max_frame_size: usize, - #[cfg(feature = "with-glommio")] pub tls_certificate_path: PathBuf, - #[cfg(feature = "with-glommio")] pub tls_private_key_path: PathBuf, - #[cfg(feature = "with-mio")] - pub use_tls: bool, - #[cfg(feature = "with-mio")] - pub tls_pkcs12_path: String, - #[cfg(feature = "with-mio")] - pub tls_pkcs12_password: String, #[cfg(feature = "with-mio")] pub poll_event_capacity: usize, #[cfg(feature = "with-mio")] @@ -143,17 +134,9 @@ impl Default for NetworkConfig { websocket_max_message_size: 64 * 1024, websocket_max_frame_size: 16 * 1024, - #[cfg(feature = "with-glommio")] tls_certificate_path: "".into(), - #[cfg(feature = "with-glommio")] tls_private_key_path: "".into(), - #[cfg(feature = "with-mio")] - use_tls: false, - #[cfg(feature = "with-mio")] - tls_pkcs12_path: "".into(), - #[cfg(feature = "with-mio")] - tls_pkcs12_password: "".into(), #[cfg(feature = "with-mio")] poll_event_capacity: 4096, #[cfg(feature = "with-mio")] @@ -176,7 +159,7 @@ impl Default for ProtocolConfig { impl Default for HandlerConfig { fn default() -> Self { Self { - max_requests_per_iter: 10000, + max_requests_per_iter: 256, channel_recv_timeout_microseconds: 200, } } diff --git a/aquatic_ws/src/glommio/mod.rs b/aquatic_ws/src/glommio/mod.rs index 88d5659..f1ceb69 100644 --- a/aquatic_ws/src/glommio/mod.rs +++ b/aquatic_ws/src/glommio/mod.rs @@ -2,13 +2,9 @@ pub mod common; pub mod request; pub mod socket; -use std::{ - fs::File, - io::BufReader, - sync::{atomic::AtomicUsize, Arc}, -}; +use std::sync::{atomic::AtomicUsize, Arc}; -use crate::config::Config; +use crate::{common::create_tls_config, config::Config}; #[cfg(feature = "cpu-pinning")] use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; use aquatic_common::privileges::drop_privileges_after_socket_binding; @@ -109,32 +105,3 @@ pub fn run(config: Config, state: State) -> anyhow::Result<()> { Ok(()) } - -fn create_tls_config(config: &Config) -> anyhow::Result { - let certs = { - let f = File::open(&config.network.tls_certificate_path)?; - let mut f = BufReader::new(f); - - rustls_pemfile::certs(&mut f)? - .into_iter() - .map(|bytes| futures_rustls::rustls::Certificate(bytes)) - .collect() - }; - - let private_key = { - let f = File::open(&config.network.tls_private_key_path)?; - let mut f = BufReader::new(f); - - rustls_pemfile::pkcs8_private_keys(&mut f)? - .first() - .map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone())) - .ok_or(anyhow::anyhow!("No private keys in file"))? - }; - - let tls_config = futures_rustls::rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, private_key)?; - - Ok(tls_config) -} diff --git a/aquatic_ws/src/mio/mod.rs b/aquatic_ws/src/mio/mod.rs index e7ea249..54a2fa4 100644 --- a/aquatic_ws/src/mio/mod.rs +++ b/aquatic_ws/src/mio/mod.rs @@ -1,5 +1,3 @@ -use std::fs::File; -use std::io::Read; use std::sync::Arc; use std::thread::Builder; use std::time::Duration; @@ -9,7 +7,6 @@ use anyhow::Context; use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; use histogram::Histogram; use mio::{Poll, Waker}; -use native_tls::{Identity, TlsAcceptor}; use parking_lot::Mutex; use privdrop::PrivDrop; @@ -17,11 +14,13 @@ pub mod common; pub mod request; pub mod socket; -use crate::config::Config; +use crate::{common::create_tls_config, config::Config}; use common::*; pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; +const SHARED_IN_CHANNEL_SIZE: usize = 1024; + pub fn run(config: Config, state: State) -> anyhow::Result<()> { start_workers(config.clone(), state.clone()).expect("couldn't start workers"); @@ -44,9 +43,10 @@ pub fn run(config: Config, state: State) -> anyhow::Result<()> { } pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { - let opt_tls_acceptor = create_tls_acceptor(&config)?; + let tls_config = Arc::new(create_tls_config(&config)?); - let (in_message_sender, in_message_receiver) = ::crossbeam_channel::unbounded(); + let (in_message_sender, in_message_receiver) = + ::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE); let mut out_message_senders = Vec::new(); let mut wakers = Vec::new(); @@ -66,11 +66,12 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { let state = state.clone(); let socket_worker_statuses = socket_worker_statuses.clone(); let in_message_sender = in_message_sender.clone(); - let opt_tls_acceptor = opt_tls_acceptor.clone(); + let tls_config = tls_config.clone(); let poll = Poll::new()?; let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?); - let (out_message_sender, out_message_receiver) = ::crossbeam_channel::unbounded(); + let (out_message_sender, out_message_receiver) = + ::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE * 16); out_message_senders.push(out_message_sender); wakers.push(waker); @@ -93,7 +94,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { poll, in_message_sender, out_message_receiver, - opt_tls_acceptor, + tls_config, ); })?; } @@ -180,27 +181,6 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { Ok(()) } -pub fn create_tls_acceptor(config: &Config) -> anyhow::Result> { - if config.network.use_tls { - let mut identity_bytes = Vec::new(); - let mut file = File::open(&config.network.tls_pkcs12_path) - .context("Couldn't open pkcs12 identity file")?; - - file.read_to_end(&mut identity_bytes) - .context("Couldn't read pkcs12 identity file")?; - - let identity = Identity::from_pkcs12(&identity_bytes, &config.network.tls_pkcs12_password) - .context("Couldn't parse pkcs12 identity file")?; - - let acceptor = TlsAcceptor::new(identity) - .context("Couldn't create TlsAcceptor from pkcs12 identity")?; - - Ok(Some(acceptor)) - } else { - Ok(None) - } -} - fn print_statistics(state: &State) { let mut peers_per_torrent = Histogram::new(); diff --git a/aquatic_ws/src/mio/socket/connection.rs b/aquatic_ws/src/mio/socket/connection.rs index 4769dc6..f5559d0 100644 --- a/aquatic_ws/src/mio/socket/connection.rs +++ b/aquatic_ws/src/mio/socket/connection.rs @@ -1,298 +1,577 @@ -use std::io::{Read, Write}; -use std::net::SocketAddr; +use std::{collections::VecDeque, io::ErrorKind, marker::PhantomData, net::Shutdown, sync::Arc}; -use either::Either; -use hashbrown::HashMap; -use log::info; -use mio::net::TcpStream; -use mio::{Poll, Token}; -use native_tls::{MidHandshakeTlsStream, TlsAcceptor, TlsStream}; -use tungstenite::handshake::{server::NoCallback, HandshakeError, MidHandshake}; -use tungstenite::protocol::WebSocketConfig; -use tungstenite::ServerHandshake; -use tungstenite::WebSocket; +use aquatic_common::ValidUntil; +use aquatic_ws_protocol::{InMessage, OutMessage}; +use mio::{net::TcpStream, Interest, Poll, Token}; +use rustls::{ServerConfig, ServerConnection}; +use tungstenite::{ + handshake::{server::NoCallback, MidHandshake}, + protocol::WebSocketConfig, + HandshakeError, ServerHandshake, +}; -use crate::common::*; +use crate::common::ConnectionMeta; -pub enum Stream { - TcpStream(TcpStream), - TlsStream(TlsStream), +const MAX_PENDING_MESSAGES: usize = 16; + +type TlsStream = rustls::StreamOwned; + +type WsHandshakeResult = + Result, HandshakeError>>; + +type ConnectionReadResult = ::std::io::Result>; + +pub trait RegistryStatus {} + +pub struct Registered; + +impl RegistryStatus for Registered {} + +pub struct NotRegistered; + +impl RegistryStatus for NotRegistered {} + +enum ConnectionReadStatus { + Message(T, InMessage), + Ok(T), + WouldBlock(T), } -impl Stream { - #[inline] - pub fn get_peer_addr(&self) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.peer_addr(), - Self::TlsStream(stream) => stream.get_ref().peer_addr(), - } - } +enum ConnectionState { + TlsHandshaking(TlsHandshaking), + WsHandshaking(WsHandshaking), + WsConnection(WsConnection), +} - #[inline] - pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { - match self { - Self::TcpStream(stream) => poll.registry().deregister(stream), - Self::TlsStream(stream) => poll.registry().deregister(stream.get_mut()), - } +pub struct Connection { + pub valid_until: ValidUntil, + meta: ConnectionMeta, + state: ConnectionState, + pub message_queue: VecDeque, + pub interest: Interest, + phantom_data: PhantomData, +} + +impl Connection { + pub fn get_meta(&self) -> ConnectionMeta { + self.meta } } -impl Read for Stream { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> Result { - match self { - Self::TcpStream(stream) => stream.read(buf), - Self::TlsStream(stream) => stream.read(buf), - } - } - - /// Not used but provided for completeness - #[inline] - fn read_vectored( - &mut self, - bufs: &mut [::std::io::IoSliceMut<'_>], - ) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.read_vectored(bufs), - Self::TlsStream(stream) => stream.read_vectored(bufs), - } - } -} - -impl Write for Stream { - #[inline] - fn write(&mut self, buf: &[u8]) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.write(buf), - Self::TlsStream(stream) => stream.write(buf), - } - } - - /// Not used but provided for completeness - #[inline] - fn write_vectored(&mut self, bufs: &[::std::io::IoSlice<'_>]) -> ::std::io::Result { - match self { - Self::TcpStream(stream) => stream.write_vectored(bufs), - Self::TlsStream(stream) => stream.write_vectored(bufs), - } - } - - #[inline] - fn flush(&mut self) -> ::std::io::Result<()> { - match self { - Self::TcpStream(stream) => stream.flush(), - Self::TlsStream(stream) => stream.flush(), - } - } -} - -enum HandshakeMachine { - TcpStream(TcpStream), - TlsStream(TlsStream), - TlsMidHandshake(MidHandshakeTlsStream), - WsMidHandshake(MidHandshake>), -} - -impl HandshakeMachine { - #[inline] - fn new(tcp_stream: TcpStream) -> Self { - Self::TcpStream(tcp_stream) - } - - #[inline] - fn advance( - self, +impl Connection { + pub fn new( + tls_config: Arc, ws_config: WebSocketConfig, - opt_tls_acceptor: &Option, // If set, run TLS - ) -> (Option>, bool) { - // bool = stop looping - match self { - HandshakeMachine::TcpStream(stream) => { - if let Some(tls_acceptor) = opt_tls_acceptor { - Self::handle_tls_handshake_result(tls_acceptor.accept(stream)) - } else { - let handshake_result = ::tungstenite::accept_with_config( - Stream::TcpStream(stream), - Some(ws_config), - ); + tcp_stream: TcpStream, + valid_until: ValidUntil, + meta: ConnectionMeta, + ) -> Self { + let state = + ConnectionState::TlsHandshaking(TlsHandshaking::new(tls_config, ws_config, tcp_stream)); - Self::handle_ws_handshake_result(handshake_result) + Self { + valid_until, + meta, + state, + message_queue: Default::default(), + interest: Interest::READABLE, + phantom_data: PhantomData::default(), + } + } + + /// Read until stream blocks (or error occurs) + /// + /// Requires Connection not to be registered, since it might be dropped on errors + pub fn read( + mut self, + message_handler: &mut F, + ) -> ::std::io::Result> + where + F: FnMut(ConnectionMeta, InMessage), + { + loop { + let result = match self.state { + ConnectionState::TlsHandshaking(inner) => inner.read(), + ConnectionState::WsHandshaking(inner) => inner.read(), + ConnectionState::WsConnection(inner) => inner.read(), + }; + + match result { + Ok(ConnectionReadStatus::Message(state, message)) => { + self.state = state; + + message_handler(self.meta, message); + + // Stop looping even if WouldBlock wasn't necessarily reached. Otherwise, + // we might get stuck reading from this connection only. Since we register + // the connection again upon reinsertion into the ConnectionMap, we should + // be getting new events anyway. + return Ok(self); } - } - HandshakeMachine::TlsStream(stream) => { - let handshake_result = ::tungstenite::accept(Stream::TlsStream(stream)); + Ok(ConnectionReadStatus::Ok(state)) => { + self.state = state; - Self::handle_ws_handshake_result(handshake_result) - } - HandshakeMachine::TlsMidHandshake(handshake) => { - Self::handle_tls_handshake_result(handshake.handshake()) - } - HandshakeMachine::WsMidHandshake(handshake) => { - Self::handle_ws_handshake_result(handshake.handshake()) - } - } - } + ::log::debug!("read connection"); + } + Ok(ConnectionReadStatus::WouldBlock(state)) => { + self.state = state; - #[inline] - fn handle_tls_handshake_result( - result: Result, ::native_tls::HandshakeError>, - ) -> (Option>, bool) { - match result { - Ok(stream) => { - ::log::trace!( - "established tls handshake with peer with addr: {:?}", - stream.get_ref().peer_addr() - ); + ::log::debug!("reading connection would block"); - (Some(Either::Right(Self::TlsStream(stream))), false) - } - Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - (Some(Either::Right(Self::TlsMidHandshake(handshake))), true) - } - Err(native_tls::HandshakeError::Failure(err)) => { - info!("tls handshake error: {}", err); - - (None, false) - } - } - } - - #[inline] - fn handle_ws_handshake_result( - result: Result, HandshakeError>>, - ) -> (Option>, bool) { - match result { - Ok(mut ws) => match ws.get_mut().get_peer_addr() { - Ok(peer_addr) => { - ::log::trace!( - "established ws handshake with peer with addr: {:?}", - peer_addr - ); - - let established_ws = EstablishedWs { ws, peer_addr }; - - (Some(Either::Left(established_ws)), false) + return Ok(self); } Err(err) => { - ::log::info!( - "get_peer_addr failed during handshake, removing connection: {:?}", - err - ); + ::log::debug!("Connection::read error: {}", err); - (None, false) + return Err(err); + } + } + } + } + + pub fn register(self, poll: &mut Poll, token: Token) -> Connection { + let state = match self.state { + ConnectionState::TlsHandshaking(inner) => { + ConnectionState::TlsHandshaking(inner.register(poll, token, self.interest)) + } + ConnectionState::WsHandshaking(inner) => { + ConnectionState::WsHandshaking(inner.register(poll, token, self.interest)) + } + ConnectionState::WsConnection(inner) => { + ConnectionState::WsConnection(inner.register(poll, token, self.interest)) + } + }; + + Connection { + valid_until: self.valid_until, + meta: self.meta, + state, + message_queue: self.message_queue, + interest: self.interest, + phantom_data: PhantomData::default(), + } + } + + pub fn close(self) { + ::log::debug!("will close connection to {}", self.meta.naive_peer_addr); + + match self.state { + ConnectionState::TlsHandshaking(inner) => inner.close(), + ConnectionState::WsHandshaking(inner) => inner.close(), + ConnectionState::WsConnection(inner) => inner.close(), + } + } +} + +impl Connection { + pub fn write_or_queue_message( + &mut self, + poll: &mut Poll, + message: OutMessage, + ) -> ::std::io::Result<()> { + let message_clone = message.clone(); + + match self.write_message(message) { + Ok(()) => Ok(()), + Err(err) if err.kind() == ErrorKind::WouldBlock => { + if self.message_queue.len() < MAX_PENDING_MESSAGES { + self.message_queue.push_back(message_clone); + + if !self.interest.is_writable() { + self.interest = Interest::WRITABLE; + + self.reregister(poll)?; + } + } else { + ::log::info!("Connection::message_queue is full, dropping message"); + } + + Ok(()) + } + Err(err) => Err(err), + } + } + + pub fn write(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { + if let ConnectionState::WsConnection(_) = self.state { + while let Some(message) = self.message_queue.pop_front() { + let message_clone = message.clone(); + + match self.write_message(message) { + Ok(()) => {} + Err(err) if err.kind() == ErrorKind::WouldBlock => { + // Can't make message queue longer than it was before pop_front + self.message_queue.push_front(message_clone); + + return Ok(()); + } + Err(err) => { + return Err(err); + } + } + } + + if self.message_queue.is_empty() { + self.interest = Interest::READABLE; + } + + self.reregister(poll)?; + + Ok(()) + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "WebSocket connection not established", + )) + } + } + + fn write_message(&mut self, message: OutMessage) -> ::std::io::Result<()> { + if let ConnectionState::WsConnection(WsConnection { + ref mut web_socket, .. + }) = self.state + { + match web_socket.write_message(message.to_ws_message()) { + Ok(_) => {} + Err(tungstenite::Error::SendQueueFull(_message)) => { + return Err(std::io::Error::new( + ErrorKind::WouldBlock, + "Send queue full", + )) + } + Err(tungstenite::Error::Io(err)) => return Err(err), + Err(err) => return Err(std::io::Error::new(ErrorKind::Other, err))?, + } + + match web_socket.write_pending() { + Ok(()) => Ok(()), + Err(tungstenite::Error::Io(err)) => Err(err), + Err(err) => Err(std::io::Error::new(ErrorKind::Other, err))?, + } + } else { + Err(std::io::Error::new( + ErrorKind::NotConnected, + "WebSocket connection not established", + )) + } + } + + pub fn reregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { + let token = Token(self.meta.connection_id.0); + + match self.state { + ConnectionState::TlsHandshaking(ref mut inner) => { + inner.reregister(poll, token, self.interest) + } + ConnectionState::WsHandshaking(ref mut inner) => { + inner.reregister(poll, token, self.interest) + } + ConnectionState::WsConnection(ref mut inner) => { + inner.reregister(poll, token, self.interest) + } + } + } + + pub fn deregister(self, poll: &mut Poll) -> Connection { + let state = match self.state { + ConnectionState::TlsHandshaking(inner) => { + ConnectionState::TlsHandshaking(inner.deregister(poll)) + } + ConnectionState::WsHandshaking(inner) => { + ConnectionState::WsHandshaking(inner.deregister(poll)) + } + ConnectionState::WsConnection(inner) => { + ConnectionState::WsConnection(inner.deregister(poll)) + } + }; + + Connection { + valid_until: self.valid_until, + meta: self.meta, + state, + message_queue: self.message_queue, + interest: self.interest, + phantom_data: PhantomData::default(), + } + } +} + +struct TlsHandshaking { + tls_conn: ServerConnection, + ws_config: WebSocketConfig, + tcp_stream: TcpStream, + phantom_data: PhantomData, +} + +impl TlsHandshaking { + fn new(tls_config: Arc, ws_config: WebSocketConfig, stream: TcpStream) -> Self { + Self { + tls_conn: ServerConnection::new(tls_config).unwrap(), + ws_config, + tcp_stream: stream, + phantom_data: PhantomData::default(), + } + } + + fn read(mut self) -> ConnectionReadResult> { + match self.tls_conn.read_tls(&mut self.tcp_stream) { + Ok(0) => { + return Err(::std::io::Error::new( + ErrorKind::ConnectionReset, + "Connection closed", + )) + } + Ok(_) => match self.tls_conn.process_new_packets() { + Ok(_) => { + while self.tls_conn.wants_write() { + self.tls_conn.write_tls(&mut self.tcp_stream)?; + } + + if self.tls_conn.is_handshaking() { + Ok(ConnectionReadStatus::WouldBlock( + ConnectionState::TlsHandshaking(self), + )) + } else { + let tls_stream = TlsStream::new(self.tls_conn, self.tcp_stream); + + WsHandshaking::handle_handshake_result(tungstenite::accept_with_config( + tls_stream, + Some(self.ws_config), + )) + } + } + Err(err) => { + let _ = self.tls_conn.write_tls(&mut self.tcp_stream); + + Err(::std::io::Error::new(ErrorKind::InvalidData, err)) } }, - Err(HandshakeError::Interrupted(handshake)) => ( - Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), - true, - ), - Err(HandshakeError::Failure(err)) => { - info!("ws handshake error: {}", err); - - (None, false) + Err(err) if err.kind() == ErrorKind::WouldBlock => { + return Ok(ConnectionReadStatus::WouldBlock( + ConnectionState::TlsHandshaking(self), + )) } - } - } -} - -pub struct EstablishedWs { - pub ws: WebSocket, - pub peer_addr: SocketAddr, -} - -pub struct Connection { - ws_config: WebSocketConfig, - pub valid_until: ValidUntil, - inner: Either, -} - -/// Create from TcpStream. Run `advance_handshakes` until `get_established_ws` -/// returns Some(EstablishedWs). -/// -/// advance_handshakes takes ownership of self because the TLS and WebSocket -/// handshake methods do. get_established_ws doesn't, since work can be done -/// on a mutable reference to a tungstenite websocket, and this way, the whole -/// Connection doesn't have to be removed from and reinserted into the -/// TorrentMap. This is also the reason for wrapping Container.inner in an -/// Either instead of combining all states into one structure just having a -/// single method for advancing handshakes and maybe returning a websocket. -impl Connection { - #[inline] - pub fn new(ws_config: WebSocketConfig, valid_until: ValidUntil, tcp_stream: TcpStream) -> Self { - Self { - ws_config, - valid_until, - inner: Either::Right(HandshakeMachine::new(tcp_stream)), + Err(err) => return Err(err), } } - #[inline] - pub fn get_established_ws(&mut self) -> Option<&mut EstablishedWs> { - match self.inner { - Either::Left(ref mut ews) => Some(ews), - Either::Right(_) => None, + fn register( + mut self, + poll: &mut Poll, + token: Token, + interest: Interest, + ) -> TlsHandshaking { + poll.registry() + .register(&mut self.tcp_stream, token, interest) + .unwrap(); + + TlsHandshaking { + tls_conn: self.tls_conn, + ws_config: self.ws_config, + tcp_stream: self.tcp_stream, + phantom_data: PhantomData::default(), } } - #[inline] - pub fn advance_handshakes( - self, - opt_tls_acceptor: &Option, - valid_until: ValidUntil, - ) -> (Option, bool) { - match self.inner { - Either::Left(_) => (Some(self), false), - Either::Right(machine) => { - let ws_config = self.ws_config; + fn close(self) { + ::log::debug!("closing connection (TlsHandshaking state)"); - let (opt_inner, stop_loop) = machine.advance(ws_config, opt_tls_acceptor); + let _ = self.tcp_stream.shutdown(Shutdown::Both); + } +} - let opt_new_self = opt_inner.map(|inner| Self { - ws_config, - valid_until, - inner, +impl TlsHandshaking { + fn deregister(mut self, poll: &mut Poll) -> TlsHandshaking { + poll.registry().deregister(&mut self.tcp_stream).unwrap(); + + TlsHandshaking { + tls_conn: self.tls_conn, + ws_config: self.ws_config, + tcp_stream: self.tcp_stream, + phantom_data: PhantomData::default(), + } + } + + fn reregister( + &mut self, + poll: &mut Poll, + token: Token, + interest: Interest, + ) -> std::io::Result<()> { + poll.registry() + .reregister(&mut self.tcp_stream, token, interest) + } +} + +struct WsHandshaking { + mid_handshake: MidHandshake>, + phantom_data: PhantomData, +} + +impl WsHandshaking { + fn read(self) -> ConnectionReadResult> { + Self::handle_handshake_result(self.mid_handshake.handshake()) + } + + fn handle_handshake_result( + handshake_result: WsHandshakeResult, + ) -> ConnectionReadResult> { + match handshake_result { + Ok(web_socket) => { + let conn = ConnectionState::WsConnection(WsConnection { + web_socket, + phantom_data: PhantomData::default(), }); - (opt_new_self, stop_loop) + Ok(ConnectionReadStatus::Ok(conn)) + } + Err(HandshakeError::Interrupted(mid_handshake)) => { + let conn = ConnectionState::WsHandshaking(WsHandshaking { + mid_handshake, + phantom_data: PhantomData::default(), + }); + + Ok(ConnectionReadStatus::WouldBlock(conn)) + } + Err(HandshakeError::Failure(err)) => { + return Err(std::io::Error::new(ErrorKind::InvalidData, err)) } } } - #[inline] - pub fn close(&mut self) { - if let Either::Left(ref mut ews) = self.inner { - if ews.ws.can_read() { - if let Err(err) = ews.ws.close(None) { - ::log::info!("error closing ws: {}", err); - } + fn register( + mut self, + poll: &mut Poll, + token: Token, + interest: Interest, + ) -> WsHandshaking { + let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; - // Required after ws.close() - if let Err(err) = ews.ws.write_pending() { - ::log::info!("error writing pending messages after closing ws: {}", err) - } - } + poll.registry() + .register(tcp_stream, token, interest) + .unwrap(); + + WsHandshaking { + mid_handshake: self.mid_handshake, + phantom_data: PhantomData::default(), } } - pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> { - use Either::{Left, Right}; + fn close(mut self) { + ::log::debug!("closing connection (WsHandshaking state)"); - match self.inner { - Left(EstablishedWs { ref mut ws, .. }) => ws.get_mut().deregister(poll), - Right(HandshakeMachine::TcpStream(ref mut stream)) => { - poll.registry().deregister(stream) - } - Right(HandshakeMachine::TlsMidHandshake(ref mut handshake)) => { - poll.registry().deregister(handshake.get_mut()) - } - Right(HandshakeMachine::TlsStream(ref mut stream)) => { - poll.registry().deregister(stream.get_mut()) - } - Right(HandshakeMachine::WsMidHandshake(ref mut handshake)) => { - handshake.get_mut().get_mut().deregister(poll) - } - } + let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; + + let _ = tcp_stream.shutdown(Shutdown::Both); } } -pub type ConnectionMap = HashMap; +impl WsHandshaking { + fn deregister(mut self, poll: &mut Poll) -> WsHandshaking { + let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; + + poll.registry().deregister(tcp_stream).unwrap(); + + WsHandshaking { + mid_handshake: self.mid_handshake, + phantom_data: PhantomData::default(), + } + } + + fn reregister( + &mut self, + poll: &mut Poll, + token: Token, + interest: Interest, + ) -> std::io::Result<()> { + let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock; + + poll.registry().reregister(tcp_stream, token, interest) + } +} + +struct WsConnection { + web_socket: tungstenite::WebSocket, + phantom_data: PhantomData, +} + +impl WsConnection { + fn read(mut self) -> ConnectionReadResult> { + match self.web_socket.read_message() { + Ok( + message @ tungstenite::Message::Text(_) | message @ tungstenite::Message::Binary(_), + ) => match InMessage::from_ws_message(message) { + Ok(message) => { + ::log::debug!("received WebSocket message"); + + Ok(ConnectionReadStatus::Message( + ConnectionState::WsConnection(self), + message, + )) + } + Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)), + }, + Ok(message) => { + ::log::info!("received unexpected WebSocket message: {}", message); + + Err(std::io::Error::new( + ErrorKind::InvalidData, + "unexpected WebSocket message type", + )) + } + Err(tungstenite::Error::Io(err)) if err.kind() == ErrorKind::WouldBlock => { + let conn = ConnectionState::WsConnection(self); + + Ok(ConnectionReadStatus::WouldBlock(conn)) + } + Err(tungstenite::Error::Io(err)) => Err(err), + Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)), + } + } + + fn register( + mut self, + poll: &mut Poll, + token: Token, + interest: Interest, + ) -> WsConnection { + poll.registry() + .register(self.web_socket.get_mut().get_mut(), token, interest) + .unwrap(); + + WsConnection { + web_socket: self.web_socket, + phantom_data: PhantomData::default(), + } + } + + fn close(mut self) { + ::log::debug!("closing connection (WsConnection state)"); + + let _ = self.web_socket.close(None); + let _ = self.web_socket.write_pending(); + } +} + +impl WsConnection { + fn deregister(mut self, poll: &mut Poll) -> WsConnection { + poll.registry() + .deregister(self.web_socket.get_mut().get_mut()) + .unwrap(); + + WsConnection { + web_socket: self.web_socket, + phantom_data: PhantomData::default(), + } + } + + fn reregister( + &mut self, + poll: &mut Poll, + token: Token, + interest: Interest, + ) -> std::io::Result<()> { + poll.registry() + .reregister(self.web_socket.get_mut().get_mut(), token, interest) + } +} diff --git a/aquatic_ws/src/mio/socket/mod.rs b/aquatic_ws/src/mio/socket/mod.rs index a109f79..237c226 100644 --- a/aquatic_ws/src/mio/socket/mod.rs +++ b/aquatic_ws/src/mio/socket/mod.rs @@ -1,14 +1,13 @@ use std::io::ErrorKind; -use std::time::Duration; -use std::vec::Drain; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use anyhow::Context; use aquatic_common::access_list::AccessListQuery; -use crossbeam_channel::Receiver; use hashbrown::HashMap; -use log::{debug, error, info}; use mio::net::TcpListener; use mio::{Events, Interest, Poll, Token}; -use native_tls::TlsAcceptor; +use socket2::{Domain, Protocol, Socket, Type}; use tungstenite::protocol::WebSocketConfig; use aquatic_common::convert_ipv4_mapped_ipv6; @@ -17,13 +16,101 @@ use aquatic_ws_protocol::*; use crate::common::*; use crate::config::Config; +pub mod connection; + use super::common::*; -pub mod connection; -pub mod utils; +use connection::{Connection, NotRegistered, Registered}; -use connection::*; -use utils::*; +struct ConnectionMap { + token_counter: Token, + connections: HashMap>, +} + +impl Default for ConnectionMap { + fn default() -> Self { + Self { + token_counter: Token(2), + connections: Default::default(), + } + } +} + +impl ConnectionMap { + fn insert_and_register_new(&mut self, poll: &mut Poll, connection_creator: F) + where + F: FnOnce(Token) -> Connection, + { + self.token_counter.0 = self.token_counter.0.wrapping_add(1); + + // Don't assign LISTENER_TOKEN or CHANNEL_TOKEN + if self.token_counter.0 < 2 { + self.token_counter.0 = 2; + } + + let token = self.token_counter; + + // Remove, deregister and close any existing connection with this token. + // This shouldn't happen in practice. + if let Some(connection) = self.connections.remove(&token) { + ::log::warn!( + "removing existing connection {} because of token reuse", + token.0 + ); + + connection.deregister(poll).close(); + } + + let connection = connection_creator(token); + + self.insert_and_register(poll, token, connection); + } + + fn insert_and_register( + &mut self, + poll: &mut Poll, + key: Token, + conn: Connection, + ) { + self.connections.insert(key, conn.register(poll, key)); + } + + fn remove_and_deregister( + &mut self, + poll: &mut Poll, + key: &Token, + ) -> Option> { + if let Some(connection) = self.connections.remove(key) { + Some(connection.deregister(poll)) + } else { + None + } + } + + fn get_mut(&mut self, key: &Token) -> Option<&mut Connection> { + self.connections.get_mut(key) + } + + /// Close and remove inactive connections + fn clean(mut self, poll: &mut Poll) -> Self { + let now = Instant::now(); + + let mut retained_connections = HashMap::default(); + + for (token, connection) in self.connections.drain() { + if connection.valid_until.0 < now { + connection.deregister(poll).close(); + } else { + retained_connections.insert(token, connection); + } + } + + ConnectionMap { + connections: retained_connections, + ..self + } + } +} pub fn run_socket_worker( config: Config, @@ -33,7 +120,7 @@ pub fn run_socket_worker( poll: Poll, in_message_sender: InMessageSender, out_message_receiver: OutMessageReceiver, - opt_tls_acceptor: Option, + tls_config: Arc, ) { match create_listener(&config) { Ok(listener) => { @@ -47,7 +134,7 @@ pub fn run_socket_worker( in_message_sender, out_message_receiver, listener, - opt_tls_acceptor, + tls_config, ); } Err(err) => { @@ -57,7 +144,7 @@ pub fn run_socket_worker( } } -pub fn run_poll_loop( +fn run_poll_loop( config: Config, state: &State, socket_worker_index: usize, @@ -65,13 +152,13 @@ pub fn run_poll_loop( in_message_sender: InMessageSender, out_message_receiver: OutMessageReceiver, listener: ::std::net::TcpListener, - opt_tls_acceptor: Option, + tls_config: Arc, ) { let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds); let ws_config = WebSocketConfig { max_message_size: Some(config.network.websocket_max_message_size), max_frame_size: Some(config.network.websocket_max_frame_size), - max_send_queue: None, + max_send_queue: Some(2), ..Default::default() }; @@ -82,10 +169,9 @@ pub fn run_poll_loop( .register(&mut listener, LISTENER_TOKEN, Interest::READABLE) .unwrap(); - let mut connections: ConnectionMap = HashMap::new(); + let mut connections = ConnectionMap::default(); let mut local_responses = Vec::new(); - let mut poll_token_counter = Token(0usize); let mut iter_counter = 0usize; loop { @@ -97,41 +183,68 @@ pub fn run_poll_loop( for event in events.iter() { let token = event.token(); - if token == LISTENER_TOKEN { - accept_new_streams( - ws_config, - &mut listener, - &mut poll, - &mut connections, - valid_until, - &mut poll_token_counter, - ); - } else if token != CHANNEL_TOKEN { - run_handshakes_and_read_messages( - &config, - state, - socket_worker_index, - &mut local_responses, - &in_message_sender, - &opt_tls_acceptor, - &mut poll, - &mut connections, - token, - valid_until, - ); + match token { + LISTENER_TOKEN => { + accept_new_streams( + &tls_config, + ws_config, + socket_worker_index, + &mut listener, + &mut poll, + &mut connections, + valid_until, + ); + } + CHANNEL_TOKEN => { + write_or_queue_messages( + &mut poll, + out_message_receiver + .try_iter() + .take(out_message_receiver.len()), + &mut connections, + ); + } + token => { + if event.is_writable() { + let mut remove_connection = false; + + if let Some(connection) = connections.get_mut(&token) { + if let Err(err) = connection.write(&mut poll) { + ::log::debug!("Connection::write error: {}", err); + + remove_connection = true; + } + } + + if remove_connection { + if let Some(connection) = + connections.remove_and_deregister(&mut poll, &token) + { + connection.close(); + } + } + } + if event.is_readable() { + handle_stream_read_event( + &config, + state, + &mut local_responses, + &in_message_sender, + &mut poll, + &mut connections, + token, + valid_until, + ); + } + } } - send_out_messages( - &mut poll, - local_responses.drain(..), - &out_message_receiver, - &mut connections, - ); + write_or_queue_messages(&mut poll, local_responses.drain(..), &mut connections); } // Remove inactive connections, but not every iteration if iter_counter % 128 == 0 { - remove_inactive_connections(&mut connections); + connections = connections.clean(&mut poll); } iter_counter = iter_counter.wrapping_add(1); @@ -139,194 +252,155 @@ pub fn run_poll_loop( } fn accept_new_streams( + tls_config: &Arc, ws_config: WebSocketConfig, + socket_worker_index: usize, listener: &mut TcpListener, poll: &mut Poll, connections: &mut ConnectionMap, valid_until: ValidUntil, - poll_token_counter: &mut Token, ) { loop { match listener.accept() { - Ok((mut stream, _)) => { - poll_token_counter.0 = poll_token_counter.0.wrapping_add(1); + Ok((stream, _)) => { + let naive_peer_addr = if let Ok(peer_addr) = stream.peer_addr() { + peer_addr + } else { + continue; + }; - if poll_token_counter.0 < 2 { - poll_token_counter.0 = 2; - } - - let token = *poll_token_counter; - - remove_connection_if_exists(poll, connections, token); - - poll.registry() - .register(&mut stream, token, Interest::READABLE) - .unwrap(); - - let connection = Connection::new(ws_config, valid_until, stream); - - connections.insert(token, connection); - } - Err(err) => { - if err.kind() == ErrorKind::WouldBlock { - break; - } - - info!("error while accepting streams: {}", err); - } - } - } -} - -/// On the stream given by poll_token, get TLS (if requested) and tungstenite -/// up and running, then read messages and pass on through channel. -pub fn run_handshakes_and_read_messages( - config: &Config, - state: &State, - socket_worker_index: usize, - local_responses: &mut Vec<(ConnectionMeta, OutMessage)>, - in_message_sender: &InMessageSender, - opt_tls_acceptor: &Option, // If set, run TLS - poll: &mut Poll, - connections: &mut ConnectionMap, - poll_token: Token, - valid_until: ValidUntil, -) { - let access_list_mode = config.access_list.mode; - - loop { - if let Some(established_ws) = connections - .get_mut(&poll_token) - .map(|c| { - // Ugly but works - c.valid_until = valid_until; - - c - }) - .and_then(Connection::get_established_ws) - { - use ::tungstenite::Error::Io; - - match established_ws.ws.read_message() { - Ok(ws_message) => { - let naive_peer_addr = established_ws.peer_addr; + connections.insert_and_register_new(poll, move |token| { let converted_peer_ip = convert_ipv4_mapped_ipv6(naive_peer_addr.ip()); let meta = ConnectionMeta { out_message_consumer_id: ConsumerId(socket_worker_index), - connection_id: ConnectionId(poll_token.0), + connection_id: ConnectionId(token.0), naive_peer_addr, converted_peer_ip, pending_scrape_id: None, // FIXME }; - debug!("read message"); - - match InMessage::from_ws_message(ws_message) { - Ok(InMessage::AnnounceRequest(ref request)) - if !state - .access_list - .allows(access_list_mode, &request.info_hash.0) => - { - let out_message = OutMessage::ErrorResponse(ErrorResponse { - failure_reason: "Info hash not allowed".into(), - action: Some(ErrorResponseAction::Announce), - info_hash: Some(request.info_hash), - }); - - local_responses.push((meta, out_message)); - } - Ok(in_message) => { - if let Err(err) = in_message_sender.send((meta, in_message)) { - error!("InMessageSender: couldn't send message: {:?}", err); - } - } - Err(_) => { - // FIXME: maybe this condition just occurs when enough data hasn't been recevied? - /* - info!("error parsing message: {:?}", err); - let out_message = OutMessage::ErrorResponse(ErrorResponse { - failure_reason: "Error parsing message".into(), - action: None, - info_hash: None, - }); - local_responses.push((meta, out_message)); - */ - } - } - } - Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(tungstenite::Error::ConnectionClosed) => { - remove_connection_if_exists(poll, connections, poll_token); - - break; - } - Err(err) => { - info!("error reading messages: {}", err); - - remove_connection_if_exists(poll, connections, poll_token); - - break; - } + Connection::new(tls_config.clone(), ws_config, stream, valid_until, meta) + }); } - } else if let Some(connection) = connections.remove(&poll_token) { - let (opt_new_connection, stop_loop) = - connection.advance_handshakes(opt_tls_acceptor, valid_until); - - if let Some(connection) = opt_new_connection { - connections.insert(poll_token, connection); - } - - if stop_loop { + Err(err) if err.kind() == ErrorKind::WouldBlock => { break; } - } else { - break; + Err(err) => { + ::log::info!("error while accepting streams: {}", err); + } } } } -/// Read messages from channel, send to peers -pub fn send_out_messages( +fn handle_stream_read_event( + config: &Config, + state: &State, + local_responses: &mut Vec<(ConnectionMeta, OutMessage)>, + in_message_sender: &InMessageSender, poll: &mut Poll, - local_responses: Drain<(ConnectionMeta, OutMessage)>, - out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>, connections: &mut ConnectionMap, + token: Token, + valid_until: ValidUntil, ) { - let len = out_message_receiver.len(); + let access_list_mode = config.access_list.mode; - for (meta, out_message) in local_responses.chain(out_message_receiver.try_iter().take(len)) { - let opt_established_ws = connections - .get_mut(&Token(meta.connection_id.0)) - .and_then(Connection::get_established_ws); + if let Some(mut connection) = connections.remove_and_deregister(poll, &token) { + let message_handler = &mut |meta, message| match message { + InMessage::AnnounceRequest(ref request) + if !state + .access_list + .allows(access_list_mode, &request.info_hash.0) => + { + let out_message = OutMessage::ErrorResponse(ErrorResponse { + failure_reason: "Info hash not allowed".into(), + action: Some(ErrorResponseAction::Announce), + info_hash: Some(request.info_hash), + }); - if let Some(established_ws) = opt_established_ws { - if established_ws.peer_addr != meta.naive_peer_addr { - info!("socket worker error: peer socket addrs didn't match"); - - continue; + local_responses.push((meta, out_message)); } - - use ::tungstenite::Error::Io; - - let ws_message = out_message.to_ws_message(); - - match established_ws.ws.write_message(ws_message) { - Ok(()) => { - debug!("sent message"); - } - Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {} - Err(tungstenite::Error::ConnectionClosed) => { - remove_connection_if_exists(poll, connections, Token(meta.connection_id.0)); - } - Err(err) => { - info!("error writing ws message: {}", err); - - remove_connection_if_exists(poll, connections, Token(meta.connection_id.0)); + in_message => { + if let Err(err) = in_message_sender.send((meta, in_message)) { + ::log::info!("InMessageSender: couldn't send message: {:?}", err); } } + }; + + connection.valid_until = valid_until; + + match connection.read(message_handler) { + Ok(connection) => { + connections.insert_and_register(poll, token, connection); + } + Err(_) => {} } } } + +fn write_or_queue_messages(poll: &mut Poll, responses: I, connections: &mut ConnectionMap) +where + I: Iterator, +{ + for (meta, out_message) in responses { + let token = Token(meta.connection_id.0); + + let mut remove_connection = false; + + if let Some(connection) = connections.get_mut(&token) { + if connection.get_meta().naive_peer_addr != meta.naive_peer_addr { + ::log::warn!( + "socket worker error: connection socket addr {} didn't match channel {}. Token: {}.", + connection.get_meta().naive_peer_addr, + meta.naive_peer_addr, + token.0 + ); + + remove_connection = true; + } else { + match connection.write_or_queue_message(poll, out_message) { + Ok(()) => {} + Err(err) => { + ::log::debug!("Connection::write_or_queue_message error: {}", err); + + remove_connection = true; + } + } + } + } + + if remove_connection { + connections.remove_and_deregister(poll, &token); + } + } +} + +pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> { + let builder = if config.network.address.is_ipv4() { + Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)) + } else { + Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)) + } + .context("Couldn't create socket2::Socket")?; + + if config.network.ipv6_only { + builder + .set_only_v6(true) + .context("Couldn't put socket in ipv6 only mode")? + } + + builder + .set_nonblocking(true) + .context("Couldn't put socket in non-blocking mode")?; + builder + .set_reuse_port(true) + .context("Couldn't put socket in reuse_port mode")?; + builder + .bind(&config.network.address.into()) + .with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?; + builder + .listen(128) + .context("Couldn't listen for connections on socket")?; + + Ok(builder.into()) +} diff --git a/aquatic_ws/src/mio/socket/utils.rs b/aquatic_ws/src/mio/socket/utils.rs deleted file mode 100644 index 2568c97..0000000 --- a/aquatic_ws/src/mio/socket/utils.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::time::Instant; - -use anyhow::Context; -use mio::{Poll, Token}; -use socket2::{Domain, Protocol, Socket, Type}; - -use crate::config::Config; - -use super::connection::*; - -pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> { - let builder = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)) - } else { - Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)) - } - .context("Couldn't create socket2::Socket")?; - - if config.network.ipv6_only { - builder - .set_only_v6(true) - .context("Couldn't put socket in ipv6 only mode")? - } - - builder - .set_nonblocking(true) - .context("Couldn't put socket in non-blocking mode")?; - builder - .set_reuse_port(true) - .context("Couldn't put socket in reuse_port mode")?; - builder - .bind(&config.network.address.into()) - .with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?; - builder - .listen(128) - .context("Couldn't listen for connections on socket")?; - - Ok(builder.into()) -} - -pub fn remove_connection_if_exists(poll: &mut Poll, connections: &mut ConnectionMap, token: Token) { - if let Some(mut connection) = connections.remove(&token) { - connection.close(); - - if let Err(err) = connection.deregister(poll) { - ::log::error!("couldn't deregister stream: {}", err); - } - } -} - -// Close and remove inactive connections -pub fn remove_inactive_connections(connections: &mut ConnectionMap) { - let now = Instant::now(); - - connections.retain(|_, connection| { - if connection.valid_until.0 < now { - connection.close(); - - false - } else { - true - } - }); - - connections.shrink_to_fit(); -} diff --git a/aquatic_ws_load_test/Cargo.toml b/aquatic_ws_load_test/Cargo.toml index a922427..e8ba249 100644 --- a/aquatic_ws_load_test/Cargo.toml +++ b/aquatic_ws_load_test/Cargo.toml @@ -22,6 +22,7 @@ futures = "0.3" futures-rustls = "0.22" glommio = { git = "https://github.com/DataDog/glommio.git", rev = "2efe2f2a08f54394a435b674e8e0125057cbff03" } hashbrown = { version = "0.11", features = ["serde"] } +log = "0.4" mimalloc = { version = "0.1", default-features = false } rand = { version = "0.8", features = ["small_rng"] } rand_distr = "0.4" diff --git a/aquatic_ws_load_test/src/network.rs b/aquatic_ws_load_test/src/network.rs index c3f75c2..62c847d 100644 --- a/aquatic_ws_load_test/src/network.rs +++ b/aquatic_ws_load_test/src/network.rs @@ -105,7 +105,7 @@ impl Connection { *num_active_connections.borrow_mut() += 1; if let Err(err) = connection.run_connection_loop().await { - eprintln!("connection error: {:?}", err); + eprintln!("connection error: {}", err); } *num_active_connections.borrow_mut() -= 1; @@ -159,7 +159,26 @@ impl Connection { } async fn read_message(&mut self) -> anyhow::Result<()> { - match OutMessage::from_ws_message(self.stream.next().await.unwrap()?) { + let message = match self + .stream + .next() + .await + .ok_or_else(|| anyhow::anyhow!("stream finished"))?? + { + message @ tungstenite::Message::Text(_) | message @ tungstenite::Message::Binary(_) => { + message + } + message => { + eprintln!( + "Received WebSocket message of unexpected type: {:?}", + message + ); + + return Ok(()); + } + }; + + match OutMessage::from_ws_message(message) { Ok(OutMessage::Offer(offer)) => { self.load_test_state .statistics @@ -205,7 +224,7 @@ impl Connection { self.can_send = true; } Err(err) => { - eprintln!("error deserializing offer: {:?}", err); + eprintln!("error deserializing message: {:?}", err); } }