udp: optimize/simplify ConnectionValidator

This commit is contained in:
Joakim Frostegård 2022-04-13 23:40:04 +02:00
parent 059ef495bf
commit 4203e86eca

View file

@ -3,7 +3,7 @@ use std::hash::Hash;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::Instant;
use crossbeam_channel::{Sender, TrySendError}; use crossbeam_channel::{Sender, TrySendError};
use getrandom::getrandom; use getrandom::getrandom;
@ -19,7 +19,7 @@ pub const MAX_PACKET_SIZE: usize = 8192;
#[derive(Clone)] #[derive(Clone)]
pub struct ConnectionValidator { pub struct ConnectionValidator {
start_time: Instant, start_time: Instant,
max_connection_age: Duration, max_connection_age: u32,
hmac: blake3::Hasher, hmac: blake3::Hasher,
} }
@ -31,21 +31,18 @@ impl ConnectionValidator {
let hmac = blake3::Hasher::new_keyed(&key); let hmac = blake3::Hasher::new_keyed(&key);
let start_time = Instant::now();
let max_connection_age = Duration::from_secs(config.cleaning.max_connection_age);
Ok(Self { Ok(Self {
hmac, hmac,
start_time, start_time: Instant::now(),
max_connection_age, max_connection_age: config.cleaning.max_connection_age as u32,
}) })
} }
pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId { pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId {
// Seconds elapsed since server start, as bytes let valid_until =
let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes(); (self.start_time.elapsed().as_secs() as u32 + self.max_connection_age).to_ne_bytes();
self.create_connection_id_inner(elapsed_time_bytes, source_addr) self.create_connection_id_inner(valid_until, source_addr)
} }
pub fn connection_id_valid( pub fn connection_id_valid(
@ -53,34 +50,31 @@ impl ConnectionValidator {
source_addr: CanonicalSocketAddr, source_addr: CanonicalSocketAddr,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> bool { ) -> bool {
let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); let valid_until = connection_id.0.to_ne_bytes()[..4].try_into().unwrap();
// i64 comparison should be constant-time // Check that recreating ConnectionId with same inputs yields identical HMAC.
let hmac_valid = //
connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_addr); // I expect i64 comparison to be be constant-time.
if connection_id != self.create_connection_id_inner(valid_until, source_addr) {
if !hmac_valid {
return false; return false;
} }
let connection_elapsed_since_start = u32::from_ne_bytes(valid_until) > self.start_time.elapsed().as_secs() as u32
Duration::from_secs(u32::from_ne_bytes(elapsed_time_bytes) as u64);
connection_elapsed_since_start + self.max_connection_age > self.start_time.elapsed()
} }
fn create_connection_id_inner( fn create_connection_id_inner(
&mut self, &mut self,
elapsed_time_bytes: [u8; 4], valid_until: [u8; 4],
source_addr: CanonicalSocketAddr, source_addr: CanonicalSocketAddr,
) -> ConnectionId { ) -> ConnectionId {
// The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a // The first 4 bytes is number of seconds since server start until
// truncated message authentication code. // connection is no longer valid. The last 4 is the truncated message
// authentication code.
let mut connection_id_bytes = [0u8; 8]; let mut connection_id_bytes = [0u8; 8];
(&mut connection_id_bytes[..4]).copy_from_slice(&elapsed_time_bytes); (&mut connection_id_bytes[..4]).copy_from_slice(&valid_until);
self.hmac.update(&elapsed_time_bytes); self.hmac.update(&valid_until);
match source_addr.get().ip() { match source_addr.get().ip() {
IpAddr::V4(ip) => self.hmac.update(&ip.octets()), IpAddr::V4(ip) => self.hmac.update(&ip.octets()),