use std::cell::RefCell; use std::collections::BTreeMap; use std::net::SocketAddr; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; use aquatic_common::access_list::AccessList; use aquatic_common::convert_ipv4_mapped_ipv6; use aquatic_ws_protocol::*; use async_tungstenite::WebSocketStream; use either::Either; use futures::stream::{SplitSink, SplitStream}; use futures_lite::StreamExt; use futures_rustls::server::TlsStream; use futures_rustls::TlsAcceptor; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::shared_channel::ConnectedReceiver; use glommio::net::{TcpListener, TcpStream}; use glommio::timer::TimerActionRepeat; use glommio::{enclose, prelude::*}; use hashbrown::HashMap; use slab::Slab; use crate::config::Config; use super::common::*; struct PendingScrapeResponse { pending_worker_out_messages: usize, stats: HashMap, } struct ConnectionReference { out_message_sender: LocalSender<(ConnectionMeta, OutMessage)>, } struct Connection { config: Rc, access_list: Rc>, in_message_senders: Rc>, out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, out_message_consumer_id: ConsumerId, ws_out: SplitSink>, tungstenite::Message>, ws_in: SplitStream>>, peer_addr: SocketAddr, connection_id: ConnectionId, } pub async fn run_socket_worker( config: Config, tls_config: Arc, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, num_bound_sockets: Arc, access_list: AccessList, ) { let config = Rc::new(config); let access_list = Rc::new(RefCell::new(access_list)); let listener = TcpListener::bind(config.network.address).expect("bind socket"); num_bound_sockets.fetch_add(1, Ordering::SeqCst); let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); let in_message_senders = Rc::new(in_message_senders); let (_, mut out_message_receivers) = out_message_mesh_builder.join(Role::Consumer).await.unwrap(); let out_message_consumer_id = ConsumerId(out_message_receivers.consumer_id().unwrap()); let connection_slab = Rc::new(RefCell::new(Slab::new())); let connections_to_remove = Rc::new(RefCell::new(Vec::new())); // Periodically update access list TimerActionRepeat::repeat(enclose!((config, access_list) move || { enclose!((config, access_list) move || async move { update_access_list(config.clone(), access_list.clone()).await; Some(Duration::from_secs(config.cleaning.interval)) })() })); // Periodically remove closed connections TimerActionRepeat::repeat( enclose!((config, connection_slab, connections_to_remove) move || { remove_closed_connections( config.clone(), connection_slab.clone(), connections_to_remove.clone(), ) }), ); for (_, out_message_receiver) in out_message_receivers.streams() { spawn_local(receive_out_messages( out_message_receiver, connection_slab.clone(), )) .detach(); } let mut incoming = listener.incoming(); while let Some(stream) = incoming.next().await { match stream { Ok(stream) => { let (out_message_sender, out_message_receiver) = new_bounded(config.request_workers); let key = connection_slab .borrow_mut() .insert(ConnectionReference { out_message_sender }); spawn_local(enclose!((config, access_list, in_message_senders, tls_config, connections_to_remove) async move { if let Err(err) = Connection::run( config, access_list, in_message_senders, out_message_receiver, out_message_consumer_id, ConnectionId(key), tls_config, stream ).await { ::log::debug!("Connection::run() error: {:?}", err); } connections_to_remove.borrow_mut().push(key); })) .detach(); } Err(err) => { ::log::error!("accept connection: {:?}", err); } } } } async fn remove_closed_connections( config: Rc, connection_slab: Rc>>, connections_to_remove: Rc>>, ) -> Option { let connections_to_remove = connections_to_remove.replace(Vec::new()); for connection_id in connections_to_remove { if let Some(_) = connection_slab.borrow_mut().try_remove(connection_id) { ::log::debug!("removed connection with id {}", connection_id); } else { ::log::error!( "couldn't remove connection with id {}, it is not in connection slab", connection_id ); } } Some(Duration::from_secs(config.cleaning.interval)) } async fn receive_out_messages( mut out_message_receiver: ConnectedReceiver<(ConnectionMeta, OutMessage)>, connection_references: Rc>>, ) { while let Some(channel_out_message) = out_message_receiver.next().await { if let Some(reference) = connection_references .borrow() .get(channel_out_message.0.connection_id.0) { if let Err(err) = reference.out_message_sender.try_send(channel_out_message) { ::log::error!("Couldn't send out_message to local receiver: {:?}", err); } } } } impl Connection { async fn run( config: Rc, access_list: Rc>, in_message_senders: Rc>, out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, out_message_consumer_id: ConsumerId, connection_id: ConnectionId, tls_config: Arc, stream: TcpStream, ) -> anyhow::Result<()> { let peer_addr = stream .peer_addr() .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?; let tls_acceptor: TlsAcceptor = tls_config.into(); let stream = tls_acceptor.accept(stream).await?; let ws_config = tungstenite::protocol::WebSocketConfig { max_frame_size: Some(config.network.websocket_max_frame_size), max_message_size: Some(config.network.websocket_max_message_size), ..Default::default() }; let stream = async_tungstenite::accept_async_with_config(stream, Some(ws_config)).await?; let (ws_out, ws_in) = futures::StreamExt::split(stream); let mut conn = Connection { config: config.clone(), access_list: access_list.clone(), in_message_senders: in_message_senders.clone(), out_message_receiver, out_message_consumer_id, ws_out, ws_in, peer_addr, connection_id, }; conn.run_message_loop().await?; Ok(()) } async fn run_message_loop(&mut self) -> anyhow::Result<()> { loop { let out_message = match self.read_in_message().await? { Either::Left(out_message) => OutMessage::ErrorResponse(out_message), Either::Right(in_message) => self.handle_in_message(in_message).await?, }; self.write_out_message(&out_message).await?; if matches!(out_message, OutMessage::ErrorResponse(_)) { // TODO: shut down? break; } } Ok(()) } async fn read_in_message(&mut self) -> anyhow::Result> { loop { ::log::debug!("read"); let message = self.ws_in.next().await.unwrap()?; match InMessage::from_ws_message(message) { Ok(in_message) => { ::log::debug!("received in_message: {:?}", in_message); return Ok(Either::Right(in_message)); } Err(err) => { ::log::debug!("Couldn't parse in_message: {:?}", err); let out_message = ErrorResponse { action: None, failure_reason: "Invalid request".into(), info_hash: None, }; return Ok(Either::Left(out_message)); } } } } async fn handle_in_message(&self, in_message: InMessage) -> anyhow::Result { match in_message { InMessage::AnnounceRequest(announce_request) => { let info_hash = announce_request.info_hash; if self .access_list .borrow() .allows(self.config.access_list.mode, &info_hash.0) { let meta = ConnectionMeta { connection_id: self.connection_id, out_message_consumer_id: self.out_message_consumer_id, naive_peer_addr: self.peer_addr, converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), }; let in_message = InMessage::AnnounceRequest(announce_request); let consumer_index = calculate_in_message_consumer_index(&self.config, info_hash); // Only fails when receiver is closed self.in_message_senders .send_to(consumer_index, (meta, in_message)) .await .unwrap(); self.wait_for_out_message(None).await } else { let out_message = OutMessage::ErrorResponse(ErrorResponse { action: Some(ErrorResponseAction::Announce), failure_reason: "Info hash not allowed".into(), info_hash: Some(info_hash), }); Ok(out_message) } } InMessage::ScrapeRequest(ScrapeRequest { info_hashes, .. }) => { let info_hashes = if let Some(info_hashes) = info_hashes { info_hashes } else { let out_message = OutMessage::ErrorResponse(ErrorResponse { action: Some(ErrorResponseAction::Scrape), failure_reason: "Full scrapes are not allowed".into(), info_hash: None, }); return Ok(out_message); }; let mut info_hashes_by_worker: BTreeMap> = BTreeMap::new(); for info_hash in info_hashes.as_vec() { let info_hashes = info_hashes_by_worker .entry(calculate_in_message_consumer_index(&self.config, info_hash)) .or_default(); info_hashes.push(info_hash); } let pending_worker_out_messages = info_hashes_by_worker.len(); let meta = ConnectionMeta { connection_id: self.connection_id, out_message_consumer_id: self.out_message_consumer_id, naive_peer_addr: self.peer_addr, converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()), }; for (consumer_index, info_hashes) in info_hashes_by_worker { let in_message = InMessage::ScrapeRequest(ScrapeRequest { action: ScrapeAction, info_hashes: Some(ScrapeRequestInfoHashes::Multiple(info_hashes)), }); // Only fails when receiver is closed self.in_message_senders .send_to(consumer_index, (meta, in_message)) .await .unwrap(); } let pending_scrape_out_message = PendingScrapeResponse { pending_worker_out_messages, stats: Default::default(), }; self.wait_for_out_message(Some(pending_scrape_out_message)) .await } } } /// Wait for announce out_message or partial scrape out_messages to arrive, /// return full out_message async fn wait_for_out_message( &self, mut opt_pending_scrape_out_message: Option, ) -> anyhow::Result { loop { let (meta, out_message) = self .out_message_receiver .recv() .await .expect("wait_for_out_message: can't receive out_message, sender is closed"); if meta.naive_peer_addr != self.peer_addr { return Err(anyhow::anyhow!("peer addresses didn't match")); } match out_message { OutMessage::ScrapeResponse(out_message) => { if let Some(mut pending) = opt_pending_scrape_out_message.take() { pending.stats.extend(out_message.files); pending.pending_worker_out_messages -= 1; if pending.pending_worker_out_messages == 0 { let out_message = OutMessage::ScrapeResponse(ScrapeResponse { action: ScrapeAction, files: pending.stats, }); break Ok(out_message); } else { opt_pending_scrape_out_message = Some(pending); } } else { return Err(anyhow::anyhow!( "received channel scrape out_message without pending scrape out_message" )); } } out_message => { break Ok(out_message); } }; } } async fn write_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> { futures::SinkExt::send(&mut self.ws_out, out_message.to_ws_message()).await?; futures::SinkExt::flush(&mut self.ws_out).await?; Ok(()) } } fn calculate_in_message_consumer_index(config: &Config, info_hash: InfoHash) -> usize { (info_hash.0[0] as usize) % config.request_workers }