From d2595e9746b94449dd66822abb76d27e4ecb7cdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Mon, 1 Nov 2021 21:34:34 +0100 Subject: [PATCH] aquatic_ws: split networking into reader and writer tasks --- aquatic_ws/src/lib/common.rs | 4 + aquatic_ws/src/lib/network.rs | 247 ++++++++++++++++++---------------- 2 files changed, 138 insertions(+), 113 deletions(-) diff --git a/aquatic_ws/src/lib/common.rs b/aquatic_ws/src/lib/common.rs index 76f08fc..9ee675c 100644 --- a/aquatic_ws/src/lib/common.rs +++ b/aquatic_ws/src/lib/common.rs @@ -19,6 +19,9 @@ use crate::config::Config; pub type TlsConfig = futures_rustls::rustls::ServerConfig; +#[derive(Copy, Clone, Debug)] +pub struct PendingScrapeId(pub usize); + #[derive(Copy, Clone, Debug)] pub struct ConsumerId(pub usize); @@ -35,6 +38,7 @@ pub struct ConnectionMeta { /// an IPv4 address if it was a IPv4-mapped IPv6 address pub naive_peer_addr: SocketAddr, pub converted_peer_ip: IpAddr, + pub pending_scrape_id: Option, } #[derive(PartialEq, Eq, Clone, Copy, Debug)] diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/network.rs index f800222..f2682d1 100644 --- a/aquatic_ws/src/lib/network.rs +++ b/aquatic_ws/src/lib/network.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cell::RefCell; use std::collections::BTreeMap; use std::net::SocketAddr; @@ -10,8 +11,8 @@ use aquatic_common::access_list::AccessList; use aquatic_common::convert_ipv4_mapped_ipv6; use aquatic_ws_protocol::*; use async_tungstenite::WebSocketStream; -use either::Either; use futures::stream::{SplitSink, SplitStream}; +use futures_lite::future::race; use futures_lite::StreamExt; use futures_rustls::server::TlsStream; use futures_rustls::TlsAcceptor; @@ -34,19 +35,7 @@ struct PendingScrapeResponse { } struct ConnectionReference { - out_message_sender: LocalSender<(ConnectionMeta, OutMessage)>, -} - -struct Connection { - config: Rc, - access_list: Rc>, - in_message_senders: Rc>, - out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, - out_message_consumer_id: ConsumerId, - ws_out: SplitSink>, tungstenite::Message>, - ws_in: SplitStream>>, - peer_addr: SocketAddr, - connection_id: ConnectionId, + out_message_sender: Rc>, } pub async fn run_socket_worker( @@ -107,16 +96,19 @@ pub async fn run_socket_worker( match stream { Ok(stream) => { let (out_message_sender, out_message_receiver) = - new_bounded(config.request_workers); - let key = connection_slab - .borrow_mut() - .insert(ConnectionReference { out_message_sender }); + new_bounded(config.request_workers + 1); + let out_message_sender = Rc::new(out_message_sender); + + let key = RefCell::borrow_mut(&connection_slab).insert(ConnectionReference { + out_message_sender: out_message_sender.clone(), + }); spawn_local(enclose!((config, access_list, in_message_senders, tls_config, connections_to_remove) async move { if let Err(err) = Connection::run( config, access_list, in_message_senders, + out_message_sender, out_message_receiver, out_message_consumer_id, ConnectionId(key), @@ -126,7 +118,7 @@ pub async fn run_socket_worker( ::log::debug!("Connection::run() error: {:?}", err); } - connections_to_remove.borrow_mut().push(key); + RefCell::borrow_mut(&connections_to_remove).push(key); })) .detach(); } @@ -145,7 +137,7 @@ async fn remove_closed_connections( let connections_to_remove = connections_to_remove.replace(Vec::new()); for connection_id in connections_to_remove { - if let Some(_) = connection_slab.borrow_mut().try_remove(connection_id) { + if let Some(_) = RefCell::borrow_mut(&connection_slab).try_remove(connection_id) { ::log::debug!("removed connection with id {}", connection_id); } else { ::log::error!( @@ -168,17 +160,23 @@ async fn receive_out_messages( .get(channel_out_message.0.connection_id.0) { if let Err(err) = reference.out_message_sender.try_send(channel_out_message) { - ::log::error!("Couldn't send out_message to local receiver: {:?}", err); + ::log::error!( + "Couldn't send out_message from shared channel to local receiver: {:?}", + err + ); } } } } +struct Connection; + impl Connection { async fn run( config: Rc, access_list: Rc>, in_message_senders: Rc>, + out_message_sender: Rc>, out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, out_message_consumer_id: ConsumerId, connection_id: ConnectionId, @@ -200,45 +198,57 @@ impl Connection { let stream = async_tungstenite::accept_async_with_config(stream, Some(ws_config)).await?; let (ws_out, ws_in) = futures::StreamExt::split(stream); - let mut conn = Connection { - config: config.clone(), - access_list: access_list.clone(), - in_message_senders: in_message_senders.clone(), - out_message_receiver, - out_message_consumer_id, - ws_out, - ws_in, - peer_addr, - connection_id, - }; + let pending_scrape_slab = Rc::new(RefCell::new(Slab::new())); - conn.run_message_loop().await?; - - Ok(()) - } - - async fn run_message_loop(&mut self) -> anyhow::Result<()> { - loop { - let out_message = match self.read_in_message().await? { - Either::Left(out_message) => OutMessage::ErrorResponse(out_message), - Either::Right(in_message) => self.handle_in_message(in_message).await?, + let reader_handle = spawn_local(enclose!((pending_scrape_slab) async move { + let mut reader = ConnectionReader { + config, + access_list, + in_message_senders, + out_message_sender, + pending_scrape_slab, + out_message_consumer_id, + ws_in, + peer_addr, + connection_id, }; - self.write_out_message(&out_message).await?; + reader.run_in_message_loop().await + })) + .detach(); - if matches!(out_message, OutMessage::ErrorResponse(_)) { - // TODO: shut down? + let writer_handle = spawn_local(async move { + let mut writer = ConnectionWriter { + out_message_receiver, + ws_out, + pending_scrape_slab, + peer_addr, + }; - break; - } - } + writer.run_out_message_loop().await + }) + .detach(); - Ok(()) + race(reader_handle, writer_handle).await.unwrap() } +} - async fn read_in_message(&mut self) -> anyhow::Result> { +struct ConnectionReader { + config: Rc, + access_list: Rc>, + in_message_senders: Rc>, + out_message_sender: Rc>, + pending_scrape_slab: Rc>>, + out_message_consumer_id: ConsumerId, + ws_in: SplitStream>>, + peer_addr: SocketAddr, + connection_id: ConnectionId, +} + +impl ConnectionReader { + async fn run_in_message_loop(&mut self) -> anyhow::Result<()> { loop { - ::log::debug!("read"); + ::log::debug!("read_in_message"); let message = self.ws_in.next().await.unwrap()?; @@ -246,24 +256,18 @@ impl Connection { Ok(in_message) => { ::log::debug!("received in_message: {:?}", in_message); - return Ok(Either::Right(in_message)); + self.handle_in_message(in_message).await?; } Err(err) => { ::log::debug!("Couldn't parse in_message: {:?}", err); - let out_message = ErrorResponse { - action: None, - failure_reason: "Invalid request".into(), - info_hash: None, - }; - - return Ok(Either::Left(out_message)); + self.send_error_response("Invalid request".into(), None); } } } } - async fn handle_in_message(&self, in_message: InMessage) -> anyhow::Result { + async fn handle_in_message(&mut self, in_message: InMessage) -> anyhow::Result<()> { match in_message { InMessage::AnnounceRequest(announce_request) => { let info_hash = announce_request.info_hash; @@ -273,12 +277,6 @@ impl Connection { .borrow() .allows(self.config.access_list.mode, &info_hash.0) { - let meta = ConnectionMeta { - connection_id: self.connection_id, - out_message_consumer_id: self.out_message_consumer_id, - naive_peer_addr: self.peer_addr, - converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), - }; let in_message = InMessage::AnnounceRequest(announce_request); let consumer_index = @@ -286,19 +284,14 @@ impl Connection { // Only fails when receiver is closed self.in_message_senders - .send_to(consumer_index, (meta, in_message)) + .send_to( + consumer_index, + (self.make_connection_meta(None), in_message), + ) .await .unwrap(); - - self.wait_for_out_message(None).await } else { - let out_message = OutMessage::ErrorResponse(ErrorResponse { - action: Some(ErrorResponseAction::Announce), - failure_reason: "Info hash not allowed".into(), - info_hash: Some(info_hash), - }); - - Ok(out_message) + self.send_error_response("Info hash not allowed".into(), Some(info_hash)); } } InMessage::ScrapeRequest(ScrapeRequest { info_hashes, .. }) => { @@ -307,13 +300,9 @@ impl Connection { } else { // If request.info_hashes is empty, don't return scrape for all // torrents, even though reference server does it. It is too expensive. - let out_message = OutMessage::ErrorResponse(ErrorResponse { - action: Some(ErrorResponseAction::Scrape), - failure_reason: "Full scrapes are not allowed".into(), - info_hash: None, - }); + self.send_error_response("Full scrapes are not allowed".into(), None); - return Ok(out_message); + return Ok(()); }; let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); @@ -328,13 +317,17 @@ impl Connection { let pending_worker_out_messages = info_hashes_by_worker.len(); - let meta = ConnectionMeta { - connection_id: self.connection_id, - out_message_consumer_id: self.out_message_consumer_id, - naive_peer_addr: self.peer_addr, - converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), + let pending_scrape_response = PendingScrapeResponse { + pending_worker_out_messages, + stats: Default::default(), }; + let pending_scrape_id = PendingScrapeId( + RefCell::borrow_mut(&mut self.pending_scrape_slab) + .insert(pending_scrape_response), + ); + let meta = self.make_connection_meta(Some(pending_scrape_id)); + for (consumer_index, info_hashes) in info_hashes_by_worker { let in_message = InMessage::ScrapeRequest(ScrapeRequest { action: ScrapeAction, @@ -347,24 +340,43 @@ impl Connection { .await .unwrap(); } - - let pending_scrape_out_message = PendingScrapeResponse { - pending_worker_out_messages, - stats: Default::default(), - }; - - self.wait_for_out_message(Some(pending_scrape_out_message)) - .await } } + + Ok(()) } - /// Wait for announce out_message or partial scrape out_messages to arrive, - /// return full out_message - async fn wait_for_out_message( - &self, - mut opt_pending_scrape_out_message: Option, - ) -> anyhow::Result { + fn send_error_response(&self, failure_reason: Cow<'static, str>, info_hash: Option) { + let out_message = OutMessage::ErrorResponse(ErrorResponse { + action: Some(ErrorResponseAction::Scrape), + failure_reason, + info_hash, + }); + + self.out_message_sender + .try_send((self.make_connection_meta(None), out_message)); + } + + fn make_connection_meta(&self, pending_scrape_id: Option) -> ConnectionMeta { + ConnectionMeta { + connection_id: self.connection_id, + out_message_consumer_id: self.out_message_consumer_id, + naive_peer_addr: self.peer_addr, + converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), + pending_scrape_id, + } + } +} + +struct ConnectionWriter { + out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, + ws_out: SplitSink>, tungstenite::Message>, + pending_scrape_slab: Rc>>, + peer_addr: SocketAddr, +} + +impl ConnectionWriter { + async fn run_out_message_loop(&mut self) -> anyhow::Result<()> { loop { let (meta, out_message) = self .out_message_receiver @@ -378,34 +390,43 @@ impl Connection { match out_message { OutMessage::ScrapeResponse(out_message) => { - if let Some(mut pending) = opt_pending_scrape_out_message.take() { + let pending_scrape_id = meta + .pending_scrape_id + .expect("meta.pending_scrape_id not set"); + + let opt_message = if let Some(pending) = Slab::get_mut( + &mut RefCell::borrow_mut(&self.pending_scrape_slab), + pending_scrape_id.0, + ) { pending.stats.extend(out_message.files); pending.pending_worker_out_messages -= 1; if pending.pending_worker_out_messages == 0 { - let out_message = OutMessage::ScrapeResponse(ScrapeResponse { + Some(OutMessage::ScrapeResponse(ScrapeResponse { action: ScrapeAction, - files: pending.stats, - }); - - break Ok(out_message); + files: pending.stats.clone(), // FIXME: clone + })) } else { - opt_pending_scrape_out_message = Some(pending); + None } } else { - return Err(anyhow::anyhow!( - "received channel scrape out_message without pending scrape out_message" - )); + return Err(anyhow::anyhow!("pending scrape not found in slab")); + }; + + if let Some(out_message) = opt_message { + self.send_out_message(&out_message).await?; + + RefCell::borrow_mut(&self.pending_scrape_slab).remove(pending_scrape_id.0); } } out_message => { - break Ok(out_message); + self.send_out_message(&out_message).await?; } }; } } - async fn write_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> { + async fn send_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> { futures::SinkExt::send(&mut self.ws_out, out_message.to_ws_message()).await?; futures::SinkExt::flush(&mut self.ws_out).await?;