diff --git a/crates/udp/src/common.rs b/crates/udp/src/common.rs index 4dd3496..d26d519 100644 --- a/crates/udp/src/common.rs +++ b/crates/udp/src/common.rs @@ -13,6 +13,7 @@ use crossbeam_utils::CachePadded; use hdrhistogram::Histogram; use crate::config::Config; +use crate::swarm::TorrentMaps; pub const BUFFER_SIZE: usize = 8192; @@ -230,13 +231,15 @@ pub enum StatisticsMessage { #[derive(Clone)] pub struct State { pub access_list: Arc, + pub torrent_maps: TorrentMaps, pub server_start_instant: ServerStartInstant, } -impl Default for State { - fn default() -> Self { +impl State { + pub fn new(config: &Config) -> Self { Self { access_list: Arc::new(AccessListArcSwap::default()), + torrent_maps: TorrentMaps::new(config), server_start_instant: ServerStartInstant::new(), } } diff --git a/crates/udp/src/lib.rs b/crates/udp/src/lib.rs index 49c8aa1..f2a7ec8 100644 --- a/crates/udp/src/lib.rs +++ b/crates/udp/src/lib.rs @@ -23,6 +23,7 @@ use common::{ SwarmWorkerIndex, }; use config::Config; +use swarm::TorrentMaps; use workers::socket::ConnectionValidator; use workers::swarm::SwarmWorker; @@ -32,81 +33,24 @@ pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn run(config: Config) -> ::anyhow::Result<()> { let mut signals = Signals::new([SIGUSR1])?; - let state = State::default(); + let state = State::new(&config); let statistics = Statistics::new(&config); let connection_validator = ConnectionValidator::new(&config)?; let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); + let mut join_handles = Vec::new(); update_access_list(&config.access_list, &state.access_list)?; - let mut request_senders = Vec::new(); - let mut request_receivers = BTreeMap::new(); - - let mut response_senders = Vec::new(); - let mut response_receivers = BTreeMap::new(); - let (statistics_sender, statistics_receiver) = unbounded(); - for i in 0..config.swarm_workers { - let (request_sender, request_receiver) = bounded(config.worker_channel_size); - - request_senders.push(request_sender); - request_receivers.insert(i, request_receiver); - } - - for i in 0..config.socket_workers { - let (response_sender, response_receiver) = bounded(config.worker_channel_size); - - response_senders.push(response_sender); - response_receivers.insert(i, response_receiver); - } - - for i in 0..config.swarm_workers { - let config = config.clone(); - let state = state.clone(); - let request_receiver = request_receivers.remove(&i).unwrap().clone(); - let response_sender = ConnectedResponseSender::new(response_senders.clone()); - let statistics_sender = statistics_sender.clone(); - let statistics = statistics.swarm[i].clone(); - - let handle = Builder::new() - .name(format!("swarm-{:02}", i + 1)) - .spawn(move || { - #[cfg(feature = "cpu-pinning")] - pin_current_if_configured_to( - &config.cpu_pinning, - config.socket_workers, - config.swarm_workers, - WorkerIndex::SwarmWorker(i), - ); - - let mut worker = SwarmWorker { - config, - state, - statistics, - request_receiver, - response_sender, - statistics_sender, - worker_index: SwarmWorkerIndex(i), - }; - - worker.run() - }) - .with_context(|| "spawn swarm worker")?; - - join_handles.push((WorkerType::Swarm(i), handle)); - } - for i in 0..config.socket_workers { let state = state.clone(); let config = config.clone(); let connection_validator = connection_validator.clone(); - let request_sender = - ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone()); - let response_receiver = response_receivers.remove(&i).unwrap(); let priv_dropper = priv_dropper.clone(); let statistics = statistics.socket[i].clone(); + let statistics_sender = statistics_sender.clone(); let handle = Builder::new() .name(format!("socket-{:02}", i + 1)) @@ -123,9 +67,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { config, state, statistics, + statistics_sender, connection_validator, - request_sender, - response_receiver, priv_dropper, ) }) diff --git a/crates/udp/src/swarm.rs b/crates/udp/src/swarm.rs index 1c671e1..159e145 100644 --- a/crates/udp/src/swarm.rs +++ b/crates/udp/src/swarm.rs @@ -16,6 +16,7 @@ use aquatic_udp_protocol::*; use arrayvec::ArrayVec; use crossbeam_channel::Sender; use hdrhistogram::Histogram; +use parking_lot::RwLockUpgradableReadGuard; use rand::prelude::SmallRng; use rand::Rng; @@ -37,7 +38,7 @@ pub struct TorrentMaps { impl TorrentMaps { pub fn new(config: &Config) -> Self { - let num_shards = 128usize; + let num_shards = 16usize; Self { ipv4: TorrentMapShards::new(num_shards), @@ -51,10 +52,10 @@ impl TorrentMaps { statistics_sender: &Sender, rng: &mut SmallRng, request: &AnnounceRequest, - ip_address: CanonicalSocketAddr, + src: CanonicalSocketAddr, valid_until: ValidUntil, ) -> Response { - match ip_address.get().ip() { + match src.get().ip() { IpAddr::V4(ip_address) => Response::AnnounceIpv4(self.ipv4.announce( config, statistics_sender, @@ -74,8 +75,8 @@ impl TorrentMaps { } } - pub fn scrape(&self, ip_addr: CanonicalSocketAddr, request: ScrapeRequest) -> ScrapeResponse { - if ip_addr.is_ipv4() { + pub fn scrape(&self, request: ScrapeRequest, src: CanonicalSocketAddr) -> ScrapeResponse { + if src.is_ipv4() { self.ipv4.scrape(request) } else { self.ipv6.scrape(request) @@ -142,20 +143,20 @@ impl TorrentMapShards { ip_address: I, valid_until: ValidUntil, ) -> AnnounceResponse { - let torrent_map_shard = self.get_shard(&request.info_hash); + let torrent_data = { + let torrent_map_shard = self.get_shard(&request.info_hash).upgradable_read(); - // Clone Arc here to avoid keeping lock on whole shard - let torrent_data = - if let Some(torrent_data) = torrent_map_shard.read().get(&request.info_hash) { + // Clone Arc here to avoid keeping lock on whole shard + if let Some(torrent_data) = torrent_map_shard.get(&request.info_hash) { torrent_data.clone() } else { // Don't overwrite entry if created in the meantime - torrent_map_shard - .write() + RwLockUpgradableReadGuard::upgrade(torrent_map_shard) .entry(request.info_hash) .or_default() .clone() - }; + } + }; let mut peer_map = torrent_data.peer_map.write(); diff --git a/crates/udp/src/workers/socket/mio.rs b/crates/udp/src/workers/socket/mio.rs index 26ab5be..9c72039 100644 --- a/crates/udp/src/workers/socket/mio.rs +++ b/crates/udp/src/workers/socket/mio.rs @@ -1,9 +1,10 @@ use std::io::{Cursor, ErrorKind}; use std::sync::atomic::Ordering; -use std::time::{Duration, Instant}; +use std::time::Duration; use anyhow::Context; use aquatic_common::access_list::AccessListCache; +use crossbeam_channel::Sender; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; @@ -12,40 +13,26 @@ use aquatic_common::{ ValidUntil, }; use aquatic_udp_protocol::*; +use rand::rngs::SmallRng; +use rand::SeedableRng; use crate::common::*; use crate::config::Config; -use super::storage::PendingScrapeResponseSlab; use super::validator::ConnectionValidator; use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; -enum HandleRequestError { - RequestChannelFull(Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>), -} - -#[derive(Clone, Copy, Debug)] -enum PollMode { - Regular, - SkipPolling, - SkipReceiving, -} - pub struct SocketWorker { config: Config, shared_state: State, statistics: CachePaddedArc>, - request_sender: ConnectedRequestSender, - response_receiver: ConnectedResponseReceiver, + statistics_sender: Sender, access_list_cache: AccessListCache, validator: ConnectionValidator, - pending_scrape_responses: PendingScrapeResponseSlab, socket: UdpSocket, opt_resend_buffer: Option>, buffer: [u8; BUFFER_SIZE], - polling_mode: PollMode, - /// Storage for requests that couldn't be sent to swarm worker because channel was full - pending_requests: Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>, + rng: SmallRng, } impl SocketWorker { @@ -53,9 +40,8 @@ impl SocketWorker { config: Config, shared_state: State, statistics: CachePaddedArc>, + statistics_sender: Sender, validator: ConnectionValidator, - request_sender: ConnectedRequestSender, - response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, ) -> anyhow::Result<()> { let socket = UdpSocket::from_std(create_socket(&config, priv_dropper)?); @@ -66,16 +52,13 @@ impl SocketWorker { config, shared_state, statistics, + statistics_sender, validator, - request_sender, - response_receiver, access_list_cache, - pending_scrape_responses: Default::default(), socket, opt_resend_buffer, buffer: [0; BUFFER_SIZE], - polling_mode: PollMode::Regular, - pending_requests: Default::default(), + rng: SmallRng::from_entropy(), }; worker.run_inner() @@ -91,39 +74,12 @@ impl SocketWorker { let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms); - let pending_scrape_cleaning_duration = - Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval); - - let mut pending_scrape_valid_until = ValidUntil::new( - self.shared_state.server_start_instant, - self.config.cleaning.max_pending_scrape_age, - ); - let mut last_pending_scrape_cleaning = Instant::now(); - - let mut iter_counter = 0usize; - loop { - match self.polling_mode { - PollMode::Regular => { - poll.poll(&mut events, Some(poll_timeout)).context("poll")?; + poll.poll(&mut events, Some(poll_timeout)).context("poll")?; - for event in events.iter() { - if event.is_readable() { - self.read_and_handle_requests(pending_scrape_valid_until); - } - } - } - PollMode::SkipPolling => { - self.polling_mode = PollMode::Regular; - - // Continue reading from socket without polling, since - // reading was previouly cancelled - self.read_and_handle_requests(pending_scrape_valid_until); - } - PollMode::SkipReceiving => { - ::log::debug!("Postponing receiving requests because swarm worker channel is full. This means that the OS will be relied on to buffer incoming packets. To prevent this, raise config.worker_channel_size."); - - self.polling_mode = PollMode::SkipPolling; + for event in events.iter() { + if event.is_readable() { + self.read_and_handle_requests(); } } @@ -141,44 +97,10 @@ impl SocketWorker { ); } } - - // Check channel for any responses generated by swarm workers - self.handle_swarm_worker_responses(); - - // Try sending pending requests - while let Some((index, request, addr)) = self.pending_requests.pop() { - if let Err(r) = self.request_sender.try_send_to(index, request, addr) { - self.pending_requests.push(r); - - self.polling_mode = PollMode::SkipReceiving; - - break; - } - } - - // Run periodic ValidUntil updates and state cleaning - if iter_counter % 256 == 0 { - let seconds_since_start = self.shared_state.server_start_instant.seconds_elapsed(); - - pending_scrape_valid_until = ValidUntil::new_with_now( - seconds_since_start, - self.config.cleaning.max_pending_scrape_age, - ); - - let now = Instant::now(); - - if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { - self.pending_scrape_responses.clean(seconds_since_start); - - last_pending_scrape_cleaning = now; - } - } - - iter_counter = iter_counter.wrapping_add(1); } } - fn read_and_handle_requests(&mut self, pending_scrape_valid_until: ValidUntil) { + fn read_and_handle_requests(&mut self) { let max_scrape_torrents = self.config.protocol.max_scrape_torrents; loop { @@ -222,14 +144,7 @@ impl SocketWorker { statistics.requests.fetch_add(1, Ordering::Relaxed); } - if let Err(HandleRequestError::RequestChannelFull(failed_requests)) = - self.handle_request(pending_scrape_valid_until, request, src) - { - self.pending_requests.extend(failed_requests); - self.polling_mode = PollMode::SkipReceiving; - - break; - } + self.handle_request(request, src); } Err(RequestParseError::Sendable { connection_id, @@ -271,20 +186,13 @@ impl SocketWorker { } } - fn handle_request( - &mut self, - pending_scrape_valid_until: ValidUntil, - request: Request, - src: CanonicalSocketAddr, - ) -> Result<(), HandleRequestError> { + fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) { let access_list_mode = self.config.access_list.mode; match request { Request::Connect(request) => { - let connection_id = self.validator.create_connection_id(src); - let response = ConnectResponse { - connection_id, + connection_id: self.validator.create_connection_id(src), transaction_id: request.transaction_id, }; @@ -297,8 +205,6 @@ impl SocketWorker { Response::Connect(response), src, ); - - Ok(()) } Request::Announce(request) => { if self @@ -310,14 +216,29 @@ impl SocketWorker { .load() .allows(access_list_mode, &request.info_hash.0) { - let worker_index = - SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); + let peer_valid_until = ValidUntil::new( + self.shared_state.server_start_instant, + self.config.cleaning.max_peer_age, + ); - self.request_sender - .try_send_to(worker_index, ConnectedRequest::Announce(request), src) - .map_err(|request| { - HandleRequestError::RequestChannelFull(vec![request]) - }) + let response = self.shared_state.torrent_maps.announce( + &self.config, + &self.statistics_sender, + &mut self.rng, + &request, + src, + peer_valid_until, + ); + + send_response( + &self.config, + &self.statistics, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + response, + src, + ); } else { let response = ErrorResponse { transaction_id: request.transaction_id, @@ -333,11 +254,7 @@ impl SocketWorker { Response::Error(response), src, ); - - Ok(()) } - } else { - Ok(()) } } Request::Scrape(request) => { @@ -345,64 +262,21 @@ impl SocketWorker { .validator .connection_id_valid(src, request.connection_id) { - let split_requests = self.pending_scrape_responses.prepare_split_requests( + let response = self.shared_state.torrent_maps.scrape(request, src); + + send_response( &self.config, - request, - pending_scrape_valid_until, + &self.statistics, + &mut self.socket, + &mut self.buffer, + &mut self.opt_resend_buffer, + Response::Scrape(response), + src, ); - - let mut failed = Vec::new(); - - for (swarm_worker_index, request) in split_requests { - if let Err(request) = self.request_sender.try_send_to( - swarm_worker_index, - ConnectedRequest::Scrape(request), - src, - ) { - failed.push(request); - } - } - - if failed.is_empty() { - Ok(()) - } else { - Err(HandleRequestError::RequestChannelFull(failed)) - } - } else { - Ok(()) } } } } - - fn handle_swarm_worker_responses(&mut self) { - for (addr, response) in self.response_receiver.try_iter() { - let response = match response { - ConnectedResponse::Scrape(response) => { - if let Some(r) = self - .pending_scrape_responses - .add_and_get_finished(&response) - { - Response::Scrape(r) - } else { - continue; - } - } - ConnectedResponse::AnnounceIpv4(r) => Response::AnnounceIpv4(r), - ConnectedResponse::AnnounceIpv6(r) => Response::AnnounceIpv6(r), - }; - - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - response, - addr, - ); - } - } } fn send_response( @@ -488,4 +362,6 @@ fn send_response( } } } + + ::log::debug!("send response fn finished"); } diff --git a/crates/udp/src/workers/socket/mod.rs b/crates/udp/src/workers/socket/mod.rs index d55e69d..67e6f7f 100644 --- a/crates/udp/src/workers/socket/mod.rs +++ b/crates/udp/src/workers/socket/mod.rs @@ -6,12 +6,13 @@ mod validator; use anyhow::Context; use aquatic_common::privileges::PrivilegeDropper; +use crossbeam_channel::Sender; use socket2::{Domain, Protocol, Socket, Type}; use crate::{ common::{ CachePaddedArc, ConnectedRequestSender, ConnectedResponseReceiver, IpVersionStatistics, - SocketWorkerStatistics, State, + SocketWorkerStatistics, State, StatisticsMessage, }, config::Config, }; @@ -43,11 +44,11 @@ pub fn run_socket_worker( config: Config, shared_state: State, statistics: CachePaddedArc>, + statistics_sender: Sender, validator: ConnectionValidator, - request_sender: ConnectedRequestSender, - response_receiver: ConnectedResponseReceiver, priv_dropper: PrivilegeDropper, ) -> anyhow::Result<()> { + /* #[cfg(all(target_os = "linux", feature = "io-uring"))] if config.network.use_io_uring { self::uring::supported_on_current_kernel().context("check for io_uring compatibility")?; @@ -62,14 +63,14 @@ pub fn run_socket_worker( priv_dropper, ); } + */ self::mio::SocketWorker::run( config, shared_state, statistics, + statistics_sender, validator, - request_sender, - response_receiver, priv_dropper, ) }