aquatic_udp: improve code for request connection validity checks

This commit is contained in:
Joakim Frostegård 2021-10-15 23:59:16 +02:00
parent 03c36f51c9
commit 881579435a
2 changed files with 34 additions and 30 deletions

View file

@ -36,6 +36,15 @@ pub struct ConnectionKey {
pub socket_addr: SocketAddr, pub socket_addr: SocketAddr,
} }
impl ConnectionKey {
pub fn new(connection_id: ConnectionId, socket_addr: SocketAddr) -> Self {
Self {
connection_id,
socket_addr,
}
}
}
pub type ConnectionMap = HashMap<ConnectionKey, ValidUntil>; pub type ConnectionMap = HashMap<ConnectionKey, ValidUntil>;
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]

View file

@ -88,43 +88,31 @@ pub fn run_request_worker(
// Check announce and scrape requests for valid connections // Check announce and scrape requests for valid connections
announce_requests.retain(|(request, src)| { announce_requests.retain(|(request, src)| {
let connection_key = ConnectionKey { let connection_valid =
connection_id: request.connection_id, connections.contains_key(&ConnectionKey::new(request.connection_id, *src));
socket_addr: *src,
};
if connections.contains_key(&connection_key) { if !connection_valid {
true responses.push((
} else { create_invalid_connection_response(request.transaction_id),
let response = ErrorResponse { *src,
transaction_id: request.transaction_id, ));
message: "Connection invalid or expired".to_string(),
};
responses.push((response.into(), *src));
return false;
} }
connection_valid
}); });
scrape_requests.retain(|(request, src)| { scrape_requests.retain(|(request, src)| {
let connection_key = ConnectionKey { let connection_valid =
connection_id: request.connection_id, connections.contains_key(&ConnectionKey::new(request.connection_id, *src));
socket_addr: *src,
};
if connections.contains_key(&connection_key) { if !connection_valid {
true responses.push((
} else { create_invalid_connection_response(request.transaction_id),
let response = ErrorResponse { *src,
transaction_id: request.transaction_id, ));
message: "Connection invalid or expired".to_string(),
};
responses.push((response.into(), *src));
false
} }
connection_valid
}); });
::std::mem::drop(connections); ::std::mem::drop(connections);
@ -151,3 +139,10 @@ pub fn run_request_worker(
} }
} }
} }
fn create_invalid_connection_response(transaction_id: TransactionId) -> Response {
Response::Error(ErrorResponse {
transaction_id,
message: "Connection invalid or expired".to_string(),
})
}