diff --git a/Cargo.lock b/Cargo.lock index 826d684..d806f52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,6 +96,7 @@ dependencies = [ "core_affinity", "either", "futures-lite", + "futures-rustls", "glommio", "hashbrown 0.11.2", "indexmap", @@ -108,7 +109,6 @@ dependencies = [ "quickcheck", "quickcheck_macros", "rand", - "rustls", "rustls-pemfile", "serde", "slab", @@ -785,6 +785,17 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d383f0425d991a05e564c2f3ec150bd6dde863179c131dd60d8aa73a05434461" +dependencies = [ + "futures-io", + "rustls", + "webpki", +] + [[package]] name = "generic-array" version = "0.14.4" diff --git a/aquatic_http/Cargo.toml b/aquatic_http/Cargo.toml index e969d83..b3f9dca 100644 --- a/aquatic_http/Cargo.toml +++ b/aquatic_http/Cargo.toml @@ -24,6 +24,7 @@ cfg-if = "1" core_affinity = "0.5" either = "1" futures-lite = "1" +futures-rustls = "0.22" glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5" } hashbrown = "0.11.2" indexmap = "1" @@ -34,7 +35,6 @@ memchr = "2" parking_lot = "0.11" privdrop = "0.5" rand = { version = "0.8", features = ["small_rng"] } -rustls = "0.20" rustls-pemfile = "0.2" serde = { version = "1", features = ["derive"] } slab = "0.4" diff --git a/aquatic_http/src/lib/common.rs b/aquatic_http/src/lib/common.rs index e814c36..df1dfce 100644 --- a/aquatic_http/src/lib/common.rs +++ b/aquatic_http/src/lib/common.rs @@ -27,6 +27,8 @@ use aquatic_http_protocol::{ response::{AnnounceResponse, ScrapeResponse}, }; +pub type TlsConfig = futures_rustls::rustls::ServerConfig; + #[derive(Copy, Clone, Debug)] pub struct ConsumerId(pub usize); diff --git a/aquatic_http/src/lib/lib.rs b/aquatic_http/src/lib/lib.rs index 5a1d900..690df06 100644 --- a/aquatic_http/src/lib/lib.rs +++ b/aquatic_http/src/lib/lib.rs @@ -5,6 +5,7 @@ use std::{ }; use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding}; +use common::TlsConfig; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; use crate::config::Config; @@ -113,14 +114,14 @@ pub fn run(config: Config) -> anyhow::Result<()> { Ok(()) } -fn create_tls_config(config: &Config) -> anyhow::Result { +fn create_tls_config(config: &Config) -> anyhow::Result { let certs = { let f = File::open(&config.network.tls_certificate_path)?; let mut f = BufReader::new(f); rustls_pemfile::certs(&mut f)? .into_iter() - .map(|bytes| rustls::Certificate(bytes)) + .map(|bytes| futures_rustls::rustls::Certificate(bytes)) .collect() }; @@ -130,11 +131,11 @@ fn create_tls_config(config: &Config) -> anyhow::Result { rustls_pemfile::pkcs8_private_keys(&mut f)? .first() - .map(|bytes| rustls::PrivateKey(bytes.clone())) + .map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone())) .ok_or(anyhow::anyhow!("No private keys in file"))? }; - let tls_config = rustls::ServerConfig::builder() + let tls_config = futures_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(certs, private_key)?; diff --git a/aquatic_http/src/lib/network.rs b/aquatic_http/src/lib/network.rs index 4efbab7..6970343 100644 --- a/aquatic_http/src/lib/network.rs +++ b/aquatic_http/src/lib/network.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::collections::BTreeMap; -use std::io::{Cursor, ErrorKind, Read, Write}; use std::net::SocketAddr; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -15,14 +14,14 @@ use aquatic_http_protocol::response::{ }; use either::Either; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; +use futures_rustls::server::TlsStream; +use futures_rustls::TlsAcceptor; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::shared_channel::ConnectedReceiver; use glommio::net::{TcpListener, TcpStream}; -use glommio::task::JoinHandle; use glommio::timer::TimerActionRepeat; use glommio::{enclose, prelude::*}; -use rustls::ServerConnection; use slab::Slab; use crate::common::num_digits_in_usize; @@ -40,7 +39,6 @@ struct PendingScrapeResponse { struct ConnectionReference { response_sender: LocalSender, - handle: JoinHandle<()>, } struct Connection { @@ -49,8 +47,8 @@ struct Connection { request_senders: Rc>, response_receiver: LocalReceiver, response_consumer_id: ConsumerId, - tls: ServerConnection, - stream: TcpStream, + stream: TlsStream, + peer_addr: SocketAddr, connection_id: ConnectionId, request_buffer: [u8; MAX_REQUEST_SIZE], request_buffer_position: usize, @@ -58,7 +56,7 @@ struct Connection { pub async fn run_socket_worker( config: Config, - tls_config: Arc, + tls_config: Arc, request_mesh_builder: MeshBuilder, response_mesh_builder: MeshBuilder, num_bound_sockets: Arc, @@ -91,22 +89,11 @@ pub async fn run_socket_worker( // Periodically remove closed connections TimerActionRepeat::repeat( enclose!((config, connection_slab, connections_to_remove) move || { - enclose!((config, connection_slab, connections_to_remove) move || async move { - let connections_to_remove = connections_to_remove.replace(Vec::new()); - - for connection_id in connections_to_remove { - if let Some(_) = connection_slab.borrow_mut().try_remove(connection_id) { - ::log::debug!("removed connection with id {}", connection_id); - } else { - ::log::error!( - "couldn't remove connection with id {}, it is not in connection slab", - connection_id - ); - } - } - - Some(Duration::from_secs(config.cleaning.interval)) - })() + remove_closed_connections( + config.clone(), + connection_slab.clone(), + connections_to_remove.clone(), + ) }), ); @@ -124,41 +111,27 @@ pub async fn run_socket_worker( match stream { Ok(stream) => { let (response_sender, response_receiver) = new_bounded(config.request_workers); + let key = connection_slab + .borrow_mut() + .insert(ConnectionReference { response_sender }); - let mut slab = connection_slab.borrow_mut(); - let entry = slab.vacant_entry(); - let key = entry.key(); - - let mut conn = Connection { - config: config.clone(), - access_list: access_list.clone(), - request_senders: request_senders.clone(), - response_receiver, - response_consumer_id, - tls: ServerConnection::new(tls_config.clone()).unwrap(), - stream, - connection_id: ConnectionId(entry.key()), - request_buffer: [0u8; MAX_REQUEST_SIZE], - request_buffer_position: 0, - }; - - let connections_to_remove = connections_to_remove.clone(); - - let handle = spawn_local(async move { - if let Err(err) = conn.handle_stream().await { - ::log::info!("conn.handle_stream() error: {:?}", err); + spawn_local(enclose!((config, access_list, request_senders, tls_config, connections_to_remove) async move { + if let Err(err) = Connection::run( + config, + access_list, + request_senders, + response_receiver, + response_consumer_id, + ConnectionId(key), + tls_config, + stream + ).await { + ::log::debug!("Connection::run() error: {:?}", err); } connections_to_remove.borrow_mut().push(key); - }) + })) .detach(); - - let connection_reference = ConnectionReference { - response_sender, - handle, - }; - - entry.insert(connection_reference); } Err(err) => { ::log::error!("accept connection: {:?}", err); @@ -167,6 +140,27 @@ pub async fn run_socket_worker( } } +async fn remove_closed_connections( + config: Rc, + connection_slab: Rc>>, + connections_to_remove: Rc>>, +) -> Option { + let connections_to_remove = connections_to_remove.replace(Vec::new()); + + for connection_id in connections_to_remove { + if let Some(_) = connection_slab.borrow_mut().try_remove(connection_id) { + ::log::debug!("removed connection with id {}", connection_id); + } else { + ::log::error!( + "couldn't remove connection with id {}, it is not in connection slab", + connection_id + ); + } + } + + Some(Duration::from_secs(config.cleaning.interval)) +} + async fn receive_responses( mut response_receiver: ConnectedReceiver, connection_references: Rc>>, @@ -184,41 +178,57 @@ async fn receive_responses( } impl Connection { - async fn handle_stream(&mut self) -> anyhow::Result<()> { - let mut close_after_writing = false; + async fn run( + config: Rc, + access_list: Rc>, + request_senders: Rc>, + response_receiver: LocalReceiver, + response_consumer_id: ConsumerId, + connection_id: ConnectionId, + tls_config: Arc, + stream: TcpStream, + ) -> anyhow::Result<()> { + let peer_addr = stream + .peer_addr() + .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?; + let tls_acceptor: TlsAcceptor = tls_config.into(); + let stream = tls_acceptor.accept(stream).await?; + + let mut conn = Connection { + config: config.clone(), + access_list: access_list.clone(), + request_senders: request_senders.clone(), + response_receiver, + response_consumer_id, + stream, + peer_addr, + connection_id, + request_buffer: [0u8; MAX_REQUEST_SIZE], + request_buffer_position: 0, + }; + + conn.run_request_response_loop().await?; + + Ok(()) + } + + async fn run_request_response_loop(&mut self) -> anyhow::Result<()> { loop { - match self.read_tls().await? { - Some(Either::Left(request)) => { - let response = match self.handle_request(request).await? { - Some(Either::Left(response)) => response, - Some(Either::Right(pending_scrape_response)) => { - self.wait_for_response(Some(pending_scrape_response)) - .await? - } - None => self.wait_for_response(None).await?, - }; + let response = match self.read_request().await? { + Either::Left(response) => Response::Failure(response), + Either::Right(request) => self.handle_request(request).await?, + }; - self.queue_response(&response)?; + self.write_response(&response).await?; - if !self.config.network.keep_alive { - close_after_writing = true; - } - } - Some(Either::Right(response)) => { - self.queue_response(&Response::Failure(response))?; - - close_after_writing = true; - } - None => { - // Still handshaking - } - } - - self.write_tls().await?; - - if close_after_writing { - let _ = self.stream.shutdown(std::net::Shutdown::Both).await; + if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive { + let _ = self + .stream + .get_ref() + .0 + .shutdown(std::net::Shutdown::Both) + .await; break; } @@ -227,73 +237,37 @@ impl Connection { Ok(()) } - async fn read_tls(&mut self) -> anyhow::Result>> { - loop { - ::log::debug!("read_tls"); + async fn read_request(&mut self) -> anyhow::Result> { + let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; - let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; + loop { + ::log::debug!("read"); let bytes_read = self.stream.read(&mut buf).await?; if bytes_read == 0 { - return Err(anyhow::anyhow!("Peer has closed connection")); + return Err(anyhow::anyhow!("peer closed connection")); } - let _ = self.tls.read_tls(&mut &buf[..bytes_read]).unwrap(); + let request_buffer_end = self.request_buffer_position + bytes_read; - let io_state = self.tls.process_new_packets()?; + if request_buffer_end > self.request_buffer.len() { + return Err(anyhow::anyhow!("request too large")); + } else { + let request_buffer_slice = + &mut self.request_buffer[self.request_buffer_position..request_buffer_end]; - let mut added_plaintext = false; + request_buffer_slice.copy_from_slice(&buf[..bytes_read]); - if io_state.plaintext_bytes_to_read() != 0 { - loop { - match self.tls.reader().read(&mut buf) { - Ok(0) => { - break; - } - Ok(amt) => { - let end = self.request_buffer_position + amt; + self.request_buffer_position = request_buffer_end; - if end > self.request_buffer.len() { - return Err(anyhow::anyhow!("request too large")); - } else { - let request_buffer_slice = - &mut self.request_buffer[self.request_buffer_position..end]; - - request_buffer_slice.copy_from_slice(&buf[..amt]); - - self.request_buffer_position = end; - - added_plaintext = true; - } - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - // Should never happen - ::log::error!("tls.reader().read error: {:?}", err); - - break; - } - } - } - } - - if added_plaintext { match Request::from_bytes(&self.request_buffer[..self.request_buffer_position]) { Ok(request) => { ::log::debug!("received request: {:?}", request); self.request_buffer_position = 0; - return Ok(Some(Either::Left(request))); - } - Err(RequestParseError::NeedMoreData) => { - ::log::debug!( - "need more request data. current data: {:?}", - std::str::from_utf8(&self.request_buffer) - ); + return Ok(Either::Right(request)); } Err(RequestParseError::Invalid(err)) => { ::log::debug!("invalid request: {:?}", err); @@ -302,50 +276,28 @@ impl Connection { failure_reason: "Invalid request".into(), }; - return Ok(Some(Either::Right(response))); + return Ok(Either::Left(response)); + } + Err(RequestParseError::NeedMoreData) => { + ::log::debug!( + "need more request data. current data: {:?}", + std::str::from_utf8( + &self.request_buffer[..self.request_buffer_position] + ) + ); } } } - - if self.tls.wants_write() { - break; - } } - - Ok(None) - } - - async fn write_tls(&mut self) -> anyhow::Result<()> { - if !self.tls.wants_write() { - return Ok(()); - } - - ::log::debug!("write_tls (wants write)"); - - let mut buf = Vec::new(); - let mut buf = Cursor::new(&mut buf); - - while self.tls.wants_write() { - self.tls.write_tls(&mut buf).unwrap(); - } - - self.stream.write_all(&buf.into_inner()).await?; - self.stream.flush().await?; - - Ok(()) } /// Take a request and: /// - Return error response if request is not allowed - /// - If it is an announce requests, pass it on to request workers and return None - /// - If it is a scrape requests, split it up and pass on parts to - /// relevant request workers, and return PendingScrapeResponse struct. - async fn handle_request( - &self, - request: Request, - ) -> anyhow::Result>> { - let peer_addr = self.get_peer_addr()?; - + /// - If it is an announce request, send it to request workers an await a + /// response + /// - If it is a scrape requests, split it up, pass on the parts to + /// relevant request workers and await a response + async fn handle_request(&self, request: Request) -> anyhow::Result { match request { Request::Announce(request) => { let info_hash = request.info_hash; @@ -359,7 +311,7 @@ impl Connection { request, connection_id: self.connection_id, response_consumer_id: self.response_consumer_id, - peer_addr, + peer_addr: self.peer_addr, }; let consumer_index = calculate_request_consumer_index(&self.config, info_hash); @@ -370,13 +322,13 @@ impl Connection { .await .unwrap(); - Ok(None) + self.wait_for_response(None).await } else { let response = Response::Failure(FailureResponse { failure_reason: "Info hash not allowed".into(), }); - Ok(Some(Either::Left(response))) + Ok(response) } } Request::Scrape(ScrapeRequest { info_hashes }) => { @@ -395,7 +347,7 @@ impl Connection { for (consumer_index, info_hashes) in info_hashes_by_worker { let request = ChannelRequest::Scrape { request: ScrapeRequest { info_hashes }, - peer_addr, + peer_addr: self.peer_addr, response_consumer_id: self.response_consumer_id, connection_id: self.connection_id, }; @@ -412,7 +364,7 @@ impl Connection { stats: Default::default(), }; - Ok(Some(Either::Right(pending_scrape_response))) + self.wait_for_response(Some(pending_scrape_response)).await } } } @@ -424,46 +376,45 @@ impl Connection { mut opt_pending_scrape_response: Option, ) -> anyhow::Result { loop { - if let Some(channel_response) = self.response_receiver.recv().await { - if channel_response.get_peer_addr() != self.get_peer_addr()? { - return Err(anyhow::anyhow!("peer addressess didn't match")); - } + let channel_response = self + .response_receiver + .recv() + .await + .expect("wait_for_response: can't receive response, sender is closed"); - match channel_response { - ChannelResponse::Announce { response, .. } => { - break Ok(Response::Announce(response)); - } - ChannelResponse::Scrape { response, .. } => { - if let Some(mut pending) = opt_pending_scrape_response.take() { - pending.stats.extend(response.files); - pending.pending_worker_responses -= 1; - - if pending.pending_worker_responses == 0 { - let response = Response::Scrape(ScrapeResponse { - files: pending.stats, - }); - - break Ok(response); - } else { - opt_pending_scrape_response = Some(pending); - } - } else { - return Err(anyhow::anyhow!( - "received channel scrape response without pending scrape response" - )); - } - } - }; - } else { - // TODO: this is a serious error condition and should maybe be handled differently - return Err(anyhow::anyhow!( - "response receiver can't receive - sender is closed" - )); + if channel_response.get_peer_addr() != self.peer_addr { + return Err(anyhow::anyhow!("peer addresses didn't match")); } + + match channel_response { + ChannelResponse::Announce { response, .. } => { + break Ok(Response::Announce(response)); + } + ChannelResponse::Scrape { response, .. } => { + if let Some(mut pending) = opt_pending_scrape_response.take() { + pending.stats.extend(response.files); + pending.pending_worker_responses -= 1; + + if pending.pending_worker_responses == 0 { + let response = Response::Scrape(ScrapeResponse { + files: pending.stats, + }); + + break Ok(response); + } else { + opt_pending_scrape_response = Some(pending); + } + } else { + return Err(anyhow::anyhow!( + "received channel scrape response without pending scrape response" + )); + } + } + }; } } - fn queue_response(&mut self, response: &Response) -> anyhow::Result<()> { + async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> { let mut body = Vec::new(); response.write(&mut body).unwrap(); @@ -479,16 +430,11 @@ impl Connection { response_bytes.append(&mut body); response_bytes.extend_from_slice(b"\r\n"); - self.tls.writer().write(&response_bytes[..])?; + self.stream.write(&response_bytes[..]).await?; + self.stream.flush().await?; Ok(()) } - - fn get_peer_addr(&self) -> anyhow::Result { - self.stream - .peer_addr() - .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err)) - } } fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize { diff --git a/aquatic_http_protocol/src/request.rs b/aquatic_http_protocol/src/request.rs index eeefa12..3197fa7 100644 --- a/aquatic_http_protocol/src/request.rs +++ b/aquatic_http_protocol/src/request.rs @@ -363,9 +363,7 @@ mod tests { return TestResult::discard(); } } - Request::Scrape(ScrapeRequest { - ref info_hashes, - }) => { + Request::Scrape(ScrapeRequest { ref info_hashes }) => { if info_hashes.is_empty() { return TestResult::discard(); }