diff --git a/aquatic_udp/src/workers/socket/uring/mod.rs b/aquatic_udp/src/workers/socket/uring/mod.rs index b68ddc6..16df167 100644 --- a/aquatic_udp/src/workers/socket/uring/mod.rs +++ b/aquatic_udp/src/workers/socket/uring/mod.rs @@ -361,51 +361,60 @@ impl SocketWorker { let buffer = buffer.as_slice(); - let (res_request, addr) = self.recv_helper.parse(buffer); + let addr = match self.recv_helper.parse(buffer) { + Ok((request, addr)) => { + self.handle_request(pending_scrape_valid_until, request, addr); - match res_request { - Ok(request) => self.handle_request(pending_scrape_valid_until, request, addr), - Err(RequestParseError::Sendable { - connection_id, - transaction_id, - err, - }) => { - ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); - - if self.validator.connection_id_valid(addr, connection_id) { - let response = ErrorResponse { + addr + } + Err(self::recv_helper::Error::RequestParseError(err, addr)) => { + match err { + RequestParseError::Sendable { + connection_id, transaction_id, - message: err.right_or("Parse error").into(), - }; + err, + } => { + ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); - self.local_responses.push_back((response.into(), addr)); + if self.validator.connection_id_valid(addr, connection_id) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + self.local_responses.push_back((response.into(), addr)); + } + } + RequestParseError::Unsendable { err } => { + ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); + } } + + addr } - Err(RequestParseError::Unsendable { err }) => { - ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); + Err(self::recv_helper::Error::InvalidSocketAddress) => { + ::log::debug!("Ignored request claiming to be from port 0"); + + return; } - } + Err(self::recv_helper::Error::RecvMsgParseError) => { + ::log::error!("RecvMsgOut::parse failed"); + + return; + } + }; if self.config.statistics.active() { - if addr.is_ipv4() { - self.shared_state - .statistics_ipv4 - .bytes_received - .fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - self.shared_state - .statistics_ipv4 - .requests_received - .fetch_add(1, Ordering::Relaxed); + let (statistics, extra_bytes) = if addr.is_ipv4() { + (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) } else { - self.shared_state - .statistics_ipv6 - .bytes_received - .fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); - self.shared_state - .statistics_ipv6 - .requests_received - .fetch_add(1, Ordering::Relaxed); - } + (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) + }; + + statistics + .bytes_received + .fetch_add(buffer.len() + extra_bytes, Ordering::Relaxed); + statistics.requests_received.fetch_add(1, Ordering::Relaxed); } } diff --git a/aquatic_udp/src/workers/socket/uring/recv_helper.rs b/aquatic_udp/src/workers/socket/uring/recv_helper.rs index 808c2f9..f87e208 100644 --- a/aquatic_udp/src/workers/socket/uring/recv_helper.rs +++ b/aquatic_udp/src/workers/socket/uring/recv_helper.rs @@ -12,6 +12,12 @@ use crate::config::Config; use super::{SOCKET_IDENTIFIER, USER_DATA_RECV}; +pub enum Error { + RecvMsgParseError, + RequestParseError(RequestParseError, CanonicalSocketAddr), + InvalidSocketAddress, +} + pub struct RecvHelper { socket_is_ipv4: bool, max_scrape_torrents: u8, @@ -82,15 +88,12 @@ impl RecvHelper { .user_data(USER_DATA_RECV) } - pub fn parse( - &self, - buffer: &[u8], - ) -> (Result, CanonicalSocketAddr) { + pub fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> { let (msg, addr) = if self.socket_is_ipv4 { let msg = unsafe { let msghdr = &*(self.msghdr_v4.get() as *const _); - RecvMsgOut::parse(buffer, msghdr).unwrap() + RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)? }; let addr = unsafe { @@ -102,12 +105,16 @@ impl RecvHelper { )) }; + if addr.port() == 0 { + return Err(Error::InvalidSocketAddress); + } + (msg, addr) } else { let msg = unsafe { let msghdr = &*(self.msghdr_v6.get() as *const _); - RecvMsgOut::parse(buffer, msghdr).unwrap() + RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)? }; let addr = unsafe { @@ -121,12 +128,18 @@ impl RecvHelper { )) }; + if addr.port() == 0 { + return Err(Error::InvalidSocketAddress); + } + (msg, addr) }; - ( - Request::from_bytes(msg.payload_data(), self.max_scrape_torrents), - CanonicalSocketAddr::new(addr), - ) + let addr = CanonicalSocketAddr::new(addr); + + let request = Request::from_bytes(msg.payload_data(), self.max_scrape_torrents) + .map_err(|err| Error::RequestParseError(err, addr))?; + + Ok((request, addr)) } }