From afce23e321354d850de7bd35e6d685d986b230bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Mon, 1 Nov 2021 01:28:48 +0100 Subject: [PATCH] aquatic_http: glommio: use futures-rustls --- Cargo.lock | 13 ++- aquatic_http/Cargo.toml | 2 +- aquatic_http/src/lib/common.rs | 2 + aquatic_http/src/lib/lib.rs | 9 +- aquatic_http/src/lib/network.rs | 201 ++++++++++---------------------- 5 files changed, 84 insertions(+), 143 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 826d684..d806f52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,6 +96,7 @@ dependencies = [ "core_affinity", "either", "futures-lite", + "futures-rustls", "glommio", "hashbrown 0.11.2", "indexmap", @@ -108,7 +109,6 @@ dependencies = [ "quickcheck", "quickcheck_macros", "rand", - "rustls", "rustls-pemfile", "serde", "slab", @@ -785,6 +785,17 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d383f0425d991a05e564c2f3ec150bd6dde863179c131dd60d8aa73a05434461" +dependencies = [ + "futures-io", + "rustls", + "webpki", +] + [[package]] name = "generic-array" version = "0.14.4" diff --git a/aquatic_http/Cargo.toml b/aquatic_http/Cargo.toml index e969d83..b3f9dca 100644 --- a/aquatic_http/Cargo.toml +++ b/aquatic_http/Cargo.toml @@ -24,6 +24,7 @@ cfg-if = "1" core_affinity = "0.5" either = "1" futures-lite = "1" +futures-rustls = "0.22" glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5" } hashbrown = "0.11.2" indexmap = "1" @@ -34,7 +35,6 @@ memchr = "2" parking_lot = "0.11" privdrop = "0.5" rand = { version = "0.8", features = ["small_rng"] } -rustls = "0.20" rustls-pemfile = "0.2" serde = { version = "1", features = ["derive"] } slab = "0.4" diff --git a/aquatic_http/src/lib/common.rs b/aquatic_http/src/lib/common.rs index e814c36..df1dfce 100644 --- a/aquatic_http/src/lib/common.rs +++ b/aquatic_http/src/lib/common.rs @@ -27,6 +27,8 @@ use aquatic_http_protocol::{ response::{AnnounceResponse, ScrapeResponse}, }; +pub type TlsConfig = futures_rustls::rustls::ServerConfig; + #[derive(Copy, Clone, Debug)] pub struct ConsumerId(pub usize); diff --git a/aquatic_http/src/lib/lib.rs b/aquatic_http/src/lib/lib.rs index 5a1d900..690df06 100644 --- a/aquatic_http/src/lib/lib.rs +++ b/aquatic_http/src/lib/lib.rs @@ -5,6 +5,7 @@ use std::{ }; use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding}; +use common::TlsConfig; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; use crate::config::Config; @@ -113,14 +114,14 @@ pub fn run(config: Config) -> anyhow::Result<()> { Ok(()) } -fn create_tls_config(config: &Config) -> anyhow::Result { +fn create_tls_config(config: &Config) -> anyhow::Result { let certs = { let f = File::open(&config.network.tls_certificate_path)?; let mut f = BufReader::new(f); rustls_pemfile::certs(&mut f)? .into_iter() - .map(|bytes| rustls::Certificate(bytes)) + .map(|bytes| futures_rustls::rustls::Certificate(bytes)) .collect() }; @@ -130,11 +131,11 @@ fn create_tls_config(config: &Config) -> anyhow::Result { rustls_pemfile::pkcs8_private_keys(&mut f)? .first() - .map(|bytes| rustls::PrivateKey(bytes.clone())) + .map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone())) .ok_or(anyhow::anyhow!("No private keys in file"))? }; - let tls_config = rustls::ServerConfig::builder() + let tls_config = futures_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(certs, private_key)?; diff --git a/aquatic_http/src/lib/network.rs b/aquatic_http/src/lib/network.rs index 4efbab7..dbfea64 100644 --- a/aquatic_http/src/lib/network.rs +++ b/aquatic_http/src/lib/network.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::collections::BTreeMap; -use std::io::{Cursor, ErrorKind, Read, Write}; use std::net::SocketAddr; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -15,6 +14,8 @@ use aquatic_http_protocol::response::{ }; use either::Either; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; +use futures_rustls::TlsAcceptor; +use futures_rustls::server::TlsStream; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::shared_channel::ConnectedReceiver; @@ -22,7 +23,6 @@ use glommio::net::{TcpListener, TcpStream}; use glommio::task::JoinHandle; use glommio::timer::TimerActionRepeat; use glommio::{enclose, prelude::*}; -use rustls::ServerConnection; use slab::Slab; use crate::common::num_digits_in_usize; @@ -49,8 +49,7 @@ struct Connection { request_senders: Rc>, response_receiver: LocalReceiver, response_consumer_id: ConsumerId, - tls: ServerConnection, - stream: TcpStream, + tls_acceptor: TlsAcceptor, connection_id: ConnectionId, request_buffer: [u8; MAX_REQUEST_SIZE], request_buffer_position: usize, @@ -58,7 +57,7 @@ struct Connection { pub async fn run_socket_worker( config: Config, - tls_config: Arc, + tls_config: Arc, request_mesh_builder: MeshBuilder, response_mesh_builder: MeshBuilder, num_bound_sockets: Arc, @@ -135,8 +134,7 @@ pub async fn run_socket_worker( request_senders: request_senders.clone(), response_receiver, response_consumer_id, - tls: ServerConnection::new(tls_config.clone()).unwrap(), - stream, + tls_acceptor: tls_config.clone().into(), connection_id: ConnectionId(entry.key()), request_buffer: [0u8; MAX_REQUEST_SIZE], request_buffer_position: 0, @@ -145,7 +143,7 @@ pub async fn run_socket_worker( let connections_to_remove = connections_to_remove.clone(); let handle = spawn_local(async move { - if let Err(err) = conn.handle_stream().await { + if let Err(err) = conn.handle_stream(stream).await { ::log::info!("conn.handle_stream() error: {:?}", err); } @@ -184,41 +182,27 @@ async fn receive_responses( } impl Connection { - async fn handle_stream(&mut self) -> anyhow::Result<()> { - let mut close_after_writing = false; + async fn handle_stream(&mut self, stream: TcpStream) -> anyhow::Result<()> { + let peer_addr = Self::get_peer_addr(&stream)?; + let mut stream = self.tls_acceptor.accept(stream).await?; loop { - match self.read_tls().await? { - Some(Either::Left(request)) => { - let response = match self.handle_request(request).await? { - Some(Either::Left(response)) => response, - Some(Either::Right(pending_scrape_response)) => { - self.wait_for_response(Some(pending_scrape_response)) - .await? + let response = match self.read_request(&mut stream).await? { + Either::Left(response) => Response::Failure(response), + Either::Right(request) => { + match self.handle_request(request, peer_addr).await? { + Either::Left(response) => Response::Failure(response), + Either::Right(opt_pending_scrape_response) => { + self.wait_for_response(peer_addr, opt_pending_scrape_response).await? } - None => self.wait_for_response(None).await?, - }; - - self.queue_response(&response)?; - - if !self.config.network.keep_alive { - close_after_writing = true; } } - Some(Either::Right(response)) => { - self.queue_response(&Response::Failure(response))?; + }; - close_after_writing = true; - } - None => { - // Still handshaking - } - } + self.write_response(&mut stream, &response).await?; - self.write_tls().await?; - - if close_after_writing { - let _ = self.stream.shutdown(std::net::Shutdown::Both).await; + if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive { + Self::close_stream(stream).await; break; } @@ -227,73 +211,42 @@ impl Connection { Ok(()) } - async fn read_tls(&mut self) -> anyhow::Result>> { + fn get_peer_addr(stream: &TcpStream) -> anyhow::Result { + stream + .peer_addr() + .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err)) + } + + async fn close_stream(stream: TlsStream) { + let _ = stream.get_ref().0.shutdown(std::net::Shutdown::Both).await; + } + + async fn read_request(&mut self, stream: &mut TlsStream) -> anyhow::Result> { + let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; + loop { - ::log::debug!("read_tls"); + ::log::debug!("read"); - let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; + let bytes_read = stream.read(&mut buf).await?; + let request_buffer_end = self.request_buffer_position + bytes_read; - let bytes_read = self.stream.read(&mut buf).await?; + if request_buffer_end > self.request_buffer.len() { + return Err(anyhow::anyhow!("request too large")); + } else { + let request_buffer_slice = + &mut self.request_buffer[self.request_buffer_position..request_buffer_end]; - if bytes_read == 0 { - return Err(anyhow::anyhow!("Peer has closed connection")); - } + request_buffer_slice.copy_from_slice(&buf[..bytes_read]); - let _ = self.tls.read_tls(&mut &buf[..bytes_read]).unwrap(); + self.request_buffer_position = request_buffer_end; - let io_state = self.tls.process_new_packets()?; - - let mut added_plaintext = false; - - if io_state.plaintext_bytes_to_read() != 0 { - loop { - match self.tls.reader().read(&mut buf) { - Ok(0) => { - break; - } - Ok(amt) => { - let end = self.request_buffer_position + amt; - - if end > self.request_buffer.len() { - return Err(anyhow::anyhow!("request too large")); - } else { - let request_buffer_slice = - &mut self.request_buffer[self.request_buffer_position..end]; - - request_buffer_slice.copy_from_slice(&buf[..amt]); - - self.request_buffer_position = end; - - added_plaintext = true; - } - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - // Should never happen - ::log::error!("tls.reader().read error: {:?}", err); - - break; - } - } - } - } - - if added_plaintext { match Request::from_bytes(&self.request_buffer[..self.request_buffer_position]) { Ok(request) => { ::log::debug!("received request: {:?}", request); self.request_buffer_position = 0; - return Ok(Some(Either::Left(request))); - } - Err(RequestParseError::NeedMoreData) => { - ::log::debug!( - "need more request data. current data: {:?}", - std::str::from_utf8(&self.request_buffer) - ); + return Ok(Either::Right(request)); } Err(RequestParseError::Invalid(err)) => { ::log::debug!("invalid request: {:?}", err); @@ -302,50 +255,29 @@ impl Connection { failure_reason: "Invalid request".into(), }; - return Ok(Some(Either::Right(response))); + return Ok(Either::Left(response)); + } + Err(RequestParseError::NeedMoreData) => { + ::log::debug!( + "need more request data. current data: {:?}", + std::str::from_utf8(&self.request_buffer) + ); } } } - - if self.tls.wants_write() { - break; - } } - - Ok(None) - } - - async fn write_tls(&mut self) -> anyhow::Result<()> { - if !self.tls.wants_write() { - return Ok(()); - } - - ::log::debug!("write_tls (wants write)"); - - let mut buf = Vec::new(); - let mut buf = Cursor::new(&mut buf); - - while self.tls.wants_write() { - self.tls.write_tls(&mut buf).unwrap(); - } - - self.stream.write_all(&buf.into_inner()).await?; - self.stream.flush().await?; - - Ok(()) } /// Take a request and: /// - Return error response if request is not allowed - /// - If it is an announce requests, pass it on to request workers and return None + /// - If it is an announce requests, pass it on to request workers and return Either::Right(None) /// - If it is a scrape requests, split it up and pass on parts to - /// relevant request workers, and return PendingScrapeResponse struct. + /// relevant request workers, and return Either::Right(Some(PendingScrapeResponse)). async fn handle_request( &self, request: Request, - ) -> anyhow::Result>> { - let peer_addr = self.get_peer_addr()?; - + peer_addr: SocketAddr, + ) -> anyhow::Result>> { match request { Request::Announce(request) => { let info_hash = request.info_hash; @@ -370,13 +302,13 @@ impl Connection { .await .unwrap(); - Ok(None) + Ok(Either::Right(None)) } else { - let response = Response::Failure(FailureResponse { + let response = FailureResponse { failure_reason: "Info hash not allowed".into(), - }); + }; - Ok(Some(Either::Left(response))) + Ok(Either::Left(response)) } } Request::Scrape(ScrapeRequest { info_hashes }) => { @@ -412,7 +344,7 @@ impl Connection { stats: Default::default(), }; - Ok(Some(Either::Right(pending_scrape_response))) + Ok(Either::Right(Some(pending_scrape_response))) } } } @@ -421,11 +353,12 @@ impl Connection { /// return full response async fn wait_for_response( &self, + peer_addr: SocketAddr, mut opt_pending_scrape_response: Option, ) -> anyhow::Result { loop { if let Some(channel_response) = self.response_receiver.recv().await { - if channel_response.get_peer_addr() != self.get_peer_addr()? { + if channel_response.get_peer_addr() != peer_addr { return Err(anyhow::anyhow!("peer addressess didn't match")); } @@ -463,7 +396,7 @@ impl Connection { } } - fn queue_response(&mut self, response: &Response) -> anyhow::Result<()> { + async fn write_response(&mut self, stream: &mut TlsStream, response: &Response) -> anyhow::Result<()> { let mut body = Vec::new(); response.write(&mut body).unwrap(); @@ -479,16 +412,10 @@ impl Connection { response_bytes.append(&mut body); response_bytes.extend_from_slice(b"\r\n"); - self.tls.writer().write(&response_bytes[..])?; + stream.write(&response_bytes[..]).await?; Ok(()) } - - fn get_peer_addr(&self) -> anyhow::Result { - self.stream - .peer_addr() - .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err)) - } } fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize {