diff --git a/aquatic_common/src/access_list.rs b/aquatic_common/src/access_list.rs index 06b5dd4..69f856c 100644 --- a/aquatic_common/src/access_list.rs +++ b/aquatic_common/src/access_list.rs @@ -42,7 +42,7 @@ impl Default for AccessListConfig { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct AccessList(HashSet<[u8; 20]>); impl AccessList { diff --git a/aquatic_udp/src/lib/glommio/network.rs b/aquatic_udp/src/lib/glommio/network.rs index 22516a7..0ef3ae4 100644 --- a/aquatic_udp/src/lib/glommio/network.rs +++ b/aquatic_udp/src/lib/glommio/network.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::io::Cursor; -use std::iter::FromIterator; use std::net::{IpAddr, SocketAddr}; use std::rc::Rc; use std::sync::{ @@ -31,7 +30,6 @@ use crate::config::Config; struct PendingScrapeResponse { pending_worker_responses: usize, valid_until: ValidUntil, - src: SocketAddr, stats: Vec, } @@ -39,16 +37,39 @@ struct PendingScrapeResponse { struct PendingScrapeResponses(HashMap); impl PendingScrapeResponses { - fn insert_empty(&mut self, transaction_id: TransactionId, src: SocketAddr, pending_worker_responses: usize, valid_until: ValidUntil) { + fn prepare(&mut self, transaction_id: TransactionId, pending_worker_responses: usize, valid_until: ValidUntil) { let pending = PendingScrapeResponse { pending_worker_responses, valid_until, - src, stats: Vec::new(), }; self.0.insert(transaction_id, pending); } + + fn add_and_get_finished(&mut self, mut response: ScrapeResponse) -> Option { + let finished = if let Some(r) = self.0.get_mut(&response.transaction_id) { + r.pending_worker_responses -= 1; + r.stats.append(&mut response.torrent_stats); + + r.pending_worker_responses == 0 + } else { + ::log::warn!("PendingScrapeResponses.add didn't find PendingScrapeResponse in map"); + + false + }; + + if finished { + let r = self.0.remove(&response.transaction_id).unwrap(); + + Some(ScrapeResponse { + transaction_id: response.transaction_id, + torrent_stats: r.stats, + }) + } else { + None + } + } } pub async fn run_socket_worker( @@ -78,9 +99,10 @@ pub async fn run_socket_worker( let response_consumer_index = response_receivers.consumer_id().unwrap(); + // FIXME: needs cleaning let pending_scrape_responses = Rc::new(RefCell::new(PendingScrapeResponses::default())); - spawn_local(read_requests( + spawn_local(enclose!((pending_scrape_responses) read_requests( config.clone(), request_senders, response_consumer_index, @@ -88,18 +110,19 @@ pub async fn run_socket_worker( socket.clone(), pending_scrape_responses, access_list, - )) + ))) .detach(); for (_, receiver) in response_receivers.streams().into_iter() { - spawn_local(send_responses( + spawn_local(enclose!((pending_scrape_responses) handle_shared_responses( socket.clone(), - receiver.map(|(response, addr)| (response.into(), addr)), - )) + pending_scrape_responses, + receiver, + ))) .detach(); } - send_responses(socket, local_receiver.stream()).await; + send_local_responses(socket, local_receiver.stream()).await; } async fn read_requests( @@ -215,9 +238,8 @@ async fn read_requests( .info_hashes.push(info_hash); } - pending_scrape_responses.borrow_mut().insert_empty( + pending_scrape_responses.borrow_mut().prepare( request.transaction_id, - src, consumer_requests.len(), connection_valid_until.borrow().to_owned(), // FIXME: use seperate ValidUntil ); @@ -262,32 +284,70 @@ async fn read_requests( } } -async fn send_responses(socket: Rc, mut stream: S) +async fn handle_shared_responses( + socket: Rc, + pending_scrape_responses: Rc>, + mut stream: S, +) where + S: Stream + ::std::marker::Unpin, +{ + let mut buf = [0u8; MAX_PACKET_SIZE]; + let mut buf = Cursor::new(&mut buf[..]); + + while let Some((response, addr)) = stream.next().await { + let opt_response = match response { + ConnectedResponse::Announce(response) => Some((Response::Announce(response), addr)), + ConnectedResponse::Scrape(response) => { + pending_scrape_responses + .borrow_mut() + .add_and_get_finished(response) + .map(|response| (Response::Scrape(response), addr)) + }, + }; + + if let Some((response, addr)) = opt_response { + write_response_to_socket(&socket, &mut buf, addr, response).await; + } + + yield_if_needed().await; + } +} + +async fn send_local_responses(socket: Rc, mut stream: S) where S: Stream + ::std::marker::Unpin, { let mut buf = [0u8; MAX_PACKET_SIZE]; let mut buf = Cursor::new(&mut buf[..]); - while let Some((response, src)) = stream.next().await { - buf.set_position(0); - - ::log::debug!("preparing to send response: {:?}", response.clone()); - - response - .write(&mut buf, ip_version_from_ip(src.ip())) - .expect("write response"); - - let position = buf.position() as usize; - - if let Err(err) = socket.send_to(&buf.get_ref()[..position], src).await { - ::log::info!("send_to failed: {:?}", err); - } + while let Some((response, addr)) = stream.next().await { + write_response_to_socket(&socket, &mut buf, addr, response).await; yield_if_needed().await; } } +async fn write_response_to_socket( + socket: &Rc, + buf: &mut Cursor<&mut [u8]>, + addr: SocketAddr, + response: Response, +) { + buf.set_position(0); + + ::log::debug!("preparing to send response: {:?}", response.clone()); + + response + .write(buf, ip_version_from_ip(addr.ip())) + .expect("write response"); + + let position = buf.position() as usize; + + if let Err(err) = socket.send_to(&buf.get_ref()[..position], addr).await { + ::log::info!("send_to failed: {:?}", err); + } +} + fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize { (info_hash.0[0] as usize) % config.request_workers }