diff --git a/crates/udp/src/workers/socket/mio.rs b/crates/udp/src/workers/socket/mio.rs index 9c72039..fb24bef 100644 --- a/crates/udp/src/workers/socket/mio.rs +++ b/crates/udp/src/workers/socket/mio.rs @@ -30,7 +30,6 @@ pub struct SocketWorker { access_list_cache: AccessListCache, validator: ConnectionValidator, socket: UdpSocket, - opt_resend_buffer: Option>, buffer: [u8; BUFFER_SIZE], rng: SmallRng, } @@ -46,7 +45,6 @@ impl SocketWorker { ) -> anyhow::Result<()> { let socket = UdpSocket::from_std(create_socket(&config, priv_dropper)?); let access_list_cache = create_access_list_cache(&shared_state.access_list); - let opt_resend_buffer = (config.network.resend_buffer_max_len > 0).then_some(Vec::new()); let mut worker = Self { config, @@ -56,7 +54,6 @@ impl SocketWorker { validator, access_list_cache, socket, - opt_resend_buffer, buffer: [0; BUFFER_SIZE], rng: SmallRng::from_entropy(), }; @@ -65,6 +62,8 @@ impl SocketWorker { } pub fn run_inner(&mut self) -> anyhow::Result<()> { + let mut opt_resend_buffer = + (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new()); let mut events = Events::with_capacity(1); let mut poll = Poll::new().context("create poll")?; @@ -79,28 +78,23 @@ impl SocketWorker { for event in events.iter() { if event.is_readable() { - self.read_and_handle_requests(); + self.read_and_handle_requests(&mut opt_resend_buffer); } } // If resend buffer is enabled, send any responses in it - if let Some(resend_buffer) = self.opt_resend_buffer.as_mut() { + if let Some(resend_buffer) = opt_resend_buffer.as_mut() { for (addr, response) in resend_buffer.drain(..) { - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut None, - response, - addr, - ); + self.send_response(&mut None, addr, response); } } } } - fn read_and_handle_requests(&mut self) { + fn read_and_handle_requests( + &mut self, + opt_resend_buffer: &mut Option>, + ) { let max_scrape_torrents = self.config.protocol.max_scrape_torrents; loop { @@ -144,7 +138,9 @@ impl SocketWorker { statistics.requests.fetch_add(1, Ordering::Relaxed); } - self.handle_request(request, src); + if let Some(response) = self.handle_request(request, src) { + self.send_response(opt_resend_buffer, src, response); + } } Err(RequestParseError::Sendable { connection_id, @@ -156,15 +152,7 @@ impl SocketWorker { message: err.into(), }; - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - Response::Error(response), - src, - ); + self.send_response(opt_resend_buffer, src, Response::Error(response)); ::log::debug!("request parse error (sent error response): {:?}", err); } @@ -186,25 +174,15 @@ impl SocketWorker { } } - fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) { + fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) -> Option { let access_list_mode = self.config.access_list.mode; match request { Request::Connect(request) => { - let response = ConnectResponse { + return Some(Response::Connect(ConnectResponse { connection_id: self.validator.create_connection_id(src), transaction_id: request.transaction_id, - }; - - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - Response::Connect(response), - src, - ); + })); } Request::Announce(request) => { if self @@ -230,30 +208,12 @@ impl SocketWorker { peer_valid_until, ); - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - response, - src, - ); + return Some(response); } else { - let response = ErrorResponse { + return Some(Response::Error(ErrorResponse { transaction_id: request.transaction_id, message: "Info hash not allowed".into(), - }; - - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - Response::Error(response), - src, - ); + })); } } } @@ -262,93 +222,85 @@ impl SocketWorker { .validator .connection_id_valid(src, request.connection_id) { - let response = self.shared_state.torrent_maps.scrape(request, src); - - send_response( - &self.config, - &self.statistics, - &mut self.socket, - &mut self.buffer, - &mut self.opt_resend_buffer, - Response::Scrape(response), - src, - ); + return Some(Response::Scrape( + self.shared_state.torrent_maps.scrape(request, src), + )); } } } - } -} -fn send_response( - config: &Config, - statistics: &CachePaddedArc>, - socket: &mut UdpSocket, - buffer: &mut [u8], - opt_resend_buffer: &mut Option>, - response: Response, - canonical_addr: CanonicalSocketAddr, -) { - let mut buffer = Cursor::new(&mut buffer[..]); - - if let Err(err) = response.write_bytes(&mut buffer) { - ::log::error!("failed writing response to buffer: {:#}", err); - - return; + None } - let bytes_written = buffer.position() as usize; + fn send_response( + &mut self, + opt_resend_buffer: &mut Option>, + canonical_addr: CanonicalSocketAddr, + response: Response, + ) { + let mut buffer = Cursor::new(&mut self.buffer[..]); - let addr = if config.network.address.is_ipv4() { - canonical_addr - .get_ipv4() - .expect("found peer ipv6 address while running bound to ipv4 address") - } else { - canonical_addr.get_ipv6_mapped() - }; + if let Err(err) = response.write_bytes(&mut buffer) { + ::log::error!("failed writing response to buffer: {:#}", err); - match socket.send_to(&buffer.into_inner()[..bytes_written], addr) { - Ok(amt) if config.statistics.active() => { - let stats = if canonical_addr.is_ipv4() { - let stats = &statistics.ipv4; + return; + } - stats - .bytes_sent - .fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + let bytes_written = buffer.position() as usize; - stats - } else { - let stats = &statistics.ipv6; + let addr = if self.config.network.address.is_ipv4() { + canonical_addr + .get_ipv4() + .expect("found peer ipv6 address while running bound to ipv4 address") + } else { + canonical_addr.get_ipv6_mapped() + }; - stats - .bytes_sent - .fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + match self + .socket + .send_to(&buffer.into_inner()[..bytes_written], addr) + { + Ok(bytes_sent) if self.config.statistics.active() => { + let stats = if canonical_addr.is_ipv4() { + let stats = &self.statistics.ipv4; - stats - }; + stats + .bytes_sent + .fetch_add(bytes_sent + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - match response { - Response::Connect(_) => { - stats.responses_connect.fetch_add(1, Ordering::Relaxed); - } - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { - stats.responses_announce.fetch_add(1, Ordering::Relaxed); - } - Response::Scrape(_) => { - stats.responses_scrape.fetch_add(1, Ordering::Relaxed); - } - Response::Error(_) => { - stats.responses_error.fetch_add(1, Ordering::Relaxed); + stats + } else { + let stats = &self.statistics.ipv6; + + stats + .bytes_sent + .fetch_add(bytes_sent + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + + stats + }; + + match response { + Response::Connect(_) => { + stats.responses_connect.fetch_add(1, Ordering::Relaxed); + } + Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { + stats.responses_announce.fetch_add(1, Ordering::Relaxed); + } + Response::Scrape(_) => { + stats.responses_scrape.fetch_add(1, Ordering::Relaxed); + } + Response::Error(_) => { + stats.responses_error.fetch_add(1, Ordering::Relaxed); + } } } - } - Ok(_) => (), - Err(err) => { - match opt_resend_buffer.as_mut() { + Ok(_) => (), + Err(err) => match opt_resend_buffer.as_mut() { Some(resend_buffer) if (err.raw_os_error() == Some(libc::ENOBUFS)) || (err.kind() == ErrorKind::WouldBlock) => { - if resend_buffer.len() < config.network.resend_buffer_max_len { + if resend_buffer.len() < self.config.network.resend_buffer_max_len { ::log::debug!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); resend_buffer.push((canonical_addr, response)); @@ -359,9 +311,9 @@ fn send_response( _ => { ::log::warn!("Sending response to {} failed: {:#}", addr, err); } - } + }, } - } - ::log::debug!("send response fn finished"); + ::log::debug!("send response fn finished"); + } }