diff --git a/aquatic_ws/src/lib/common.rs b/aquatic_ws/src/lib/common.rs index 0cd1a07..899adfb 100644 --- a/aquatic_ws/src/lib/common.rs +++ b/aquatic_ws/src/lib/common.rs @@ -13,6 +13,10 @@ pub use aquatic_common::ValidUntil; use aquatic_ws_protocol::*; +pub const LISTENER_TOKEN: Token = Token(0); +pub const CHANNEL_TOKEN: Token = Token(1); + + #[derive(Clone, Copy, Debug)] pub struct ConnectionMeta { /// Index of socket worker responsible for this connection. Required for diff --git a/aquatic_ws/src/lib/config.rs b/aquatic_ws/src/lib/config.rs index 339bc20..1a70095 100644 --- a/aquatic_ws/src/lib/config.rs +++ b/aquatic_ws/src/lib/config.rs @@ -114,7 +114,7 @@ impl Default for NetworkConfig { tls_pkcs12_path: "".into(), tls_pkcs12_password: "".into(), poll_event_capacity: 4096, - poll_timeout_microseconds: 1000, + poll_timeout_microseconds: 200_000, websocket_max_message_size: 64 * 1024, websocket_max_frame_size: 16 * 1024, } diff --git a/aquatic_ws/src/lib/handler.rs b/aquatic_ws/src/lib/handler.rs index 0738293..8ec0e03 100644 --- a/aquatic_ws/src/lib/handler.rs +++ b/aquatic_ws/src/lib/handler.rs @@ -1,7 +1,9 @@ use std::time::Duration; use std::vec::Drain; +use std::sync::Arc; use hashbrown::HashMap; +use mio::Waker; use parking_lot::MutexGuard; use rand::{Rng, SeedableRng, rngs::SmallRng}; @@ -17,8 +19,11 @@ pub fn run_request_worker( state: State, in_message_receiver: InMessageReceiver, out_message_sender: OutMessageSender, + wakers: Vec>, ){ - let mut out_messages = Vec::new(); + let mut wake_socket_workers: Vec = (0..config.socket_workers) + .map(|_| false) + .collect(); let mut announce_requests = Vec::new(); let mut scrape_requests = Vec::new(); @@ -63,21 +68,27 @@ pub fn run_request_worker( &config, &mut rng, &mut torrent_map_guard, - &mut out_messages, + &out_message_sender, + &mut wake_socket_workers, announce_requests.drain(..) ); handle_scrape_requests( &config, &mut torrent_map_guard, - &mut out_messages, + &out_message_sender, + &mut wake_socket_workers, scrape_requests.drain(..) ); - ::std::mem::drop(torrent_map_guard); + for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate(){ + if *wake { + if let Err(err) = wakers[worker_index].wake(){ + ::log::error!("request handler couldn't wake poll: {:?}", err); + } - for (meta, out_message) in out_messages.drain(..){ - out_message_sender.send(meta, out_message); + *wake = false; + } } } } @@ -87,7 +98,8 @@ pub fn handle_announce_requests( config: &Config, rng: &mut impl Rng, torrent_maps: &mut TorrentMaps, - messages_out: &mut Vec<(ConnectionMeta, OutMessage)>, + out_message_sender: &OutMessageSender, + wake_socket_workers: &mut Vec, requests: Drain<(ConnectionMeta, AnnounceRequest)>, ){ let valid_until = ValidUntil::new(config.cleaning.max_peer_age); @@ -191,10 +203,11 @@ pub fn handle_announce_requests( offer_id: offer.offer_id, }; - messages_out.push(( + out_message_sender.send( offer_receiver.connection_meta, OutMessage::Offer(middleman_offer) - )); + ); + wake_socket_workers[offer_receiver.connection_meta.worker_index] = true; } } @@ -213,10 +226,11 @@ pub fn handle_announce_requests( offer_id, }; - messages_out.push(( + out_message_sender.send( answer_receiver.connection_meta, OutMessage::Answer(middleman_answer) - )); + ); + wake_socket_workers[answer_receiver.connection_meta.worker_index] = true; } } @@ -228,7 +242,8 @@ pub fn handle_announce_requests( announce_interval: config.protocol.peer_announce_interval, }); - messages_out.push((request_sender_meta, response)); + out_message_sender.send(request_sender_meta, response); + wake_socket_workers[request_sender_meta.worker_index] = true; } } @@ -236,7 +251,8 @@ pub fn handle_announce_requests( pub fn handle_scrape_requests( config: &Config, torrent_maps: &mut TorrentMaps, - messages_out: &mut Vec<(ConnectionMeta, OutMessage)>, + out_message_sender: &OutMessageSender, + wake_socket_workers: &mut Vec, requests: Drain<(ConnectionMeta, ScrapeRequest)>, ){ for (meta, request) in requests { @@ -275,6 +291,7 @@ pub fn handle_scrape_requests( } } - messages_out.push((meta, OutMessage::ScrapeResponse(response))); + out_message_sender.send(meta, OutMessage::ScrapeResponse(response)); + wake_socket_workers[meta.worker_index] = true; } } \ No newline at end of file diff --git a/aquatic_ws/src/lib/lib.rs b/aquatic_ws/src/lib/lib.rs index 91ec9da..fa057f0 100644 --- a/aquatic_ws/src/lib/lib.rs +++ b/aquatic_ws/src/lib/lib.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::thread::Builder; use anyhow::Context; +use mio::{Poll, Waker}; use native_tls::{Identity, TlsAcceptor}; use parking_lot::Mutex; use privdrop::PrivDrop; @@ -27,6 +28,7 @@ pub fn run(config: Config) -> anyhow::Result<()> { let (in_message_sender, in_message_receiver) = ::crossbeam_channel::unbounded(); let mut out_message_senders = Vec::new(); + let mut wakers = Vec::new(); let socket_worker_statuses: SocketWorkerStatuses = { let mut statuses = Vec::new(); @@ -43,16 +45,20 @@ pub fn run(config: Config) -> anyhow::Result<()> { let socket_worker_statuses = socket_worker_statuses.clone(); let in_message_sender = in_message_sender.clone(); let opt_tls_acceptor = opt_tls_acceptor.clone(); + let poll = Poll::new()?; + let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?); let (out_message_sender, out_message_receiver) = ::crossbeam_channel::unbounded(); out_message_senders.push(out_message_sender); + wakers.push(waker); Builder::new().name(format!("socket-{:02}", i + 1)).spawn(move || { network::run_socket_worker( config, i, socket_worker_statuses, + poll, in_message_sender, out_message_receiver, opt_tls_acceptor @@ -99,6 +105,7 @@ pub fn run(config: Config) -> anyhow::Result<()> { state, in_message_receiver, out_message_sender, + wakers, ); })?; } diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs index a227b13..6aec229 100644 --- a/aquatic_ws/src/lib/network/mod.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -26,6 +26,7 @@ pub fn run_socket_worker( config: Config, socket_worker_index: usize, socket_worker_statuses: SocketWorkerStatuses, + poll: Poll, in_message_sender: InMessageSender, out_message_receiver: OutMessageReceiver, opt_tls_acceptor: Option, @@ -37,6 +38,7 @@ pub fn run_socket_worker( run_poll_loop( config, socket_worker_index, + poll, in_message_sender, out_message_receiver, listener, @@ -55,6 +57,7 @@ pub fn run_socket_worker( pub fn run_poll_loop( config: Config, socket_worker_index: usize, + mut poll: Poll, in_message_sender: InMessageSender, out_message_receiver: OutMessageReceiver, listener: ::std::net::TcpListener, @@ -70,11 +73,10 @@ pub fn run_poll_loop( }; let mut listener = TcpListener::from_std(listener); - let mut poll = Poll::new().expect("create poll"); let mut events = Events::with_capacity(config.network.poll_event_capacity); poll.registry() - .register(&mut listener, Token(0), Interest::READABLE) + .register(&mut listener, LISTENER_TOKEN, Interest::READABLE) .unwrap(); let mut connections: ConnectionMap = HashMap::new(); @@ -91,7 +93,7 @@ pub fn run_poll_loop( for event in events.iter(){ let token = event.token(); - if token.0 == 0 { + if token == LISTENER_TOKEN { accept_new_streams( ws_config, &mut listener, @@ -100,7 +102,7 @@ pub fn run_poll_loop( valid_until, &mut poll_token_counter, ); - } else { + } else if token != CHANNEL_TOKEN { run_handshakes_and_read_messages( socket_worker_index, &in_message_sender, @@ -111,9 +113,7 @@ pub fn run_poll_loop( valid_until, ); } - } - if !out_message_receiver.is_empty(){ send_out_messages( &mut poll, &out_message_receiver, @@ -144,8 +144,8 @@ fn accept_new_streams( Ok((mut stream, _)) => { poll_token_counter.0 = poll_token_counter.0.wrapping_add(1); - if poll_token_counter.0 == 0 { - poll_token_counter.0 = 1; + if poll_token_counter.0 < 2 { + poll_token_counter.0 = 2; } let token = *poll_token_counter; @@ -263,7 +263,9 @@ pub fn send_out_messages( out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>, connections: &mut ConnectionMap, ){ - for (meta, out_message) in out_message_receiver.try_iter(){ + let len = out_message_receiver.len(); + + for (meta, out_message) in out_message_receiver.try_iter().take(len){ let opt_established_ws = connections.get_mut(&meta.poll_token) .and_then(Connection::get_established_ws);