diff --git a/aquatic_udp/src/lib/network.rs b/aquatic_udp/src/lib/network.rs index 0ce505d..8ab5d2f 100644 --- a/aquatic_udp/src/lib/network.rs +++ b/aquatic_udp/src/lib/network.rs @@ -19,23 +19,31 @@ use aquatic_udp_protocol::{IpVersion, Request, Response}; use crate::common::*; use crate::config::Config; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ConnectionKey { - pub connection_id: ConnectionId, - pub socket_addr: SocketAddr, -} +#[derive(Default)] +struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>); -impl ConnectionKey { - pub fn new(connection_id: ConnectionId, socket_addr: SocketAddr) -> Self { - Self { - connection_id, - socket_addr, - } +impl ConnectionMap { + fn insert( + &mut self, + connection_id: ConnectionId, + socket_addr: SocketAddr, + valid_until: ValidUntil, + ) { + self.0.insert((connection_id, socket_addr), valid_until); + } + + fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool { + self.0.contains_key(&(connection_id, socket_addr)) + } + + fn clean(&mut self) { + let now = Instant::now(); + + self.0.retain(|_, v| v.0 > now); + self.0.shrink_to_fit(); } } -pub type ConnectionMap = HashMap; - pub fn run_socket_worker( state: State, config: Config, @@ -107,10 +115,7 @@ pub fn run_socket_worker( local_responses.drain(..), ); - let now = Instant::now(); - - connections.retain(|_, v| v.0 > now); - connections.shrink_to_fit(); + connections.clean(); } } @@ -180,7 +185,7 @@ fn read_requests( Ok(Request::Connect(request)) => { let connection_id = ConnectionId(rng.gen()); - connections.insert(ConnectionKey::new(connection_id, src), valid_until); + connections.insert(connection_id, src, valid_until); let response = Response::Connect(ConnectResponse { connection_id, @@ -190,9 +195,7 @@ fn read_requests( local_responses.push((response, src)) } Ok(Request::Announce(request)) => { - let key = ConnectionKey::new(request.connection_id, src); - - if connections.contains_key(&key) { + if connections.contains(request.connection_id, src) { if state .access_list .allows(access_list_mode, &request.info_hash.0) @@ -209,9 +212,7 @@ fn read_requests( } } Ok(Request::Scrape(request)) => { - let key = ConnectionKey::new(request.connection_id, src); - - if connections.contains_key(&key) { + if connections.contains(request.connection_id, src) { requests.push((ConnectedRequest::Scrape(request), src)); } } @@ -224,7 +225,7 @@ fn read_requests( err, } = err { - if connections.contains_key(&ConnectionKey::new(connection_id, src)) { + if connections.contains(connection_id, src) { let response = ErrorResponse { transaction_id, message: err.right_or("Parse error").into(),