diff --git a/TODO.md b/TODO.md index 1647d00..5396993 100644 --- a/TODO.md +++ b/TODO.md @@ -20,13 +20,9 @@ * stagger cleaning tasks? * aquatic_ws - * remove peer from all torrent maps when connection is closed - * store `Vec` in ConnectionReference, containing all used - info hashes. When connection is closed, send - InMessage::ConnectionClosed or similar to request workers. - Storing PeerId in ConnectionReference will also be necessary, as - well as making sure clients only use a single one. Alternatively, - a HashMap> can be used for storage. + * Can peer IP address change after connection has been established + due to some kind of renegotition? It would cause issues. + * Add cleaning task for ConnectionHandle.announced_info_hashes? * RES memory still high after traffic stops, even if torrent maps and connection slabs go down to 0 len and capacity * replacing indexmap_amortized / simd_json with equivalents doesn't help * SinkExt::send maybe doesn't wake up properly? diff --git a/aquatic_ws/src/common.rs b/aquatic_ws/src/common.rs index 006ceb8..27aa134 100644 --- a/aquatic_ws/src/common.rs +++ b/aquatic_ws/src/common.rs @@ -4,6 +4,7 @@ use aquatic_common::access_list::AccessListArcSwap; use aquatic_common::CanonicalSocketAddr; pub use aquatic_common::ValidUntil; +use aquatic_ws_protocol::{InfoHash, PeerId}; #[derive(Default, Clone)] pub struct State { @@ -28,3 +29,12 @@ pub struct ConnectionMeta { pub peer_addr: CanonicalSocketAddr, pub pending_scrape_id: Option, } + +#[derive(Clone, Copy, Debug)] +pub enum SwarmControlMessage { + ConnectionClosed { + info_hash: InfoHash, + peer_id: PeerId, + peer_addr: CanonicalSocketAddr, + }, +} diff --git a/aquatic_ws/src/lib.rs b/aquatic_ws/src/lib.rs index daa77ed..66cc57b 100644 --- a/aquatic_ws/src/lib.rs +++ b/aquatic_ws/src/lib.rs @@ -36,6 +36,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); + let control_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel(); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); @@ -52,6 +53,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let config = config.clone(); let state = state.clone(); let tls_config = tls_config.clone(); + let control_mesh_builder = control_mesh_builder.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); let priv_dropper = priv_dropper.clone(); @@ -71,6 +73,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { config, state, tls_config, + control_mesh_builder, request_mesh_builder, response_mesh_builder, priv_dropper, @@ -86,6 +89,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let sentinel = sentinel.clone(); let config = config.clone(); let state = state.clone(); + let control_mesh_builder = control_mesh_builder.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); @@ -103,6 +107,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { sentinel, config, state, + control_mesh_builder, request_mesh_builder, response_mesh_builder, ) diff --git a/aquatic_ws/src/workers/socket.rs b/aquatic_ws/src/workers/socket.rs index 8c04798..e21a72d 100644 --- a/aquatic_ws/src/workers/socket.rs +++ b/aquatic_ws/src/workers/socket.rs @@ -48,6 +48,7 @@ struct ConnectionReference { valid_until: ValidUntil, peer_id: Option, announced_info_hashes: HashSet, + peer_addr: CanonicalSocketAddr, } pub async fn run_socket_worker( @@ -55,6 +56,7 @@ pub async fn run_socket_worker( config: Config, state: State, tls_config: Arc, + control_message_mesh_builder: MeshBuilder, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, priv_dropper: PrivilegeDropper, @@ -64,6 +66,12 @@ pub async fn run_socket_worker( let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener"); + let (control_message_senders, _) = control_message_mesh_builder + .join(Role::Producer) + .await + .unwrap(); + let control_message_senders = Rc::new(control_message_senders); + let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); let in_message_senders = Rc::new(in_message_senders); @@ -107,6 +115,18 @@ pub async fn run_socket_worker( while let Some(stream) = incoming.next().await { match stream { Ok(stream) => { + let peer_addr = match stream.peer_addr() { + Ok(peer_addr) => CanonicalSocketAddr::new(peer_addr), + Err(err) => { + ::log::info!( + "could not extract peer address, closing connection: {:#}", + err + ); + + continue; + } + }; + let (out_message_sender, out_message_receiver) = new_bounded(LOCAL_CHANNEL_SIZE); let out_message_sender = Rc::new(out_message_sender); @@ -116,13 +136,14 @@ pub async fn run_socket_worker( valid_until: ValidUntil::new(config.cleaning.max_connection_idle), peer_id: None, announced_info_hashes: Default::default(), + peer_addr, }); ::log::info!("accepting stream: {}", key); - let task_handle = spawn_local_into(enclose!((config, access_list, in_message_senders, connection_slab, tls_config) async move { + let task_handle = spawn_local_into(enclose!((config, access_list, control_message_senders, in_message_senders, connection_slab, tls_config) async move { if let Err(err) = run_connection( - config, + config.clone(), access_list, in_message_senders, tq_prioritized, @@ -133,12 +154,40 @@ pub async fn run_socket_worker( out_message_consumer_id, ConnectionId(key), tls_config, - stream + stream, + peer_addr, ).await { ::log::debug!("Connection::run() error: {:?}", err); } - connection_slab.borrow_mut().try_remove(key); + // Remove reference in separate statement to avoid + // multiple RefCell borrows + let opt_reference = connection_slab.borrow_mut().try_remove(key); + + // Tell swarm workers to remove peer + if let Some(reference) = opt_reference { + if let Some(peer_id) = reference.peer_id { + for info_hash in reference.announced_info_hashes { + let message = SwarmControlMessage::ConnectionClosed { + info_hash, + peer_id, + peer_addr: reference.peer_addr, + }; + + let consumer_index = + calculate_in_message_consumer_index(&config, info_hash); + + // Only fails when receiver is closed + control_message_senders + .send_to( + consumer_index, + message + ) + .await + .unwrap(); + } + } + } }), tq_regular) .unwrap() .detach(); @@ -148,7 +197,7 @@ pub async fn run_socket_worker( } } Err(err) => { - ::log::error!("accept connection: {:?}", err); + ::log::error!("accept connection: {:#}", err); } } } @@ -221,12 +270,8 @@ async fn run_connection( connection_id: ConnectionId, tls_config: Arc, stream: TcpStream, + peer_addr: CanonicalSocketAddr, ) -> anyhow::Result<()> { - let peer_addr = stream - .peer_addr() - .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?; - let peer_addr = CanonicalSocketAddr::new(peer_addr); - let tls_acceptor: TlsAcceptor = tls_config.into(); let stream = tls_acceptor.accept(stream).await?; @@ -357,6 +402,7 @@ impl ConnectionReader { ) })?; + // Store peer id / check if stored peer id matches match &mut connection_reference.peer_id { Some(peer_id) if *peer_id != announce_request.peer_id => { self.send_error_response( @@ -373,6 +419,7 @@ impl ConnectionReader { } } + // Remember info hash for later connection_reference .announced_info_hashes .insert(announce_request.info_hash); diff --git a/aquatic_ws/src/workers/swarm.rs b/aquatic_ws/src/workers/swarm.rs index 87cdd48..0ea07ac 100644 --- a/aquatic_ws/src/workers/swarm.rs +++ b/aquatic_ws/src/workers/swarm.rs @@ -68,6 +68,22 @@ impl Default for TorrentData { } } +impl TorrentData { + pub fn remove_peer(&mut self, peer_id: PeerId) { + if let Some(peer) = self.peers.remove(&peer_id) { + match peer.status { + PeerStatus::Leeching => { + self.num_leechers -= 1; + } + PeerStatus::Seeding => { + self.num_seeders -= 1; + } + PeerStatus::Stopped => (), + } + } + } +} + type TorrentMap = AmortizedIndexMap; #[derive(Default)] @@ -131,9 +147,15 @@ pub async fn run_swarm_worker( _sentinel: PanicSentinel, config: Config, state: State, + control_message_mesh_builder: MeshBuilder, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, ) { + let (_, mut control_message_receivers) = control_message_mesh_builder + .join(Role::Consumer) + .await + .unwrap(); + let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); @@ -153,6 +175,13 @@ pub async fn run_swarm_worker( let mut handles = Vec::new(); + for (_, receiver) in control_message_receivers.streams() { + let handle = + spawn_local(handle_control_message_stream(torrents.clone(), receiver)).detach(); + + handles.push(handle); + } + for (_, receiver) in in_message_receivers.streams() { let handle = spawn_local(handle_request_stream( config.clone(), @@ -170,6 +199,36 @@ pub async fn run_swarm_worker( } } +async fn handle_control_message_stream(torrents: Rc>, mut stream: S) +where + S: futures_lite::Stream + ::std::marker::Unpin, +{ + while let Some(message) = stream.next().await { + match message { + SwarmControlMessage::ConnectionClosed { + info_hash, + peer_id, + peer_addr, + } => { + ::log::debug!( + "Removing peer {} from torrents because connection was closed", + peer_addr.get() + ); + + if peer_addr.is_ipv4() { + if let Some(torrent_data) = torrents.borrow_mut().ipv4.get_mut(&info_hash) { + torrent_data.remove_peer(peer_id); + } + } else { + if let Some(torrent_data) = torrents.borrow_mut().ipv6.get_mut(&info_hash) { + torrent_data.remove_peer(peer_id); + } + } + } + } + } +} + async fn handle_request_stream( config: Config, torrents: Rc>,