aquatic_udp: use constant_time_eq crate for ConnectionValidator

Crate is used in official blake3 implementation.

Improves speed and removed need for error-prone custom assembly.
This commit is contained in:
Joakim Frostegård 2022-04-15 23:40:05 +02:00
parent fb9b345990
commit 64452503e7
3 changed files with 28 additions and 52 deletions

9
Cargo.lock generated
View file

@ -211,6 +211,7 @@ dependencies = [
"aquatic_udp_protocol", "aquatic_udp_protocol",
"blake3", "blake3",
"cfg-if", "cfg-if",
"constant_time_eq 0.2.1",
"crossbeam-channel", "crossbeam-channel",
"getrandom", "getrandom",
"hashbrown 0.12.0", "hashbrown 0.12.0",
@ -549,7 +550,7 @@ dependencies = [
"arrayvec 0.7.2", "arrayvec 0.7.2",
"cc", "cc",
"cfg-if", "cfg-if",
"constant_time_eq", "constant_time_eq 0.1.5",
"digest 0.10.3", "digest 0.10.3",
] ]
@ -690,6 +691,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
[[package]]
name = "constant_time_eq"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df04a53a7e91248c27eb6bfc1db165e8f47453e98478e4609f9cce020bb3c65a"
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.2" version = "0.2.2"

View file

@ -26,6 +26,7 @@ aquatic_udp_protocol = { version = "0.2.0", path = "../aquatic_udp_protocol" }
anyhow = "1" anyhow = "1"
blake3 = "1" blake3 = "1"
cfg-if = "1" cfg-if = "1"
constant_time_eq = "0.2"
crossbeam-channel = "0.5" crossbeam-channel = "0.5"
getrandom = "0.2" getrandom = "0.2"
hashbrown = { version = "0.12", default-features = false } hashbrown = { version = "0.12", default-features = false }

View file

@ -6,6 +6,7 @@ use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use anyhow::Context; use anyhow::Context;
use constant_time_eq::constant_time_eq;
use crossbeam_channel::{Sender, TrySendError}; use crossbeam_channel::{Sender, TrySendError};
use getrandom::getrandom; use getrandom::getrandom;
@ -59,7 +60,14 @@ impl ConnectionValidator {
let valid_until = let valid_until =
(self.start_time.elapsed().as_secs() as u32 + self.max_connection_age).to_ne_bytes(); (self.start_time.elapsed().as_secs() as u32 + self.max_connection_age).to_ne_bytes();
self.create_connection_id_inner(valid_until, source_addr) let hash = self.hash(valid_until, source_addr.get().ip());
let mut connection_id_bytes = [0u8; 8];
(&mut connection_id_bytes[..4]).copy_from_slice(&valid_until);
(&mut connection_id_bytes[4..]).copy_from_slice(&hash);
ConnectionId(i64::from_ne_bytes(connection_id_bytes))
} }
pub fn connection_id_valid( pub fn connection_id_valid(
@ -67,63 +75,31 @@ impl ConnectionValidator {
source_addr: CanonicalSocketAddr, source_addr: CanonicalSocketAddr,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> bool { ) -> bool {
let valid_until = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); let bytes = connection_id.0.to_ne_bytes();
let (valid_until, hash) = bytes.split_at(4);
let valid_until: [u8; 4] = valid_until.try_into().unwrap();
// Check that recreating ConnectionId with same inputs yields identical hash. if !constant_time_eq(hash, &self.hash(valid_until, source_addr.get().ip())) {
if !Self::connection_id_eq(
connection_id,
self.create_connection_id_inner(valid_until, source_addr),
) {
return false; return false;
} }
u32::from_ne_bytes(valid_until) > self.start_time.elapsed().as_secs() as u32 u32::from_ne_bytes(valid_until) > self.start_time.elapsed().as_secs() as u32
} }
fn create_connection_id_inner( fn hash(&mut self, valid_until: [u8; 4], ip_addr: IpAddr) -> [u8; 4] {
&mut self,
valid_until: [u8; 4],
source_addr: CanonicalSocketAddr,
) -> ConnectionId {
let mut connection_id_bytes = [0u8; 8];
(&mut connection_id_bytes[..4]).copy_from_slice(&valid_until);
self.keyed_hasher.update(&valid_until); self.keyed_hasher.update(&valid_until);
match source_addr.get().ip() { match ip_addr {
IpAddr::V4(ip) => self.keyed_hasher.update(&ip.octets()), IpAddr::V4(ip) => self.keyed_hasher.update(&ip.octets()),
IpAddr::V6(ip) => self.keyed_hasher.update(&ip.octets()), IpAddr::V6(ip) => self.keyed_hasher.update(&ip.octets()),
}; };
self.keyed_hasher let mut hash = [0u8; 4];
.finalize_xof()
.fill(&mut connection_id_bytes[4..]); self.keyed_hasher.finalize_xof().fill(&mut hash);
self.keyed_hasher.reset(); self.keyed_hasher.reset();
ConnectionId(i64::from_ne_bytes(connection_id_bytes)) hash
}
/// Compare ConnectionIDs without breaking constant time requirements
///
/// Use this instead of PartialEq::eq to avoid optimizations breaking constant
/// time HMAC comparison and thus strongly reducing security.
#[cfg(target_arch = "x86_64")]
fn connection_id_eq(a: ConnectionId, b: ConnectionId) -> bool {
let mut eq = 0u8;
unsafe {
::std::arch::asm!(
"cmp {a}, {b}",
"sete {eq}",
a = in(reg) a.0,
b = in(reg) b.0,
eq = inout(reg_byte) eq,
options(nomem, nostack),
);
}
eq != 0
} }
} }
@ -396,12 +372,4 @@ mod tests {
quickcheck::TestResult::from_bool(original_valid) quickcheck::TestResult::from_bool(original_valid)
} }
} }
#[quickcheck]
fn test_connection_id_eq(a: i64, b: i64) -> bool {
let a = ConnectionId(a);
let b = ConnectionId(b);
ConnectionValidator::connection_id_eq(a, b) == (a == b)
}
} }