From 4203e86eca1cb7d35221e2b280cf921257c6588a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Wed, 13 Apr 2022 23:40:04 +0200 Subject: [PATCH] udp: optimize/simplify ConnectionValidator --- aquatic_udp/src/common.rs | 44 +++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/aquatic_udp/src/common.rs b/aquatic_udp/src/common.rs index 0384177..bce0afe 100644 --- a/aquatic_udp/src/common.rs +++ b/aquatic_udp/src/common.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use crossbeam_channel::{Sender, TrySendError}; use getrandom::getrandom; @@ -19,7 +19,7 @@ pub const MAX_PACKET_SIZE: usize = 8192; #[derive(Clone)] pub struct ConnectionValidator { start_time: Instant, - max_connection_age: Duration, + max_connection_age: u32, hmac: blake3::Hasher, } @@ -31,21 +31,18 @@ impl ConnectionValidator { let hmac = blake3::Hasher::new_keyed(&key); - let start_time = Instant::now(); - let max_connection_age = Duration::from_secs(config.cleaning.max_connection_age); - Ok(Self { hmac, - start_time, - max_connection_age, + start_time: Instant::now(), + max_connection_age: config.cleaning.max_connection_age as u32, }) } pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId { - // Seconds elapsed since server start, as bytes - let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes(); + let valid_until = + (self.start_time.elapsed().as_secs() as u32 + self.max_connection_age).to_ne_bytes(); - self.create_connection_id_inner(elapsed_time_bytes, source_addr) + self.create_connection_id_inner(valid_until, source_addr) } pub fn connection_id_valid( @@ -53,34 +50,31 @@ impl ConnectionValidator { source_addr: CanonicalSocketAddr, connection_id: ConnectionId, ) -> bool { - let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); + let valid_until = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); - // i64 comparison should be constant-time - let hmac_valid = - connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_addr); - - if !hmac_valid { + // Check that recreating ConnectionId with same inputs yields identical HMAC. + // + // I expect i64 comparison to be be constant-time. + if connection_id != self.create_connection_id_inner(valid_until, source_addr) { return false; } - let connection_elapsed_since_start = - Duration::from_secs(u32::from_ne_bytes(elapsed_time_bytes) as u64); - - connection_elapsed_since_start + self.max_connection_age > self.start_time.elapsed() + u32::from_ne_bytes(valid_until) > self.start_time.elapsed().as_secs() as u32 } fn create_connection_id_inner( &mut self, - elapsed_time_bytes: [u8; 4], + valid_until: [u8; 4], source_addr: CanonicalSocketAddr, ) -> ConnectionId { - // The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a - // truncated message authentication code. + // The first 4 bytes is number of seconds since server start until + // connection is no longer valid. The last 4 is the truncated message + // authentication code. let mut connection_id_bytes = [0u8; 8]; - (&mut connection_id_bytes[..4]).copy_from_slice(&elapsed_time_bytes); + (&mut connection_id_bytes[..4]).copy_from_slice(&valid_until); - self.hmac.update(&elapsed_time_bytes); + self.hmac.update(&valid_until); match source_addr.get().ip() { IpAddr::V4(ip) => self.hmac.update(&ip.octets()),