diff --git a/TODO.md b/TODO.md index 875687e..17a10ef 100644 --- a/TODO.md +++ b/TODO.md @@ -2,9 +2,6 @@ ## High priority -* udp - * make ConnectionValidator faster by avoiding calling time functions so often - ## Medium priority * stagger cleaning tasks? diff --git a/crates/udp/src/workers/socket/mio.rs b/crates/udp/src/workers/socket/mio.rs index 95bca78..a73a35e 100644 --- a/crates/udp/src/workers/socket/mio.rs +++ b/crates/udp/src/workers/socket/mio.rs @@ -98,6 +98,8 @@ impl SocketWorker { } if iter_counter % 256 == 0 { + self.validator.update_elapsed(); + self.peer_valid_until = ValidUntil::new( self.shared_state.server_start_instant, self.config.cleaning.max_peer_age, diff --git a/crates/udp/src/workers/socket/uring/mod.rs b/crates/udp/src/workers/socket/uring/mod.rs index 2f85e42..ac572d4 100644 --- a/crates/udp/src/workers/socket/uring/mod.rs +++ b/crates/udp/src/workers/socket/uring/mod.rs @@ -134,9 +134,10 @@ impl SocketWorker { let recv_sqe = recv_helper.create_entry(buf_ring.bgid()); - // This timeout enables regular updates of peer_valid_until + // This timeout enables regular updates of ConnectionValidator and + // peer_valid_until let pulse_timeout_sqe = { - let timespec_ptr = Box::into_raw(Box::new(Timespec::new().sec(1))) as *const _; + let timespec_ptr = Box::into_raw(Box::new(Timespec::new().sec(5))) as *const _; Timeout::new(timespec_ptr) .build() @@ -238,6 +239,8 @@ impl SocketWorker { } } USER_DATA_PULSE_TIMEOUT => { + self.validator.update_elapsed(); + self.peer_valid_until = ValidUntil::new( self.shared_state.server_start_instant, self.config.cleaning.max_peer_age, diff --git a/crates/udp/src/workers/socket/validator.rs b/crates/udp/src/workers/socket/validator.rs index c68d1ef..61e744c 100644 --- a/crates/udp/src/workers/socket/validator.rs +++ b/crates/udp/src/workers/socket/validator.rs @@ -12,6 +12,8 @@ use crate::config::Config; /// HMAC (BLAKE3) based ConnectionId creator and validator /// +/// Method update_elapsed must be called at least once a minute. +/// /// The purpose of using ConnectionIds is to make IP spoofing costly, mainly to /// prevent the tracker from being used as an amplification vector for DDoS /// attacks. By including 32 bits of BLAKE3 keyed hash output in the Ids, an @@ -32,6 +34,7 @@ pub struct ConnectionValidator { start_time: Instant, max_connection_age: u64, keyed_hasher: blake3::Hasher, + seconds_since_start: u32, } impl ConnectionValidator { @@ -49,11 +52,12 @@ impl ConnectionValidator { keyed_hasher, start_time: Instant::now(), max_connection_age: config.cleaning.max_connection_age.into(), + seconds_since_start: 0, }) } pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId { - let elapsed = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes(); + let elapsed = (self.seconds_since_start).to_ne_bytes(); let hash = self.hash(elapsed, source_addr.get().ip()); @@ -78,16 +82,23 @@ impl ConnectionValidator { return false; } - let tracker_elapsed = self.start_time.elapsed().as_secs(); + let seconds_since_start = self.seconds_since_start as u64; let client_elapsed = u64::from(u32::from_ne_bytes(elapsed)); let client_expiration_time = client_elapsed + self.max_connection_age; // In addition to checking if the client connection is expired, - // disallow client_elapsed values that are in future and thus could not - // have been sent by the tracker. This prevents brute forcing with - // `u32::MAX` as 'elapsed' part of ConnectionId to find a hash that + // disallow client_elapsed values that are too far in future and thus + // could not have been sent by the tracker. This prevents brute forcing + // with `u32::MAX` as 'elapsed' part of ConnectionId to find a hash that // works until the tracker is restarted. - (client_expiration_time > tracker_elapsed) & (client_elapsed <= tracker_elapsed) + let client_not_expired = client_expiration_time > seconds_since_start; + let client_elapsed_not_in_far_future = client_elapsed <= (seconds_since_start + 60); + + client_not_expired & client_elapsed_not_in_far_future + } + + pub fn update_elapsed(&mut self) { + self.seconds_since_start = self.start_time.elapsed().as_secs() as u32; } fn hash(&mut self, elapsed: [u8; 4], ip_addr: IpAddr) -> [u8; 4] { @@ -145,11 +156,6 @@ mod tests { return quickcheck::TestResult::failed(); } - if max_connection_age == 0 { - quickcheck::TestResult::from_bool(!original_valid) - } else { - // Note: depends on that running this test takes less than a second - quickcheck::TestResult::from_bool(original_valid) - } + quickcheck::TestResult::from_bool(original_valid) } }