From 6404be8ef8ac0c88012e800183d945179b0fe8ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Wed, 27 Oct 2021 13:55:17 +0200 Subject: [PATCH] aquatic_http: glommio: add access list, refactor networking slightly --- aquatic_http/src/lib/glommio/mod.rs | 2 +- aquatic_http/src/lib/glommio/network.rs | 112 +++++++++++++++--------- 2 files changed, 74 insertions(+), 40 deletions(-) diff --git a/aquatic_http/src/lib/glommio/mod.rs b/aquatic_http/src/lib/glommio/mod.rs index acb3e41..c93fbc5 100644 --- a/aquatic_http/src/lib/glommio/mod.rs +++ b/aquatic_http/src/lib/glommio/mod.rs @@ -54,7 +54,7 @@ pub fn run(config: Config) -> anyhow::Result<()> { request_mesh_builder, response_mesh_builder, num_bound_sockets, - // access_list, + access_list, ) .await }); diff --git a/aquatic_http/src/lib/glommio/network.rs b/aquatic_http/src/lib/glommio/network.rs index 9ca5984..a7a2b37 100644 --- a/aquatic_http/src/lib/glommio/network.rs +++ b/aquatic_http/src/lib/glommio/network.rs @@ -7,11 +7,13 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; +use aquatic_common::access_list::AccessList; use aquatic_http_protocol::common::InfoHash; use aquatic_http_protocol::request::{AnnounceRequest, Request, RequestParseError, ScrapeRequest}; use aquatic_http_protocol::response::{ FailureResponse, Response, ScrapeResponse, ScrapeStatistics, }; +use either::Either; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; @@ -42,6 +44,7 @@ struct ConnectionReference { struct Connection { config: Rc, + access_list: Rc>, request_senders: Rc>, response_receiver: LocalReceiver, response_consumer_id: ConsumerId, @@ -50,7 +53,6 @@ struct Connection { connection_id: ConnectionId, request_buffer: Vec, close_after_writing: bool, - pending_scrape_response: Option, } pub async fn run_socket_worker( @@ -59,8 +61,10 @@ pub async fn run_socket_worker( request_mesh_builder: MeshBuilder, response_mesh_builder: MeshBuilder, num_bound_sockets: Arc, + access_list: AccessList, ) { let config = Rc::new(config); + let access_list = Rc::new(RefCell::new(access_list)); let listener = TcpListener::bind(config.network.address).expect("bind socket"); num_bound_sockets.fetch_add(1, Ordering::SeqCst); @@ -115,6 +119,7 @@ pub async fn run_socket_worker( let mut conn = Connection { config: config.clone(), + access_list: access_list.clone(), request_senders: request_senders.clone(), response_receiver, response_consumer_id, @@ -123,7 +128,6 @@ pub async fn run_socket_worker( connection_id: ConnectionId(entry.key()), request_buffer: Vec::new(), close_after_writing: false, - pending_scrape_response: None, }; let connections_to_remove = connections_to_remove.clone(); @@ -172,8 +176,23 @@ impl Connection { let opt_request = self.read_tls().await?; if let Some(request) = opt_request { - self.handle_request(request).await?; - self.wait_for_and_send_response().await?; + let response = match self.handle_request(request).await? { + Some(Either::Left(response)) => { + response + } + Some(Either::Right(pending_scrape_response)) => { + self.wait_for_response(Some(pending_scrape_response)).await? + }, + None => { + self.wait_for_response(None).await? + } + }; + + self.queue_response(&response)?; + + if !self.config.network.keep_alive { + self.close_after_writing = true; + } } self.write_tls().await?; @@ -292,26 +311,43 @@ impl Connection { Ok(()) } - /// Send on request to proper request worker/workers - async fn handle_request(&mut self, request: Request) -> anyhow::Result<()> { + /// Take a request and: + /// - Return error response if request is not allowed + /// - If it is an announce requests, pass it on to request workers and return None + /// - If it is a scrape requests, split it up and pass on parts to + /// relevant request workers, and return PendingScrapeResponse struct. + async fn handle_request( + &self, + request: Request + ) -> anyhow::Result>> { let peer_addr = self.get_peer_addr()?; match request { Request::Announce(request @ AnnounceRequest { info_hash, .. }) => { - let request = ChannelRequest::Announce { - request, - connection_id: self.connection_id, - response_consumer_id: self.response_consumer_id, - peer_addr, - }; + if self.access_list.borrow().allows(self.config.access_list.mode, &info_hash.0) { + let request = ChannelRequest::Announce { + request, + connection_id: self.connection_id, + response_consumer_id: self.response_consumer_id, + peer_addr, + }; - let consumer_index = calculate_request_consumer_index(&self.config, info_hash); + let consumer_index = calculate_request_consumer_index(&self.config, info_hash); - // Only fails when receiver is closed - self.request_senders - .send_to(consumer_index, request) - .await - .unwrap(); + // Only fails when receiver is closed + self.request_senders + .send_to(consumer_index, request) + .await + .unwrap(); + + Ok(None) + } else { + let response = Response::Failure(FailureResponse { + failure_reason: "Info hash not allowed".into(), + }); + + Ok(Some(Either::Left(response))) + } } Request::Scrape(ScrapeRequest { info_hashes }) => { let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); @@ -324,10 +360,7 @@ impl Connection { info_hashes.push(info_hash); } - self.pending_scrape_response = Some(PendingScrapeResponse { - pending_worker_responses: info_hashes_by_worker.len(), - stats: Default::default(), - }); + let pending_worker_responses = info_hashes_by_worker.len(); for (consumer_index, info_hashes) in info_hashes_by_worker { let request = ChannelRequest::Scrape { @@ -343,15 +376,24 @@ impl Connection { .await .unwrap(); } + + let pending_scrape_response = PendingScrapeResponse { + pending_worker_responses, + stats: Default::default(), + }; + + Ok(Some(Either::Right(pending_scrape_response))) } } - - Ok(()) } - // Wait for response/responses to arrive, then queue response for sending to peer - async fn wait_for_and_send_response(&mut self) -> anyhow::Result<()> { - let response = loop { + /// Wait for announce response or partial scrape responses to arrive, + /// return full response + async fn wait_for_response( + &self, + mut opt_pending_scrape_response: Option + ) -> anyhow::Result { + loop { if let Some(channel_response) = self.response_receiver.recv().await { if channel_response.get_peer_addr() != self.get_peer_addr()? { return Err(anyhow::anyhow!("peer addressess didn't match")); @@ -359,10 +401,10 @@ impl Connection { match channel_response { ChannelResponse::Announce { response, .. } => { - break Response::Announce(response); + break Ok(Response::Announce(response)); } ChannelResponse::Scrape { response, .. } => { - if let Some(mut pending) = self.pending_scrape_response.take() { + if let Some(mut pending) = opt_pending_scrape_response.take() { pending.stats.extend(response.files); pending.pending_worker_responses -= 1; @@ -371,9 +413,9 @@ impl Connection { files: pending.stats, }); - break response; + break Ok(response); } else { - self.pending_scrape_response = Some(pending); + opt_pending_scrape_response = Some(pending); } } else { return Err(anyhow::anyhow!( @@ -388,15 +430,7 @@ impl Connection { "response receiver can't receive - sender is closed" )); } - }; - - self.queue_response(&response)?; - - if !self.config.network.keep_alive { - self.close_after_writing = true; } - - Ok(()) } fn queue_response(&mut self, response: &Response) -> anyhow::Result<()> {