From 923b3637e872e9957855ff5340a64d85d98f11a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Fri, 17 Nov 2023 18:16:29 +0100 Subject: [PATCH] http: allow disabling TLS, allow reverse proxies, general fixes --- CHANGELOG.md | 9 + Cargo.lock | 5 +- README.md | 10 +- TODO.md | 3 + crates/http/Cargo.toml | 4 +- crates/http/README.md | 7 +- crates/http/src/common.rs | 6 +- crates/http/src/config.rs | 39 +- crates/http/src/lib.rs | 36 +- crates/http/src/workers/socket.rs | 576 ------------------- crates/http/src/workers/socket/connection.rs | 471 +++++++++++++++ crates/http/src/workers/socket/mod.rs | 212 +++++++ crates/http/src/workers/socket/request.rs | 147 +++++ crates/http_load_test/Cargo.toml | 1 + crates/http_load_test/src/config.rs | 2 + crates/http_load_test/src/main.rs | 12 +- crates/http_load_test/src/network.rs | 77 ++- crates/http_protocol/src/request.rs | 33 +- 18 files changed, 986 insertions(+), 664 deletions(-) delete mode 100644 crates/http/src/workers/socket.rs create mode 100644 crates/http/src/workers/socket/connection.rs create mode 100644 crates/http/src/workers/socket/mod.rs create mode 100644 crates/http/src/workers/socket/request.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index ddefb9d..747bd56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,15 @@ * Reload TLS certificate (and key) on SIGUSR1 +#### Changed + +* Allow running without TLS +* Allow running behind reverse proxy + +#### Fixed + +* Fix bug where clean up after closing connections wasn't always done + ### aquatic_ws #### Added diff --git a/Cargo.lock b/Cargo.lock index 2b2a24c..a20e9cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -125,6 +125,7 @@ dependencies = [ "futures-lite", "futures-rustls", "glommio", + "httparse", "itoa", "libc", "log", @@ -140,8 +141,9 @@ dependencies = [ "rustls-pemfile", "serde", "signal-hook", - "slab", + "slotmap", "socket2 0.5.4", + "thiserror", ] [[package]] @@ -152,6 +154,7 @@ dependencies = [ "aquatic_common", 
"aquatic_http_protocol", "aquatic_toml_config", + "futures", "futures-lite", "futures-rustls", "glommio", diff --git a/README.md b/README.md index 8cdcbb6..e0c2ea0 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,11 @@ of sub-implementations for different protocols: [aquatic_http]: ./crates/http [aquatic_ws]: ./crates/ws -| Name | Protocol | OS requirements | -|----------------|---------------------------------|-----------------| -| [aquatic_udp] | BitTorrent over UDP | Unix-like | -| [aquatic_http] | BitTorrent over HTTP over TLS | Linux 5.8+ | -| [aquatic_ws] | WebTorrent, optionally over TLS | Linux 5.8+ | +| Name | Protocol | OS requirements | +|----------------|-------------------------------------------|-----------------| +| [aquatic_udp] | BitTorrent over UDP | Unix-like | +| [aquatic_http] | BitTorrent over HTTP, optionally over TLS | Linux 5.8+ | +| [aquatic_ws] | WebTorrent, optionally over TLS | Linux 5.8+ | Features at a glance: diff --git a/TODO.md b/TODO.md index da43bc6..abff8c2 100644 --- a/TODO.md +++ b/TODO.md @@ -5,6 +5,9 @@ * aquatic_ws * Validate SDP data +* http + * panic sentinel not working + ## Medium priority * stagger cleaning tasks? 
diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 8f52edd..2660d46 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -35,6 +35,7 @@ futures = "0.3" futures-lite = "1" futures-rustls = "0.24" glommio = "0.8" +httparse = "1" itoa = "1" libc = "0.2" log = "0.4" @@ -48,8 +49,9 @@ rand = { version = "0.8", features = ["small_rng"] } rustls-pemfile = "1" serde = { version = "1", features = ["derive"] } signal-hook = { version = "0.3" } -slab = "0.4" +slotmap = "1" socket2 = { version = "0.5", features = ["all"] } +thiserror = "1" [dev-dependencies] quickcheck = "1" diff --git a/crates/http/README.md b/crates/http/README.md index eecad8e..fbc0372 100644 --- a/crates/http/README.md +++ b/crates/http/README.md @@ -49,12 +49,9 @@ Generate the configuration file: Make necessary adjustments to the file. You will likely want to adjust `address` (listening address) under the `network` section. -`aquatic_http` __only__ runs over TLS, so configuring certificate and private -key files is required. +To run over TLS, configure certificate and private key files. -Running behind a reverse proxy is currently not supported due to the -[difficulties of determining the originating IP address](https://adam-p.ca/blog/2022/03/x-forwarded-for/) -without knowing the exact setup. +Running behind a reverse proxy is supported. ### Running diff --git a/crates/http/src/common.rs b/crates/http/src/common.rs index 7f2b356..9c326d6 100644 --- a/crates/http/src/common.rs +++ b/crates/http/src/common.rs @@ -10,12 +10,14 @@ use aquatic_http_protocol::{ response::{AnnounceResponse, ScrapeResponse}, }; use glommio::channels::shared_channel::SharedSender; +use slotmap::new_key_type; #[derive(Copy, Clone, Debug)] pub struct ConsumerId(pub usize); -#[derive(Clone, Copy, Debug)] -pub struct ConnectionId(pub usize); +new_key_type! 
{ + pub struct ConnectionId; +} #[derive(Debug)] pub enum ChannelRequest { diff --git a/crates/http/src/config.rs b/crates/http/src/config.rs index 7696bf0..f7a4666 100644 --- a/crates/http/src/config.rs +++ b/crates/http/src/config.rs @@ -5,13 +5,18 @@ use aquatic_common::{ privileges::PrivilegeConfig, }; use aquatic_toml_config::TomlConfig; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use aquatic_common::cli::LogLevel; +#[derive(Clone, Copy, Debug, PartialEq, Serialize, TomlConfig, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ReverseProxyPeerIpHeaderFormat { + #[default] + LastAddress, +} + /// aquatic_http configuration -/// -/// Does not support running behind a reverse proxy. #[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] #[serde(default, deny_unknown_fields)] pub struct Config { @@ -76,29 +81,55 @@ pub struct NetworkConfig { pub only_ipv6: bool, /// Maximum number of pending TCP connections pub tcp_backlog: i32, - /// Path to TLS certificate (DER-encoded X.509) + /// Enable TLS /// /// The TLS files are read on start and when the program receives `SIGUSR1`. /// If initial parsing fails, the program exits. Later failures result in /// in emitting of an error-level log message, while successful updates /// result in emitting of an info-level log message. Updates only affect /// new connections. + pub enable_tls: bool, + /// Path to TLS certificate (DER-encoded X.509) pub tls_certificate_path: PathBuf, /// Path to TLS private key (DER-encoded ASN.1 in PKCS#8 or PKCS#1 format) pub tls_private_key_path: PathBuf, /// Keep connections alive after sending a response pub keep_alive: bool, + /// Does tracker run behind reverse proxy? + /// + /// MUST be set to false if not running behind reverse proxy. + /// + /// If set to true, make sure that reverse_proxy_ip_header_name and + /// reverse_proxy_ip_header_format are set to match your reverse proxy + /// setup. 
+ /// + /// More info on what can go wrong when running behind reverse proxies: + /// https://adam-p.ca/blog/2022/03/x-forwarded-for/ + pub runs_behind_reverse_proxy: bool, + /// Name of header set by reverse proxy to indicate peer ip + pub reverse_proxy_ip_header_name: String, + /// How to extract peer IP from header field + /// + /// Options: + /// - last_address: use the last address in the last instance of the + /// header. Works with typical multi-IP setups (e.g., "X-Forwarded-For") + /// as well as for single-IP setups (e.g., nginx "X-Real-IP") + pub reverse_proxy_ip_header_format: ReverseProxyPeerIpHeaderFormat, } impl Default for NetworkConfig { fn default() -> Self { Self { address: SocketAddr::from(([0, 0, 0, 0], 3000)), + enable_tls: false, tls_certificate_path: "".into(), tls_private_key_path: "".into(), only_ipv6: false, tcp_backlog: 1024, keep_alive: true, + runs_behind_reverse_proxy: false, + reverse_proxy_ip_header_name: "X-Forwarded-For".into(), + reverse_proxy_ip_header_format: Default::default(), } } } diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index d4e1864..3bb0484 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -24,7 +24,7 @@ mod common; pub mod config; mod workers; -pub const APP_NAME: &str = "aquatic_http: BitTorrent tracker (HTTP over TLS)"; +pub const APP_NAME: &str = "aquatic_http: HTTP BitTorrent tracker"; pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); const SHARED_CHANNEL_SIZE: usize = 1024; @@ -58,10 +58,14 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel(); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); - let tls_config = Arc::new(ArcSwap::from_pointee(create_rustls_config( - &config.network.tls_certificate_path, - &config.network.tls_private_key_path, - )?)); + let opt_tls_config = if config.network.enable_tls { + 
Some(Arc::new(ArcSwap::from_pointee(create_rustls_config( + &config.network.tls_certificate_path, + &config.network.tls_private_key_path, + )?))) + } else { + None + }; let server_start_instant = ServerStartInstant::new(); @@ -71,7 +75,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let sentinel = sentinel.clone(); let config = config.clone(); let state = state.clone(); - let tls_config = tls_config.clone(); + let opt_tls_config = opt_tls_config.clone(); let request_mesh_builder = request_mesh_builder.clone(); let priv_dropper = priv_dropper.clone(); @@ -89,7 +93,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { sentinel, config, state, - tls_config, + opt_tls_config, request_mesh_builder, priv_dropper, server_start_instant, @@ -146,16 +150,18 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { SIGUSR1 => { let _ = update_access_list(&config.access_list, &state.access_list); - match create_rustls_config( - &config.network.tls_certificate_path, - &config.network.tls_private_key_path, - ) { - Ok(config) => { - tls_config.store(Arc::new(config)); + if let Some(tls_config) = opt_tls_config.as_ref() { + match create_rustls_config( + &config.network.tls_certificate_path, + &config.network.tls_private_key_path, + ) { + Ok(config) => { + tls_config.store(Arc::new(config)); - ::log::info!("successfully updated tls config"); + ::log::info!("successfully updated tls config"); + } + Err(err) => ::log::error!("could not update tls config: {:#}", err), } - Err(err) => ::log::error!("could not update tls config: {:#}", err), } } SIGTERM => { diff --git a/crates/http/src/workers/socket.rs b/crates/http/src/workers/socket.rs deleted file mode 100644 index a3751ec..0000000 --- a/crates/http/src/workers/socket.rs +++ /dev/null @@ -1,576 +0,0 @@ -use std::cell::RefCell; -use std::collections::BTreeMap; -use std::os::unix::prelude::{FromRawFd, IntoRawFd}; -use std::rc::Rc; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Context; -use 
aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; -use aquatic_common::privileges::PrivilegeDropper; -use aquatic_common::rustls_config::RustlsConfig; -use aquatic_common::{CanonicalSocketAddr, PanicSentinel, ServerStartInstant}; -use aquatic_http_protocol::common::InfoHash; -use aquatic_http_protocol::request::{Request, RequestParseError, ScrapeRequest}; -use aquatic_http_protocol::response::{ - FailureResponse, Response, ScrapeResponse, ScrapeStatistics, -}; -use arc_swap::ArcSwap; -use either::Either; -use futures::stream::FuturesUnordered; -use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; -use futures_rustls::server::TlsStream; -use futures_rustls::TlsAcceptor; -use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; -use glommio::channels::shared_channel::{self, SharedReceiver}; -use glommio::net::{TcpListener, TcpStream}; -use glommio::task::JoinHandle; -use glommio::timer::TimerActionRepeat; -use glommio::{enclose, prelude::*}; -use once_cell::sync::Lazy; -use slab::Slab; - -use crate::common::*; -use crate::config::Config; - -const REQUEST_BUFFER_SIZE: usize = 2048; -const RESPONSE_BUFFER_SIZE: usize = 4096; - -const RESPONSE_HEADER_A: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: "; -const RESPONSE_HEADER_B: &[u8] = b" "; -const RESPONSE_HEADER_C: &[u8] = b"\r\n\r\n"; - -#[cfg(feature = "metrics")] -thread_local! 
{ static WORKER_INDEX: ::std::cell::Cell = Default::default() } - -static RESPONSE_HEADER: Lazy> = - Lazy::new(|| [RESPONSE_HEADER_A, RESPONSE_HEADER_B, RESPONSE_HEADER_C].concat()); - -struct PendingScrapeResponse { - pending_worker_responses: usize, - stats: BTreeMap, -} - -struct ConnectionReference { - task_handle: Option>, - valid_until: ValidUntil, -} - -pub async fn run_socket_worker( - _sentinel: PanicSentinel, - config: Config, - state: State, - tls_config: Arc>, - request_mesh_builder: MeshBuilder, - priv_dropper: PrivilegeDropper, - server_start_instant: ServerStartInstant, - worker_index: usize, -) { - #[cfg(feature = "metrics")] - WORKER_INDEX.with(|index| index.set(worker_index)); - - let config = Rc::new(config); - let access_list = state.access_list; - - let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener"); - - let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); - let request_senders = Rc::new(request_senders); - - let connection_slab = Rc::new(RefCell::new(Slab::new())); - - TimerActionRepeat::repeat(enclose!((config, connection_slab) move || { - clean_connections( - config.clone(), - connection_slab.clone(), - server_start_instant, - ) - })); - - let mut incoming = listener.incoming(); - - while let Some(stream) = incoming.next().await { - match stream { - Ok(stream) => { - let key = connection_slab.borrow_mut().insert(ConnectionReference { - task_handle: None, - valid_until: ValidUntil::new( - server_start_instant, - config.cleaning.max_connection_idle, - ), - }); - - let task_handle = spawn_local(enclose!((config, access_list, request_senders, tls_config, connection_slab) async move { - let result = match stream.peer_addr() { - Ok(peer_addr) => { - let peer_addr = CanonicalSocketAddr::new(peer_addr); - - #[cfg(feature = "metrics")] - let ip_version_str = peer_addr_to_ip_version_str(&peer_addr); - - #[cfg(feature = "metrics")] - ::metrics::increment_gauge!( - 
"aquatic_active_connections", - 1.0, - "ip_version" => ip_version_str, - "worker_index" => worker_index.to_string(), - ); - - let result = Connection::run( - config, - access_list, - request_senders, - server_start_instant, - ConnectionId(key), - tls_config, - connection_slab.clone(), - stream, - peer_addr - ).await; - - #[cfg(feature = "metrics")] - ::metrics::decrement_gauge!( - "aquatic_active_connections", - 1.0, - "ip_version" => ip_version_str, - "worker_index" => worker_index.to_string(), - ); - - result - } - Err(err) => { - Err(anyhow::anyhow!("Couldn't get peer addr: {:?}", err)) - } - }; - - if let Err(err) = result { - ::log::debug!("Connection::run() error: {:?}", err); - } - - connection_slab.borrow_mut().try_remove(key); - })) - .detach(); - - if let Some(reference) = connection_slab.borrow_mut().get_mut(key) { - reference.task_handle = Some(task_handle); - } - } - Err(err) => { - ::log::error!("accept connection: {:?}", err); - } - } - } -} - -async fn clean_connections( - config: Rc, - connection_slab: Rc>>, - server_start_instant: ServerStartInstant, -) -> Option { - let now = server_start_instant.seconds_elapsed(); - - connection_slab.borrow_mut().retain(|_, reference| { - if reference.valid_until.valid(now) { - true - } else { - if let Some(ref handle) = reference.task_handle { - handle.cancel(); - } - - false - } - }); - - connection_slab.borrow_mut().shrink_to_fit(); - - Some(Duration::from_secs( - config.cleaning.connection_cleaning_interval, - )) -} - -struct Connection { - config: Rc, - access_list_cache: AccessListCache, - request_senders: Rc>, - connection_slab: Rc>>, - server_start_instant: ServerStartInstant, - stream: TlsStream, - peer_addr: CanonicalSocketAddr, - connection_id: ConnectionId, - request_buffer: [u8; REQUEST_BUFFER_SIZE], - request_buffer_position: usize, - response_buffer: [u8; RESPONSE_BUFFER_SIZE], -} - -impl Connection { - async fn run( - config: Rc, - access_list: Arc, - request_senders: Rc>, - server_start_instant: 
ServerStartInstant, - connection_id: ConnectionId, - tls_config: Arc>, - connection_slab: Rc>>, - stream: TcpStream, - peer_addr: CanonicalSocketAddr, - ) -> anyhow::Result<()> { - let tls_acceptor: TlsAcceptor = tls_config.load_full().into(); - let stream = tls_acceptor.accept(stream).await?; - - let mut response_buffer = [0; RESPONSE_BUFFER_SIZE]; - - response_buffer[..RESPONSE_HEADER.len()].copy_from_slice(&RESPONSE_HEADER); - - let mut conn = Connection { - config: config.clone(), - access_list_cache: create_access_list_cache(&access_list), - request_senders: request_senders.clone(), - connection_slab, - server_start_instant, - stream, - peer_addr, - connection_id, - request_buffer: [0; REQUEST_BUFFER_SIZE], - request_buffer_position: 0, - response_buffer, - }; - - conn.run_request_response_loop().await?; - - Ok(()) - } - - async fn run_request_response_loop(&mut self) -> anyhow::Result<()> { - loop { - let response = match self.read_request().await? { - Either::Left(response) => Response::Failure(response), - Either::Right(request) => self.handle_request(request).await?, - }; - - self.write_response(&response).await?; - - if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive { - let _ = self - .stream - .get_ref() - .0 - .shutdown(std::net::Shutdown::Both) - .await; - - break; - } - } - - Ok(()) - } - - async fn read_request(&mut self) -> anyhow::Result> { - self.request_buffer_position = 0; - - loop { - if self.request_buffer_position == self.request_buffer.len() { - return Err(anyhow::anyhow!("request buffer is full")); - } - - let bytes_read = self - .stream - .read(&mut self.request_buffer[self.request_buffer_position..]) - .await?; - - if bytes_read == 0 { - return Err(anyhow::anyhow!("peer closed connection")); - } - - self.request_buffer_position += bytes_read; - - match Request::from_bytes(&self.request_buffer[..self.request_buffer_position]) { - Ok(request) => { - return Ok(Either::Right(request)); - } - 
Err(RequestParseError::Invalid(err)) => { - let response = FailureResponse { - failure_reason: "Invalid request".into(), - }; - - ::log::debug!("Invalid request: {:#}", err); - - return Ok(Either::Left(response)); - } - Err(RequestParseError::NeedMoreData) => { - ::log::debug!( - "need more request data. current data: {}", - &self.request_buffer[..self.request_buffer_position].escape_ascii() - ); - } - } - } - } - - /// Take a request and: - /// - Update connection ValidUntil - /// - Return error response if request is not allowed - /// - If it is an announce request, send it to swarm workers an await a - /// response - /// - If it is a scrape requests, split it up, pass on the parts to - /// relevant swarm workers and await a response - async fn handle_request(&mut self, request: Request) -> anyhow::Result { - if let Ok(mut slab) = self.connection_slab.try_borrow_mut() { - if let Some(reference) = slab.get_mut(self.connection_id.0) { - reference.valid_until = ValidUntil::new( - self.server_start_instant, - self.config.cleaning.max_connection_idle, - ); - } - } - - match request { - Request::Announce(request) => { - #[cfg(feature = "metrics")] - ::metrics::increment_counter!( - "aquatic_requests_total", - "type" => "announce", - "ip_version" => peer_addr_to_ip_version_str(&self.peer_addr), - "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), - ); - - let info_hash = request.info_hash; - - if self - .access_list_cache - .load() - .allows(self.config.access_list.mode, &info_hash.0) - { - let (response_sender, response_receiver) = shared_channel::new_bounded(1); - - let request = ChannelRequest::Announce { - request, - peer_addr: self.peer_addr, - response_sender, - }; - - let consumer_index = calculate_request_consumer_index(&self.config, info_hash); - - // Only fails when receiver is closed - self.request_senders - .send_to(consumer_index, request) - .await - .unwrap(); - - response_receiver - .connect() - .await - .recv() - .await - 
.ok_or_else(|| anyhow::anyhow!("response sender closed")) - .map(Response::Announce) - } else { - let response = Response::Failure(FailureResponse { - failure_reason: "Info hash not allowed".into(), - }); - - Ok(response) - } - } - Request::Scrape(ScrapeRequest { info_hashes }) => { - #[cfg(feature = "metrics")] - ::metrics::increment_counter!( - "aquatic_requests_total", - "type" => "scrape", - "ip_version" => peer_addr_to_ip_version_str(&self.peer_addr), - "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), - ); - - let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); - - for info_hash in info_hashes.into_iter() { - let info_hashes = info_hashes_by_worker - .entry(calculate_request_consumer_index(&self.config, info_hash)) - .or_default(); - - info_hashes.push(info_hash); - } - - let pending_worker_responses = info_hashes_by_worker.len(); - let mut response_receivers = Vec::with_capacity(pending_worker_responses); - - for (consumer_index, info_hashes) in info_hashes_by_worker { - let (response_sender, response_receiver) = shared_channel::new_bounded(1); - - response_receivers.push(response_receiver); - - let request = ChannelRequest::Scrape { - request: ScrapeRequest { info_hashes }, - peer_addr: self.peer_addr, - response_sender, - }; - - // Only fails when receiver is closed - self.request_senders - .send_to(consumer_index, request) - .await - .unwrap(); - } - - let pending_scrape_response = PendingScrapeResponse { - pending_worker_responses, - stats: Default::default(), - }; - - self.wait_for_scrape_responses(response_receivers, pending_scrape_response) - .await - } - } - } - - /// Wait for partial scrape responses to arrive, - /// return full response - async fn wait_for_scrape_responses( - &self, - response_receivers: Vec>, - mut pending: PendingScrapeResponse, - ) -> anyhow::Result { - let mut responses = response_receivers - .into_iter() - .map(|receiver| async { receiver.connect().await.recv().await }) - .collect::>(); - - loop { - 
let response = responses - .next() - .await - .ok_or_else(|| { - anyhow::anyhow!("stream ended before all partial scrape responses received") - })? - .ok_or_else(|| { - anyhow::anyhow!( - "wait_for_scrape_response: can't receive response, sender is closed" - ) - })?; - - pending.stats.extend(response.files); - pending.pending_worker_responses -= 1; - - if pending.pending_worker_responses == 0 { - let response = Response::Scrape(ScrapeResponse { - files: pending.stats, - }); - - break Ok(response); - } - } - } - - async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> { - // Write body and final newline to response buffer - - let mut position = RESPONSE_HEADER.len(); - - let body_len = response.write(&mut &mut self.response_buffer[position..])?; - - position += body_len; - - if position + 2 > self.response_buffer.len() { - ::log::error!("Response buffer is too short for response"); - - return Err(anyhow::anyhow!("Response buffer is too short for response")); - } - - (&mut self.response_buffer[position..position + 2]).copy_from_slice(b"\r\n"); - - position += 2; - - let content_len = body_len + 2; - - // Clear content-len header value - - { - let start = RESPONSE_HEADER_A.len(); - let end = start + RESPONSE_HEADER_B.len(); - - (&mut self.response_buffer[start..end]).copy_from_slice(RESPONSE_HEADER_B); - } - - // Set content-len header value - - { - let mut buf = ::itoa::Buffer::new(); - let content_len_bytes = buf.format(content_len).as_bytes(); - - let start = RESPONSE_HEADER_A.len(); - let end = start + content_len_bytes.len(); - - (&mut self.response_buffer[start..end]).copy_from_slice(content_len_bytes); - } - - // Write buffer to stream - - self.stream.write(&self.response_buffer[..position]).await?; - self.stream.flush().await?; - - #[cfg(feature = "metrics")] - { - let response_type = match response { - Response::Announce(_) => "announce", - Response::Scrape(_) => "scrape", - Response::Failure(_) => "error", - }; - - 
::metrics::increment_counter!( - "aquatic_responses_total", - "type" => response_type, - "ip_version" => peer_addr_to_ip_version_str(&self.peer_addr), - "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), - ); - } - - Ok(()) - } -} - -fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize { - (info_hash.0[0] as usize) % config.swarm_workers -} - -fn create_tcp_listener( - config: &Config, - priv_dropper: PrivilegeDropper, -) -> anyhow::Result { - let domain = if config.network.address.is_ipv4() { - socket2::Domain::IPV4 - } else { - socket2::Domain::IPV6 - }; - - let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?; - - if config.network.only_ipv6 { - socket - .set_only_v6(true) - .with_context(|| "socket: set only ipv6")?; - } - - socket - .set_reuse_port(true) - .with_context(|| "socket: set reuse port")?; - - socket - .bind(&config.network.address.into()) - .with_context(|| format!("socket: bind to {}", config.network.address))?; - - socket - .listen(config.network.tcp_backlog) - .with_context(|| format!("socket: listen on {}", config.network.address))?; - - priv_dropper.after_socket_creation()?; - - Ok(unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) }) -} - -#[cfg(feature = "metrics")] -fn peer_addr_to_ip_version_str(addr: &CanonicalSocketAddr) -> &'static str { - if addr.is_ipv4() { - "4" - } else { - "6" - } -} diff --git a/crates/http/src/workers/socket/connection.rs b/crates/http/src/workers/socket/connection.rs new file mode 100644 index 0000000..7db5146 --- /dev/null +++ b/crates/http/src/workers/socket/connection.rs @@ -0,0 +1,471 @@ +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::net::SocketAddr; +use std::rc::Rc; +use std::sync::Arc; + +use anyhow::Context; +use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; +use aquatic_common::rustls_config::RustlsConfig; +use 
aquatic_common::{CanonicalSocketAddr, ServerStartInstant}; +use aquatic_http_protocol::common::InfoHash; +use aquatic_http_protocol::request::{Request, ScrapeRequest}; +use aquatic_http_protocol::response::{ + FailureResponse, Response, ScrapeResponse, ScrapeStatistics, +}; +use arc_swap::ArcSwap; +use either::Either; +use futures::stream::FuturesUnordered; +use futures_lite::future::race; +use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; +use futures_rustls::TlsAcceptor; +use glommio::channels::channel_mesh::Senders; +use glommio::channels::local_channel::LocalReceiver; +use glommio::channels::shared_channel::{self, SharedReceiver}; +use glommio::net::TcpStream; +use once_cell::sync::Lazy; + +use crate::common::*; +use crate::config::Config; + +use super::request::{parse_request, RequestParseError}; +#[cfg(feature = "metrics")] +use super::{peer_addr_to_ip_version_str, WORKER_INDEX}; + +const REQUEST_BUFFER_SIZE: usize = 2048; +const RESPONSE_BUFFER_SIZE: usize = 4096; + +const RESPONSE_HEADER_A: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: "; +const RESPONSE_HEADER_B: &[u8] = b" "; +const RESPONSE_HEADER_C: &[u8] = b"\r\n\r\n"; + +static RESPONSE_HEADER: Lazy> = + Lazy::new(|| [RESPONSE_HEADER_A, RESPONSE_HEADER_B, RESPONSE_HEADER_C].concat()); + +struct PendingScrapeResponse { + pending_worker_responses: usize, + stats: BTreeMap, +} + +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("inactive")] + Inactive, + #[error("socket peer addr extraction failed")] + NoSocketPeerAddr(String), + #[error("request buffer full")] + RequestBufferFull, + #[error("response buffer full")] + ResponseBufferFull, + #[error("response buffer write error: {0}")] + ResponseBufferWrite(::std::io::Error), + #[error("peer closed")] + PeerClosed, + #[error("response sender closed")] + ResponseSenderClosed, + #[error("scrape channel error: {0}")] + ScrapeChannelError(&'static str), + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +pub(super) 
async fn run_connection( + config: Rc, + access_list: Arc, + request_senders: Rc>, + server_start_instant: ServerStartInstant, + opt_tls_config: Option>>, + valid_until: Rc>, + close_conn_receiver: LocalReceiver<()>, + stream: TcpStream, +) -> Result<(), ConnectionError> { + let access_list_cache = create_access_list_cache(&access_list); + let request_buffer = Box::new([0u8; REQUEST_BUFFER_SIZE]); + + let mut response_buffer = Box::new([0; RESPONSE_BUFFER_SIZE]); + + response_buffer[..RESPONSE_HEADER.len()].copy_from_slice(&RESPONSE_HEADER); + + let remote_addr = stream + .peer_addr() + .map_err(|err| ConnectionError::NoSocketPeerAddr(err.to_string()))?; + + let opt_peer_addr = if config.network.runs_behind_reverse_proxy { + None + } else { + Some(CanonicalSocketAddr::new(remote_addr)) + }; + + let peer_port = remote_addr.port(); + + if let Some(tls_config) = opt_tls_config { + let tls_acceptor: TlsAcceptor = tls_config.load_full().into(); + let stream = tls_acceptor + .accept(stream) + .await + .with_context(|| "tls accept")?; + + let mut conn = Connection { + config, + access_list_cache, + request_senders, + valid_until, + server_start_instant, + opt_peer_addr, + peer_port, + request_buffer, + request_buffer_position: 0, + response_buffer, + stream, + }; + + conn.run(close_conn_receiver).await?; + } else { + let mut conn = Connection { + config, + access_list_cache, + request_senders, + valid_until, + server_start_instant, + opt_peer_addr, + peer_port, + request_buffer, + request_buffer_position: 0, + response_buffer, + stream, + }; + + conn.run(close_conn_receiver).await?; + } + + Ok(()) +} + +struct Connection { + config: Rc, + access_list_cache: AccessListCache, + request_senders: Rc>, + valid_until: Rc>, + server_start_instant: ServerStartInstant, + opt_peer_addr: Option, + peer_port: u16, + request_buffer: Box<[u8; REQUEST_BUFFER_SIZE]>, + request_buffer_position: usize, + response_buffer: Box<[u8; RESPONSE_BUFFER_SIZE]>, + stream: S, +} + +impl Connection 
+where + S: futures::AsyncRead + futures::AsyncWrite + Unpin + 'static, +{ + async fn run(&mut self, close_conn_receiver: LocalReceiver<()>) -> Result<(), ConnectionError> { + let f1 = async { self.run_request_response_loop().await }; + let f2 = async { + close_conn_receiver.recv().await; + + Err(ConnectionError::Inactive) + }; + + race(f1, f2).await + } + + async fn run_request_response_loop(&mut self) -> Result<(), ConnectionError> { + loop { + let response = match self.read_request().await? { + Either::Left(response) => Response::Failure(response), + Either::Right(request) => self.handle_request(request).await?, + }; + + self.write_response(&response).await?; + + if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive { + break; + } + } + + Ok(()) + } + + async fn read_request(&mut self) -> Result, ConnectionError> { + self.request_buffer_position = 0; + + loop { + if self.request_buffer_position == self.request_buffer.len() { + return Err(ConnectionError::RequestBufferFull); + } + + let bytes_read = self + .stream + .read(&mut self.request_buffer[self.request_buffer_position..]) + .await + .with_context(|| "read")?; + + if bytes_read == 0 { + return Err(ConnectionError::PeerClosed); + } + + self.request_buffer_position += bytes_read; + + let buffer_slice = &self.request_buffer[..self.request_buffer_position]; + + match parse_request(&self.config, buffer_slice) { + Ok((request, opt_peer_ip)) => { + if self.config.network.runs_behind_reverse_proxy { + let peer_ip = opt_peer_ip + .expect("logic error: peer ip must have been extracted at this point"); + + self.opt_peer_addr = Some(CanonicalSocketAddr::new(SocketAddr::new( + peer_ip, + self.peer_port, + ))); + } + + return Ok(Either::Right(request)); + } + Err(RequestParseError::MoreDataNeeded) => continue, + Err(RequestParseError::RequiredPeerIpHeaderMissing(err)) => { + panic!("Tracker configured as running behind reverse proxy, but no corresponding IP header set in request. 
Please check your reverse proxy setup as well as your aquatic configuration. Error: {:#}", err); + } + Err(RequestParseError::Other(err)) => { + ::log::debug!("Failed parsing request: {:#}", err); + + let response = FailureResponse { + failure_reason: "Invalid request".into(), + }; + + return Ok(Either::Left(response)); + } + } + } + } + + /// Take a request and: + /// - Update connection ValidUntil + /// - Return error response if request is not allowed + /// - If it is an announce request, send it to swarm workers and await a + /// response + /// - If it is a scrape request, split it up, pass on the parts to + /// relevant swarm workers and await a response + async fn handle_request(&mut self, request: Request) -> Result { + let peer_addr = self + .opt_peer_addr + .expect("peer addr should already have been extracted by now"); + + *self.valid_until.borrow_mut() = ValidUntil::new( + self.server_start_instant, + self.config.cleaning.max_connection_idle, + ); + + match request { + Request::Announce(request) => { + #[cfg(feature = "metrics")] + ::metrics::increment_counter!( + "aquatic_requests_total", + "type" => "announce", + "ip_version" => peer_addr_to_ip_version_str(&peer_addr), + "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), + ); + + let info_hash = request.info_hash; + + if self + .access_list_cache + .load() + .allows(self.config.access_list.mode, &info_hash.0) + { + let (response_sender, response_receiver) = shared_channel::new_bounded(1); + + let request = ChannelRequest::Announce { + request, + peer_addr, + response_sender, + }; + + let consumer_index = calculate_request_consumer_index(&self.config, info_hash); + + // Only fails when receiver is closed + self.request_senders + .send_to(consumer_index, request) + .await + .unwrap(); + + response_receiver + .connect() + .await + .recv() + .await + .ok_or(ConnectionError::ResponseSenderClosed) + .map(Response::Announce) + } else { + let response = Response::Failure(FailureResponse { + 
failure_reason: "Info hash not allowed".into(), + }); + + Ok(response) + } + } + Request::Scrape(ScrapeRequest { info_hashes }) => { + #[cfg(feature = "metrics")] + ::metrics::increment_counter!( + "aquatic_requests_total", + "type" => "scrape", + "ip_version" => peer_addr_to_ip_version_str(&peer_addr), + "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), + ); + + let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); + + for info_hash in info_hashes.into_iter() { + let info_hashes = info_hashes_by_worker + .entry(calculate_request_consumer_index(&self.config, info_hash)) + .or_default(); + + info_hashes.push(info_hash); + } + + let pending_worker_responses = info_hashes_by_worker.len(); + let mut response_receivers = Vec::with_capacity(pending_worker_responses); + + for (consumer_index, info_hashes) in info_hashes_by_worker { + let (response_sender, response_receiver) = shared_channel::new_bounded(1); + + response_receivers.push(response_receiver); + + let request = ChannelRequest::Scrape { + request: ScrapeRequest { info_hashes }, + peer_addr, + response_sender, + }; + + // Only fails when receiver is closed + self.request_senders + .send_to(consumer_index, request) + .await + .unwrap(); + } + + let pending_scrape_response = PendingScrapeResponse { + pending_worker_responses, + stats: Default::default(), + }; + + self.wait_for_scrape_responses(response_receivers, pending_scrape_response) + .await + } + } + } + + /// Wait for partial scrape responses to arrive, + /// return full response + async fn wait_for_scrape_responses( + &self, + response_receivers: Vec>, + mut pending: PendingScrapeResponse, + ) -> Result { + let mut responses = response_receivers + .into_iter() + .map(|receiver| async { receiver.connect().await.recv().await }) + .collect::>(); + + loop { + let response = responses + .next() + .await + .ok_or_else(|| { + ConnectionError::ScrapeChannelError( + "stream ended before all partial scrape responses received", + ) + })? 
+ .ok_or_else(|| ConnectionError::ScrapeChannelError("sender is closed"))?; + + pending.stats.extend(response.files); + pending.pending_worker_responses -= 1; + + if pending.pending_worker_responses == 0 { + let response = Response::Scrape(ScrapeResponse { + files: pending.stats, + }); + + break Ok(response); + } + } + } + + async fn write_response(&mut self, response: &Response) -> Result<(), ConnectionError> { + // Write body and final newline to response buffer + + let mut position = RESPONSE_HEADER.len(); + + let body_len = response + .write(&mut &mut self.response_buffer[position..]) + .map_err(|err| ConnectionError::ResponseBufferWrite(err))?; + + position += body_len; + + if position + 2 > self.response_buffer.len() { + return Err(ConnectionError::ResponseBufferFull); + } + + (&mut self.response_buffer[position..position + 2]).copy_from_slice(b"\r\n"); + + position += 2; + + let content_len = body_len + 2; + + // Clear content-len header value + + { + let start = RESPONSE_HEADER_A.len(); + let end = start + RESPONSE_HEADER_B.len(); + + (&mut self.response_buffer[start..end]).copy_from_slice(RESPONSE_HEADER_B); + } + + // Set content-len header value + + { + let mut buf = ::itoa::Buffer::new(); + let content_len_bytes = buf.format(content_len).as_bytes(); + + let start = RESPONSE_HEADER_A.len(); + let end = start + content_len_bytes.len(); + + (&mut self.response_buffer[start..end]).copy_from_slice(content_len_bytes); + } + + // Write buffer to stream + + self.stream + .write(&self.response_buffer[..position]) + .await + .with_context(|| "write")?; + self.stream.flush().await.with_context(|| "flush")?; + + #[cfg(feature = "metrics")] + { + let response_type = match response { + Response::Announce(_) => "announce", + Response::Scrape(_) => "scrape", + Response::Failure(_) => "error", + }; + + let peer_addr = self + .opt_peer_addr + .expect("peer addr should already have been extracted by now"); + + ::metrics::increment_counter!( + "aquatic_responses_total", + 
"type" => response_type, + "ip_version" => peer_addr_to_ip_version_str(&peer_addr), + "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), + ); + } + + Ok(()) + } +} + +fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize { + (info_hash.0[0] as usize) % config.swarm_workers +} diff --git a/crates/http/src/workers/socket/mod.rs b/crates/http/src/workers/socket/mod.rs new file mode 100644 index 0000000..0baa68b --- /dev/null +++ b/crates/http/src/workers/socket/mod.rs @@ -0,0 +1,212 @@ +mod connection; +mod request; + +use std::cell::RefCell; +use std::os::unix::prelude::{FromRawFd, IntoRawFd}; +use std::rc::Rc; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Context; +use aquatic_common::privileges::PrivilegeDropper; +use aquatic_common::rustls_config::RustlsConfig; +use aquatic_common::{CanonicalSocketAddr, PanicSentinel, ServerStartInstant}; +use arc_swap::ArcSwap; +use futures_lite::StreamExt; +use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role}; +use glommio::channels::local_channel::{new_bounded, LocalSender}; +use glommio::net::TcpListener; +use glommio::timer::TimerActionRepeat; +use glommio::{enclose, prelude::*}; +use slotmap::HopSlotMap; + +use crate::common::*; +use crate::config::Config; +use crate::workers::socket::connection::{run_connection, ConnectionError}; + +#[cfg(feature = "metrics")] +thread_local! 
{ static WORKER_INDEX: ::std::cell::Cell = Default::default() } + +struct ConnectionHandle { + close_conn_sender: LocalSender<()>, + valid_until: Rc>, +} + +pub async fn run_socket_worker( + _sentinel: PanicSentinel, + config: Config, + state: State, + opt_tls_config: Option>>, + request_mesh_builder: MeshBuilder, + priv_dropper: PrivilegeDropper, + server_start_instant: ServerStartInstant, + worker_index: usize, +) { + #[cfg(feature = "metrics")] + WORKER_INDEX.with(|index| index.set(worker_index)); + + let config = Rc::new(config); + let access_list = state.access_list; + + let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener"); + + let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); + let request_senders = Rc::new(request_senders); + + let connection_handles = Rc::new(RefCell::new(HopSlotMap::with_key())); + + TimerActionRepeat::repeat(enclose!((config, connection_handles) move || { + clean_connections( + config.clone(), + connection_handles.clone(), + server_start_instant, + ) + })); + + let mut incoming = listener.incoming(); + + while let Some(stream) = incoming.next().await { + match stream { + Ok(stream) => { + let (close_conn_sender, close_conn_receiver) = new_bounded(1); + + let valid_until = Rc::new(RefCell::new(ValidUntil::new( + server_start_instant, + config.cleaning.max_connection_idle, + ))); + + let connection_id = connection_handles.borrow_mut().insert(ConnectionHandle { + close_conn_sender, + valid_until: valid_until.clone(), + }); + + spawn_local(enclose!( + ( + config, + access_list, + request_senders, + opt_tls_config, + connection_handles, + valid_until, + ) + async move { + #[cfg(feature = "metrics")] + ::metrics::increment_gauge!( + "aquatic_active_connections", + 1.0, + "worker_index" => worker_index.to_string(), + ); + + let result = run_connection( + config, + access_list, + request_senders, + server_start_instant, + opt_tls_config, + valid_until.clone(), + 
close_conn_receiver, + stream, + ).await; + + #[cfg(feature = "metrics")] + ::metrics::decrement_gauge!( + "aquatic_active_connections", + 1.0, + "worker_index" => worker_index.to_string(), + ); + + match result { + Ok(()) => (), + Err(err@( + ConnectionError::ResponseBufferWrite(_) | + ConnectionError::ResponseBufferFull | + ConnectionError::ScrapeChannelError(_) | + ConnectionError::ResponseSenderClosed + )) => { + ::log::error!("connection closed: {:#}", err); + } + Err(err@ConnectionError::RequestBufferFull) => { + ::log::info!("connection closed: {:#}", err); + } + Err(err) => { + ::log::debug!("connection closed: {:#}", err); + } + } + + connection_handles.borrow_mut().remove(connection_id); + } + )) + .detach(); + } + Err(err) => { + ::log::error!("accept connection: {:?}", err); + } + } + } +} + +async fn clean_connections( + config: Rc, + connection_slab: Rc>>, + server_start_instant: ServerStartInstant, +) -> Option { + let now = server_start_instant.seconds_elapsed(); + + connection_slab.borrow_mut().retain(|_, handle| { + if handle.valid_until.borrow().valid(now) { + true + } else { + let _ = handle.close_conn_sender.try_send(()); + + false + } + }); + + Some(Duration::from_secs( + config.cleaning.connection_cleaning_interval, + )) +} + +fn create_tcp_listener( + config: &Config, + priv_dropper: PrivilegeDropper, +) -> anyhow::Result { + let domain = if config.network.address.is_ipv4() { + socket2::Domain::IPV4 + } else { + socket2::Domain::IPV6 + }; + + let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?; + + if config.network.only_ipv6 { + socket + .set_only_v6(true) + .with_context(|| "socket: set only ipv6")?; + } + + socket + .set_reuse_port(true) + .with_context(|| "socket: set reuse port")?; + + socket + .bind(&config.network.address.into()) + .with_context(|| format!("socket: bind to {}", config.network.address))?; + + socket + .listen(config.network.tcp_backlog) + .with_context(|| format!("socket: 
listen on {}", config.network.address))?; + + priv_dropper.after_socket_creation()?; + + Ok(unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) }) +} + +#[cfg(feature = "metrics")] +fn peer_addr_to_ip_version_str(addr: &CanonicalSocketAddr) -> &'static str { + if addr.is_ipv4() { + "4" + } else { + "6" + } +} diff --git a/crates/http/src/workers/socket/request.rs b/crates/http/src/workers/socket/request.rs new file mode 100644 index 0000000..4412382 --- /dev/null +++ b/crates/http/src/workers/socket/request.rs @@ -0,0 +1,147 @@ +use std::net::IpAddr; + +use anyhow::Context; +use aquatic_http_protocol::request::Request; + +use crate::config::{Config, ReverseProxyPeerIpHeaderFormat}; + +#[derive(Debug, thiserror::Error)] +pub enum RequestParseError { + #[error("required peer ip header missing or invalid")] + RequiredPeerIpHeaderMissing(anyhow::Error), + #[error("more data needed")] + MoreDataNeeded, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +pub fn parse_request( + config: &Config, + buffer: &[u8], +) -> Result<(Request, Option), RequestParseError> { + let mut headers = [httparse::EMPTY_HEADER; 16]; + let mut http_request = httparse::Request::new(&mut headers); + + match http_request.parse(buffer).with_context(|| "httparse")? 
{ + httparse::Status::Complete(_) => { + let path = http_request.path.ok_or(anyhow::anyhow!("no http path"))?; + let request = Request::from_http_get_path(path)?; + + let opt_peer_ip = if config.network.runs_behind_reverse_proxy { + let header_name = &config.network.reverse_proxy_ip_header_name; + let header_format = config.network.reverse_proxy_ip_header_format; + + match parse_forwarded_header(header_name, header_format, http_request.headers) { + Ok(peer_ip) => Some(peer_ip), + Err(err) => { + return Err(RequestParseError::RequiredPeerIpHeaderMissing(err)); + } + } + } else { + None + }; + + Ok((request, opt_peer_ip)) + } + httparse::Status::Partial => Err(RequestParseError::MoreDataNeeded), + } +} + +fn parse_forwarded_header( + header_name: &str, + header_format: ReverseProxyPeerIpHeaderFormat, + headers: &[httparse::Header<'_>], +) -> anyhow::Result { + for header in headers.into_iter().rev() { + if header.name == header_name { + match header_format { + ReverseProxyPeerIpHeaderFormat::LastAddress => { + return ::std::str::from_utf8(header.value)? + .split(',') + .last() + .ok_or(anyhow::anyhow!("no header value"))? 
+ .trim() + .parse::() + .with_context(|| "parse ip"); + } + } + } + } + + Err(anyhow::anyhow!("header not present")) +} + +#[cfg(test)] +mod tests { + use super::*; + + const REQUEST_START: &str = "GET /announce?info_hash=%04%0bkV%3f%5cr%14%a6%b7%98%adC%c3%c9.%40%24%00%b9&peer_id=-ABC940-5ert69muw5t8&port=12345&uploaded=1&downloaded=2&left=3&numwant=0&key=4ab4b877&compact=1&supportcrypto=1&event=started HTTP/1.1\r\nHost: example.com\r\n"; + + #[test] + fn test_parse_peer_ip_header_multiple() { + let mut config = Config::default(); + + config.network.runs_behind_reverse_proxy = true; + config.network.reverse_proxy_ip_header_name = "X-Forwarded-For".into(); + config.network.reverse_proxy_ip_header_format = ReverseProxyPeerIpHeaderFormat::LastAddress; + + let mut request = REQUEST_START.to_string(); + + request.push_str("X-Forwarded-For: 200.0.0.1\r\n"); + request.push_str("X-Forwarded-For: 1.2.3.4, 5.6.7.8,9.10.11.12\r\n"); + request.push_str("\r\n"); + + let expected_ip = IpAddr::from([9, 10, 11, 12]); + + assert_eq!( + parse_request(&config, request.as_bytes()) + .unwrap() + .1 + .unwrap(), + expected_ip + ) + } + + #[test] + fn test_parse_peer_ip_header_single() { + let mut config = Config::default(); + + config.network.runs_behind_reverse_proxy = true; + config.network.reverse_proxy_ip_header_name = "X-Forwarded-For".into(); + config.network.reverse_proxy_ip_header_format = ReverseProxyPeerIpHeaderFormat::LastAddress; + + let mut request = REQUEST_START.to_string(); + + request.push_str("X-Forwarded-For: 1.2.3.4, 5.6.7.8,9.10.11.12\r\n"); + request.push_str("X-Forwarded-For: 200.0.0.1\r\n"); + request.push_str("\r\n"); + + let expected_ip = IpAddr::from([200, 0, 0, 1]); + + assert_eq!( + parse_request(&config, request.as_bytes()) + .unwrap() + .1 + .unwrap(), + expected_ip + ) + } + + #[test] + fn test_parse_peer_ip_header_no_header() { + let mut config = Config::default(); + + config.network.runs_behind_reverse_proxy = true; + + let mut request = 
REQUEST_START.to_string(); + + request.push_str("\r\n"); + + let res = parse_request(&config, request.as_bytes()); + + assert!(matches!( + res, + Err(RequestParseError::RequiredPeerIpHeaderMissing(_)) + )); + } +} diff --git a/crates/http_load_test/Cargo.toml b/crates/http_load_test/Cargo.toml index 291d094..b4f4411 100644 --- a/crates/http_load_test/Cargo.toml +++ b/crates/http_load_test/Cargo.toml @@ -19,6 +19,7 @@ aquatic_http_protocol.workspace = true aquatic_toml_config.workspace = true anyhow = "1" +futures = "0.3" futures-lite = "1" futures-rustls = "0.24" hashbrown = "0.14" diff --git a/crates/http_load_test/src/config.rs b/crates/http_load_test/src/config.rs index 060b401..b41d58a 100644 --- a/crates/http_load_test/src/config.rs +++ b/crates/http_load_test/src/config.rs @@ -23,6 +23,7 @@ pub struct Config { pub url_suffix: String, pub duration: usize, pub keep_alive: bool, + pub enable_tls: bool, pub torrents: TorrentConfig, pub cpu_pinning: CpuPinningConfigDesc, } @@ -44,6 +45,7 @@ impl Default for Config { url_suffix: "".into(), duration: 0, keep_alive: true, + enable_tls: true, torrents: TorrentConfig::default(), cpu_pinning: Default::default(), } diff --git a/crates/http_load_test/src/main.rs b/crates/http_load_test/src/main.rs index 7e3c017..3630124 100644 --- a/crates/http_load_test/src/main.rs +++ b/crates/http_load_test/src/main.rs @@ -59,11 +59,15 @@ fn run(config: Config) -> ::anyhow::Result<()> { gamma: Arc::new(gamma), }; - let tls_config = create_tls_config().unwrap(); + let opt_tls_config = if config.enable_tls { + Some(create_tls_config().unwrap()) + } else { + None + }; for i in 0..config.num_workers { let config = config.clone(); - let tls_config = tls_config.clone(); + let opt_tls_config = opt_tls_config.clone(); let state = state.clone(); let placement = get_worker_placement( @@ -76,7 +80,9 @@ fn run(config: Config) -> ::anyhow::Result<()> { LocalExecutorBuilder::new(placement) .name("load-test") .spawn(move || async move { - 
run_socket_thread(config, tls_config, state).await.unwrap(); + run_socket_thread(config, opt_tls_config, state) + .await + .unwrap(); }) .unwrap(); } diff --git a/crates/http_load_test/src/network.rs b/crates/http_load_test/src/network.rs index f4d718d..6c1ff12 100644 --- a/crates/http_load_test/src/network.rs +++ b/crates/http_load_test/src/network.rs @@ -9,7 +9,7 @@ use std::{ use aquatic_http_protocol::response::Response; use futures_lite::{AsyncReadExt, AsyncWriteExt}; -use futures_rustls::{client::TlsStream, TlsConnector}; +use futures_rustls::TlsConnector; use glommio::net::TcpStream; use glommio::{prelude::*, timer::TimerActionRepeat}; use rand::{prelude::SmallRng, SeedableRng}; @@ -18,7 +18,7 @@ use crate::{common::LoadTestState, config::Config, utils::create_random_request} pub async fn run_socket_thread( config: Config, - tls_config: Arc, + opt_tls_config: Option>, load_test_state: LoadTestState, ) -> anyhow::Result<()> { let config = Rc::new(config); @@ -30,9 +30,9 @@ pub async fn run_socket_thread( if interval == 0 { loop { if *num_active_connections.borrow() < config.num_connections { - if let Err(err) = Connection::run( + if let Err(err) = run_connection( config.clone(), - tls_config.clone(), + opt_tls_config.clone(), load_test_state.clone(), num_active_connections.clone(), rng.clone(), @@ -50,7 +50,7 @@ pub async fn run_socket_thread( periodically_open_connections( config.clone(), interval, - tls_config.clone(), + opt_tls_config.clone(), load_test_state.clone(), num_active_connections.clone(), rng.clone(), @@ -66,16 +66,16 @@ pub async fn run_socket_thread( async fn periodically_open_connections( config: Rc, interval: Duration, - tls_config: Arc, + opt_tls_config: Option>, load_test_state: LoadTestState, num_active_connections: Rc>, rng: Rc>, ) -> Option { if *num_active_connections.borrow() < config.num_connections { spawn_local(async move { - if let Err(err) = Connection::run( + if let Err(err) = run_connection( config, - tls_config, + 
opt_tls_config, load_test_state, num_active_connections, rng.clone(), @@ -91,26 +91,18 @@ async fn periodically_open_connections( Some(interval) } -struct Connection { +async fn run_connection( config: Rc, + opt_tls_config: Option>, load_test_state: LoadTestState, + num_active_connections: Rc>, rng: Rc>, - stream: TlsStream, - buffer: [u8; 2048], -} - -impl Connection { - async fn run( - config: Rc, - tls_config: Arc, - load_test_state: LoadTestState, - num_active_connections: Rc>, - rng: Rc>, - ) -> anyhow::Result<()> { - let stream = TcpStream::connect(config.server_address) - .await - .map_err(|err| anyhow::anyhow!("connect: {:?}", err))?; +) -> anyhow::Result<()> { + let stream = TcpStream::connect(config.server_address) + .await + .map_err(|err| anyhow::anyhow!("connect: {:?}", err))?; + if let Some(tls_config) = opt_tls_config { let stream = TlsConnector::from(tls_config) .connect("example.com".try_into().unwrap(), stream) .await?; @@ -120,18 +112,49 @@ impl Connection { load_test_state, rng, stream, - buffer: [0; 2048], + buffer: Box::new([0; 2048]), }; + connection.run(num_active_connections).await?; + } else { + let mut connection = Connection { + config, + load_test_state, + rng, + stream, + buffer: Box::new([0; 2048]), + }; + + connection.run(num_active_connections).await?; + } + + Ok(()) +} + +struct Connection { + config: Rc, + load_test_state: LoadTestState, + rng: Rc>, + stream: S, + buffer: Box<[u8; 2048]>, +} + +impl Connection +where + S: futures::AsyncRead + futures::AsyncWrite + Unpin + 'static, +{ + async fn run(&mut self, num_active_connections: Rc>) -> anyhow::Result<()> { *num_active_connections.borrow_mut() += 1; - if let Err(err) = connection.run_connection_loop().await { + let result = self.run_connection_loop().await; + + if let Err(err) = &result { ::log::info!("connection error: {:?}", err); } *num_active_connections.borrow_mut() -= 1; - Ok(()) + result } async fn run_connection_loop(&mut self) -> anyhow::Result<()> { diff --git 
a/crates/http_protocol/src/request.rs b/crates/http_protocol/src/request.rs index db4c7e2..865a448 100644 --- a/crates/http_protocol/src/request.rs +++ b/crates/http_protocol/src/request.rs @@ -244,23 +244,6 @@ impl ScrapeRequest { } } -#[derive(Debug)] -pub enum RequestParseError { - NeedMoreData, - Invalid(anyhow::Error), -} - -impl ::std::fmt::Display for RequestParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::NeedMoreData => write!(f, "Incomplete request, more data needed"), - Self::Invalid(err) => write!(f, "Invalid request: {:#}", err), - } - } -} - -impl ::std::error::Error for RequestParseError {} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum Request { Announce(AnnounceRequest), @@ -269,20 +252,20 @@ pub enum Request { impl Request { /// Parse Request from HTTP request bytes - pub fn from_bytes(bytes: &[u8]) -> Result { + pub fn from_bytes(bytes: &[u8]) -> anyhow::Result> { let mut headers = [httparse::EMPTY_HEADER; 16]; let mut http_request = httparse::Request::new(&mut headers); match http_request.parse(bytes) { Ok(httparse::Status::Complete(_)) => { if let Some(path) = http_request.path { - Self::from_http_get_path(path).map_err(RequestParseError::Invalid) + Self::from_http_get_path(path).map(Some) } else { - Err(RequestParseError::Invalid(anyhow::anyhow!("no http path"))) + Err(anyhow::anyhow!("no http path")) } } - Ok(httparse::Status::Partial) => Err(RequestParseError::NeedMoreData), - Err(err) => Err(RequestParseError::Invalid(anyhow::Error::from(err))), + Ok(httparse::Status::Partial) => Ok(None), + Err(err) => Err(anyhow::Error::from(err)), } } @@ -368,7 +351,7 @@ mod tests { bytes.extend_from_slice(&ANNOUNCE_REQUEST_PATH.as_bytes()); bytes.extend_from_slice(b" HTTP/1.1\r\n\r\n"); - let parsed_request = Request::from_bytes(&bytes[..]).unwrap(); + let parsed_request = Request::from_bytes(&bytes[..]).unwrap().unwrap(); let reference_request = get_reference_announce_request(); 
assert_eq!(parsed_request, reference_request); @@ -382,7 +365,7 @@ mod tests { bytes.extend_from_slice(&SCRAPE_REQUEST_PATH.as_bytes()); bytes.extend_from_slice(b" HTTP/1.1\r\n\r\n"); - let parsed_request = Request::from_bytes(&bytes[..]).unwrap(); + let parsed_request = Request::from_bytes(&bytes[..]).unwrap().unwrap(); let reference_request = Request::Scrape(ScrapeRequest { info_hashes: vec![InfoHash(REFERENCE_INFO_HASH)], }); @@ -449,7 +432,7 @@ mod tests { request.write(&mut bytes, &[]).unwrap(); - let parsed_request = Request::from_bytes(&bytes[..]).unwrap(); + let parsed_request = Request::from_bytes(&bytes[..]).unwrap().unwrap(); let success = request == parsed_request;