diff --git a/aquatic_http/src/lib/glommio/common.rs b/aquatic_http/src/lib/glommio/common.rs index 0ca0bed..bfc8053 100644 --- a/aquatic_http/src/lib/glommio/common.rs +++ b/aquatic_http/src/lib/glommio/common.rs @@ -1,9 +1,6 @@ use std::net::SocketAddr; -use aquatic_http_protocol::{ - request::{AnnounceRequest, ScrapeRequest}, - response::{AnnounceResponse, ScrapeResponse}, -}; +use aquatic_http_protocol::{common::InfoHash, request::{AnnounceRequest, ScrapeRequest}, response::{AnnounceResponse, ScrapeResponse, ScrapeStatistics}}; #[derive(Copy, Clone, Debug)] pub struct ConsumerId(pub usize); @@ -22,7 +19,6 @@ pub enum ChannelRequest { Scrape { request: ScrapeRequest, peer_addr: SocketAddr, - original_indices: Vec, connection_id: ConnectionId, response_consumer_id: ConsumerId, }, @@ -38,7 +34,6 @@ pub enum ChannelResponse { Scrape { response: ScrapeResponse, peer_addr: SocketAddr, - original_indices: Vec, connection_id: ConnectionId, }, } diff --git a/aquatic_http/src/lib/glommio/handlers.rs b/aquatic_http/src/lib/glommio/handlers.rs index e7abd65..13ca6b2 100644 --- a/aquatic_http/src/lib/glommio/handlers.rs +++ b/aquatic_http/src/lib/glommio/handlers.rs @@ -118,7 +118,6 @@ async fn handle_request_stream( peer_addr, response_consumer_id, connection_id, - original_indices, } => { let meta = ConnectionMeta { worker_index: response_consumer_id.0, @@ -133,7 +132,6 @@ async fn handle_request_stream( response, peer_addr, connection_id, - original_indices, }; (response, response_consumer_id) diff --git a/aquatic_http/src/lib/glommio/network.rs b/aquatic_http/src/lib/glommio/network.rs index 9d8a264..7f46e84 100644 --- a/aquatic_http/src/lib/glommio/network.rs +++ b/aquatic_http/src/lib/glommio/network.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::BTreeMap; use std::io::{Cursor, ErrorKind, Read, Write}; use std::net::SocketAddr; use std::rc::Rc; @@ -6,8 +7,8 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use aquatic_http_protocol::common::InfoHash; -use aquatic_http_protocol::request::{AnnounceRequest, Request, RequestParseError}; -use aquatic_http_protocol::response::{FailureResponse, Response}; +use aquatic_http_protocol::request::{AnnounceRequest, Request, RequestParseError, ScrapeRequest}; +use aquatic_http_protocol::response::{FailureResponse, Response, ScrapeResponse, ScrapeStatistics}; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; @@ -25,6 +26,11 @@ use super::common::*; const BUFFER_SIZE: usize = 1024; +struct PendingScrapeResponse { + pending_worker_responses: usize, + stats: BTreeMap, +} + struct ConnectionReference { response_sender: LocalSender, handle: JoinHandle<()>, @@ -40,6 +46,7 @@ struct Connection { connection_id: ConnectionId, request_buffer: Vec, close_after_writing: bool, + pending_scrape_response: Option, } pub async fn run_socket_worker( @@ -90,6 +97,7 @@ pub async fn run_socket_worker( connection_id: ConnectionId(entry.key()), request_buffer: Vec::new(), close_after_writing: false, + pending_scrape_response: None, }; async fn handle_stream(mut conn: Connection) { @@ -269,43 +277,92 @@ impl Connection { let consumer_index = calculate_request_consumer_index(&self.config, info_hash); - self.request_senders.try_send_to(consumer_index, request); + + if let Err(err) = self.request_senders.try_send_to( + consumer_index, + request, + ) { + ::log::warn!("request_sender.try_send failed: {:?}", err); + } } - Request::Scrape(request) => { - // TODO + Request::Scrape(ScrapeRequest { info_hashes }) => { + let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); + + for info_hash in info_hashes.into_iter() { + let info_hashes = info_hashes_by_worker + .entry(calculate_request_consumer_index(&self.config, info_hash)) + .or_default(); + + info_hashes.push(info_hash); + } + + self.pending_scrape_response = Some(PendingScrapeResponse { + pending_worker_responses: info_hashes_by_worker.len(), + stats: Default::default(), + }); + + for (consumer_index, info_hashes) in info_hashes_by_worker { + let request = ChannelRequest::Scrape { + request: ScrapeRequest { info_hashes }, + peer_addr, + response_consumer_id: self.response_consumer_id, + connection_id: self.connection_id, + }; + + if let Err(err) = self.request_senders.try_send_to( + consumer_index, + request, + ) { + ::log::warn!("request_sender.try_send failed: {:?}", err); + } + } } } Ok(()) } - // Wait for response to arrive, then queue it for sending to peer + // Wait for response/responses to arrive, then queue response for sending to peer async fn wait_for_and_send_response(&mut self) -> anyhow::Result<()> { - if let Some(channel_response) = self.response_receiver.recv().await { - if channel_response.get_peer_addr() != self.get_peer_addr()? { - return Err(anyhow::anyhow!("peer addressess didn't match")); + let response = loop { + if let Some(channel_response) = self.response_receiver.recv().await { + if channel_response.get_peer_addr() != self.get_peer_addr()? { + return Err(anyhow::anyhow!("peer addressess didn't match")); + } + + match channel_response { + ChannelResponse::Announce { response, .. } => { + break Response::Announce(response); + } + ChannelResponse::Scrape { response, .. } => { + if let Some(mut pending) = self.pending_scrape_response.take() { + pending.stats.extend(response.files); + pending.pending_worker_responses -= 1; + + if pending.pending_worker_responses == 0 { + let response = Response::Scrape(ScrapeResponse { + files: pending.stats, + }); + + break response; + } else { + self.pending_scrape_response = Some(pending); + } + } else { + return Err(anyhow::anyhow!("received channel scrape response without pending scrape response")); + } + } + }; + } else { + // TODO: this is a serious error condition and should maybe be handled differently + return Err(anyhow::anyhow!("response receiver can't receive - sender is closed")); } + }; - let opt_response = match channel_response { - ChannelResponse::Announce { response, .. } => { - Some(Response::Announce(response)) - } - ChannelResponse::Scrape { - response, - original_indices, - .. - } => { - None // TODO: accumulate scrape requests - } - }; + self.queue_response(&response)?; - if let Some(response) = opt_response { - self.queue_response(&response)?; - - if !self.config.network.keep_alive { - self.close_after_writing = true; - } - } + if !self.config.network.keep_alive { + self.close_after_writing = true; } Ok(())