From 0c4140165ba3397a4857722eeec52ea0aa0bb788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Sun, 10 Dec 2023 12:07:38 +0100 Subject: [PATCH] udp: reuse response channel mem, add backpressure, faster peer extract --- CHANGELOG.md | 8 + Cargo.lock | 34 +++ crates/udp/Cargo.toml | 1 + crates/udp/src/common.rs | 129 ++++++-- crates/udp/src/config.rs | 8 +- crates/udp/src/lib.rs | 10 +- crates/udp/src/workers/socket/mio.rs | 278 ++++++++++------- crates/udp/src/workers/socket/mod.rs | 9 +- crates/udp/src/workers/socket/storage.rs | 8 +- crates/udp/src/workers/socket/uring/mod.rs | 108 ++++--- .../src/workers/socket/uring/send_buffers.rs | 41 +-- crates/udp/src/workers/swarm/mod.rs | 178 ++++------- crates/udp/src/workers/swarm/storage.rs | 282 +++++++++--------- crates/udp/tests/requests_responses.rs | 13 +- crates/udp_protocol/src/response.rs | 81 +++-- 15 files changed, 666 insertions(+), 522 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 747bd56..668e3b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,14 @@ * Add support for reporting peer client information +#### Changed + +* Remove support for unbounded worker channels +* Add backpressure in socket workers. 
They will postpone reading from the + socket if sending a request to a swarm worker failed +* Reuse allocations in swarm response channel +* Remove config key `network.poll_event_capacity` + ### aquatic_http #### Added diff --git a/Cargo.lock b/Cargo.lock index 2b80164..e3bb402 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -256,6 +256,7 @@ dependencies = [ "slab", "socket2 0.5.5", "tempfile", + "thingbuf", "time", "tinytemplate", ] @@ -1958,6 +1959,29 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.48.5", +] + [[package]] name = "percent-encoding" version = "2.3.0" @@ -2628,6 +2652,16 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" +[[package]] +name = "thingbuf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4706f1bfb859af03f099ada2de3cea3e515843c2d3e93b7893f16d94a37f9415" +dependencies = [ + "parking_lot", + "pin-project", +] + [[package]] name = "thiserror" version = "1.0.50" diff --git a/crates/udp/Cargo.toml b/crates/udp/Cargo.toml index b0c0ef2..4e9eac3 100644 --- a/crates/udp/Cargo.toml +++ b/crates/udp/Cargo.toml @@ -51,6 +51,7 @@ serde = { version = "1", features = ["derive"] } signal-hook = { version = "0.3" } slab = "0.4" socket2 = 
{ version = "0.5", features = ["all"] } +thingbuf = "0.1" time = { version = "0.3", features = ["formatting"] } tinytemplate = "1" diff --git a/crates/udp/src/common.rs b/crates/udp/src/common.rs index 1fe08d1..439a497 100644 --- a/crates/udp/src/common.rs +++ b/crates/udp/src/common.rs @@ -1,5 +1,8 @@ +use std::borrow::Cow; use std::collections::BTreeMap; use std::hash::Hash; +use std::io::Write; +use std::net::{SocketAddr, SocketAddrV4}; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -9,11 +12,56 @@ use aquatic_common::access_list::AccessListArcSwap; use aquatic_common::CanonicalSocketAddr; use aquatic_udp_protocol::*; use hdrhistogram::Histogram; +use thingbuf::mpsc::blocking::SendRef; use crate::config::Config; pub const BUFFER_SIZE: usize = 8192; +#[derive(PartialEq, Eq, Clone, Debug)] +pub enum CowResponse<'a> { + Connect(Cow<'a, ConnectResponse>), + AnnounceIpv4(Cow<'a, AnnounceResponse>), + AnnounceIpv6(Cow<'a, AnnounceResponse>), + Scrape(Cow<'a, ScrapeResponse>), + Error(Cow<'a, ErrorResponse>), +} + +impl From for CowResponse<'_> { + fn from(value: Response) -> Self { + match value { + Response::AnnounceIpv4(r) => Self::AnnounceIpv4(Cow::Owned(r)), + Response::AnnounceIpv6(r) => Self::AnnounceIpv6(Cow::Owned(r)), + Response::Connect(r) => Self::Connect(Cow::Owned(r)), + Response::Scrape(r) => Self::Scrape(Cow::Owned(r)), + Response::Error(r) => Self::Error(Cow::Owned(r)), + } + } +} + +impl<'a> CowResponse<'a> { + pub fn into_owned(self) -> Response { + match self { + CowResponse::Connect(r) => Response::Connect(r.into_owned()), + CowResponse::AnnounceIpv4(r) => Response::AnnounceIpv4(r.into_owned()), + CowResponse::AnnounceIpv6(r) => Response::AnnounceIpv6(r.into_owned()), + CowResponse::Scrape(r) => Response::Scrape(r.into_owned()), + CowResponse::Error(r) => Response::Error(r.into_owned()), + } + } + + #[inline] + pub fn write(&self, bytes: &mut impl Write) -> Result<(), ::std::io::Error> { + match self { + Self::Connect(r) => 
r.write(bytes), + Self::AnnounceIpv4(r) => r.write(bytes), + Self::AnnounceIpv6(r) => r.write(bytes), + Self::Scrape(r) => r.write(bytes), + Self::Error(r) => r.write(bytes), + } + } +} + #[derive(Debug)] pub struct PendingScrapeRequest { pub slab_key: usize, @@ -39,6 +87,43 @@ pub enum ConnectedResponse { Scrape(PendingScrapeResponse), } +pub enum ConnectedResponseKind { + AnnounceIpv4, + AnnounceIpv6, + Scrape, +} + +pub struct ConnectedResponseWithAddr { + pub kind: ConnectedResponseKind, + pub announce_ipv4: AnnounceResponse, + pub announce_ipv6: AnnounceResponse, + pub scrape: PendingScrapeResponse, + pub addr: CanonicalSocketAddr, +} + +pub struct Recycler; + +impl thingbuf::Recycle for Recycler { + fn new_element(&self) -> ConnectedResponseWithAddr { + ConnectedResponseWithAddr { + kind: ConnectedResponseKind::AnnounceIpv4, + announce_ipv4: AnnounceResponse::empty(), + announce_ipv6: AnnounceResponse::empty(), + scrape: PendingScrapeResponse { + slab_key: 0, + torrent_stats: Default::default(), + }, + addr: CanonicalSocketAddr::new(SocketAddr::V4(SocketAddrV4::new(0.into(), 0))), + } + } + fn recycle(&self, element: &mut ConnectedResponseWithAddr) { + element.announce_ipv4.peers.clear(); + element.announce_ipv6.peers.clear(); + element.scrape.torrent_stats.clear(); + element.addr = CanonicalSocketAddr::new(SocketAddr::V4(SocketAddrV4::new(0.into(), 0))); + } +} + #[derive(Clone, Copy, Debug)] pub struct SocketWorkerIndex(pub usize); @@ -64,17 +149,19 @@ impl ConnectedRequestSender { Self { index, senders } } + pub fn any_full(&self) -> bool { + self.senders.iter().any(|sender| sender.is_full()) + } + pub fn try_send_to( &self, index: SwarmWorkerIndex, request: ConnectedRequest, addr: CanonicalSocketAddr, - ) { + ) -> Result<(), (SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)> { match self.senders[index.0].try_send((self.index, request, addr)) { - Ok(()) => {} - Err(TrySendError::Full(_)) => { - ::log::error!("Request channel {} is full, dropping 
request. Try increasing number of swarm workers or raising config.worker_channel_size.", index.0) - } + Ok(()) => Ok(()), + Err(TrySendError::Full(r)) => Err((index, r.1, r.2)), Err(TrySendError::Disconnected(_)) => { panic!("Request channel {} is disconnected", index.0); } @@ -83,32 +170,34 @@ impl ConnectedRequestSender { } pub struct ConnectedResponseSender { - senders: Vec>, + senders: Vec>, } impl ConnectedResponseSender { - pub fn new(senders: Vec>) -> Self { + pub fn new( + senders: Vec>, + ) -> Self { Self { senders } } - pub fn try_send_to( + pub fn try_send_ref_to( &self, index: SocketWorkerIndex, - response: ConnectedResponse, - addr: CanonicalSocketAddr, - ) { - match self.senders[index.0].try_send((response, addr)) { - Ok(()) => {} - Err(TrySendError::Full(_)) => { - ::log::error!("Response channel {} is full, dropping response. Try increasing number of socket workers or raising config.worker_channel_size.", index.0) - } - Err(TrySendError::Disconnected(_)) => { - panic!("Response channel {} is disconnected", index.0); - } - } + ) -> Result, thingbuf::mpsc::errors::TrySendError> { + self.senders[index.0].try_send_ref() + } + + pub fn send_ref_to( + &self, + index: SocketWorkerIndex, + ) -> Result, thingbuf::mpsc::errors::Closed> { + self.senders[index.0].send_ref() } } +pub type ConnectedResponseReceiver = + thingbuf::mpsc::blocking::Receiver; + #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] pub enum PeerStatus { Seeding, diff --git a/crates/udp/src/config.rs b/crates/udp/src/config.rs index f2438f0..852fbcf 100644 --- a/crates/udp/src/config.rs +++ b/crates/udp/src/config.rs @@ -26,8 +26,7 @@ pub struct Config { pub swarm_workers: usize, pub log_level: LogLevel, /// Maximum number of items in each channel passing requests/responses - /// between workers. A value of zero means that the channels will be of - /// unbounded size. + /// between workers. A value of zero is no longer allowed. 
pub worker_channel_size: usize, /// How long to block waiting for requests in swarm workers. /// @@ -59,7 +58,7 @@ impl Default for Config { socket_workers: 1, swarm_workers: 1, log_level: LogLevel::Error, - worker_channel_size: 0, + worker_channel_size: 1024 * 16, request_channel_recv_timeout_ms: 100, network: NetworkConfig::default(), protocol: ProtocolConfig::default(), @@ -99,8 +98,6 @@ pub struct NetworkConfig { /// $ sudo sysctl -w net.core.rmem_max=104857600 /// $ sudo sysctl -w net.core.rmem_default=104857600 pub socket_recv_buffer_size: usize, - /// Poll event capacity (mio backend only) - pub poll_event_capacity: usize, /// Poll timeout in milliseconds (mio backend only) pub poll_timeout_ms: u64, /// Number of ring entries (io_uring backend only) @@ -133,7 +130,6 @@ impl Default for NetworkConfig { address: SocketAddr::from(([0, 0, 0, 0], 3000)), only_ipv6: false, socket_recv_buffer_size: 4096 * 128, - poll_event_capacity: 4096, poll_timeout_ms: 50, #[cfg(feature = "io-uring")] ring_size: 1024, diff --git a/crates/udp/src/lib.rs b/crates/udp/src/lib.rs index 6f8ffe6..6658b37 100644 --- a/crates/udp/src/lib.rs +++ b/crates/udp/src/lib.rs @@ -18,7 +18,8 @@ use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::{PanicSentinelWatcher, ServerStartInstant}; use common::{ - ConnectedRequestSender, ConnectedResponseSender, SocketWorkerIndex, State, SwarmWorkerIndex, + ConnectedRequestSender, ConnectedResponseSender, Recycler, SocketWorkerIndex, State, + SwarmWorkerIndex, }; use config::Config; use workers::socket::ConnectionValidator; @@ -58,11 +59,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { } for i in 0..config.socket_workers { - let (response_sender, response_receiver) = if config.worker_channel_size == 0 { - unbounded() - } else { - bounded(config.worker_channel_size) - }; + let (response_sender, response_receiver) = + thingbuf::mpsc::blocking::with_recycle(config.worker_channel_size, Recycler); 
response_senders.push(response_sender); response_receivers.insert(i, response_receiver); diff --git a/crates/udp/src/workers/socket/mio.rs b/crates/udp/src/workers/socket/mio.rs index 1d254f3..fab0f67 100644 --- a/crates/udp/src/workers/socket/mio.rs +++ b/crates/udp/src/workers/socket/mio.rs @@ -1,10 +1,10 @@ +use std::borrow::Cow; use std::io::{Cursor, ErrorKind}; use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use aquatic_common::access_list::AccessListCache; use aquatic_common::ServerStartInstant; -use crossbeam_channel::Receiver; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; @@ -21,17 +21,32 @@ use super::storage::PendingScrapeResponseSlab; use super::validator::ConnectionValidator; use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; +enum HandleRequestError { + RequestChannelFull(Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>), +} + +#[derive(Clone, Copy, Debug)] +enum PollMode { + Regular, + SkipPolling, + SkipReceiving, +} + pub struct SocketWorker { config: Config, shared_state: State, request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + response_receiver: ConnectedResponseReceiver, access_list_cache: AccessListCache, validator: ConnectionValidator, server_start_instant: ServerStartInstant, pending_scrape_responses: PendingScrapeResponseSlab, socket: UdpSocket, + opt_resend_buffer: Option>, buffer: [u8; BUFFER_SIZE], + polling_mode: PollMode, + /// Storage for requests that couldn't be sent to swarm worker because channel was full + pending_requests: Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>, } impl SocketWorker { @@ -42,12 +57,13 @@ impl SocketWorker { validator: ConnectionValidator, server_start_instant: ServerStartInstant, request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + response_receiver: ConnectedResponseReceiver, 
priv_dropper: PrivilegeDropper, ) { let socket = UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); let access_list_cache = create_access_list_cache(&shared_state.access_list); + let opt_resend_buffer = (config.network.resend_buffer_max_len > 0).then_some(Vec::new()); let mut worker = Self { config, @@ -59,18 +75,17 @@ impl SocketWorker { access_list_cache, pending_scrape_responses: Default::default(), socket, + opt_resend_buffer, buffer: [0; BUFFER_SIZE], + polling_mode: PollMode::Regular, + pending_requests: Default::default(), }; worker.run_inner(); } pub fn run_inner(&mut self) { - let mut local_responses = Vec::new(); - let mut opt_resend_buffer = - (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new()); - - let mut events = Events::with_capacity(self.config.network.poll_event_capacity); + let mut events = Events::with_capacity(1); let mut poll = Poll::new().expect("create poll"); poll.registry() @@ -91,17 +106,33 @@ impl SocketWorker { let mut iter_counter = 0usize; loop { - poll.poll(&mut events, Some(poll_timeout)) - .expect("failed polling"); + match self.polling_mode { + PollMode::Regular => { + poll.poll(&mut events, Some(poll_timeout)) + .expect("failed polling"); - for event in events.iter() { - if event.is_readable() { - self.read_and_handle_requests(&mut local_responses, pending_scrape_valid_until); + for event in events.iter() { + if event.is_readable() { + self.read_and_handle_requests(pending_scrape_valid_until); + } + } + } + PollMode::SkipPolling => { + self.polling_mode = PollMode::Regular; + + // Continue reading from socket without polling, since + // reading was previously cancelled + self.read_and_handle_requests(pending_scrape_valid_until); + } + PollMode::SkipReceiving => { + ::log::info!("Postponing receiving requests because swarm worker channel is full. This means that the OS will be relied on to buffer incoming packets. 
To prevent this, raise config.worker_channel_size."); + + self.polling_mode = PollMode::SkipPolling; } } // If resend buffer is enabled, send any responses in it - if let Some(resend_buffer) = opt_resend_buffer.as_mut() { + if let Some(resend_buffer) = self.opt_resend_buffer.as_mut() { for (response, addr) in resend_buffer.drain(..) { Self::send_response( &self.config, @@ -109,46 +140,23 @@ impl SocketWorker { &mut self.socket, &mut self.buffer, &mut None, - response, + response.into(), addr, ); } } - // Send any connect and error responses generated by this socket worker - for (response, addr) in local_responses.drain(..) { - Self::send_response( - &self.config, - &self.shared_state, - &mut self.socket, - &mut self.buffer, - &mut opt_resend_buffer, - response, - addr, - ); - } - // Check channel for any responses generated by swarm workers - for (response, addr) in self.response_receiver.try_iter() { - let opt_response = match response { - ConnectedResponse::Scrape(r) => self - .pending_scrape_responses - .add_and_get_finished(r) - .map(Response::Scrape), - ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)), - ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)), - }; + self.handle_swarm_worker_responses(); - if let Some(response) = opt_response { - Self::send_response( - &self.config, - &self.shared_state, - &mut self.socket, - &mut self.buffer, - &mut opt_resend_buffer, - response, - addr, - ); + // Try sending pending requests + while let Some((index, request, addr)) = self.pending_requests.pop() { + if let Err(r) = self.request_sender.try_send_to(index, request, addr) { + self.pending_requests.push(r); + + self.polling_mode = PollMode::SkipReceiving; + + break; } } @@ -174,11 +182,7 @@ impl SocketWorker { } } - fn read_and_handle_requests( - &mut self, - local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - pending_scrape_valid_until: ValidUntil, - ) { + fn read_and_handle_requests(&mut self, 
pending_scrape_valid_until: ValidUntil) { let mut requests_received_ipv4: usize = 0; let mut requests_received_ipv6: usize = 0; let mut bytes_received_ipv4: usize = 0; @@ -194,20 +198,19 @@ impl SocketWorker { } let src = CanonicalSocketAddr::new(src); - - ::log::trace!("received request bytes: {}", hex_slice(&self.buffer[..bytes_read])); - let request_parsable = match Request::from_bytes( &self.buffer[..bytes_read], self.config.protocol.max_scrape_torrents, ) { Ok(request) => { - self.handle_request( - local_responses, - pending_scrape_valid_until, - request, - src, - ); + if let Err(HandleRequestError::RequestChannelFull(failed_requests)) = + self.handle_request(pending_scrape_valid_until, request, src) + { + self.pending_requests.extend(failed_requests.into_iter()); + self.polling_mode = PollMode::SkipReceiving; + + break; + } true } @@ -226,7 +229,15 @@ impl SocketWorker { message: err.into(), }; - local_responses.push((response.into(), src)); + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + CowResponse::Error(Cow::Owned(response)), + src, + ); } } @@ -278,29 +289,34 @@ impl SocketWorker { fn handle_request( &mut self, - local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, pending_scrape_valid_until: ValidUntil, request: Request, src: CanonicalSocketAddr, - ) { + ) -> Result<(), HandleRequestError> { let access_list_mode = self.config.access_list.mode; match request { Request::Connect(request) => { - ::log::trace!("received {:?} from {:?}", request, src); - let connection_id = self.validator.create_connection_id(src); - let response = Response::Connect(ConnectResponse { + let response = ConnectResponse { connection_id, transaction_id: request.transaction_id, - }); + }; - local_responses.push((response, src)) + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + 
CowResponse::Connect(Cow::Owned(response)), + src, + ); + + Ok(()) } Request::Announce(request) => { - ::log::trace!("received {:?} from {:?}", request, src); - if self .validator .connection_id_valid(src, request.connection_id) @@ -313,24 +329,34 @@ impl SocketWorker { let worker_index = SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); - self.request_sender.try_send_to( - worker_index, - ConnectedRequest::Announce(request), - src, - ); + self.request_sender + .try_send_to(worker_index, ConnectedRequest::Announce(request), src) + .map_err(|request| { + HandleRequestError::RequestChannelFull(vec![request]) + }) } else { - let response = Response::Error(ErrorResponse { + let response = ErrorResponse { transaction_id: request.transaction_id, message: "Info hash not allowed".into(), - }); + }; - local_responses.push((response, src)) + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + CowResponse::Error(Cow::Owned(response)), + src, + ); + + Ok(()) } + } else { + Ok(()) } } Request::Scrape(request) => { - ::log::trace!("received {:?} from {:?}", request, src); - if self .validator .connection_id_valid(src, request.connection_id) @@ -341,36 +367,87 @@ impl SocketWorker { pending_scrape_valid_until, ); + let mut failed = Vec::new(); + for (swarm_worker_index, request) in split_requests { - self.request_sender.try_send_to( + if let Err(request) = self.request_sender.try_send_to( swarm_worker_index, ConnectedRequest::Scrape(request), src, - ); + ) { + failed.push(request); + } } + + if failed.is_empty() { + Ok(()) + } else { + Err(HandleRequestError::RequestChannelFull(failed)) + } + } else { + Ok(()) } } } } + fn handle_swarm_worker_responses(&mut self) { + loop { + let recv_ref = if let Ok(recv_ref) = self.response_receiver.try_recv_ref() { + recv_ref + } else { + break; + }; + + let response = match recv_ref.kind { + ConnectedResponseKind::Scrape => { + if let Some(r) = 
self + .pending_scrape_responses + .add_and_get_finished(&recv_ref.scrape) + { + CowResponse::Scrape(Cow::Owned(r)) + } else { + continue; + } + } + ConnectedResponseKind::AnnounceIpv4 => { + CowResponse::AnnounceIpv4(Cow::Borrowed(&recv_ref.announce_ipv4)) + } + ConnectedResponseKind::AnnounceIpv6 => { + CowResponse::AnnounceIpv6(Cow::Borrowed(&recv_ref.announce_ipv6)) + } + }; + + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + response, + recv_ref.addr, + ); + } + } + fn send_response( config: &Config, shared_state: &State, socket: &mut UdpSocket, buffer: &mut [u8], opt_resend_buffer: &mut Option>, - response: Response, + response: CowResponse, canonical_addr: CanonicalSocketAddr, ) { - let mut cursor = Cursor::new(buffer); + let mut buffer = Cursor::new(&mut buffer[..]); - if let Err(err) = response.write(&mut cursor) { - ::log::error!("Converting response to bytes failed: {:#}", err); + if let Err(err) = response.write(&mut buffer) { + ::log::error!("failed writing response to buffer: {:#}", err); return; } - let bytes_written = cursor.position() as usize; + let bytes_written = buffer.position() as usize; let addr = if config.network.address.is_ipv4() { canonical_addr @@ -380,9 +457,7 @@ impl SocketWorker { canonical_addr.get_ipv6_mapped() }; - ::log::trace!("sending {:?} to {}, bytes: {}", response, addr, hex_slice(&cursor.get_ref()[..bytes_written])); - - match socket.send_to(&cursor.get_ref()[..bytes_written], addr) { + match socket.send_to(&buffer.into_inner()[..bytes_written], addr) { Ok(amt) if config.statistics.active() => { let stats = if canonical_addr.is_ipv4() { let stats = &shared_state.statistics_ipv4; @@ -403,18 +478,18 @@ impl SocketWorker { }; match response { - Response::Connect(_) => { + CowResponse::Connect(_) => { stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed); } - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { + 
CowResponse::AnnounceIpv4(_) | CowResponse::AnnounceIpv6(_) => { stats .responses_sent_announce .fetch_add(1, Ordering::Relaxed); } - Response::Scrape(_) => { + CowResponse::Scrape(_) => { stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed); } - Response::Error(_) => { + CowResponse::Error(_) => { stats.responses_sent_error.fetch_add(1, Ordering::Relaxed); } } @@ -428,7 +503,7 @@ impl SocketWorker { if resend_buffer.len() < config.network.resend_buffer_max_len { ::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); - resend_buffer.push((response, canonical_addr)); + resend_buffer.push((response.into_owned(), canonical_addr)); } else { ::log::warn!("Response resend buffer full, dropping response"); } @@ -440,14 +515,3 @@ impl SocketWorker { } } } - -fn hex_slice(bytes: &[u8]) -> String { - let mut output = String::with_capacity(bytes.len() * 3); - - for chunk in bytes.chunks(4) { - output.push_str(&hex::encode(chunk)); - output.push(' '); - } - - output -} \ No newline at end of file diff --git a/crates/udp/src/workers/socket/mod.rs b/crates/udp/src/workers/socket/mod.rs index dfe57a3..b683e13 100644 --- a/crates/udp/src/workers/socket/mod.rs +++ b/crates/udp/src/workers/socket/mod.rs @@ -5,14 +5,11 @@ mod uring; mod validator; use anyhow::Context; -use aquatic_common::{ - privileges::PrivilegeDropper, CanonicalSocketAddr, PanicSentinel, ServerStartInstant, -}; -use crossbeam_channel::Receiver; +use aquatic_common::{privileges::PrivilegeDropper, PanicSentinel, ServerStartInstant}; use socket2::{Domain, Protocol, Socket, Type}; use crate::{ - common::{ConnectedRequestSender, ConnectedResponse, State}, + common::{ConnectedRequestSender, ConnectedResponseReceiver, State}, config::Config, }; @@ -46,7 +43,7 @@ pub fn run_socket_worker( validator: ConnectionValidator, server_start_instant: ServerStartInstant, request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, 
CanonicalSocketAddr)>, + response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, ) { #[cfg(all(target_os = "linux", feature = "io-uring"))] diff --git a/crates/udp/src/workers/socket/storage.rs b/crates/udp/src/workers/socket/storage.rs index b9e0952..5ded3ac 100644 --- a/crates/udp/src/workers/socket/storage.rs +++ b/crates/udp/src/workers/socket/storage.rs @@ -65,14 +65,12 @@ impl PendingScrapeResponseSlab { pub fn add_and_get_finished( &mut self, - response: PendingScrapeResponse, + response: &PendingScrapeResponse, ) -> Option { let finished = if let Some(entry) = self.0.get_mut(response.slab_key) { entry.num_pending -= 1; - entry - .torrent_stats - .extend(response.torrent_stats.into_iter()); + entry.torrent_stats.extend(response.torrent_stats.iter()); entry.num_pending == 0 } else { @@ -205,7 +203,7 @@ mod tests { torrent_stats, }; - if let Some(response) = map.add_and_get_finished(response) { + if let Some(response) = map.add_and_get_finished(&response) { responses.push(response); } } diff --git a/crates/udp/src/workers/socket/uring/mod.rs b/crates/udp/src/workers/socket/uring/mod.rs index c9b7a21..4ddf875 100644 --- a/crates/udp/src/workers/socket/uring/mod.rs +++ b/crates/udp/src/workers/socket/uring/mod.rs @@ -2,6 +2,7 @@ mod buf_ring; mod recv_helper; mod send_buffers; +use std::borrow::Cow; use std::cell::RefCell; use std::collections::VecDeque; use std::net::UdpSocket; @@ -12,7 +13,6 @@ use std::sync::atomic::Ordering; use anyhow::Context; use aquatic_common::access_list::AccessListCache; use aquatic_common::ServerStartInstant; -use crossbeam_channel::Receiver; use io_uring::opcode::Timeout; use io_uring::types::{Fixed, Timespec}; use io_uring::{IoUring, Probe}; @@ -78,7 +78,7 @@ pub struct SocketWorker { config: Config, shared_state: State, request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + response_receiver: ConnectedResponseReceiver, access_list_cache: 
AccessListCache, validator: ConnectionValidator, server_start_instant: ServerStartInstant, @@ -104,7 +104,7 @@ impl SocketWorker { validator: ConnectionValidator, server_start_instant: ServerStartInstant, request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, ) { let ring_entries = config.network.ring_size.next_power_of_two(); @@ -210,14 +210,15 @@ impl SocketWorker { // Enqueue local responses for _ in 0..sq_space { if let Some((response, addr)) = self.local_responses.pop_front() { - match self.send_buffers.prepare_entry(&response, addr) { + match self.send_buffers.prepare_entry(response.into(), addr) { Ok(entry) => { unsafe { ring.submission().push(&entry).unwrap() }; num_send_added += 1; } - Err(send_buffers::Error::NoBuffers) => { - self.local_responses.push_front((response, addr)); + Err(send_buffers::Error::NoBuffers(response)) => { + self.local_responses + .push_front((response.into_owned(), addr)); break; } @@ -232,24 +233,46 @@ impl SocketWorker { // Enqueue swarm worker responses for _ in 0..(sq_space - num_send_added) { - if let Some((response, addr)) = self.get_next_swarm_response() { - match self.send_buffers.prepare_entry(&response, addr) { - Ok(entry) => { - unsafe { ring.submission().push(&entry).unwrap() }; - - num_send_added += 1; - } - Err(send_buffers::Error::NoBuffers) => { - self.local_responses.push_back((response, addr)); - - break; - } - Err(send_buffers::Error::SerializationFailed(err)) => { - ::log::error!("Failed serializing response: {:#}", err); - } - } + let recv_ref = if let Ok(recv_ref) = self.response_receiver.try_recv_ref() { + recv_ref } else { break; + }; + + let response = match recv_ref.kind { + ConnectedResponseKind::AnnounceIpv4 => { + CowResponse::AnnounceIpv4(Cow::Borrowed(&recv_ref.announce_ipv4)) + } + ConnectedResponseKind::AnnounceIpv6 => { + 
CowResponse::AnnounceIpv6(Cow::Borrowed(&recv_ref.announce_ipv6)) + } + ConnectedResponseKind::Scrape => { + if let Some(response) = self + .pending_scrape_responses + .add_and_get_finished(&recv_ref.scrape) + { + CowResponse::Scrape(Cow::Owned(response)) + } else { + continue; + } + } + }; + + match self.send_buffers.prepare_entry(response, recv_ref.addr) { + Ok(entry) => { + unsafe { ring.submission().push(&entry).unwrap() }; + + num_send_added += 1; + } + Err(send_buffers::Error::NoBuffers(response)) => { + self.local_responses + .push_back((response.into_owned(), recv_ref.addr)); + + break; + } + Err(send_buffers::Error::SerializationFailed(err)) => { + ::log::error!("Failed serializing response: {:#}", err); + } } } @@ -283,12 +306,6 @@ impl SocketWorker { self.config.cleaning.max_pending_scrape_age, ); - ::log::info!( - "pending responses: {} local, {} swarm", - self.local_responses.len(), - self.response_receiver.len() - ); - self.resubmittable_sqe_buf .push(self.pulse_timeout_sqe.clone()); } @@ -464,11 +481,13 @@ impl SocketWorker { let worker_index = SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); - self.request_sender.try_send_to( + if let Err(_) = self.request_sender.try_send_to( worker_index, ConnectedRequest::Announce(request), src, - ); + ) { + ::log::warn!("request sender full, dropping request"); + } } else { let response = Response::Error(ErrorResponse { transaction_id: request.transaction_id, @@ -491,39 +510,18 @@ impl SocketWorker { ); for (swarm_worker_index, request) in split_requests { - self.request_sender.try_send_to( + if let Err(_) = self.request_sender.try_send_to( swarm_worker_index, ConnectedRequest::Scrape(request), src, - ); + ) { + ::log::warn!("request sender full, dropping request"); + } } } } } } - - fn get_next_swarm_response(&mut self) -> Option<(Response, CanonicalSocketAddr)> { - loop { - match self.response_receiver.try_recv() { - Ok((ConnectedResponse::AnnounceIpv4(response), addr)) => { - return 
Some((Response::AnnounceIpv4(response), addr)); - } - Ok((ConnectedResponse::AnnounceIpv6(response), addr)) => { - return Some((Response::AnnounceIpv6(response), addr)); - } - Ok((ConnectedResponse::Scrape(response), addr)) => { - if let Some(response) = - self.pending_scrape_responses.add_and_get_finished(response) - { - return Some((Response::Scrape(response), addr)); - } - } - Err(_) => { - return None; - } - } - } - } } pub fn supported_on_current_kernel() -> anyhow::Result<()> { diff --git a/crates/udp/src/workers/socket/uring/send_buffers.rs b/crates/udp/src/workers/socket/uring/send_buffers.rs index 262de06..b62ca90 100644 --- a/crates/udp/src/workers/socket/uring/send_buffers.rs +++ b/crates/udp/src/workers/socket/uring/send_buffers.rs @@ -6,15 +6,14 @@ use std::{ }; use aquatic_common::CanonicalSocketAddr; -use aquatic_udp_protocol::Response; use io_uring::opcode::SendMsg; -use crate::config::Config; +use crate::{common::CowResponse, config::Config}; use super::{RESPONSE_BUF_LEN, SOCKET_IDENTIFIER}; -pub enum Error { - NoBuffers, +pub enum Error<'a> { + NoBuffers(CowResponse<'a>), SerializationFailed(std::io::Error), } @@ -58,12 +57,16 @@ impl SendBuffers { self.likely_next_free_index = 0; } - pub fn prepare_entry( + pub fn prepare_entry<'a>( &mut self, - response: &Response, + response: CowResponse<'a>, addr: CanonicalSocketAddr, - ) -> Result { - let index = self.next_free_index()?; + ) -> Result> { + let index = if let Some(index) = self.next_free_index() { + index + } else { + return Err(Error::NoBuffers(response)); + }; let (buffer_metadata, buffer) = self.buffers.get_mut(index).unwrap(); @@ -82,9 +85,9 @@ impl SendBuffers { } } - fn next_free_index(&self) -> Result { + fn next_free_index(&self) -> Option { if self.likely_next_free_index >= self.buffers.len() { - return Err(Error::NoBuffers); + return None; } for (i, (meta, _)) in self.buffers[self.likely_next_free_index..] 
@@ -92,11 +95,11 @@ impl SendBuffers { .enumerate() { if meta.free { - return Ok(self.likely_next_free_index + i); + return Some(self.likely_next_free_index + i); } } - Err(Error::NoBuffers) + None } } @@ -160,7 +163,7 @@ impl SendBuffer { fn prepare_entry( &mut self, - response: &Response, + response: CowResponse, addr: CanonicalSocketAddr, socket_is_ipv4: bool, metadata: &mut SendBufferMetadata, @@ -196,7 +199,7 @@ impl SendBuffer { Ok(()) => { self.iovec.iov_len = cursor.position() as usize; - metadata.response_type = ResponseType::from_response(response); + metadata.response_type = ResponseType::from_response(&response); Ok(SendMsg::new(SOCKET_IDENTIFIER, addr_of_mut!(self.msghdr)).build()) } @@ -234,12 +237,12 @@ pub enum ResponseType { } impl ResponseType { - fn from_response(response: &Response) -> Self { + fn from_response(response: &CowResponse) -> Self { match response { - Response::Connect(_) => Self::Connect, - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => Self::Announce, - Response::Scrape(_) => Self::Scrape, - Response::Error(_) => Self::Error, + CowResponse::Connect(_) => Self::Connect, + CowResponse::AnnounceIpv4(_) | CowResponse::AnnounceIpv6(_) => Self::Announce, + CowResponse::Scrape(_) => Self::Scrape, + CowResponse::Error(_) => Self::Error, } } } diff --git a/crates/udp/src/workers/swarm/mod.rs b/crates/udp/src/workers/swarm/mod.rs index 6284d42..43ee480 100644 --- a/crates/udp/src/workers/swarm/mod.rs +++ b/crates/udp/src/workers/swarm/mod.rs @@ -12,12 +12,10 @@ use rand::{rngs::SmallRng, SeedableRng}; use aquatic_common::{CanonicalSocketAddr, PanicSentinel, ValidUntil}; -use aquatic_udp_protocol::*; - use crate::common::*; use crate::config::Config; -use storage::{TorrentMap, TorrentMaps}; +use storage::TorrentMaps; pub fn run_swarm_worker( _sentinel: PanicSentinel, @@ -45,42 +43,65 @@ pub fn run_swarm_worker( loop { if let Ok((sender_index, request, src)) = request_receiver.recv_timeout(timeout) { - let response = match (request, 
src.get().ip()) { - (ConnectedRequest::Announce(request), IpAddr::V4(ip)) => { - let response = handle_announce_request( - &config, - &mut rng, - &statistics_sender, - &mut torrents.ipv4, - request, - ip.into(), - peer_valid_until, - ); + // It is OK to block here as long as we don't do blocking sends + // in socket workers, which could cause a deadlock + match response_sender.send_ref_to(sender_index) { + Ok(mut send_ref) => { + send_ref.addr = src; - ConnectedResponse::AnnounceIpv4(response) - } - (ConnectedRequest::Announce(request), IpAddr::V6(ip)) => { - let response = handle_announce_request( - &config, - &mut rng, - &statistics_sender, - &mut torrents.ipv6, - request, - ip.into(), - peer_valid_until, - ); + match (request, src.get().ip()) { + (ConnectedRequest::Announce(request), IpAddr::V4(ip)) => { + send_ref.kind = ConnectedResponseKind::AnnounceIpv4; - ConnectedResponse::AnnounceIpv6(response) - } - (ConnectedRequest::Scrape(request), IpAddr::V4(_)) => { - ConnectedResponse::Scrape(handle_scrape_request(&mut torrents.ipv4, request)) - } - (ConnectedRequest::Scrape(request), IpAddr::V6(_)) => { - ConnectedResponse::Scrape(handle_scrape_request(&mut torrents.ipv6, request)) - } - }; + torrents + .ipv4 + .0 + .entry(request.info_hash) + .or_default() + .announce( + &config, + &statistics_sender, + &mut rng, + &request, + ip.into(), + peer_valid_until, + &mut send_ref.announce_ipv4, + ); + } + (ConnectedRequest::Announce(request), IpAddr::V6(ip)) => { + send_ref.kind = ConnectedResponseKind::AnnounceIpv6; - response_sender.try_send_to(sender_index, response, src); + torrents + .ipv6 + .0 + .entry(request.info_hash) + .or_default() + .announce( + &config, + &statistics_sender, + &mut rng, + &request, + ip.into(), + peer_valid_until, + &mut send_ref.announce_ipv6, + ); + } + (ConnectedRequest::Scrape(request), IpAddr::V4(_)) => { + send_ref.kind = ConnectedResponseKind::Scrape; + + torrents.ipv4.scrape(request, &mut send_ref.scrape); + } + 
(ConnectedRequest::Scrape(request), IpAddr::V6(_)) => { + send_ref.kind = ConnectedResponseKind::Scrape; + + torrents.ipv6.scrape(request, &mut send_ref.scrape); + } + }; + } + Err(_) => { + panic!("swarm response channel closed"); + } + } } // Run periodic tasks @@ -116,88 +137,3 @@ pub fn run_swarm_worker( iter_counter = iter_counter.wrapping_add(1); } } - -fn handle_announce_request( - config: &Config, - rng: &mut SmallRng, - statistics_sender: &Sender, - torrents: &mut TorrentMap, - request: AnnounceRequest, - peer_ip: I, - peer_valid_until: ValidUntil, -) -> AnnounceResponse { - let max_num_peers_to_take: usize = if request.peers_wanted.0.get() <= 0 { - config.protocol.max_response_peers - } else { - ::std::cmp::min( - config.protocol.max_response_peers, - request.peers_wanted.0.get().try_into().unwrap(), - ) - }; - - let torrent_data = torrents.0.entry(request.info_hash).or_default(); - - let peer_status = - PeerStatus::from_event_and_bytes_left(request.event.into(), request.bytes_left); - - torrent_data.update_peer( - config, - statistics_sender, - request.peer_id, - peer_ip, - request.port, - peer_status, - peer_valid_until, - ); - - let response_peers = if let PeerStatus::Stopped = peer_status { - Vec::new() - } else { - torrent_data.extract_response_peers(rng, request.peer_id, max_num_peers_to_take) - }; - - AnnounceResponse { - fixed: AnnounceResponseFixedData { - transaction_id: request.transaction_id, - announce_interval: AnnounceInterval::new(config.protocol.peer_announce_interval), - leechers: NumberOfPeers::new( - torrent_data.num_leechers().try_into().unwrap_or(i32::MAX), - ), - seeders: NumberOfPeers::new(torrent_data.num_seeders().try_into().unwrap_or(i32::MAX)), - }, - peers: response_peers, - } -} - -fn handle_scrape_request( - torrents: &mut TorrentMap, - request: PendingScrapeRequest, -) -> PendingScrapeResponse { - let torrent_stats = request - .info_hashes - .into_iter() - .map(|(i, info_hash)| { - let stats = torrents - .0 - 
.get(&info_hash) - .map(|torrent_data| torrent_data.scrape_statistics()) - .unwrap_or_else(|| create_torrent_scrape_statistics(0, 0)); - - (i, stats) - }) - .collect(); - - PendingScrapeResponse { - slab_key: request.slab_key, - torrent_stats, - } -} - -#[inline(always)] -fn create_torrent_scrape_statistics(seeders: i32, leechers: i32) -> TorrentScrapeStatistics { - TorrentScrapeStatistics { - seeders: NumberOfPeers::new(seeders), - completed: NumberOfDownloads::new(0), // No implementation planned - leechers: NumberOfPeers::new(leechers), - } -} diff --git a/crates/udp/src/workers/swarm/storage.rs b/crates/udp/src/workers/swarm/storage.rs index 676ad91..acc83bd 100644 --- a/crates/udp/src/workers/swarm/storage.rs +++ b/crates/udp/src/workers/swarm/storage.rs @@ -6,19 +6,18 @@ use aquatic_common::SecondsSinceServerStart; use aquatic_common::ServerStartInstant; use aquatic_common::{ access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache, AccessListMode}, - extract_response_peers, ValidUntil, + ValidUntil, }; use aquatic_udp_protocol::*; use crossbeam_channel::Sender; use hdrhistogram::Histogram; use rand::prelude::SmallRng; +use rand::Rng; use crate::common::*; use crate::config::Config; -use super::create_torrent_scrape_statistics; - #[derive(Clone, Debug)] struct Peer { ip_address: I, @@ -44,65 +43,29 @@ pub struct TorrentData { } impl TorrentData { - pub fn update_peer( + pub fn announce( &mut self, config: &Config, statistics_sender: &Sender, - peer_id: PeerId, + rng: &mut SmallRng, + request: &AnnounceRequest, ip_address: I, - port: Port, - status: PeerStatus, valid_until: ValidUntil, + response: &mut AnnounceResponse, ) { - let opt_removed_peer = match status { - PeerStatus::Leeching => { - let peer = Peer { - ip_address, - port, - is_seeder: false, - valid_until, - }; - - self.peers.insert(peer_id, peer) - } - PeerStatus::Seeding => { - let peer = Peer { - ip_address, - port, - is_seeder: true, - valid_until, - }; - - self.num_seeders += 
1; - - self.peers.insert(peer_id, peer) - } - PeerStatus::Stopped => self.peers.remove(&peer_id), + let max_num_peers_to_take: usize = if request.peers_wanted.0.get() <= 0 { + config.protocol.max_response_peers + } else { + ::std::cmp::min( + config.protocol.max_response_peers, + request.peers_wanted.0.get().try_into().unwrap(), + ) }; - if config.statistics.peer_clients { - match (status, opt_removed_peer.is_some()) { - // We added a new peer - (PeerStatus::Leeching | PeerStatus::Seeding, false) => { - if let Err(_) = - statistics_sender.try_send(StatisticsMessage::PeerAdded(peer_id)) - { - // Should never happen in practice - ::log::error!("Couldn't send StatisticsMessage::PeerAdded"); - } - } - // We removed an existing peer - (PeerStatus::Stopped, true) => { - if let Err(_) = - statistics_sender.try_send(StatisticsMessage::PeerRemoved(peer_id)) - { - // Should never happen in practice - ::log::error!("Couldn't send StatisticsMessage::PeerRemoved"); - } - } - _ => (), - } - } + let status = + PeerStatus::from_event_and_bytes_left(request.event.into(), request.bytes_left); + + let opt_removed_peer = self.peers.remove(&request.peer_id); if let Some(Peer { is_seeder: true, .. @@ -110,21 +73,69 @@ impl TorrentData { { self.num_seeders -= 1; } - } - pub fn extract_response_peers( - &self, - rng: &mut SmallRng, - peer_id: PeerId, - max_num_peers_to_take: usize, - ) -> Vec> { + // Create the response before inserting the peer. 
This means that we + // don't have to filter it out from the response peers, and that the + // reported number of seeders/leechers will not include it + + response.fixed = AnnounceResponseFixedData { + transaction_id: request.transaction_id, + announce_interval: AnnounceInterval::new(config.protocol.peer_announce_interval), + leechers: NumberOfPeers::new(self.num_leechers().try_into().unwrap_or(i32::MAX)), + seeders: NumberOfPeers::new(self.num_seeders().try_into().unwrap_or(i32::MAX)), + }; + extract_response_peers( rng, &self.peers, max_num_peers_to_take, - peer_id, Peer::to_response_peer, - ) + &mut response.peers, + ); + + match status { + PeerStatus::Leeching => { + let peer = Peer { + ip_address, + port: request.port, + is_seeder: false, + valid_until, + }; + + self.peers.insert(request.peer_id, peer); + + if config.statistics.peer_clients && opt_removed_peer.is_none() { + statistics_sender + .try_send(StatisticsMessage::PeerAdded(request.peer_id)) + .expect("statistics channel should be unbounded"); + } + } + PeerStatus::Seeding => { + let peer = Peer { + ip_address, + port: request.port, + is_seeder: true, + valid_until, + }; + + self.peers.insert(request.peer_id, peer); + + self.num_seeders += 1; + + if config.statistics.peer_clients && opt_removed_peer.is_none() { + statistics_sender + .try_send(StatisticsMessage::PeerAdded(request.peer_id)) + .expect("statistics channel should be unbounded"); + } + } + PeerStatus::Stopped => { + if config.statistics.peer_clients && opt_removed_peer.is_some() { + statistics_sender + .try_send(StatisticsMessage::PeerRemoved(request.peer_id)) + .expect("statistics channel should be unbounded"); + } + } + }; } pub fn num_leechers(&self) -> usize { @@ -188,6 +199,21 @@ impl Default for TorrentData { pub struct TorrentMap(pub IndexMap>); impl TorrentMap { + pub fn scrape(&mut self, request: PendingScrapeRequest, response: &mut PendingScrapeResponse) { + response.slab_key = request.slab_key; + + let torrent_stats = 
request.info_hashes.into_iter().map(|(i, info_hash)| { + let stats = self + .0 + .get(&info_hash) + .map(|torrent_data| torrent_data.scrape_statistics()) + .unwrap_or_else(|| create_torrent_scrape_statistics(0, 0)); + + (i, stats) + }); + + response.torrent_stats.extend(torrent_stats); + } /// Remove forbidden or inactive torrents, reclaim space and return number of remaining peers fn clean_and_get_statistics( &mut self, @@ -306,94 +332,60 @@ impl TorrentMaps { } } } +/// Extract response peers +/// +/// If there are more peers in map than `max_num_peers_to_take`, do a random +/// selection of peers from first and second halves of map in order to avoid +/// returning too homogeneous peers. +/// +/// Does NOT filter out announcing peer. +#[inline] +pub fn extract_response_peers( + rng: &mut impl Rng, + peer_map: &IndexMap, + max_num_peers_to_take: usize, + peer_conversion_function: F, + peers: &mut Vec, +) where + K: Eq + ::std::hash::Hash, + F: Fn(&K, &V) -> R, +{ + if peer_map.len() <= max_num_peers_to_take { + peers.extend(peer_map.iter().map(|(k, v)| peer_conversion_function(k, v))); + } else { + let middle_index = peer_map.len() / 2; + let num_to_take_per_half = max_num_peers_to_take / 2; -#[cfg(test)] -mod tests { - use std::collections::HashSet; + let offset_half_one = { + let from = 0; + let to = usize::max(1, middle_index - num_to_take_per_half); - use quickcheck::{quickcheck, TestResult}; - use rand::thread_rng; + rng.gen_range(from..to) + }; + let offset_half_two = { + let from = middle_index; + let to = usize::max(middle_index + 1, peer_map.len() - num_to_take_per_half); - use super::*; + rng.gen_range(from..to) + }; - fn gen_peer_id(i: u32) -> PeerId { - let mut peer_id = PeerId([0; 20]); + let end_half_one = offset_half_one + num_to_take_per_half; + let end_half_two = offset_half_two + num_to_take_per_half; - peer_id.0[0..4].copy_from_slice(&i.to_ne_bytes()); - - peer_id - } - fn gen_peer(i: u32) -> Peer { - Peer { - ip_address: 
Ipv4AddrBytes(i.to_be_bytes()), - port: Port::new(1), - is_seeder: false, - valid_until: ValidUntil::new(ServerStartInstant::new(), 0), + if let Some(slice) = peer_map.get_range(offset_half_one..end_half_one) { + peers.extend(slice.iter().map(|(k, v)| peer_conversion_function(k, v))); } - } - - #[test] - fn test_extract_response_peers() { - fn prop(data: (u16, u16)) -> TestResult { - let gen_num_peers = data.0 as u32; - let req_num_peers = data.1 as usize; - - let mut peer_map: PeerMap = Default::default(); - - let mut opt_sender_key = None; - let mut opt_sender_peer = None; - - for i in 0..gen_num_peers { - let key = gen_peer_id(i); - let peer = gen_peer((i << 16) + i); - - if i == 0 { - opt_sender_key = Some(key); - opt_sender_peer = Some(Peer::to_response_peer(&key, &peer)); - } - - peer_map.insert(key, peer); - } - - let mut rng = thread_rng(); - - let peers = extract_response_peers( - &mut rng, - &peer_map, - req_num_peers, - opt_sender_key.unwrap_or_else(|| gen_peer_id(1)), - Peer::to_response_peer, - ); - - // Check that number of returned peers is correct - - let mut success = peers.len() <= req_num_peers; - - if req_num_peers >= gen_num_peers as usize { - success &= peers.len() == gen_num_peers as usize - || peers.len() + 1 == gen_num_peers as usize; - } - - // Check that returned peers are unique (no overlap) and that sender - // isn't returned - - let mut ip_addresses = HashSet::with_capacity(peers.len()); - - for peer in peers { - if peer == opt_sender_peer.clone().unwrap() - || ip_addresses.contains(&peer.ip_address) - { - success = false; - - break; - } - - ip_addresses.insert(peer.ip_address); - } - - TestResult::from_bool(success) + if let Some(slice) = peer_map.get_range(offset_half_two..end_half_two) { + peers.extend(slice.iter().map(|(k, v)| peer_conversion_function(k, v))); } - - quickcheck(prop as fn((u16, u16)) -> TestResult); + } +} + +#[inline(always)] +fn create_torrent_scrape_statistics(seeders: i32, leechers: i32) -> 
TorrentScrapeStatistics { + TorrentScrapeStatistics { + seeders: NumberOfPeers::new(seeders), + completed: NumberOfDownloads::new(0), // No implementation planned + leechers: NumberOfPeers::new(leechers), } } diff --git a/crates/udp/tests/requests_responses.rs b/crates/udp/tests/requests_responses.rs index ee7abc2..0ce345d 100644 --- a/crates/udp/tests/requests_responses.rs +++ b/crates/udp/tests/requests_responses.rs @@ -35,12 +35,6 @@ fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> { for i in 0..20 { let is_seeder = i % 3 == 0; - if is_seeder { - num_seeders += 1; - } else { - num_leechers += 1; - } - let socket = UdpSocket::bind(peer_addr)?; socket.set_read_timeout(Some(Duration::from_secs(1)))?; @@ -81,6 +75,13 @@ fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> { assert_eq!(response_peer_ports, expected_peer_ports); } + // Do this after announce is evaluated, since it is expected not to include announcing peer + if is_seeder { + num_seeders += 1; + } else { + num_leechers += 1; + } + let scrape_response = scrape( &socket, tracker_addr, diff --git a/crates/udp_protocol/src/response.rs b/crates/udp_protocol/src/response.rs index 45d45b8..882a928 100644 --- a/crates/udp_protocol/src/response.rs +++ b/crates/udp_protocol/src/response.rs @@ -20,33 +20,12 @@ impl Response { #[inline] pub fn write(&self, bytes: &mut impl Write) -> Result<(), io::Error> { match self { - Response::Connect(r) => { - bytes.write_i32::(0)?; - bytes.write_all(r.as_bytes())?; - } - Response::AnnounceIpv4(r) => { - bytes.write_i32::(1)?; - bytes.write_all(r.fixed.as_bytes())?; - bytes.write_all((*r.peers.as_slice()).as_bytes())?; - } - Response::AnnounceIpv6(r) => { - bytes.write_i32::(1)?; - bytes.write_all(r.fixed.as_bytes())?; - bytes.write_all((*r.peers.as_slice()).as_bytes())?; - } - Response::Scrape(r) => { - bytes.write_i32::(2)?; - bytes.write_all(r.transaction_id.as_bytes())?; - bytes.write_all((*r.torrent_stats.as_slice()).as_bytes())?; - } - 
Response::Error(r) => { - bytes.write_i32::(3)?; - bytes.write_all(r.transaction_id.as_bytes())?; - bytes.write_all(r.message.as_bytes())?; - } + Response::Connect(r) => r.write(bytes), + Response::AnnounceIpv4(r) => r.write(bytes), + Response::AnnounceIpv6(r) => r.write(bytes), + Response::Scrape(r) => r.write(bytes), + Response::Error(r) => r.write(bytes), } - - Ok(()) } #[inline] @@ -156,12 +135,40 @@ pub struct ConnectResponse { pub connection_id: ConnectionId, } +impl ConnectResponse { + #[inline] + pub fn write(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_i32::(0)?; + bytes.write_all(self.as_bytes())?; + + Ok(()) + } +} + #[derive(PartialEq, Eq, Clone, Debug)] pub struct AnnounceResponse { pub fixed: AnnounceResponseFixedData, pub peers: Vec>, } +impl AnnounceResponse { + pub fn empty() -> Self { + Self { + fixed: FromZeroes::new_zeroed(), + peers: Default::default(), + } + } + + #[inline] + pub fn write(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_i32::(1)?; + bytes.write_all(self.fixed.as_bytes())?; + bytes.write_all((*self.peers.as_slice()).as_bytes())?; + + Ok(()) + } +} + #[derive(PartialEq, Eq, Clone, Debug, AsBytes, FromBytes, FromZeroes)] #[repr(C, packed)] pub struct AnnounceResponseFixedData { @@ -177,6 +184,17 @@ pub struct ScrapeResponse { pub torrent_stats: Vec, } +impl ScrapeResponse { + #[inline] + pub fn write(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_i32::(2)?; + bytes.write_all(self.transaction_id.as_bytes())?; + bytes.write_all((*self.torrent_stats.as_slice()).as_bytes())?; + + Ok(()) + } +} + #[derive(PartialEq, Eq, Debug, Copy, Clone, AsBytes, FromBytes, FromZeroes)] #[repr(C, packed)] pub struct TorrentScrapeStatistics { @@ -191,6 +209,17 @@ pub struct ErrorResponse { pub message: Cow<'static, str>, } +impl ErrorResponse { + #[inline] + pub fn write(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_i32::(3)?; + 
bytes.write_all(self.transaction_id.as_bytes())?; + bytes.write_all(self.message.as_bytes())?; + + Ok(()) + } +} + #[cfg(test)] mod tests { use quickcheck_macros::quickcheck;