diff --git a/aquatic_udp/src/lib.rs b/aquatic_udp/src/lib.rs index 2a735b1..a6b22f8 100644 --- a/aquatic_udp/src/lib.rs +++ b/aquatic_udp/src/lib.rs @@ -21,6 +21,7 @@ use common::{ }; use config::Config; use workers::socket::validator::ConnectionValidator; +use workers::socket::SocketWorker; pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker"; pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -121,11 +122,10 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { WorkerIndex::SocketWorker(i), ); - workers::socket::run_socket_worker( + SocketWorker::run( sentinel, state, config, - i, connection_validator, server_start_instant, request_sender, diff --git a/aquatic_udp/src/workers/socket/mod.rs b/aquatic_udp/src/workers/socket/mod.rs index 6daeeb2..2c196e8 100644 --- a/aquatic_udp/src/workers/socket/mod.rs +++ b/aquatic_udp/src/workers/socket/mod.rs @@ -1,11 +1,12 @@ -mod requests; -mod responses; mod storage; pub mod validator; +use std::io::{Cursor, ErrorKind}; +use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use anyhow::Context; +use aquatic_common::access_list::AccessListCache; use aquatic_common::ServerStartInstant; use crossbeam_channel::Receiver; use mio::net::UdpSocket; @@ -21,105 +22,397 @@ use aquatic_udp_protocol::*; use crate::common::*; use crate::config::Config; -use requests::read_requests; -use responses::send_responses; use storage::PendingScrapeResponseSlab; use validator::ConnectionValidator; -pub fn run_socket_worker( - _sentinel: PanicSentinel, - state: State, +pub struct SocketWorker { config: Config, - token_num: usize, - mut connection_validator: ConnectionValidator, - server_start_instant: ServerStartInstant, + shared_state: State, request_sender: ConnectedRequestSender, response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, - priv_dropper: PrivilegeDropper, -) { - let mut buffer = [0u8; BUFFER_SIZE]; + access_list_cache: AccessListCache, + validator: ConnectionValidator, + server_start_instant: ServerStartInstant, + pending_scrape_responses: PendingScrapeResponseSlab, + socket: UdpSocket, + buffer: [u8; BUFFER_SIZE], +} - let mut socket = - UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); - let mut poll = Poll::new().expect("create poll"); +impl SocketWorker { + pub fn run( + _sentinel: PanicSentinel, + shared_state: State, + config: Config, + validator: ConnectionValidator, + server_start_instant: ServerStartInstant, + request_sender: ConnectedRequestSender, + response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + priv_dropper: PrivilegeDropper, + ) { + let socket = + UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); + let access_list_cache = create_access_list_cache(&shared_state.access_list); - let interests = Interest::READABLE; + let mut worker = Self { + config, + shared_state, + validator, + server_start_instant, + request_sender, + response_receiver, + access_list_cache, + pending_scrape_responses: Default::default(), + socket, + buffer: [0; BUFFER_SIZE], + }; - poll.registry() - .register(&mut socket, Token(token_num), interests) - .unwrap(); + worker.run_inner(); + } - let mut events = Events::with_capacity(config.network.poll_event_capacity); - let mut pending_scrape_responses = PendingScrapeResponseSlab::default(); - let mut access_list_cache = create_access_list_cache(&state.access_list); + pub fn run_inner(&mut self) { + let mut local_responses = Vec::new(); + let mut opt_resend_buffer = + (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new()); - let mut local_responses: Vec<(Response, CanonicalSocketAddr)> = Vec::new(); - let mut opt_resend_buffer = (config.network.resend_buffer_max_len > 0).then_some(Vec::new()); + let mut events = Events::with_capacity(self.config.network.poll_event_capacity); + let mut poll = Poll::new().expect("create poll"); - let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms); + poll.registry() + .register(&mut self.socket, Token(0), Interest::READABLE) + .expect("register poll"); - let pending_scrape_cleaning_duration = - Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval); + let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms); - let mut pending_scrape_valid_until = - ValidUntil::new(server_start_instant, config.cleaning.max_pending_scrape_age); - let mut last_pending_scrape_cleaning = Instant::now(); + let pending_scrape_cleaning_duration = + Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval); - let mut iter_counter = 0usize; + let mut pending_scrape_valid_until = ValidUntil::new( + self.server_start_instant, + self.config.cleaning.max_pending_scrape_age, + ); + let mut last_pending_scrape_cleaning = Instant::now(); - loop { - poll.poll(&mut events, Some(poll_timeout)) - .expect("failed polling"); + let mut iter_counter = 0usize; - for event in events.iter() { - let token = event.token(); + loop { + poll.poll(&mut events, Some(poll_timeout)) + .expect("failed polling"); - if (token.0 == token_num) & event.is_readable() { - read_requests( - &config, - &state, - &mut connection_validator, - &mut pending_scrape_responses, - &mut access_list_cache, - &mut socket, - &mut buffer, - &request_sender, - &mut local_responses, - pending_scrape_valid_until, + for event in events.iter() { + if event.is_readable() { + self.read_requests(&mut local_responses, pending_scrape_valid_until); + } + } + + if let Some(resend_buffer) = opt_resend_buffer.as_mut() { + for (response, addr) in resend_buffer.drain(..) { + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut None, + response, + addr, + ); + } + } + + for (response, addr) in local_responses.drain(..) { + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut opt_resend_buffer, + response, + addr, ); } + + for (response, addr) in self.response_receiver.try_iter() { + let opt_response = match response { + ConnectedResponse::Scrape(r) => self + .pending_scrape_responses + .add_and_get_finished(r) + .map(Response::Scrape), + ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)), + ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)), + }; + + if let Some(response) = opt_response { + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut opt_resend_buffer, + response, + addr, + ); + } + } + + // Run periodic ValidUntil updates and state cleaning + if iter_counter % 256 == 0 { + let seconds_since_start = self.server_start_instant.seconds_elapsed(); + + pending_scrape_valid_until = ValidUntil::new_with_now( + seconds_since_start, + self.config.cleaning.max_pending_scrape_age, + ); + + let now = Instant::now(); + + if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { + self.pending_scrape_responses.clean(seconds_since_start); + + last_pending_scrape_cleaning = now; + } + } + + iter_counter = iter_counter.wrapping_add(1); } + } - send_responses( - &state, - &config, - &mut socket, - &mut buffer, - &response_receiver, - &mut pending_scrape_responses, - local_responses.drain(..), - &mut opt_resend_buffer, - ); + fn read_requests( + &mut self, + local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, + pending_scrape_valid_until: ValidUntil, + ) { + let mut requests_received_ipv4: usize = 0; + let mut requests_received_ipv6: usize = 0; + let mut bytes_received_ipv4: usize = 0; + let mut bytes_received_ipv6 = 0; - // Run periodic ValidUntil updates and state cleaning - if iter_counter % 256 == 0 { - let seconds_since_start = server_start_instant.seconds_elapsed(); + loop { + match self.socket.recv_from(&mut self.buffer[..]) { + Ok((bytes_read, src)) => { + if src.port() == 0 { + ::log::info!("Ignored request from {} because source port is zero", src); - pending_scrape_valid_until = ValidUntil::new_with_now( - seconds_since_start, - config.cleaning.max_pending_scrape_age, - ); + continue; + } - let now = Instant::now(); + let res_request = Request::from_bytes( + &self.buffer[..bytes_read], + self.config.protocol.max_scrape_torrents, + ); - if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { - pending_scrape_responses.clean(seconds_since_start); + let src = CanonicalSocketAddr::new(src); - last_pending_scrape_cleaning = now; + // Update statistics for converted address + if src.is_ipv4() { + if res_request.is_ok() { + requests_received_ipv4 += 1; + } + bytes_received_ipv4 += bytes_read; + } else { + if res_request.is_ok() { + requests_received_ipv6 += 1; + } + bytes_received_ipv6 += bytes_read; + } + + self.handle_request( + local_responses, + pending_scrape_valid_until, + res_request, + src, + ); + } + Err(err) if err.kind() == ErrorKind::WouldBlock => { + break; + } + Err(err) => { + ::log::warn!("recv_from error: {:#}", err); + } } } - iter_counter = iter_counter.wrapping_add(1); + if self.config.statistics.active() { + self.shared_state + .statistics_ipv4 + .requests_received + .fetch_add(requests_received_ipv4, Ordering::Release); + self.shared_state + .statistics_ipv6 + .requests_received + .fetch_add(requests_received_ipv6, Ordering::Release); + self.shared_state + .statistics_ipv4 + .bytes_received + .fetch_add(bytes_received_ipv4, Ordering::Release); + self.shared_state + .statistics_ipv6 + .bytes_received + .fetch_add(bytes_received_ipv6, Ordering::Release); + } + } + + fn handle_request( + &mut self, + local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, + pending_scrape_valid_until: ValidUntil, + res_request: Result, + src: CanonicalSocketAddr, + ) { + let access_list_mode = self.config.access_list.mode; + + match res_request { + Ok(Request::Connect(request)) => { + let connection_id = self.validator.create_connection_id(src); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, + }); + + local_responses.push((response, src)) + } + Ok(Request::Announce(request)) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + if self + .access_list_cache + .load() + .allows(access_list_mode, &request.info_hash.0) + { + let worker_index = + SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); + + self.request_sender.try_send_to( + worker_index, + ConnectedRequest::Announce(request), + src, + ); + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + local_responses.push((response, src)) + } + } + } + Ok(Request::Scrape(request)) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + let split_requests = self.pending_scrape_responses.prepare_split_requests( + &self.config, + request, + pending_scrape_valid_until, + ); + + for (swarm_worker_index, request) in split_requests { + self.request_sender.try_send_to( + swarm_worker_index, + ConnectedRequest::Scrape(request), + src, + ); + } + } + } + Err(err) => { + ::log::debug!("Request::from_bytes error: {:?}", err); + + if let RequestParseError::Sendable { + connection_id, + transaction_id, + err, + } = err + { + if self.validator.connection_id_valid(src, connection_id) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + local_responses.push((response.into(), src)); + } + } + } + } + } + + fn send_response( + config: &Config, + shared_state: &State, + socket: &mut UdpSocket, + buffer: &mut [u8], + opt_resend_buffer: &mut Option>, + response: Response, + canonical_addr: CanonicalSocketAddr, + ) { + let mut cursor = Cursor::new(buffer); + + if let Err(err) = response.write(&mut cursor) { + ::log::error!("Converting response to bytes failed: {:#}", err); + + return; + } + + let bytes_written = cursor.position() as usize; + + let addr = if config.network.address.is_ipv4() { + canonical_addr + .get_ipv4() + .expect("found peer ipv6 address while running bound to ipv4 address") + } else { + canonical_addr.get_ipv6_mapped() + }; + + match socket.send_to(&cursor.get_ref()[..bytes_written], addr) { + Ok(amt) if config.statistics.active() => { + let stats = if canonical_addr.is_ipv4() { + &shared_state.statistics_ipv4 + } else { + &shared_state.statistics_ipv6 + }; + + stats.bytes_sent.fetch_add(amt, Ordering::Relaxed); + + match response { + Response::Connect(_) => { + stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed); + } + Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { + stats + .responses_sent_announce + .fetch_add(1, Ordering::Relaxed); + } + Response::Scrape(_) => { + stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed); + } + Response::Error(_) => { + stats.responses_sent_error.fetch_add(1, Ordering::Relaxed); + } + } + } + Ok(_) => (), + Err(err) => match opt_resend_buffer.as_mut() { + Some(resend_buffer) + if (err.raw_os_error() == Some(libc::ENOBUFS)) + || (err.kind() == ErrorKind::WouldBlock) => + { + if resend_buffer.len() < config.network.resend_buffer_max_len { + ::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); + + resend_buffer.push((response, canonical_addr)); + } else { + ::log::warn!("Response resend buffer full, dropping response"); + } + } + _ => { + ::log::warn!("Sending response to {} failed: {:#}", addr, err); + } + }, + } } } diff --git a/aquatic_udp/src/workers/socket/requests.rs b/aquatic_udp/src/workers/socket/requests.rs deleted file mode 100644 index eb995bd..0000000 --- a/aquatic_udp/src/workers/socket/requests.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::io::ErrorKind; -use std::sync::atomic::Ordering; - -use mio::net::UdpSocket; - -use aquatic_common::{access_list::AccessListCache, CanonicalSocketAddr, ValidUntil}; -use aquatic_udp_protocol::*; - -use crate::common::*; -use crate::config::Config; - -use super::storage::PendingScrapeResponseSlab; -use super::validator::ConnectionValidator; - -pub fn read_requests( - config: &Config, - state: &State, - connection_validator: &mut ConnectionValidator, - pending_scrape_responses: &mut PendingScrapeResponseSlab, - access_list_cache: &mut AccessListCache, - socket: &mut UdpSocket, - buffer: &mut [u8], - request_sender: &ConnectedRequestSender, - local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - pending_scrape_valid_until: ValidUntil, -) { - let mut requests_received_ipv4: usize = 0; - let mut requests_received_ipv6: usize = 0; - let mut bytes_received_ipv4: usize = 0; - let mut bytes_received_ipv6 = 0; - - loop { - match socket.recv_from(&mut buffer[..]) { - Ok((bytes_read, src)) => { - if src.port() == 0 { - ::log::info!("Ignored request from {} because source port is zero", src); - - continue; - } - - let res_request = - Request::from_bytes(&buffer[..bytes_read], config.protocol.max_scrape_torrents); - - let src = CanonicalSocketAddr::new(src); - - // Update statistics for converted address - if src.is_ipv4() { - if res_request.is_ok() { - requests_received_ipv4 += 1; - } - bytes_received_ipv4 += bytes_read; - } else { - if res_request.is_ok() { - requests_received_ipv6 += 1; - } - bytes_received_ipv6 += bytes_read; - } - - handle_request( - config, - connection_validator, - pending_scrape_responses, - access_list_cache, - request_sender, - local_responses, - pending_scrape_valid_until, - res_request, - src, - ); - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - ::log::warn!("recv_from error: {:#}", err); - } - } - } - - if config.statistics.active() { - state - .statistics_ipv4 - .requests_received - .fetch_add(requests_received_ipv4, Ordering::Release); - state - .statistics_ipv6 - .requests_received - .fetch_add(requests_received_ipv6, Ordering::Release); - state - .statistics_ipv4 - .bytes_received - .fetch_add(bytes_received_ipv4, Ordering::Release); - state - .statistics_ipv6 - .bytes_received - .fetch_add(bytes_received_ipv6, Ordering::Release); - } -} - -fn handle_request( - config: &Config, - connection_validator: &mut ConnectionValidator, - pending_scrape_responses: &mut PendingScrapeResponseSlab, - access_list_cache: &mut AccessListCache, - request_sender: &ConnectedRequestSender, - local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - pending_scrape_valid_until: ValidUntil, - res_request: Result, - src: CanonicalSocketAddr, -) { - let access_list_mode = config.access_list.mode; - - match res_request { - Ok(Request::Connect(request)) => { - let connection_id = connection_validator.create_connection_id(src); - - let response = Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - local_responses.push((response, src)) - } - Ok(Request::Announce(request)) => { - if connection_validator.connection_id_valid(src, request.connection_id) { - if access_list_cache - .load() - .allows(access_list_mode, &request.info_hash.0) - { - let worker_index = SwarmWorkerIndex::from_info_hash(config, request.info_hash); - - request_sender.try_send_to( - worker_index, - ConnectedRequest::Announce(request), - src, - ); - } else { - let response = Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - }); - - local_responses.push((response, src)) - } - } - } - Ok(Request::Scrape(request)) => { - if connection_validator.connection_id_valid(src, request.connection_id) { - let split_requests = pending_scrape_responses.prepare_split_requests( - config, - request, - pending_scrape_valid_until, - ); - - for (swarm_worker_index, request) in split_requests { - request_sender.try_send_to( - swarm_worker_index, - ConnectedRequest::Scrape(request), - src, - ); - } - } - } - Err(err) => { - ::log::debug!("Request::from_bytes error: {:?}", err); - - if let RequestParseError::Sendable { - connection_id, - transaction_id, - err, - } = err - { - if connection_validator.connection_id_valid(src, connection_id) { - let response = ErrorResponse { - transaction_id, - message: err.right_or("Parse error").into(), - }; - - local_responses.push((response.into(), src)); - } - } - } - } -} diff --git a/aquatic_udp/src/workers/socket/responses.rs b/aquatic_udp/src/workers/socket/responses.rs deleted file mode 100644 index 6ed78d1..0000000 --- a/aquatic_udp/src/workers/socket/responses.rs +++ /dev/null @@ -1,143 +0,0 @@ -use std::io::{Cursor, ErrorKind}; -use std::sync::atomic::Ordering; -use std::vec::Drain; - -use crossbeam_channel::Receiver; -use libc::ENOBUFS; -use mio::net::UdpSocket; - -use aquatic_common::CanonicalSocketAddr; -use aquatic_udp_protocol::*; - -use crate::common::*; -use crate::config::Config; - -use super::storage::PendingScrapeResponseSlab; - -pub fn send_responses( - state: &State, - config: &Config, - socket: &mut UdpSocket, - buffer: &mut [u8], - response_receiver: &Receiver<(ConnectedResponse, CanonicalSocketAddr)>, - pending_scrape_responses: &mut PendingScrapeResponseSlab, - local_responses: Drain<(Response, CanonicalSocketAddr)>, - opt_resend_buffer: &mut Option>, -) { - if let Some(resend_buffer) = opt_resend_buffer { - for (response, addr) in resend_buffer.drain(..) { - send_response(state, config, socket, buffer, response, addr, &mut None); - } - } - - for (response, addr) in local_responses { - send_response( - state, - config, - socket, - buffer, - response, - addr, - opt_resend_buffer, - ); - } - - for (response, addr) in response_receiver.try_iter() { - let opt_response = match response { - ConnectedResponse::Scrape(r) => pending_scrape_responses - .add_and_get_finished(r) - .map(Response::Scrape), - ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)), - ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)), - }; - - if let Some(response) = opt_response { - send_response( - state, - config, - socket, - buffer, - response, - addr, - opt_resend_buffer, - ); - } - } -} - -fn send_response( - state: &State, - config: &Config, - socket: &mut UdpSocket, - buffer: &mut [u8], - response: Response, - canonical_addr: CanonicalSocketAddr, - resend_buffer: &mut Option>, -) { - let mut cursor = Cursor::new(buffer); - - if let Err(err) = response.write(&mut cursor) { - ::log::error!("Converting response to bytes failed: {:#}", err); - - return; - } - - let bytes_written = cursor.position() as usize; - - let addr = if config.network.address.is_ipv4() { - canonical_addr - .get_ipv4() - .expect("found peer ipv6 address while running bound to ipv4 address") - } else { - canonical_addr.get_ipv6_mapped() - }; - - match socket.send_to(&cursor.get_ref()[..bytes_written], addr) { - Ok(amt) if config.statistics.active() => { - let stats = if canonical_addr.is_ipv4() { - &state.statistics_ipv4 - } else { - &state.statistics_ipv6 - }; - - stats.bytes_sent.fetch_add(amt, Ordering::Relaxed); - - match response { - Response::Connect(_) => { - stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed); - } - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { - stats - .responses_sent_announce - .fetch_add(1, Ordering::Relaxed); - } - Response::Scrape(_) => { - stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed); - } - Response::Error(_) => { - stats.responses_sent_error.fetch_add(1, Ordering::Relaxed); - } - } - } - Ok(_) => (), - Err(err) => { - match resend_buffer { - Some(resend_buffer) - if (err.raw_os_error() == Some(ENOBUFS)) - || (err.kind() == ErrorKind::WouldBlock) => - { - if resend_buffer.len() < config.network.resend_buffer_max_len { - ::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); - - resend_buffer.push((response, canonical_addr)); - } else { - ::log::warn!("Response resend buffer full, dropping response"); - } - } - _ => { - ::log::warn!("Sending response to {} failed: {:#}", addr, err); - } - } - } - } -}