diff --git a/aquatic_udp/src/common.rs b/aquatic_udp/src/common.rs index 082e036..0384177 100644 --- a/aquatic_udp/src/common.rs +++ b/aquatic_udp/src/common.rs @@ -16,13 +16,14 @@ use crate::config::Config; pub const MAX_PACKET_SIZE: usize = 8192; -pub struct ConnectionIdHandler { +#[derive(Clone)] +pub struct ConnectionValidator { start_time: Instant, max_connection_age: Duration, hmac: blake3::Hasher, } -impl ConnectionIdHandler { +impl ConnectionValidator { pub fn new(config: &Config) -> anyhow::Result { let mut key = [0; 32]; @@ -40,19 +41,23 @@ impl ConnectionIdHandler { }) } - pub fn create_connection_id(&mut self, source_ip: IpAddr) -> ConnectionId { + pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId { // Seconds elapsed since server start, as bytes let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes(); - self.create_connection_id_inner(elapsed_time_bytes, source_ip) + self.create_connection_id_inner(elapsed_time_bytes, source_addr) } - pub fn connection_id_valid(&mut self, source_ip: IpAddr, connection_id: ConnectionId) -> bool { + pub fn connection_id_valid( + &mut self, + source_addr: CanonicalSocketAddr, + connection_id: ConnectionId, + ) -> bool { let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); // i64 comparison should be constant-time let hmac_valid = - connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_ip); + connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_addr); if !hmac_valid { return false; @@ -67,7 +72,7 @@ impl ConnectionIdHandler { fn create_connection_id_inner( &mut self, elapsed_time_bytes: [u8; 4], - source_ip: IpAddr, + source_addr: CanonicalSocketAddr, ) -> ConnectionId { // The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a // truncated message authentication code. @@ -77,7 +82,7 @@ impl ConnectionIdHandler { self.hmac.update(&elapsed_time_bytes); - match source_ip { + match source_addr.get().ip() { IpAddr::V4(ip) => self.hmac.update(&ip.octets()), IpAddr::V6(ip) => self.hmac.update(&ip.octets()), }; diff --git a/aquatic_udp/src/lib.rs b/aquatic_udp/src/lib.rs index ef436d0..864104e 100644 --- a/aquatic_udp/src/lib.rs +++ b/aquatic_udp/src/lib.rs @@ -17,7 +17,8 @@ use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::PanicSentinelWatcher; use common::{ - ConnectedRequestSender, ConnectedResponseSender, RequestWorkerIndex, SocketWorkerIndex, State, + ConnectedRequestSender, ConnectedResponseSender, ConnectionValidator, RequestWorkerIndex, + SocketWorkerIndex, State, }; use config::Config; @@ -31,6 +32,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let mut signals = Signals::new([SIGUSR1, SIGTERM])?; + let connection_validator = ConnectionValidator::new(&config)?; + let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel(); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); @@ -96,6 +99,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let sentinel = sentinel.clone(); let state = state.clone(); let config = config.clone(); + let connection_validator = connection_validator.clone(); let request_sender = ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone()); let response_receiver = response_receivers.remove(&i).unwrap(); @@ -117,6 +121,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { state, config, i, + connection_validator, request_sender, response_receiver, priv_dropper, diff --git a/aquatic_udp/src/workers/socket.rs b/aquatic_udp/src/workers/socket.rs index c8b5e05..d67fa40 100644 --- a/aquatic_udp/src/workers/socket.rs +++ b/aquatic_udp/src/workers/socket.rs @@ -9,7 +9,6 @@ use aquatic_common::privileges::PrivilegeDropper; use crossbeam_channel::Receiver; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; -use rand::prelude::{Rng, SeedableRng, StdRng}; use slab::Slab; use aquatic_common::access_list::create_access_list_cache; @@ -22,31 +21,6 @@ use socket2::{Domain, Protocol, Socket, Type}; use crate::common::*; use crate::config::Config; -#[derive(Default)] -pub struct ConnectionMap(AmortizedIndexMap<(ConnectionId, CanonicalSocketAddr), ValidUntil>); - -impl ConnectionMap { - pub fn insert( - &mut self, - connection_id: ConnectionId, - socket_addr: CanonicalSocketAddr, - valid_until: ValidUntil, - ) { - self.0.insert((connection_id, socket_addr), valid_until); - } - - pub fn contains(&self, connection_id: ConnectionId, socket_addr: CanonicalSocketAddr) -> bool { - self.0.contains_key(&(connection_id, socket_addr)) - } - - pub fn clean(&mut self) { - let now = Instant::now(); - - self.0.retain(|_, v| v.0 > now); - self.0.shrink_to_fit(); - } -} - #[derive(Debug)] pub struct PendingScrapeResponseSlabEntry { num_pending: usize, @@ -155,11 +129,11 @@ pub fn run_socket_worker( state: State, config: Config, token_num: usize, + mut connection_validator: ConnectionValidator, request_sender: ConnectedRequestSender, response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, priv_dropper: PrivilegeDropper, ) { - let mut rng = StdRng::from_entropy(); let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut socket = @@ -173,7 +147,6 @@ pub fn run_socket_worker( .unwrap(); let mut events = Events::with_capacity(config.network.poll_event_capacity); - let mut connections = ConnectionMap::default(); let mut pending_scrape_responses = PendingScrapeResponseSlab::default(); let mut access_list_cache = create_access_list_cache(&state.access_list); @@ -181,15 +154,10 @@ pub fn run_socket_worker( let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms); - let connection_cleaning_duration = - Duration::from_secs(config.cleaning.connection_cleaning_interval); let pending_scrape_cleaning_duration = Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval); - let mut connection_valid_until = ValidUntil::new(config.cleaning.max_connection_age); let mut pending_scrape_valid_until = ValidUntil::new(config.cleaning.max_pending_scrape_age); - - let mut last_connection_cleaning = Instant::now(); let mut last_pending_scrape_cleaning = Instant::now(); let mut iter_counter = 0usize; @@ -205,15 +173,13 @@ pub fn run_socket_worker( read_requests( &config, &state, - &mut connections, + &mut connection_validator, &mut pending_scrape_responses, &mut access_list_cache, - &mut rng, &mut socket, &mut buffer, &request_sender, &mut local_responses, - connection_valid_until, pending_scrape_valid_until, ); } @@ -233,16 +199,9 @@ pub fn run_socket_worker( if iter_counter % 128 == 0 { let now = Instant::now(); - connection_valid_until = - ValidUntil::new_with_now(now, config.cleaning.max_connection_age); pending_scrape_valid_until = ValidUntil::new_with_now(now, config.cleaning.max_pending_scrape_age); - if now > last_connection_cleaning + connection_cleaning_duration { - connections.clean(); - - last_connection_cleaning = now; - } if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { pending_scrape_responses.clean(); @@ -258,15 +217,13 @@ pub fn run_socket_worker( fn read_requests( config: &Config, state: &State, - connections: &mut ConnectionMap, + connection_validator: &mut ConnectionValidator, pending_scrape_responses: &mut PendingScrapeResponseSlab, access_list_cache: &mut AccessListCache, - rng: &mut StdRng, socket: &mut UdpSocket, buffer: &mut [u8], request_sender: &ConnectedRequestSender, local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - connection_valid_until: ValidUntil, pending_scrape_valid_until: ValidUntil, ) { let mut requests_received_ipv4: usize = 0; @@ -297,13 +254,11 @@ fn read_requests( handle_request( config, - connections, + connection_validator, pending_scrape_responses, access_list_cache, - rng, request_sender, local_responses, - connection_valid_until, pending_scrape_valid_until, res_request, src, @@ -341,13 +296,11 @@ fn read_requests( pub fn handle_request( config: &Config, - connections: &mut ConnectionMap, + connection_validator: &mut ConnectionValidator, pending_scrape_responses: &mut PendingScrapeResponseSlab, access_list_cache: &mut AccessListCache, - rng: &mut StdRng, request_sender: &ConnectedRequestSender, local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - connection_valid_until: ValidUntil, pending_scrape_valid_until: ValidUntil, res_request: Result, src: CanonicalSocketAddr, @@ -356,9 +309,7 @@ pub fn handle_request( match res_request { Ok(Request::Connect(request)) => { - let connection_id = ConnectionId(rng.gen()); - - connections.insert(connection_id, src, connection_valid_until); + let connection_id = connection_validator.create_connection_id(src); let response = Response::Connect(ConnectResponse { connection_id, @@ -368,7 +319,7 @@ pub fn handle_request( local_responses.push((response, src)) } Ok(Request::Announce(request)) => { - if connections.contains(request.connection_id, src) { + if connection_validator.connection_id_valid(src, request.connection_id) { if access_list_cache .load() .allows(access_list_mode, &request.info_hash.0) @@ -392,7 +343,7 @@ pub fn handle_request( } } Ok(Request::Scrape(request)) => { - if connections.contains(request.connection_id, src) { + if connection_validator.connection_id_valid(src, request.connection_id) { let split_requests = pending_scrape_responses.prepare_split_requests( config, request, @@ -417,7 +368,7 @@ pub fn handle_request( err, } = err { - if connections.contains(connection_id, src) { + if connection_validator.connection_id_valid(src, connection_id) { let response = ErrorResponse { transaction_id, message: err.right_or("Parse error").into(),