diff --git a/aquatic_udp/src/common.rs b/aquatic_udp/src/common.rs index 0864621..8832477 100644 --- a/aquatic_udp/src/common.rs +++ b/aquatic_udp/src/common.rs @@ -65,9 +65,10 @@ impl ConnectionValidator { let valid_until = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); // Check that recreating ConnectionId with same inputs yields identical hash. - // - // I expect i64 comparison to be be constant-time. - if connection_id != self.create_connection_id_inner(valid_until, source_addr) { + if !Self::connection_id_eq_constant_time( + connection_id, + self.create_connection_id_inner(valid_until, source_addr), + ) { return false; } @@ -97,6 +98,27 @@ impl ConnectionValidator { ConnectionId(i64::from_ne_bytes(connection_id_bytes)) } + + /// Compare ConnectionIDs in constant time + /// + /// Use this instead of PartialEq::eq to avoid optimizations breaking constant + /// time HMAC comparison. + #[cfg(target_arch = "x86_64")] + fn connection_id_eq_constant_time(a: ConnectionId, b: ConnectionId) -> bool { + let mut eq = 0u8; + + unsafe { + ::std::arch::asm!( + "cmp {a}, {b}", + "sete {eq}", + a = in(reg) a.0, + b = in(reg) b.0, + eq = inout(reg_byte) eq, + ); + } + + eq != 0 + } } #[derive(Debug)] @@ -368,4 +390,12 @@ mod tests { quickcheck::TestResult::from_bool(original_valid) } } + + #[quickcheck] + fn test_connection_id_eq(a: i64, b: i64) -> bool { + let a = ConnectionId(a); + let b = ConnectionId(b); + + ConnectionValidator::connection_id_eq_constant_time(a, b) == (a == b) + } }