use std::collections::BTreeMap; use std::hash::Hash; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use std::time::Instant; use anyhow::Context; use crossbeam_channel::{Sender, TrySendError}; use getrandom::getrandom; use aquatic_common::access_list::AccessListArcSwap; use aquatic_common::CanonicalSocketAddr; use aquatic_udp_protocol::*; use crate::config::Config; pub const BUFFER_SIZE: usize = 8192; /// HMAC (BLAKE3) based ConnectionID creator and validator /// /// Structure of created ConnectionID (bytes making up inner i64): /// - &[0..4]: connection expiration time as number of seconds after /// ConnectionValidator instance was created, encoded as u32 bytes. /// Value fits around 136 years. /// - &[4..8]: truncated keyed BLAKE3 hash of above 4 bytes and octets of /// client IP address /// /// The purpose of using ConnectionIDs is to prevent IP spoofing, mainly to /// prevent the tracker from being used as an amplification vector for DDoS /// attacks. By including 32 bits of BLAKE3 keyed hash output in its contents, /// such abuse should be rendered impractical. #[derive(Clone)] pub struct ConnectionValidator { start_time: Instant, max_connection_age: u32, keyed_hasher: blake3::Hasher, } impl ConnectionValidator { /// Create new instance. Must be created once and cloned if used in several /// threads. pub fn new(config: &Config) -> anyhow::Result { let mut key = [0; 32]; getrandom(&mut key) .with_context(|| "Couldn't get random bytes for ConnectionValidator key")?; let keyed_hasher = blake3::Hasher::new_keyed(&key); Ok(Self { keyed_hasher, start_time: Instant::now(), max_connection_age: config.cleaning.max_connection_age, }) } pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId { let valid_until = (self.start_time.elapsed().as_secs() as u32 + self.max_connection_age).to_ne_bytes(); self.create_connection_id_inner(valid_until, source_addr) } pub fn connection_id_valid( &mut self, source_addr: CanonicalSocketAddr, connection_id: ConnectionId, ) -> bool { let valid_until = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); // Check that recreating ConnectionId with same inputs yields identical hash. if !Self::connection_id_eq( connection_id, self.create_connection_id_inner(valid_until, source_addr), ) { return false; } u32::from_ne_bytes(valid_until) > self.start_time.elapsed().as_secs() as u32 } fn create_connection_id_inner( &mut self, valid_until: [u8; 4], source_addr: CanonicalSocketAddr, ) -> ConnectionId { let mut connection_id_bytes = [0u8; 8]; (&mut connection_id_bytes[..4]).copy_from_slice(&valid_until); self.keyed_hasher.update(&valid_until); match source_addr.get().ip() { IpAddr::V4(ip) => self.keyed_hasher.update(&ip.octets()), IpAddr::V6(ip) => self.keyed_hasher.update(&ip.octets()), }; self.keyed_hasher .finalize_xof() .fill(&mut connection_id_bytes[4..]); self.keyed_hasher.reset(); ConnectionId(i64::from_ne_bytes(connection_id_bytes)) } /// Compare ConnectionIDs without breaking constant time requirements /// /// Use this instead of PartialEq::eq to avoid optimizations breaking constant /// time HMAC comparison and thus strongly reducing security. #[cfg(target_arch = "x86_64")] fn connection_id_eq(a: ConnectionId, b: ConnectionId) -> bool { let mut eq = 0u8; unsafe { ::std::arch::asm!( "cmp {a}, {b}", "sete {eq}", a = in(reg) a.0, b = in(reg) b.0, eq = inout(reg_byte) eq, options(nomem, nostack), ); } eq != 0 } } #[derive(Debug)] pub struct PendingScrapeRequest { pub slab_key: usize, pub info_hashes: BTreeMap, } #[derive(Debug)] pub struct PendingScrapeResponse { pub slab_key: usize, pub torrent_stats: BTreeMap, } #[derive(Debug)] pub enum ConnectedRequest { Announce(AnnounceRequest), Scrape(PendingScrapeRequest), } #[derive(Debug)] pub enum ConnectedResponse { AnnounceIpv4(AnnounceResponse), AnnounceIpv6(AnnounceResponse), Scrape(PendingScrapeResponse), } #[derive(Clone, Copy, Debug)] pub struct SocketWorkerIndex(pub usize); #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct RequestWorkerIndex(pub usize); impl RequestWorkerIndex { pub fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self { Self(info_hash.0[0] as usize % config.request_workers) } } pub struct ConnectedRequestSender { index: SocketWorkerIndex, senders: Vec>, } impl ConnectedRequestSender { pub fn new( index: SocketWorkerIndex, senders: Vec>, ) -> Self { Self { index, senders } } pub fn try_send_to( &self, index: RequestWorkerIndex, request: ConnectedRequest, addr: CanonicalSocketAddr, ) { match self.senders[index.0].try_send((self.index, request, addr)) { Ok(()) => {} Err(TrySendError::Full(_)) => { ::log::error!("Request channel {} is full, dropping request. Try increasing number of request workers or raising config.worker_channel_size.", index.0) } Err(TrySendError::Disconnected(_)) => { panic!("Request channel {} is disconnected", index.0); } } } } pub struct ConnectedResponseSender { senders: Vec>, } impl ConnectedResponseSender { pub fn new(senders: Vec>) -> Self { Self { senders } } pub fn try_send_to( &self, index: SocketWorkerIndex, response: ConnectedResponse, addr: CanonicalSocketAddr, ) { match self.senders[index.0].try_send((response, addr)) { Ok(()) => {} Err(TrySendError::Full(_)) => { ::log::error!("Response channel {} is full, dropping response. Try increasing number of socket workers or raising config.worker_channel_size.", index.0) } Err(TrySendError::Disconnected(_)) => { panic!("Response channel {} is disconnected", index.0); } } } } #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] pub enum PeerStatus { Seeding, Leeching, Stopped, } impl PeerStatus { /// Determine peer status from announce event and number of bytes left. /// /// Likely, the last branch will be taken most of the time. #[inline] pub fn from_event_and_bytes_left(event: AnnounceEvent, bytes_left: NumberOfBytes) -> Self { if event == AnnounceEvent::Stopped { Self::Stopped } else if bytes_left.0 == 0 { Self::Seeding } else { Self::Leeching } } } pub struct Statistics { pub requests_received: AtomicUsize, pub responses_sent_connect: AtomicUsize, pub responses_sent_announce: AtomicUsize, pub responses_sent_scrape: AtomicUsize, pub responses_sent_error: AtomicUsize, pub bytes_received: AtomicUsize, pub bytes_sent: AtomicUsize, pub torrents: Vec, pub peers: Vec, } impl Statistics { pub fn new(num_request_workers: usize) -> Self { Self { requests_received: Default::default(), responses_sent_connect: Default::default(), responses_sent_announce: Default::default(), responses_sent_scrape: Default::default(), responses_sent_error: Default::default(), bytes_received: Default::default(), bytes_sent: Default::default(), torrents: Self::create_atomic_usize_vec(num_request_workers), peers: Self::create_atomic_usize_vec(num_request_workers), } } fn create_atomic_usize_vec(len: usize) -> Vec { ::std::iter::repeat_with(|| AtomicUsize::default()) .take(len) .collect() } } #[derive(Clone)] pub struct State { pub access_list: Arc, pub statistics_ipv4: Arc, pub statistics_ipv6: Arc, } impl State { pub fn new(num_request_workers: usize) -> Self { Self { access_list: Arc::new(AccessListArcSwap::default()), statistics_ipv4: Arc::new(Statistics::new(num_request_workers)), statistics_ipv6: Arc::new(Statistics::new(num_request_workers)), } } } #[cfg(test)] mod tests { use std::net::{Ipv6Addr, SocketAddr}; use quickcheck_macros::quickcheck; use crate::config::Config; use super::*; #[test] fn test_peer_status_from_event_and_bytes_left() { use crate::common::*; use PeerStatus::*; let f = PeerStatus::from_event_and_bytes_left; assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes(0))); assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes(1))); assert_eq!(Seeding, f(AnnounceEvent::Started, NumberOfBytes(0))); assert_eq!(Leeching, f(AnnounceEvent::Started, NumberOfBytes(1))); assert_eq!(Seeding, f(AnnounceEvent::Completed, NumberOfBytes(0))); assert_eq!(Leeching, f(AnnounceEvent::Completed, NumberOfBytes(1))); assert_eq!(Seeding, f(AnnounceEvent::None, NumberOfBytes(0))); assert_eq!(Leeching, f(AnnounceEvent::None, NumberOfBytes(1))); } // Assumes that announce response with maximum amount of ipv6 peers will // be the longest #[test] fn test_buffer_size() { use aquatic_udp_protocol::*; let config = Config::default(); let peers = ::std::iter::repeat(ResponsePeer { ip_address: Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1), port: Port(1), }) .take(config.protocol.max_response_peers) .collect(); let response = Response::AnnounceIpv6(AnnounceResponse { transaction_id: TransactionId(1), announce_interval: AnnounceInterval(1), seeders: NumberOfPeers(1), leechers: NumberOfPeers(1), peers, }); let mut buf = Vec::new(); response.write(&mut buf).unwrap(); println!("Buffer len: {}", buf.len()); assert!(buf.len() <= BUFFER_SIZE); } #[quickcheck] fn test_connection_validator( original_addr: IpAddr, different_addr: IpAddr, max_connection_age: u32, ) -> quickcheck::TestResult { let original_addr = CanonicalSocketAddr::new(SocketAddr::new(original_addr, 0)); let different_addr = CanonicalSocketAddr::new(SocketAddr::new(different_addr, 0)); if original_addr == different_addr { return quickcheck::TestResult::discard(); } let mut validator = { let mut config = Config::default(); config.cleaning.max_connection_age = max_connection_age; ConnectionValidator::new(&config).unwrap() }; let connection_id = validator.create_connection_id(original_addr); let original_valid = validator.connection_id_valid(original_addr, connection_id); let different_valid = validator.connection_id_valid(different_addr, connection_id); if different_valid { return quickcheck::TestResult::failed(); } if max_connection_age == 0 { quickcheck::TestResult::from_bool(!original_valid) } else { // Note: depends on that running this test takes less than a second quickcheck::TestResult::from_bool(original_valid) } } #[quickcheck] fn test_connection_id_eq(a: i64, b: i64) -> bool { let a = ConnectionId(a); let b = ConnectionId(b); ConnectionValidator::connection_id_eq(a, b) == (a == b) } }