ws: remove peer from all torrent maps when connection is closed

This commit is contained in:
Joakim Frostegård 2022-07-05 13:13:53 +02:00
parent b30da1a930
commit 720ceacf99
5 changed files with 134 additions and 17 deletions

View file

@ -48,6 +48,7 @@ struct ConnectionReference {
valid_until: ValidUntil,
peer_id: Option<PeerId>,
announced_info_hashes: HashSet<InfoHash>,
peer_addr: CanonicalSocketAddr,
}
pub async fn run_socket_worker(
@ -55,6 +56,7 @@ pub async fn run_socket_worker(
config: Config,
state: State,
tls_config: Arc<RustlsConfig>,
control_message_mesh_builder: MeshBuilder<SwarmControlMessage, Partial>,
in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
priv_dropper: PrivilegeDropper,
@ -64,6 +66,12 @@ pub async fn run_socket_worker(
let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener");
let (control_message_senders, _) = control_message_mesh_builder
.join(Role::Producer)
.await
.unwrap();
let control_message_senders = Rc::new(control_message_senders);
let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap();
let in_message_senders = Rc::new(in_message_senders);
@ -107,6 +115,18 @@ pub async fn run_socket_worker(
while let Some(stream) = incoming.next().await {
match stream {
Ok(stream) => {
let peer_addr = match stream.peer_addr() {
Ok(peer_addr) => CanonicalSocketAddr::new(peer_addr),
Err(err) => {
::log::info!(
"could not extract peer address, closing connection: {:#}",
err
);
continue;
}
};
let (out_message_sender, out_message_receiver) = new_bounded(LOCAL_CHANNEL_SIZE);
let out_message_sender = Rc::new(out_message_sender);
@ -116,13 +136,14 @@ pub async fn run_socket_worker(
valid_until: ValidUntil::new(config.cleaning.max_connection_idle),
peer_id: None,
announced_info_hashes: Default::default(),
peer_addr,
});
::log::info!("accepting stream: {}", key);
let task_handle = spawn_local_into(enclose!((config, access_list, in_message_senders, connection_slab, tls_config) async move {
let task_handle = spawn_local_into(enclose!((config, access_list, control_message_senders, in_message_senders, connection_slab, tls_config) async move {
if let Err(err) = run_connection(
config,
config.clone(),
access_list,
in_message_senders,
tq_prioritized,
@ -133,12 +154,40 @@ pub async fn run_socket_worker(
out_message_consumer_id,
ConnectionId(key),
tls_config,
stream
stream,
peer_addr,
).await {
::log::debug!("Connection::run() error: {:?}", err);
}
connection_slab.borrow_mut().try_remove(key);
// Remove reference in separate statement to avoid
// multiple RefCell borrows
let opt_reference = connection_slab.borrow_mut().try_remove(key);
// Tell swarm workers to remove peer
if let Some(reference) = opt_reference {
if let Some(peer_id) = reference.peer_id {
for info_hash in reference.announced_info_hashes {
let message = SwarmControlMessage::ConnectionClosed {
info_hash,
peer_id,
peer_addr: reference.peer_addr,
};
let consumer_index =
calculate_in_message_consumer_index(&config, info_hash);
// Only fails when receiver is closed
control_message_senders
.send_to(
consumer_index,
message
)
.await
.unwrap();
}
}
}
}), tq_regular)
.unwrap()
.detach();
@ -148,7 +197,7 @@ pub async fn run_socket_worker(
}
}
Err(err) => {
::log::error!("accept connection: {:?}", err);
::log::error!("accept connection: {:#}", err);
}
}
}
@ -221,12 +270,8 @@ async fn run_connection(
connection_id: ConnectionId,
tls_config: Arc<RustlsConfig>,
stream: TcpStream,
peer_addr: CanonicalSocketAddr,
) -> anyhow::Result<()> {
let peer_addr = stream
.peer_addr()
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
let peer_addr = CanonicalSocketAddr::new(peer_addr);
let tls_acceptor: TlsAcceptor = tls_config.into();
let stream = tls_acceptor.accept(stream).await?;
@ -357,6 +402,7 @@ impl ConnectionReader {
)
})?;
// Store peer id / check if stored peer id matches
match &mut connection_reference.peer_id {
Some(peer_id) if *peer_id != announce_request.peer_id => {
self.send_error_response(
@ -373,6 +419,7 @@ impl ConnectionReader {
}
}
// Remember info hash for later
connection_reference
.announced_info_hashes
.insert(announce_request.info_hash);