diff --git a/aquatic_ws/src/workers/socket.rs b/aquatic_ws/src/workers/socket.rs index 0d7f09b..5d730a1 100644 --- a/aquatic_ws/src/workers/socket.rs +++ b/aquatic_ws/src/workers/socket.rs @@ -24,7 +24,8 @@ use glommio::net::{TcpListener, TcpStream}; use glommio::task::JoinHandle; use glommio::timer::{sleep, timeout, TimerActionRepeat}; use glommio::{enclose, prelude::*}; -use hashbrown::{HashMap, HashSet}; +use hashbrown::hash_map::Entry; +use hashbrown::HashMap; use slab::Slab; use crate::config::Config; @@ -45,8 +46,7 @@ struct ConnectionReference { out_message_sender: Rc>, /// Updated after sending message to peer valid_until: ValidUntil, - peer_id: Option, - announced_info_hashes: HashSet, + announced_info_hashes: HashMap, ip_version: IpVersion, } @@ -145,7 +145,6 @@ pub async fn run_socket_worker( server_start_instant, config.cleaning.max_connection_idle, ), - peer_id: None, announced_info_hashes: Default::default(), ip_version, }); @@ -180,26 +179,24 @@ pub async fn run_socket_worker( // Tell swarm workers to remove peer if let Some(reference) = opt_reference { - if let Some(peer_id) = reference.peer_id { - for info_hash in reference.announced_info_hashes { - let message = SwarmControlMessage::ConnectionClosed { - info_hash, - peer_id, - ip_version: reference.ip_version, - }; + for (info_hash, peer_id) in reference.announced_info_hashes { + let message = SwarmControlMessage::ConnectionClosed { + info_hash, + peer_id, + ip_version: reference.ip_version, + }; - let consumer_index = - calculate_in_message_consumer_index(&config, info_hash); + let consumer_index = + calculate_in_message_consumer_index(&config, info_hash); - // Only fails when receiver is closed - control_message_senders - .send_to( - consumer_index, - message - ) - .await - .unwrap(); - } + // Only fails when receiver is closed + control_message_senders + .send_to( + consumer_index, + message + ) + .await + .unwrap(); } } }), tq_regular) @@ -499,26 +496,31 @@ impl ConnectionReader { })?; // Store peer id / check if stored peer id matches - match &mut connection_reference.peer_id { - Some(peer_id) if *peer_id != announce_request.peer_id => { - self.send_error_response( - "Only one peer id can be used per connection".into(), - Some(ErrorResponseAction::Announce), - Some(info_hash), - ) - .await?; - return Err(anyhow::anyhow!("Peer used more than one PeerId")); + match connection_reference + .announced_info_hashes + .entry(announce_request.info_hash) + { + Entry::Occupied(entry) => { + if *entry.get() != announce_request.peer_id { + // Drop Rc borrow before awaiting + drop(connection_slab); + + self.send_error_response( + "Only one peer id can be used per torrent".into(), + Some(ErrorResponseAction::Announce), + Some(info_hash), + ) + .await?; + + return Err(anyhow::anyhow!( + "Peer used more than one PeerId for a single torrent" + )); + } } - Some(_) => (), - opt_peer_id @ None => { - *opt_peer_id = Some(announce_request.peer_id); + Entry::Vacant(entry) => { + entry.insert(announce_request.peer_id); } } - - // Remember info hash for later - connection_reference - .announced_info_hashes - .insert(announce_request.info_hash); } let in_message = InMessage::AnnounceRequest(announce_request);