From edec526d419b0cadb53662ad4bd596a6c71ea887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Mon, 1 Nov 2021 01:53:00 +0100 Subject: [PATCH] aquatic_http: glommio: clean up Connection code, stop storing handle --- aquatic_http/src/lib/network.rs | 142 +++++++++++++++++--------------- 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/aquatic_http/src/lib/network.rs b/aquatic_http/src/lib/network.rs index dbfea64..966495c 100644 --- a/aquatic_http/src/lib/network.rs +++ b/aquatic_http/src/lib/network.rs @@ -14,13 +14,12 @@ use aquatic_http_protocol::response::{ }; use either::Either; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; -use futures_rustls::TlsAcceptor; use futures_rustls::server::TlsStream; +use futures_rustls::TlsAcceptor; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::shared_channel::ConnectedReceiver; use glommio::net::{TcpListener, TcpStream}; -use glommio::task::JoinHandle; use glommio::timer::TimerActionRepeat; use glommio::{enclose, prelude::*}; use slab::Slab; @@ -40,7 +39,6 @@ struct PendingScrapeResponse { struct ConnectionReference { response_sender: LocalSender, - handle: JoinHandle<()>, } struct Connection { @@ -49,7 +47,8 @@ struct Connection { request_senders: Rc>, response_receiver: LocalReceiver, response_consumer_id: ConsumerId, - tls_acceptor: TlsAcceptor, + stream: TlsStream, + peer_addr: SocketAddr, connection_id: ConnectionId, request_buffer: [u8; MAX_REQUEST_SIZE], request_buffer_position: usize, @@ -123,40 +122,27 @@ pub async fn run_socket_worker( match stream { Ok(stream) => { let (response_sender, response_receiver) = new_bounded(config.request_workers); + let key = connection_slab + .borrow_mut() + .insert(ConnectionReference { response_sender }); - let mut slab = connection_slab.borrow_mut(); - let entry = slab.vacant_entry(); - let key = entry.key(); - - let mut conn = Connection { - config: config.clone(), - access_list: access_list.clone(), - request_senders: request_senders.clone(), - response_receiver, - response_consumer_id, - tls_acceptor: tls_config.clone().into(), - connection_id: ConnectionId(entry.key()), - request_buffer: [0u8; MAX_REQUEST_SIZE], - request_buffer_position: 0, - }; - - let connections_to_remove = connections_to_remove.clone(); - - let handle = spawn_local(async move { - if let Err(err) = conn.handle_stream(stream).await { - ::log::info!("conn.handle_stream() error: {:?}", err); + spawn_local(enclose!((config, access_list, request_senders, tls_config, connections_to_remove) async move { + if let Err(err) = Connection::run( + config, + access_list, + request_senders, + response_receiver, + response_consumer_id, + ConnectionId(key), + tls_config, + stream + ).await { + ::log::info!("Connection::run() error: {:?}", err); } connections_to_remove.borrow_mut().push(key); - }) + })) .detach(); - - let connection_reference = ConnectionReference { - response_sender, - handle, - }; - - entry.insert(connection_reference); } Err(err) => { ::log::error!("accept connection: {:?}", err); @@ -182,27 +168,62 @@ async fn receive_responses( } impl Connection { - async fn handle_stream(&mut self, stream: TcpStream) -> anyhow::Result<()> { - let peer_addr = Self::get_peer_addr(&stream)?; - let mut stream = self.tls_acceptor.accept(stream).await?; + async fn run( + config: Rc, + access_list: Rc>, + request_senders: Rc>, + response_receiver: LocalReceiver, + response_consumer_id: ConsumerId, + connection_id: ConnectionId, + tls_config: Arc, + stream: TcpStream, + ) -> anyhow::Result<()> { + let peer_addr = stream + .peer_addr() + .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?; + let tls_acceptor: TlsAcceptor = tls_config.into(); + let stream = tls_acceptor.accept(stream).await?; + + let mut conn = Connection { + config: config.clone(), + access_list: access_list.clone(), + request_senders: request_senders.clone(), + response_receiver, + response_consumer_id, + stream, + peer_addr, + connection_id, + request_buffer: [0u8; MAX_REQUEST_SIZE], + request_buffer_position: 0, + }; + + conn.run_request_response_loop().await?; + + Ok(()) + } + + async fn run_request_response_loop(&mut self) -> anyhow::Result<()> { loop { - let response = match self.read_request(&mut stream).await? { + let response = match self.read_request().await? { Either::Left(response) => Response::Failure(response), - Either::Right(request) => { - match self.handle_request(request, peer_addr).await? { - Either::Left(response) => Response::Failure(response), - Either::Right(opt_pending_scrape_response) => { - self.wait_for_response(peer_addr, opt_pending_scrape_response).await? - } + Either::Right(request) => match self.handle_request(request).await? { + Either::Left(response) => Response::Failure(response), + Either::Right(opt_pending_scrape_response) => { + self.wait_for_response(opt_pending_scrape_response).await? } - } + }, }; - self.write_response(&mut stream, &response).await?; + self.write_response(&response).await?; if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive { - Self::close_stream(stream).await; + let _ = self + .stream + .get_ref() + .0 + .shutdown(std::net::Shutdown::Both) + .await; break; } @@ -211,23 +232,13 @@ impl Connection { Ok(()) } - fn get_peer_addr(stream: &TcpStream) -> anyhow::Result { - stream - .peer_addr() - .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err)) - } - - async fn close_stream(stream: TlsStream) { - let _ = stream.get_ref().0.shutdown(std::net::Shutdown::Both).await; - } - - async fn read_request(&mut self, stream: &mut TlsStream) -> anyhow::Result> { + async fn read_request(&mut self) -> anyhow::Result> { let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; loop { ::log::debug!("read"); - let bytes_read = stream.read(&mut buf).await?; + let bytes_read = self.stream.read(&mut buf).await?; let request_buffer_end = self.request_buffer_position + bytes_read; if request_buffer_end > self.request_buffer.len() { @@ -276,7 +287,6 @@ impl Connection { async fn handle_request( &self, request: Request, - peer_addr: SocketAddr, ) -> anyhow::Result>> { match request { Request::Announce(request) => { @@ -291,7 +301,7 @@ impl Connection { request, connection_id: self.connection_id, response_consumer_id: self.response_consumer_id, - peer_addr, + peer_addr: self.peer_addr, }; let consumer_index = calculate_request_consumer_index(&self.config, info_hash); @@ -327,7 +337,7 @@ impl Connection { for (consumer_index, info_hashes) in info_hashes_by_worker { let request = ChannelRequest::Scrape { request: ScrapeRequest { info_hashes }, - peer_addr, + peer_addr: self.peer_addr, response_consumer_id: self.response_consumer_id, connection_id: self.connection_id, }; @@ -353,13 +363,12 @@ impl Connection { /// return full response async fn wait_for_response( &self, - peer_addr: SocketAddr, mut opt_pending_scrape_response: Option, ) -> anyhow::Result { loop { if let Some(channel_response) = self.response_receiver.recv().await { - if channel_response.get_peer_addr() != peer_addr { - return Err(anyhow::anyhow!("peer addressess didn't match")); + if channel_response.get_peer_addr() != self.peer_addr { + return Err(anyhow::anyhow!("peer addresses didn't match")); } match channel_response { @@ -396,7 +405,7 @@ impl Connection { } } - async fn write_response(&mut self, stream: &mut TlsStream, response: &Response) -> anyhow::Result<()> { + async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> { let mut body = Vec::new(); response.write(&mut body).unwrap(); @@ -412,7 +421,8 @@ impl Connection { response_bytes.append(&mut body); response_bytes.extend_from_slice(b"\r\n"); - stream.write(&response_bytes[..]).await?; + self.stream.write(&response_bytes[..]).await?; + self.stream.flush().await?; Ok(()) }