diff --git a/aquatic_udp/src/lib/common.rs b/aquatic_udp/src/lib/common.rs index cccd553..69ae48d 100644 --- a/aquatic_udp/src/lib/common.rs +++ b/aquatic_udp/src/lib/common.rs @@ -36,6 +36,15 @@ pub struct ConnectionKey { pub socket_addr: SocketAddr, } +impl ConnectionKey { + pub fn new(connection_id: ConnectionId, socket_addr: SocketAddr) -> Self { + Self { + connection_id, + socket_addr, + } + } +} + pub type ConnectionMap = HashMap; #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] diff --git a/aquatic_udp/src/lib/handlers/mod.rs b/aquatic_udp/src/lib/handlers/mod.rs index 96fa2ba..dab6baf 100644 --- a/aquatic_udp/src/lib/handlers/mod.rs +++ b/aquatic_udp/src/lib/handlers/mod.rs @@ -88,43 +88,31 @@ pub fn run_request_worker( // Check announce and scrape requests for valid connections announce_requests.retain(|(request, src)| { - let connection_key = ConnectionKey { - connection_id: request.connection_id, - socket_addr: *src, - }; + let connection_valid = + connections.contains_key(&ConnectionKey::new(request.connection_id, *src)); - if connections.contains_key(&connection_key) { - true - } else { - let response = ErrorResponse { - transaction_id: request.transaction_id, - message: "Connection invalid or expired".to_string(), - }; - - responses.push((response.into(), *src)); - - return false; + if !connection_valid { + responses.push(( + create_invalid_connection_response(request.transaction_id), + *src, + )); } + + connection_valid }); scrape_requests.retain(|(request, src)| { - let connection_key = ConnectionKey { - connection_id: request.connection_id, - socket_addr: *src, - }; + let connection_valid = + connections.contains_key(&ConnectionKey::new(request.connection_id, *src)); - if connections.contains_key(&connection_key) { - true - } else { - let response = ErrorResponse { - transaction_id: request.transaction_id, - message: "Connection invalid or expired".to_string(), - }; - - responses.push((response.into(), *src)); - - false + if !connection_valid { + responses.push(( + create_invalid_connection_response(request.transaction_id), + *src, + )); } + + connection_valid }); ::std::mem::drop(connections); @@ -151,3 +139,10 @@ pub fn run_request_worker( } } } + +fn create_invalid_connection_response(transaction_id: TransactionId) -> Response { + Response::Error(ErrorResponse { + transaction_id, + message: "Connection invalid or expired".to_string(), + }) +}