diff --git a/CHANGELOG.md b/CHANGELOG.md index 00908f2..560f800 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,10 +34,12 @@ * Remove config key `network.poll_event_capacity` (always use 1) * Speed up parsing and serialization of requests and responses by using [zerocopy](https://crates.io/crates/zerocopy) +* Report socket worker related prometheus stats per worker #### Fixed * Quit whole application if any worker thread quits +* Disallow announce requests with port value of 0 ### aquatic_http diff --git a/Cargo.lock b/Cargo.lock index 7ac690c..59407d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,7 @@ dependencies = [ "compact_str", "constant_time_eq", "crossbeam-channel", + "crossbeam-utils", "getrandom", "hashbrown 0.14.3", "hdrhistogram", diff --git a/crates/udp/Cargo.toml b/crates/udp/Cargo.toml index 041a3cb..3828132 100644 --- a/crates/udp/Cargo.toml +++ b/crates/udp/Cargo.toml @@ -38,6 +38,7 @@ cfg-if = "1" compact_str = "0.7" constant_time_eq = "0.3" crossbeam-channel = "0.5" +crossbeam-utils = "0.8" getrandom = "0.2" hashbrown = { version = "0.14", default-features = false } hdrhistogram = "7" diff --git a/crates/udp/src/common.rs b/crates/udp/src/common.rs index 19f8b27..4dd3496 100644 --- a/crates/udp/src/common.rs +++ b/crates/udp/src/common.rs @@ -1,19 +1,49 @@ use std::collections::BTreeMap; use std::hash::Hash; +use std::iter::repeat_with; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crossbeam_channel::{Receiver, SendError, Sender, TrySendError}; use aquatic_common::access_list::AccessListArcSwap; -use aquatic_common::CanonicalSocketAddr; +use aquatic_common::{CanonicalSocketAddr, ServerStartInstant}; use aquatic_udp_protocol::*; +use crossbeam_utils::CachePadded; use hdrhistogram::Histogram; use crate::config::Config; pub const BUFFER_SIZE: usize = 8192; +#[derive(Clone, Copy, Debug)] +pub enum IpVersion { + V4, + V6, +} + +#[cfg(feature = "prometheus")] +impl IpVersion { + pub fn prometheus_str(&self) -> &'static str { + match self { + Self::V4 => "4", + Self::V6 => "6", + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct SocketWorkerIndex(pub usize); + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct SwarmWorkerIndex(pub usize); + +impl SwarmWorkerIndex { + pub fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self { + Self(info_hash.0[0] as usize % config.swarm_workers) + } +} + #[derive(Debug)] pub struct PendingScrapeRequest { pub slab_key: usize, @@ -39,18 +69,6 @@ pub enum ConnectedResponse { Scrape(PendingScrapeResponse), } -#[derive(Clone, Copy, Debug)] -pub struct SocketWorkerIndex(pub usize); - -#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub struct SwarmWorkerIndex(pub usize); - -impl SwarmWorkerIndex { - pub fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self { - Self(info_hash.0[0] as usize % config.swarm_workers) - } -} - pub struct ConnectedRequestSender { index: SocketWorkerIndex, senders: Vec>, @@ -64,10 +82,6 @@ impl ConnectedRequestSender { Self { index, senders } } - pub fn any_full(&self) -> bool { - self.senders.iter().any(|sender| sender.is_full()) - } - pub fn try_send_to( &self, index: SwarmWorkerIndex, @@ -153,29 +167,59 @@ impl ConnectedResponseSender { pub type ConnectedResponseReceiver = Receiver<(CanonicalSocketAddr, ConnectedResponse)>; -#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] -pub enum PeerStatus { - Seeding, - Leeching, - Stopped, +#[derive(Clone)] +pub struct Statistics { + pub socket: Vec>>, + pub swarm: Vec>>, } -impl PeerStatus { - /// Determine peer status from announce event and number of bytes left. - /// - /// Likely, the last branch will be taken most of the time. - #[inline] - pub fn from_event_and_bytes_left(event: AnnounceEvent, bytes_left: NumberOfBytes) -> Self { - if event == AnnounceEvent::Stopped { - Self::Stopped - } else if bytes_left.0.get() == 0 { - Self::Seeding - } else { - Self::Leeching +impl Statistics { + pub fn new(config: &Config) -> Self { + Self { + socket: repeat_with(Default::default) + .take(config.socket_workers) + .collect(), + swarm: repeat_with(Default::default) + .take(config.swarm_workers) + .collect(), } } } +#[derive(Default)] +pub struct IpVersionStatistics { + pub ipv4: T, + pub ipv6: T, +} + +impl IpVersionStatistics { + pub fn by_ip_version(&self, ip_version: IpVersion) -> &T { + match ip_version { + IpVersion::V4 => &self.ipv4, + IpVersion::V6 => &self.ipv6, + } + } +} + +#[derive(Default)] +pub struct SocketWorkerStatistics { + pub requests: AtomicUsize, + pub responses_connect: AtomicUsize, + pub responses_announce: AtomicUsize, + pub responses_scrape: AtomicUsize, + pub responses_error: AtomicUsize, + pub bytes_received: AtomicUsize, + pub bytes_sent: AtomicUsize, +} + +pub type CachePaddedArc = CachePadded>>; + +#[derive(Default)] +pub struct SwarmWorkerStatistics { + pub torrents: AtomicUsize, + pub peers: AtomicUsize, +} + pub enum StatisticsMessage { Ipv4PeerHistogram(Histogram), Ipv6PeerHistogram(Histogram), @@ -183,86 +227,29 @@ pub enum StatisticsMessage { PeerRemoved(PeerId), } -pub struct Statistics { - pub requests_received: AtomicUsize, - pub responses_sent_connect: AtomicUsize, - pub responses_sent_announce: AtomicUsize, - pub responses_sent_scrape: AtomicUsize, - pub responses_sent_error: AtomicUsize, - pub bytes_received: AtomicUsize, - pub bytes_sent: AtomicUsize, - pub torrents: Vec, - pub peers: Vec, -} - -impl Statistics { - pub fn new(num_swarm_workers: usize) -> Self { - Self { - requests_received: Default::default(), - responses_sent_connect: Default::default(), - responses_sent_announce: Default::default(), - responses_sent_scrape: Default::default(), - responses_sent_error: Default::default(), - bytes_received: Default::default(), - bytes_sent: Default::default(), - torrents: Self::create_atomic_usize_vec(num_swarm_workers), - peers: Self::create_atomic_usize_vec(num_swarm_workers), - } - } - - fn create_atomic_usize_vec(len: usize) -> Vec { - ::std::iter::repeat_with(AtomicUsize::default) - .take(len) - .collect() - } -} - #[derive(Clone)] pub struct State { pub access_list: Arc, - pub statistics_ipv4: Arc, - pub statistics_ipv6: Arc, + pub server_start_instant: ServerStartInstant, } -impl State { - pub fn new(num_swarm_workers: usize) -> Self { +impl Default for State { + fn default() -> Self { Self { access_list: Arc::new(AccessListArcSwap::default()), - statistics_ipv4: Arc::new(Statistics::new(num_swarm_workers)), - statistics_ipv6: Arc::new(Statistics::new(num_swarm_workers)), + server_start_instant: ServerStartInstant::new(), } } } #[cfg(test)] mod tests { - use std::net::Ipv6Addr; + use std::{net::Ipv6Addr, num::NonZeroU16}; use crate::config::Config; use super::*; - #[test] - fn test_peer_status_from_event_and_bytes_left() { - use crate::common::*; - - use PeerStatus::*; - - let f = PeerStatus::from_event_and_bytes_left; - - assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes::new(0))); - assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes::new(1))); - - assert_eq!(Seeding, f(AnnounceEvent::Started, NumberOfBytes::new(0))); - assert_eq!(Leeching, f(AnnounceEvent::Started, NumberOfBytes::new(1))); - - assert_eq!(Seeding, f(AnnounceEvent::Completed, NumberOfBytes::new(0))); - assert_eq!(Leeching, f(AnnounceEvent::Completed, NumberOfBytes::new(1))); - - assert_eq!(Seeding, f(AnnounceEvent::None, NumberOfBytes::new(0))); - assert_eq!(Leeching, f(AnnounceEvent::None, NumberOfBytes::new(1))); - } - // Assumes that announce response with maximum amount of ipv6 peers will // be the longest #[test] @@ -273,7 +260,7 @@ mod tests { let peers = ::std::iter::repeat(ResponsePeer { ip_address: Ipv6AddrBytes(Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1).octets()), - port: Port::new(1), + port: Port::new(NonZeroU16::new(1).unwrap()), }) .take(config.protocol.max_response_peers) .collect(); diff --git a/crates/udp/src/lib.rs b/crates/udp/src/lib.rs index d26e8bb..639e2a9 100644 --- a/crates/udp/src/lib.rs +++ b/crates/udp/src/lib.rs @@ -16,10 +16,10 @@ use aquatic_common::access_list::update_access_list; #[cfg(feature = "cpu-pinning")] use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; use aquatic_common::privileges::PrivilegeDropper; -use aquatic_common::ServerStartInstant; use common::{ - ConnectedRequestSender, ConnectedResponseSender, SocketWorkerIndex, State, SwarmWorkerIndex, + ConnectedRequestSender, ConnectedResponseSender, SocketWorkerIndex, State, Statistics, + SwarmWorkerIndex, }; use config::Config; use workers::socket::ConnectionValidator; @@ -31,7 +31,8 @@ pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn run(config: Config) -> ::anyhow::Result<()> { let mut signals = Signals::new([SIGUSR1])?; - let state = State::new(config.swarm_workers); + let state = State::default(); + let statistics = Statistics::new(&config); let connection_validator = ConnectionValidator::new(&config)?; let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let mut join_handles = Vec::new(); @@ -46,8 +47,6 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let (statistics_sender, statistics_receiver) = unbounded(); - let server_start_instant = ServerStartInstant::new(); - for i in 0..config.swarm_workers { let (request_sender, request_receiver) = bounded(config.worker_channel_size); @@ -68,6 +67,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let request_receiver = request_receivers.remove(&i).unwrap().clone(); let response_sender = ConnectedResponseSender::new(response_senders.clone()); let statistics_sender = statistics_sender.clone(); + let statistics = statistics.swarm[i].clone(); let handle = Builder::new() .name(format!("swarm-{:02}", i + 1)) @@ -83,7 +83,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let mut worker = SwarmWorker { config, state, - server_start_instant, + statistics, request_receiver, response_sender, statistics_sender, @@ -105,6 +105,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone()); let response_receiver = response_receivers.remove(&i).unwrap(); let priv_dropper = priv_dropper.clone(); + let statistics = statistics.socket[i].clone(); let handle = Builder::new() .name(format!("socket-{:02}", i + 1)) @@ -118,10 +119,10 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { ); workers::socket::run_socket_worker( - state, config, + state, + statistics, connection_validator, - server_start_instant, request_sender, response_receiver, priv_dropper, @@ -147,7 +148,12 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { WorkerIndex::Util, ); - workers::statistics::run_statistics_worker(config, state, statistics_receiver) + workers::statistics::run_statistics_worker( + config, + state, + statistics, + statistics_receiver, + ) }) .with_context(|| "spawn statistics worker")?; @@ -187,6 +193,11 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { .build() .context("build prometheus recorder and exporter")?; + let recorder_handle = recorder.handle(); + + ::metrics::set_global_recorder(recorder) + .context("set global metrics recorder")?; + ::tokio::spawn(async move { let mut interval = ::tokio::time::interval(Duration::from_secs(5)); @@ -195,7 +206,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { // Periodically render metrics to make sure // idles are cleaned up - recorder.handle().render(); + recorder_handle.render(); } }); diff --git a/crates/udp/src/workers/socket/mio.rs b/crates/udp/src/workers/socket/mio.rs index 05f90fa..26ab5be 100644 --- a/crates/udp/src/workers/socket/mio.rs +++ b/crates/udp/src/workers/socket/mio.rs @@ -4,7 +4,6 @@ use std::time::{Duration, Instant}; use anyhow::Context; use aquatic_common::access_list::AccessListCache; -use aquatic_common::ServerStartInstant; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; @@ -35,11 +34,11 @@ enum PollMode { pub struct SocketWorker { config: Config, shared_state: State, + statistics: CachePaddedArc>, request_sender: ConnectedRequestSender, response_receiver: ConnectedResponseReceiver, access_list_cache: AccessListCache, validator: ConnectionValidator, - server_start_instant: ServerStartInstant, pending_scrape_responses: PendingScrapeResponseSlab, socket: UdpSocket, opt_resend_buffer: Option>, @@ -51,10 +50,10 @@ pub struct SocketWorker { impl SocketWorker { pub fn run( - shared_state: State, config: Config, + shared_state: State, + statistics: CachePaddedArc>, validator: ConnectionValidator, - server_start_instant: ServerStartInstant, request_sender: ConnectedRequestSender, response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, @@ -66,8 +65,8 @@ impl SocketWorker { let mut worker = Self { config, shared_state, + statistics, validator, - server_start_instant, request_sender, response_receiver, access_list_cache, @@ -96,7 +95,7 @@ impl SocketWorker { Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval); let mut pending_scrape_valid_until = ValidUntil::new( - self.server_start_instant, + self.shared_state.server_start_instant, self.config.cleaning.max_pending_scrape_age, ); let mut last_pending_scrape_cleaning = Instant::now(); @@ -131,9 +130,9 @@ impl SocketWorker { // If resend buffer is enabled, send any responses in it if let Some(resend_buffer) = self.opt_resend_buffer.as_mut() { for (addr, response) in resend_buffer.drain(..) { - Self::send_response( + send_response( &self.config, - &self.shared_state, + &self.statistics, &mut self.socket, &mut self.buffer, &mut None, @@ -159,7 +158,7 @@ impl SocketWorker { // Run periodic ValidUntil updates and state cleaning if iter_counter % 256 == 0 { - let seconds_since_start = self.server_start_instant.seconds_elapsed(); + let seconds_since_start = self.shared_state.server_start_instant.seconds_elapsed(); pending_scrape_valid_until = ValidUntil::new_with_now( seconds_since_start, @@ -180,26 +179,49 @@ impl SocketWorker { } fn read_and_handle_requests(&mut self, pending_scrape_valid_until: ValidUntil) { - let mut requests_received_ipv4: usize = 0; - let mut requests_received_ipv6: usize = 0; - let mut bytes_received_ipv4: usize = 0; - let mut bytes_received_ipv6 = 0; + let max_scrape_torrents = self.config.protocol.max_scrape_torrents; loop { match self.socket.recv_from(&mut self.buffer[..]) { Ok((bytes_read, src)) => { - if src.port() == 0 { - ::log::debug!("Ignored request from {} because source port is zero", src); + let src_port = src.port(); + let src = CanonicalSocketAddr::new(src); + + // Use canonical address for statistics + let opt_statistics = if self.config.statistics.active() { + if src.is_ipv4() { + let statistics = &self.statistics.ipv4; + + statistics + .bytes_received + .fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + + Some(statistics) + } else { + let statistics = &self.statistics.ipv6; + + statistics + .bytes_received + .fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + + Some(statistics) + } + } else { + None + }; + + if src_port == 0 { + ::log::debug!("Ignored request because source port is zero"); continue; } - let src = CanonicalSocketAddr::new(src); - let request_parsable = match Request::parse_bytes( - &self.buffer[..bytes_read], - self.config.protocol.max_scrape_torrents, - ) { + match Request::parse_bytes(&self.buffer[..bytes_read], max_scrape_torrents) { Ok(request) => { + if let Some(statistics) = opt_statistics { + statistics.requests.fetch_add(1, Ordering::Relaxed); + } + if let Err(HandleRequestError::RequestChannelFull(failed_requests)) = self.handle_request(pending_scrape_valid_until, request, src) { @@ -208,52 +230,36 @@ impl SocketWorker { break; } + } + Err(RequestParseError::Sendable { + connection_id, + transaction_id, + err, + }) if self.validator.connection_id_valid(src, connection_id) => { + let response = ErrorResponse { + transaction_id, + message: err.into(), + }; - true + send_response( + &self.config, + &self.statistics, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + Response::Error(response), + src, + ); + + ::log::debug!("request parse error (sent error response): {:?}", err); } Err(err) => { - ::log::debug!("Request::from_bytes error: {:?}", err); - - if let RequestParseError::Sendable { - connection_id, - transaction_id, - err, - } = err - { - if self.validator.connection_id_valid(src, connection_id) { - let response = ErrorResponse { - transaction_id, - message: err.into(), - }; - - Self::send_response( - &self.config, - &self.shared_state, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - Response::Error(response), - src, - ); - } - } - - false + ::log::debug!( + "request parse error (didn't send error response): {:?}", + err + ); } }; - - // Update statistics for converted address - if src.is_ipv4() { - if request_parsable { - requests_received_ipv4 += 1; - } - bytes_received_ipv4 += bytes_read + EXTRA_PACKET_SIZE_IPV4; - } else { - if request_parsable { - requests_received_ipv6 += 1; - } - bytes_received_ipv6 += bytes_read + EXTRA_PACKET_SIZE_IPV6; - } } Err(err) if err.kind() == ErrorKind::WouldBlock => { break; @@ -263,25 +269,6 @@ impl SocketWorker { } } } - - if self.config.statistics.active() { - self.shared_state - .statistics_ipv4 - .requests_received - .fetch_add(requests_received_ipv4, Ordering::Relaxed); - self.shared_state - .statistics_ipv6 - .requests_received - .fetch_add(requests_received_ipv6, Ordering::Relaxed); - self.shared_state - .statistics_ipv4 - .bytes_received - .fetch_add(bytes_received_ipv4, Ordering::Relaxed); - self.shared_state - .statistics_ipv6 - .bytes_received - .fetch_add(bytes_received_ipv6, Ordering::Relaxed); - } } fn handle_request( @@ -301,9 +288,9 @@ impl SocketWorker { transaction_id: request.transaction_id, }; - Self::send_response( + send_response( &self.config, - &self.shared_state, + &self.statistics, &mut self.socket, &mut self.buffer, &mut self.opt_resend_buffer, @@ -337,9 +324,9 @@ impl SocketWorker { message: "Info hash not allowed".into(), }; - Self::send_response( + send_response( &self.config, - &self.shared_state, + &self.statistics, &mut self.socket, &mut self.buffer, &mut self.opt_resend_buffer, @@ -405,9 +392,9 @@ impl SocketWorker { ConnectedResponse::AnnounceIpv6(r) => Response::AnnounceIpv6(r), }; - Self::send_response( + send_response( &self.config, - &self.shared_state, + &self.statistics, &mut self.socket, &mut self.buffer, &mut self.opt_resend_buffer, @@ -416,73 +403,73 @@ impl SocketWorker { ); } } +} - fn send_response( - config: &Config, - shared_state: &State, - socket: &mut UdpSocket, - buffer: &mut [u8], - opt_resend_buffer: &mut Option>, - response: Response, - canonical_addr: CanonicalSocketAddr, - ) { - let mut buffer = Cursor::new(&mut buffer[..]); +fn send_response( + config: &Config, + statistics: &CachePaddedArc>, + socket: &mut UdpSocket, + buffer: &mut [u8], + opt_resend_buffer: &mut Option>, + response: Response, + canonical_addr: CanonicalSocketAddr, +) { + let mut buffer = Cursor::new(&mut buffer[..]); - if let Err(err) = response.write_bytes(&mut buffer) { - ::log::error!("failed writing response to buffer: {:#}", err); + if let Err(err) = response.write_bytes(&mut buffer) { + ::log::error!("failed writing response to buffer: {:#}", err); - return; - } + return; + } - let bytes_written = buffer.position() as usize; + let bytes_written = buffer.position() as usize; - let addr = if config.network.address.is_ipv4() { - canonical_addr - .get_ipv4() - .expect("found peer ipv6 address while running bound to ipv4 address") - } else { - canonical_addr.get_ipv6_mapped() - }; + let addr = if config.network.address.is_ipv4() { + canonical_addr + .get_ipv4() + .expect("found peer ipv6 address while running bound to ipv4 address") + } else { + canonical_addr.get_ipv6_mapped() + }; - match socket.send_to(&buffer.into_inner()[..bytes_written], addr) { - Ok(amt) if config.statistics.active() => { - let stats = if canonical_addr.is_ipv4() { - let stats = &shared_state.statistics_ipv4; + match socket.send_to(&buffer.into_inner()[..bytes_written], addr) { + Ok(amt) if config.statistics.active() => { + let stats = if canonical_addr.is_ipv4() { + let stats = &statistics.ipv4; - stats - .bytes_sent - .fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + stats + .bytes_sent + .fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - stats - } else { - let stats = &shared_state.statistics_ipv6; + stats + } else { + let stats = &statistics.ipv6; - stats - .bytes_sent - .fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + stats + .bytes_sent + .fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); - stats - }; + stats + }; - match response { - Response::Connect(_) => { - stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed); - } - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { - stats - .responses_sent_announce - .fetch_add(1, Ordering::Relaxed); - } - Response::Scrape(_) => { - stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed); - } - Response::Error(_) => { - stats.responses_sent_error.fetch_add(1, Ordering::Relaxed); - } + match response { + Response::Connect(_) => { + stats.responses_connect.fetch_add(1, Ordering::Relaxed); + } + Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { + stats.responses_announce.fetch_add(1, Ordering::Relaxed); + } + Response::Scrape(_) => { + stats.responses_scrape.fetch_add(1, Ordering::Relaxed); + } + Response::Error(_) => { + stats.responses_error.fetch_add(1, Ordering::Relaxed); } } - Ok(_) => (), - Err(err) => match opt_resend_buffer.as_mut() { + } + Ok(_) => (), + Err(err) => { + match opt_resend_buffer.as_mut() { Some(resend_buffer) if (err.raw_os_error() == Some(libc::ENOBUFS)) || (err.kind() == ErrorKind::WouldBlock) => @@ -498,7 +485,7 @@ impl SocketWorker { _ => { ::log::warn!("Sending response to {} failed: {:#}", addr, err); } - }, + } } } } diff --git a/crates/udp/src/workers/socket/mod.rs b/crates/udp/src/workers/socket/mod.rs index f8597cc..f724138 100644 --- a/crates/udp/src/workers/socket/mod.rs +++ b/crates/udp/src/workers/socket/mod.rs @@ -5,11 +5,14 @@ mod uring; mod validator; use anyhow::Context; -use aquatic_common::{privileges::PrivilegeDropper, ServerStartInstant}; +use aquatic_common::privileges::PrivilegeDropper; use socket2::{Domain, Protocol, Socket, Type}; use crate::{ - common::{ConnectedRequestSender, ConnectedResponseReceiver, State}, + common::{ + CachePaddedArc, ConnectedRequestSender, ConnectedResponseReceiver, IpVersionStatistics, + SocketWorkerStatistics, State, + }, config::Config, }; @@ -37,10 +40,10 @@ const EXTRA_PACKET_SIZE_IPV4: usize = 8 + 18 + 20 + 8; const EXTRA_PACKET_SIZE_IPV6: usize = 8 + 18 + 40 + 8; pub fn run_socket_worker( - shared_state: State, config: Config, + shared_state: State, + statistics: CachePaddedArc>, validator: ConnectionValidator, - server_start_instant: ServerStartInstant, request_sender: ConnectedRequestSender, response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, @@ -49,10 +52,10 @@ pub fn run_socket_worker( match self::uring::supported_on_current_kernel() { Ok(()) => { return self::uring::SocketWorker::run( - shared_state, config, + shared_state, + statistics, validator, - server_start_instant, request_sender, response_receiver, priv_dropper, @@ -67,10 +70,10 @@ pub fn run_socket_worker( } self::mio::SocketWorker::run( - shared_state, config, + shared_state, + statistics, validator, - server_start_instant, request_sender, response_receiver, priv_dropper, diff --git a/crates/udp/src/workers/socket/uring/mod.rs b/crates/udp/src/workers/socket/uring/mod.rs index 955a7b3..fe8b490 100644 --- a/crates/udp/src/workers/socket/uring/mod.rs +++ b/crates/udp/src/workers/socket/uring/mod.rs @@ -11,7 +11,6 @@ use std::sync::atomic::Ordering; use anyhow::Context; use aquatic_common::access_list::AccessListCache; -use aquatic_common::ServerStartInstant; use io_uring::opcode::Timeout; use io_uring::types::{Fixed, Timespec}; use io_uring::{IoUring, Probe}; @@ -76,11 +75,11 @@ impl CurrentRing { pub struct SocketWorker { config: Config, shared_state: State, + statistics: CachePaddedArc>, request_sender: ConnectedRequestSender, response_receiver: ConnectedResponseReceiver, access_list_cache: AccessListCache, validator: ConnectionValidator, - server_start_instant: ServerStartInstant, #[allow(dead_code)] socket: UdpSocket, pending_scrape_responses: PendingScrapeResponseSlab, @@ -97,10 +96,10 @@ pub struct SocketWorker { impl SocketWorker { pub fn run( - shared_state: State, config: Config, + shared_state: State, + statistics: CachePaddedArc>, validator: ConnectionValidator, - server_start_instant: ServerStartInstant, request_sender: ConnectedRequestSender, response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, @@ -164,14 +163,16 @@ impl SocketWorker { cleaning_timeout_sqe.clone(), ]; - let pending_scrape_valid_until = - ValidUntil::new(server_start_instant, config.cleaning.max_pending_scrape_age); + let pending_scrape_valid_until = ValidUntil::new( + shared_state.server_start_instant, + config.cleaning.max_pending_scrape_age, + ); let mut worker = Self { config, shared_state, + statistics, validator, - server_start_instant, request_sender, response_receiver, access_list_cache, @@ -293,7 +294,7 @@ impl SocketWorker { } USER_DATA_PULSE_TIMEOUT => { self.pending_scrape_valid_until = ValidUntil::new( - self.server_start_instant, + self.shared_state.server_start_instant, self.config.cleaning.max_pending_scrape_age, ); @@ -302,7 +303,7 @@ impl SocketWorker { } USER_DATA_CLEANING_TIMEOUT => { self.pending_scrape_responses - .clean(self.server_start_instant.seconds_elapsed()); + .clean(self.shared_state.server_start_instant.seconds_elapsed()); self.resubmittable_sqe_buf .push(self.cleaning_timeout_sqe.clone()); @@ -322,9 +323,9 @@ impl SocketWorker { self.send_buffers.response_type_and_ipv4(send_buffer_index); let (statistics, extra_bytes) = if receiver_is_ipv4 { - (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) + (&self.statistics.ipv4, EXTRA_PACKET_SIZE_IPV4) } else { - (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) + (&self.statistics.ipv6, EXTRA_PACKET_SIZE_IPV6) }; statistics @@ -332,10 +333,10 @@ impl SocketWorker { .fetch_add(result as usize + extra_bytes, Ordering::Relaxed); let response_counter = match response_type { - ResponseType::Connect => &statistics.responses_sent_connect, - ResponseType::Announce => &statistics.responses_sent_announce, - ResponseType::Scrape => &statistics.responses_sent_scrape, - ResponseType::Error => &statistics.responses_sent_error, + ResponseType::Connect => &statistics.responses_connect, + ResponseType::Announce => &statistics.responses_announce, + ResponseType::Scrape => &statistics.responses_scrape, + ResponseType::Error => &statistics.responses_error, }; response_counter.fetch_add(1, Ordering::Relaxed); @@ -433,15 +434,15 @@ impl SocketWorker { if self.config.statistics.active() { let (statistics, extra_bytes) = if addr.is_ipv4() { - (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) + (&self.statistics.ipv4, EXTRA_PACKET_SIZE_IPV4) } else { - (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) + (&self.statistics.ipv6, EXTRA_PACKET_SIZE_IPV6) }; statistics .bytes_received .fetch_add(buffer.len() + extra_bytes, Ordering::Relaxed); - statistics.requests_received.fetch_add(1, Ordering::Relaxed); + statistics.requests.fetch_add(1, Ordering::Relaxed); } } diff --git a/crates/udp/src/workers/statistics/collector.rs b/crates/udp/src/workers/statistics/collector.rs index 820e0b5..5297853 100644 --- a/crates/udp/src/workers/statistics/collector.rs +++ b/crates/udp/src/workers/statistics/collector.rs @@ -1,43 +1,41 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::atomic::Ordering; use std::time::Instant; use hdrhistogram::Histogram; use num_format::{Locale, ToFormattedString}; use serde::Serialize; -use crate::common::Statistics; use crate::config::Config; +use super::{IpVersion, Statistics}; + #[cfg(feature = "prometheus")] macro_rules! set_peer_histogram_gauge { - ($ip_version:ident, $data:expr, $type_label:expr) => { + ($ip_version:expr, $data:expr, $type_label:expr) => { ::metrics::gauge!( "aquatic_peers_per_torrent", "type" => $type_label, - "ip_version" => $ip_version.clone(), + "ip_version" => $ip_version, ) .set($data as f64); }; } pub struct StatisticsCollector { - shared: Arc, + statistics: Statistics, + ip_version: IpVersion, last_update: Instant, pending_histograms: Vec>, last_complete_histogram: PeerHistogramStatistics, - #[cfg(feature = "prometheus")] - ip_version: String, } impl StatisticsCollector { - pub fn new(shared: Arc, #[cfg(feature = "prometheus")] ip_version: String) -> Self { + pub fn new(statistics: Statistics, ip_version: IpVersion) -> Self { Self { - shared, + statistics, last_update: Instant::now(), pending_histograms: Vec::new(), last_complete_histogram: Default::default(), - #[cfg(feature = "prometheus")] ip_version, } } @@ -55,27 +53,177 @@ impl StatisticsCollector { &mut self, #[cfg(feature = "prometheus")] config: &Config, ) -> CollectedStatistics { - let requests_received = Self::fetch_and_reset(&self.shared.requests_received); - let responses_sent_connect = Self::fetch_and_reset(&self.shared.responses_sent_connect); - let responses_sent_announce = Self::fetch_and_reset(&self.shared.responses_sent_announce); - let responses_sent_scrape = Self::fetch_and_reset(&self.shared.responses_sent_scrape); - let responses_sent_error = Self::fetch_and_reset(&self.shared.responses_sent_error); + let mut requests = 0; + let mut responses_connect: usize = 0; + let mut responses_announce: usize = 0; + let mut responses_scrape: usize = 0; + let mut responses_error: usize = 0; + let mut bytes_received: usize = 0; + let mut bytes_sent: usize = 0; + let mut num_torrents: usize = 0; + let mut num_peers: usize = 0; - let bytes_received = Self::fetch_and_reset(&self.shared.bytes_received); - let bytes_sent = Self::fetch_and_reset(&self.shared.bytes_sent); + #[cfg(feature = "prometheus")] + let ip_version_prometheus_str = self.ip_version.prometheus_str(); - let num_torrents_by_worker: Vec = self - .shared - .torrents + for (i, statistics) in self + .statistics + .socket .iter() - .map(|n| n.load(Ordering::Relaxed)) - .collect(); - let num_peers_by_worker: Vec = self - .shared - .peers + .map(|s| s.by_ip_version(self.ip_version)) + .enumerate() + { + { + let n = statistics.requests.fetch_and(0, Ordering::Relaxed); + + requests += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_requests_total", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + { + let n = statistics.responses_connect.fetch_and(0, Ordering::Relaxed); + + responses_connect += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_responses_total", + "type" => "connect", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + { + let n = statistics + .responses_announce + .fetch_and(0, Ordering::Relaxed); + + responses_announce += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_responses_total", + "type" => "announce", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + { + let n = statistics.responses_scrape.fetch_and(0, Ordering::Relaxed); + + responses_scrape += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_responses_total", + "type" => "scrape", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + { + let n = statistics.responses_error.fetch_and(0, Ordering::Relaxed); + + responses_error += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_responses_total", + "type" => "error", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + { + let n = statistics.bytes_received.fetch_and(0, Ordering::Relaxed); + + bytes_received += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_rx_bytes", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + { + let n = statistics.bytes_sent.fetch_and(0, Ordering::Relaxed); + + bytes_sent += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::counter!( + "aquatic_tx_bytes", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .increment(n.try_into().unwrap()); + } + } + } + + for (i, statistics) in self + .statistics + .swarm .iter() - .map(|n| n.load(Ordering::Relaxed)) - .collect(); + .map(|s| s.by_ip_version(self.ip_version)) + .enumerate() + { + { + let n = statistics.torrents.load(Ordering::Relaxed); + + num_torrents += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::gauge!( + "aquatic_torrents", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .set(n as f64); + } + } + { + let n = statistics.peers.load(Ordering::Relaxed); + + num_peers += n; + + #[cfg(feature = "prometheus")] + if config.statistics.run_prometheus_endpoint { + ::metrics::gauge!( + "aquatic_peers", + "ip_version" => ip_version_prometheus_str, + "worker_index" => i.to_string(), + ) + .set(n as f64); + } + } + } let elapsed = { let now = Instant::now(); @@ -88,84 +236,16 @@ impl StatisticsCollector { }; #[cfg(feature = "prometheus")] - if config.statistics.run_prometheus_endpoint { - ::metrics::counter!( - "aquatic_requests_total", - "ip_version" => self.ip_version.clone(), - ) - .increment(requests_received.try_into().unwrap()); - - ::metrics::counter!( - "aquatic_responses_total", - "type" => "connect", - "ip_version" => self.ip_version.clone(), - ) - .increment(responses_sent_connect.try_into().unwrap()); - - ::metrics::counter!( - "aquatic_responses_total", - "type" => "announce", - "ip_version" => self.ip_version.clone(), - ) - .increment(responses_sent_announce.try_into().unwrap()); - - ::metrics::counter!( - "aquatic_responses_total", - "type" => "scrape", - "ip_version" => self.ip_version.clone(), - ) - .increment(responses_sent_scrape.try_into().unwrap()); - - ::metrics::counter!( - "aquatic_responses_total", - "type" => "error", - "ip_version" => self.ip_version.clone(), - ) - .increment(responses_sent_error.try_into().unwrap()); - - ::metrics::counter!( - "aquatic_rx_bytes", - "ip_version" => self.ip_version.clone(), - ) - .increment(bytes_received.try_into().unwrap()); - - ::metrics::counter!( - "aquatic_tx_bytes", - "ip_version" => self.ip_version.clone(), - ) - .increment(bytes_sent.try_into().unwrap()); - - for (worker_index, n) in num_torrents_by_worker.iter().copied().enumerate() { - ::metrics::gauge!( - "aquatic_torrents", - "ip_version" => self.ip_version.clone(), - "worker_index" => worker_index.to_string(), - ) - .set(n as f64); - } - for (worker_index, n) in num_peers_by_worker.iter().copied().enumerate() { - ::metrics::gauge!( - "aquatic_peers", - "ip_version" => self.ip_version.clone(), - "worker_index" => worker_index.to_string(), - ) - .set(n as f64); - } - - if config.statistics.torrent_peer_histograms { - self.last_complete_histogram - .update_metrics(self.ip_version.clone()); - } + if config.statistics.run_prometheus_endpoint && config.statistics.torrent_peer_histograms { + self.last_complete_histogram + .update_metrics(ip_version_prometheus_str); } - let num_peers: usize = num_peers_by_worker.into_iter().sum(); - let num_torrents: usize = num_torrents_by_worker.into_iter().sum(); - - let requests_per_second = requests_received as f64 / elapsed; - let responses_per_second_connect = responses_sent_connect as f64 / elapsed; - let responses_per_second_announce = responses_sent_announce as f64 / elapsed; - let responses_per_second_scrape = responses_sent_scrape as f64 / elapsed; - let responses_per_second_error = responses_sent_error as f64 / elapsed; + let requests_per_second = requests as f64 / elapsed; + let responses_per_second_connect = responses_connect as f64 / elapsed; + let responses_per_second_announce = responses_announce as f64 / elapsed; + let responses_per_second_scrape = responses_scrape as f64 / elapsed; + let responses_per_second_error = responses_error as f64 / elapsed; let bytes_received_per_second = bytes_received as f64 / elapsed; let bytes_sent_per_second = bytes_sent as f64 / elapsed; @@ -193,10 +273,6 @@ impl StatisticsCollector { peer_histogram: self.last_complete_histogram.clone(), } } - - fn fetch_and_reset(atomic: &AtomicUsize) -> usize { - atomic.fetch_and(0, Ordering::Relaxed) - } } #[derive(Clone, Debug, Serialize)] @@ -253,7 +329,7 @@ impl PeerHistogramStatistics { } #[cfg(feature = "prometheus")] - fn update_metrics(&self, ip_version: String) { + fn update_metrics(&self, ip_version: &'static str) { set_peer_histogram_gauge!(ip_version, self.min, "min"); set_peer_histogram_gauge!(ip_version, self.p10, "p10"); set_peer_histogram_gauge!(ip_version, self.p20, "p20"); diff --git a/crates/udp/src/workers/statistics/mod.rs b/crates/udp/src/workers/statistics/mod.rs index 7600b33..25fc157 100644 --- a/crates/udp/src/workers/statistics/mod.rs +++ b/crates/udp/src/workers/statistics/mod.rs @@ -44,6 +44,7 @@ struct TemplateData { pub fn run_statistics_worker( config: Config, shared_state: State, + statistics: Statistics, statistics_receiver: Receiver, ) -> anyhow::Result<()> { let process_peer_client_data = { @@ -68,16 +69,8 @@ pub fn run_statistics_worker( None }; - let mut ipv4_collector = StatisticsCollector::new( - shared_state.statistics_ipv4, - #[cfg(feature = "prometheus")] - "4".into(), - ); - let mut ipv6_collector = StatisticsCollector::new( - shared_state.statistics_ipv6, - #[cfg(feature = "prometheus")] - "6".into(), - ); + let mut ipv4_collector = StatisticsCollector::new(statistics.clone(), IpVersion::V4); + let mut ipv6_collector = StatisticsCollector::new(statistics, IpVersion::V6); // Store a count to enable not removing peers from the count completely // just because they were removed from one torrent diff --git a/crates/udp/src/workers/swarm/mod.rs b/crates/udp/src/workers/swarm/mod.rs index 4474d71..ccdab8a 100644 --- a/crates/udp/src/workers/swarm/mod.rs +++ b/crates/udp/src/workers/swarm/mod.rs @@ -5,7 +5,6 @@ use std::sync::atomic::Ordering; use std::time::Duration; use std::time::Instant; -use aquatic_common::ServerStartInstant; use crossbeam_channel::Receiver; use crossbeam_channel::Sender; use rand::{rngs::SmallRng, SeedableRng}; @@ -20,7 +19,7 @@ use storage::TorrentMaps; pub struct SwarmWorker { pub config: Config, pub state: State, - pub server_start_instant: ServerStartInstant, + pub statistics: CachePaddedArc>, pub request_receiver: Receiver<(SocketWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>, pub response_sender: ConnectedResponseSender, pub statistics_sender: Sender, @@ -33,8 +32,10 @@ impl SwarmWorker { let mut rng = SmallRng::from_entropy(); let timeout = Duration::from_millis(self.config.request_channel_recv_timeout_ms); - let mut peer_valid_until = - ValidUntil::new(self.server_start_instant, self.config.cleaning.max_peer_age); + let mut peer_valid_until = ValidUntil::new( + self.state.server_start_instant, + self.config.cleaning.max_peer_age, + ); let cleaning_interval = Duration::from_secs(self.config.cleaning.torrent_cleaning_interval); let statistics_update_interval = Duration::from_secs(self.config.statistics.interval); @@ -110,17 +111,18 @@ impl SwarmWorker { if iter_counter % 128 == 0 { let now = Instant::now(); - peer_valid_until = - ValidUntil::new(self.server_start_instant, self.config.cleaning.max_peer_age); + peer_valid_until = ValidUntil::new( + self.state.server_start_instant, + self.config.cleaning.max_peer_age, + ); if now > last_cleaning + cleaning_interval { torrents.clean_and_update_statistics( &self.config, &self.state, + &self.statistics, &self.statistics_sender, &self.state.access_list, - self.server_start_instant, - self.worker_index, ); last_cleaning = now; @@ -128,10 +130,14 @@ impl SwarmWorker { if self.config.statistics.active() && now > last_statistics_update + statistics_update_interval { - self.state.statistics_ipv4.torrents[self.worker_index.0] - .store(torrents.ipv4.num_torrents(), Ordering::Release); - self.state.statistics_ipv6.torrents[self.worker_index.0] - .store(torrents.ipv6.num_torrents(), Ordering::Release); + self.statistics + .ipv4 + .torrents + .store(torrents.ipv4.num_torrents(), Ordering::Relaxed); + self.statistics + .ipv6 + .torrents + .store(torrents.ipv6.num_torrents(), Ordering::Relaxed); last_statistics_update = now; } diff --git a/crates/udp/src/workers/swarm/storage.rs b/crates/udp/src/workers/swarm/storage.rs index 0c1dcc2..4289370 100644 --- a/crates/udp/src/workers/swarm/storage.rs +++ b/crates/udp/src/workers/swarm/storage.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use aquatic_common::IndexMap; use aquatic_common::SecondsSinceServerStart; -use aquatic_common::ServerStartInstant; use aquatic_common::{ access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache, AccessListMode}, ValidUntil, @@ -41,14 +40,13 @@ impl TorrentMaps { &mut self, config: &Config, state: &State, + statistics: &CachePaddedArc>, statistics_sender: &Sender, access_list: &Arc, - server_start_instant: ServerStartInstant, - worker_index: SwarmWorkerIndex, ) { let mut cache = create_access_list_cache(access_list); let mode = config.access_list.mode; - let now = server_start_instant.seconds_elapsed(); + let now = state.server_start_instant.seconds_elapsed(); let ipv4 = self.ipv4 @@ -58,8 +56,8 @@ impl TorrentMaps { .clean_and_get_statistics(config, statistics_sender, &mut cache, mode, now); if config.statistics.active() { - state.statistics_ipv4.peers[worker_index.0].store(ipv4.0, Ordering::Release); - state.statistics_ipv6.peers[worker_index.0].store(ipv6.0, Ordering::Release); + statistics.ipv4.peers.store(ipv4.0, Ordering::Relaxed); + statistics.ipv6.peers.store(ipv6.0, Ordering::Relaxed); if let Some(message) = ipv4.1.map(StatisticsMessage::Ipv4PeerHistogram) { if let Err(err) = statistics_sender.try_send(message) { @@ -516,3 +514,50 @@ struct Peer { is_seeder: bool, valid_until: ValidUntil, } + +#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] +pub enum PeerStatus { + Seeding, + Leeching, + Stopped, +} + +impl PeerStatus { + /// Determine peer status from announce event and number of bytes left. + /// + /// Likely, the last branch will be taken most of the time. + #[inline] + pub fn from_event_and_bytes_left(event: AnnounceEvent, bytes_left: NumberOfBytes) -> Self { + if event == AnnounceEvent::Stopped { + Self::Stopped + } else if bytes_left.0.get() == 0 { + Self::Seeding + } else { + Self::Leeching + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_peer_status_from_event_and_bytes_left() { + use PeerStatus::*; + + let f = PeerStatus::from_event_and_bytes_left; + + assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes::new(0))); + assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes::new(1))); + + assert_eq!(Seeding, f(AnnounceEvent::Started, NumberOfBytes::new(0))); + assert_eq!(Leeching, f(AnnounceEvent::Started, NumberOfBytes::new(1))); + + assert_eq!(Seeding, f(AnnounceEvent::Completed, NumberOfBytes::new(0))); + assert_eq!(Leeching, f(AnnounceEvent::Completed, NumberOfBytes::new(1))); + + assert_eq!(Seeding, f(AnnounceEvent::None, NumberOfBytes::new(0))); + assert_eq!(Leeching, f(AnnounceEvent::None, NumberOfBytes::new(1))); + } +} diff --git a/crates/udp/tests/access_list.rs b/crates/udp/tests/access_list.rs index cd0ec79..bc3dacf 100644 --- a/crates/udp/tests/access_list.rs +++ b/crates/udp/tests/access_list.rs @@ -6,6 +6,7 @@ use std::{ fs::File, io::Write, net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, + num::NonZeroU16, time::Duration, }; @@ -78,7 +79,7 @@ fn test_access_list( &socket, tracker_addr, connection_id, - 1, + NonZeroU16::new(1).unwrap(), info_hash_fail, 10, false, @@ -95,7 +96,7 @@ fn test_access_list( &socket, tracker_addr, connection_id, - 1, + NonZeroU16::new(1).unwrap(), info_hash_success, 10, false, diff --git a/crates/udp/tests/common/mod.rs b/crates/udp/tests/common/mod.rs index 0656b73..7fbec29 100644 --- a/crates/udp/tests/common/mod.rs +++ b/crates/udp/tests/common/mod.rs @@ -3,6 +3,7 @@ use std::{ io::Cursor, net::{SocketAddr, UdpSocket}, + num::NonZeroU16, time::Duration, }; @@ -42,7 +43,7 @@ pub fn announce( socket: &UdpSocket, tracker_addr: SocketAddr, connection_id: ConnectionId, - peer_port: u16, + peer_port: NonZeroU16, info_hash: InfoHash, peers_wanted: usize, seeder: bool, @@ -50,7 +51,7 @@ pub fn announce( let mut peer_id = PeerId([0; 20]); for chunk in peer_id.0.chunks_exact_mut(2) { - chunk.copy_from_slice(&peer_port.to_ne_bytes()); + chunk.copy_from_slice(&peer_port.get().to_ne_bytes()); } let request = Request::Announce(AnnounceRequest { diff --git a/crates/udp/tests/invalid_connection_id.rs b/crates/udp/tests/invalid_connection_id.rs index 7506854..96ecf67 100644 --- a/crates/udp/tests/invalid_connection_id.rs +++ b/crates/udp/tests/invalid_connection_id.rs @@ -5,6 +5,7 @@ use common::*; use std::{ io::{Cursor, ErrorKind}, net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, + num::NonZeroU16, time::Duration, }; @@ -51,7 +52,7 @@ fn test_invalid_connection_id() -> anyhow::Result<()> { ip_address: Ipv4AddrBytes([0; 4]), key: PeerKey::new(0), peers_wanted: NumberOfPeers::new(10), - port: Port::new(1), + port: Port::new(NonZeroU16::new(1).unwrap()), }); let scrape_request = Request::Scrape(ScrapeRequest { diff --git a/crates/udp/tests/requests_responses.rs b/crates/udp/tests/requests_responses.rs index 0ce345d..dc55aa0 100644 --- a/crates/udp/tests/requests_responses.rs +++ b/crates/udp/tests/requests_responses.rs @@ -5,6 +5,7 @@ use common::*; use std::{ collections::{hash_map::RandomState, HashSet}, net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, + num::NonZeroU16, time::Duration, }; @@ -45,7 +46,7 @@ fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> { &socket, tracker_addr, connection_id, - PEER_PORT_START + i as u16, + NonZeroU16::new(PEER_PORT_START + i as u16).unwrap(), info_hash, PEERS_WANTED, is_seeder, diff --git a/crates/udp_protocol/src/common.rs b/crates/udp_protocol/src/common.rs index 6c54cb2..51fbfaf 100644 --- a/crates/udp_protocol/src/common.rs +++ b/crates/udp_protocol/src/common.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; use std::net::{Ipv4Addr, Ipv6Addr}; +use std::num::NonZeroU16; pub use aquatic_peer_id::{PeerClient, PeerId}; use zerocopy::network_endian::{I32, I64, U16, U32}; @@ -76,8 +77,8 @@ impl NumberOfDownloads { pub struct Port(pub U16); impl Port { - pub fn new(v: u16) -> Self { - Self(U16::new(v)) + pub fn new(v: NonZeroU16) -> Self { + Self(U16::new(v.into())) } } diff --git a/crates/udp_protocol/src/request.rs b/crates/udp_protocol/src/request.rs index a60791b..1b8cb43 100644 --- a/crates/udp_protocol/src/request.rs +++ b/crates/udp_protocol/src/request.rs @@ -58,15 +58,21 @@ impl Request { let request = AnnounceRequest::read_from_prefix(bytes) .ok_or_else(|| RequestParseError::unsendable_text("invalid data"))?; - // Make sure not to create AnnounceEventBytes with invalid value - if matches!(request.event.0.get(), (0..=3)) { - Ok(Request::Announce(request)) - } else { + if request.port.0.get() == 0 { + Err(RequestParseError::sendable_text( + "Port can't be 0", + request.connection_id, + request.transaction_id, + )) + } else if !matches!(request.event.0.get(), (0..=3)) { + // Make sure not to allow AnnounceEventBytes with invalid value Err(RequestParseError::sendable_text( "Invalid announce event", request.connection_id, request.transaction_id, )) + } else { + Ok(Request::Announce(request)) } } // Scrape @@ -275,7 +281,7 @@ impl RequestParseError { mod tests { use quickcheck::TestResult; use quickcheck_macros::quickcheck; - use zerocopy::network_endian::{I32, I64, U16}; + use zerocopy::network_endian::{I32, I64}; use super::*; @@ -313,7 +319,7 @@ mod tests { ip_address: Ipv4AddrBytes::arbitrary(g), key: PeerKey::new(i32::arbitrary(g)), peers_wanted: NumberOfPeers(I32::new(i32::arbitrary(g))), - port: Port(U16::new(u16::arbitrary(g))), + port: Port::new(quickcheck::Arbitrary::arbitrary(g)), } } }