diff --git a/TODO.md b/TODO.md index 2e3608b..e75a256 100644 --- a/TODO.md +++ b/TODO.md @@ -3,6 +3,8 @@ ## High priority * udp uring + * should queues be synced? + * miri * uneven performance? * thiserror? * CI diff --git a/aquatic_udp/src/workers/socket/uring/mod.rs b/aquatic_udp/src/workers/socket/uring/mod.rs index b68ddc6..7adf062 100644 --- a/aquatic_udp/src/workers/socket/uring/mod.rs +++ b/aquatic_udp/src/workers/socket/uring/mod.rs @@ -196,7 +196,7 @@ impl SocketWorker { break; } Err(send_buffers::Error::SerializationFailed(err)) => { - ::log::error!("write response to buffer: {:#}", err); + ::log::error!("Failed serializing response: {:#}", err); } } } else { @@ -245,7 +245,7 @@ impl SocketWorker { break; } Err(send_buffers::Error::SerializationFailed(err)) => { - ::log::error!("write response to buffer: {:#}", err); + ::log::error!("Failed serializing response: {:#}", err); } } } @@ -291,30 +291,31 @@ impl SocketWorker { if result < 0 { ::log::error!( - "send: {:#}", + "Couldn't send response: {:#}", ::std::io::Error::from_raw_os_error(-result) ); } else if self.config.statistics.active() { let send_buffer_index = send_buffer_index as usize; - let (statistics, extra_bytes) = - if self.send_buffers.receiver_is_ipv4(send_buffer_index) { - (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) - } else { - (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) - }; + let (response_type, receiver_is_ipv4) = + self.send_buffers.response_type_and_ipv4(send_buffer_index); + + let (statistics, extra_bytes) = if receiver_is_ipv4 { + (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) + } else { + (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) + }; statistics .bytes_sent .fetch_add(result as usize + extra_bytes, Ordering::Relaxed); - let response_counter = - match self.send_buffers.response_type(send_buffer_index) { - ResponseType::Connect => &statistics.responses_sent_connect, - ResponseType::Announce => &statistics.responses_sent_announce, - ResponseType::Scrape => &statistics.responses_sent_scrape, - ResponseType::Error => &statistics.responses_sent_error, - }; + let response_counter = match response_type { + ResponseType::Connect => &statistics.responses_sent_connect, + ResponseType::Announce => &statistics.responses_sent_announce, + ResponseType::Scrape => &statistics.responses_sent_scrape, + ResponseType::Error => &statistics.responses_sent_error, + }; response_counter.fetch_add(1, Ordering::Relaxed); } @@ -337,8 +338,14 @@ impl SocketWorker { let result = cqe.result(); if result < 0 { - // Will produce ENOBUFS if there were no free buffers - ::log::warn!("recv: {:#}", ::std::io::Error::from_raw_os_error(-result)); + if -result == libc::ENOBUFS { + ::log::warn!("recv failed due to lack of buffers, try increasing ring size"); + } else { + ::log::warn!( + "recv failed: {:#}", + ::std::io::Error::from_raw_os_error(-result) + ); + } return; } @@ -347,12 +354,12 @@ impl SocketWorker { match self.buf_ring.get_buf(result as u32, cqe.flags()) { Ok(Some(buffer)) => buffer, Ok(None) => { - ::log::error!("Couldn't get buffer"); + ::log::error!("Couldn't get recv buffer"); return; } Err(err) => { - ::log::error!("Couldn't get buffer: {:#}", err); + ::log::error!("Couldn't get recv buffer: {:#}", err); return; } @@ -361,51 +368,60 @@ impl SocketWorker { let buffer = buffer.as_slice(); - let (res_request, addr) = self.recv_helper.parse(buffer); + let addr = match self.recv_helper.parse(buffer) { + Ok((request, addr)) => { + self.handle_request(pending_scrape_valid_until, request, addr); - match res_request { - Ok(request) => self.handle_request(pending_scrape_valid_until, request, addr), - Err(RequestParseError::Sendable { - connection_id, - transaction_id, - err, - }) => { - ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); - - if self.validator.connection_id_valid(addr, connection_id) { - let response = ErrorResponse { + addr + } + Err(self::recv_helper::Error::RequestParseError(err, addr)) => { + match err { + RequestParseError::Sendable { + connection_id, transaction_id, - message: err.right_or("Parse error").into(), - }; + err, + } => { + ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); - self.local_responses.push_back((response.into(), addr)); + if self.validator.connection_id_valid(addr, connection_id) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + self.local_responses.push_back((response.into(), addr)); + } + } + RequestParseError::Unsendable { err } => { + ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); + } } + + addr } - Err(RequestParseError::Unsendable { err }) => { - ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); + Err(self::recv_helper::Error::InvalidSocketAddress) => { + ::log::debug!("Ignored request claiming to be from port 0"); + + return; } - } + Err(self::recv_helper::Error::RecvMsgParseError) => { + ::log::error!("RecvMsgOut::parse failed"); + + return; + } + }; if self.config.statistics.active() { - if addr.is_ipv4() { - self.shared_state - .statistics_ipv4 - .bytes_received - .fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - self.shared_state - .statistics_ipv4 - .requests_received - .fetch_add(1, Ordering::Relaxed); + let (statistics, extra_bytes) = if addr.is_ipv4() { + (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) } else { - self.shared_state - .statistics_ipv6 - .bytes_received - .fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); - self.shared_state - .statistics_ipv6 - .requests_received - .fetch_add(1, Ordering::Relaxed); - } + (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) + }; + + statistics + .bytes_received + .fetch_add(buffer.len() + extra_bytes, Ordering::Relaxed); + statistics.requests_received.fetch_add(1, Ordering::Relaxed); } } diff --git a/aquatic_udp/src/workers/socket/uring/recv_helper.rs b/aquatic_udp/src/workers/socket/uring/recv_helper.rs index 920cb2f..f87e208 100644 --- a/aquatic_udp/src/workers/socket/uring/recv_helper.rs +++ b/aquatic_udp/src/workers/socket/uring/recv_helper.rs @@ -1,5 +1,6 @@ use std::{ - net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + cell::UnsafeCell, + net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, ptr::null_mut, }; @@ -11,56 +12,62 @@ use crate::config::Config; use super::{SOCKET_IDENTIFIER, USER_DATA_RECV}; +pub enum Error { + RecvMsgParseError, + RequestParseError(RequestParseError, CanonicalSocketAddr), + InvalidSocketAddress, +} + pub struct RecvHelper { - network_address: IpAddr, + socket_is_ipv4: bool, max_scrape_torrents: u8, #[allow(dead_code)] - name_v4: Box, - msghdr_v4: Box, + name_v4: Box>, + msghdr_v4: Box>, #[allow(dead_code)] - name_v6: Box, - msghdr_v6: Box, + name_v6: Box>, + msghdr_v6: Box>, } impl RecvHelper { pub fn new(config: &Config) -> Self { - let mut name_v4 = Box::new(libc::sockaddr_in { + let name_v4 = Box::new(UnsafeCell::new(libc::sockaddr_in { sin_family: 0, sin_port: 0, sin_addr: libc::in_addr { s_addr: 0 }, sin_zero: [0; 8], - }); + })); - let msghdr_v4 = Box::new(libc::msghdr { - msg_name: &mut name_v4 as *mut _ as *mut libc::c_void, + let msghdr_v4 = Box::new(UnsafeCell::new(libc::msghdr { + msg_name: name_v4.get() as *mut libc::c_void, msg_namelen: core::mem::size_of::() as u32, msg_iov: null_mut(), msg_iovlen: 0, msg_control: null_mut(), msg_controllen: 0, msg_flags: 0, - }); + })); - let mut name_v6 = Box::new(libc::sockaddr_in6 { + let name_v6 = Box::new(UnsafeCell::new(libc::sockaddr_in6 { sin6_family: 0, sin6_port: 0, sin6_flowinfo: 0, sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, sin6_scope_id: 0, - }); + })); - let msghdr_v6 = Box::new(libc::msghdr { - msg_name: &mut name_v6 as *mut _ as *mut libc::c_void, + let msghdr_v6 = Box::new(UnsafeCell::new(libc::msghdr { + msg_name: name_v6.get() as *mut libc::c_void, msg_namelen: core::mem::size_of::() as u32, msg_iov: null_mut(), msg_iovlen: 0, msg_control: null_mut(), msg_controllen: 0, msg_flags: 0, - }); + })); Self { - network_address: config.network.address.ip(), + socket_is_ipv4: config.network.address.is_ipv4(), max_scrape_torrents: config.protocol.max_scrape_torrents, name_v4, msghdr_v4, @@ -70,10 +77,10 @@ impl RecvHelper { } pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry { - let msghdr: *const libc::msghdr = if self.network_address.is_ipv4() { - &*self.msghdr_v4 + let msghdr: *const libc::msghdr = if self.socket_is_ipv4 { + self.msghdr_v4.get() } else { - &*self.msghdr_v6 + self.msghdr_v6.get() }; RecvMsgMulti::new(SOCKET_IDENTIFIER, msghdr, buf_group) @@ -81,27 +88,36 @@ impl RecvHelper { .user_data(USER_DATA_RECV) } - pub fn parse( - &self, - buffer: &[u8], - ) -> (Result, CanonicalSocketAddr) { - let msghdr = if self.network_address.is_ipv4() { - &self.msghdr_v4 - } else { - &self.msghdr_v6 - }; + pub fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> { + let (msg, addr) = if self.socket_is_ipv4 { + let msg = unsafe { + let msghdr = &*(self.msghdr_v4.get() as *const _); - let msg = RecvMsgOut::parse(buffer, msghdr).unwrap(); + RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)? + }; - let addr = unsafe { - if self.network_address.is_ipv4() { + let addr = unsafe { let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in); SocketAddr::V4(SocketAddrV4::new( u32::from_be(name_data.sin_addr.s_addr).into(), u16::from_be(name_data.sin_port), )) - } else { + }; + + if addr.port() == 0 { + return Err(Error::InvalidSocketAddress); + } + + (msg, addr) + } else { + let msg = unsafe { + let msghdr = &*(self.msghdr_v6.get() as *const _); + + RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)? + }; + + let addr = unsafe { let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in6); SocketAddr::V6(SocketAddrV6::new( @@ -110,12 +126,20 @@ impl RecvHelper { u32::from_be(name_data.sin6_flowinfo), u32::from_be(name_data.sin6_scope_id), )) + }; + + if addr.port() == 0 { + return Err(Error::InvalidSocketAddress); } + + (msg, addr) }; - ( - Request::from_bytes(msg.payload_data(), self.max_scrape_torrents), - CanonicalSocketAddr::new(addr), - ) + let addr = CanonicalSocketAddr::new(addr); + + let request = Request::from_bytes(msg.payload_data(), self.max_scrape_torrents) + .map_err(|err| Error::RequestParseError(err, addr))?; + + Ok((request, addr)) } } diff --git a/aquatic_udp/src/workers/socket/uring/send_buffers.rs b/aquatic_udp/src/workers/socket/uring/send_buffers.rs index 0eae144..38fe6c5 100644 --- a/aquatic_udp/src/workers/socket/uring/send_buffers.rs +++ b/aquatic_udp/src/workers/socket/uring/send_buffers.rs @@ -1,4 +1,4 @@ -use std::{io::Cursor, net::IpAddr, ops::IndexMut, ptr::null_mut}; +use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut}; use aquatic_common::CanonicalSocketAddr; use aquatic_udp_protocol::Response; @@ -32,114 +32,173 @@ impl ResponseType { } } -pub struct SendBuffers { - likely_next_free_index: usize, - network_address: IpAddr, - names_v4: Vec, - names_v6: Vec, - buffers: Vec<[u8; BUF_LEN]>, - iovecs: Vec, - msghdrs: Vec, - free: Vec, +struct SendBuffer { + name_v4: UnsafeCell, + name_v6: UnsafeCell, + bytes: UnsafeCell<[u8; BUF_LEN]>, + iovec: UnsafeCell, + msghdr: UnsafeCell, + free: bool, // Only used for statistics - receiver_is_ipv4: Vec, + receiver_is_ipv4: bool, // Only used for statistics - response_types: Vec, + response_type: ResponseType, } -impl SendBuffers { - pub fn new(config: &Config, capacity: usize) -> Self { - let mut buffers = ::std::iter::repeat([0u8; BUF_LEN]) - .take(capacity) - .collect::>(); - - let mut iovecs = buffers - .iter_mut() - .map(|buffer| libc::iovec { - iov_base: buffer.as_mut_ptr() as *mut libc::c_void, - iov_len: buffer.len(), - }) - .collect::>(); - - let (names_v4, names_v6, msghdrs) = if config.network.address.is_ipv4() { - let mut names_v4 = ::std::iter::repeat(libc::sockaddr_in { +impl SendBuffer { + fn new_with_null_pointers() -> Self { + Self { + name_v4: UnsafeCell::new(libc::sockaddr_in { sin_family: libc::AF_INET as u16, sin_port: 0, sin_addr: libc::in_addr { s_addr: 0 }, sin_zero: [0; 8], - }) - .take(capacity) - .collect::>(); - - let msghdrs = names_v4 - .iter_mut() - .zip(iovecs.iter_mut()) - .map(|(msg_name, msg_iov)| libc::msghdr { - msg_name: msg_name as *mut _ as *mut libc::c_void, - msg_namelen: core::mem::size_of::() as u32, - msg_iov: msg_iov as *mut _, - msg_iovlen: 1, - msg_control: null_mut(), - msg_controllen: 0, - msg_flags: 0, - }) - .collect::>(); - - (names_v4, Vec::new(), msghdrs) - } else { - let mut names_v6 = ::std::iter::repeat(libc::sockaddr_in6 { + }), + name_v6: UnsafeCell::new(libc::sockaddr_in6 { sin6_family: libc::AF_INET6 as u16, sin6_port: 0, sin6_flowinfo: 0, sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, sin6_scope_id: 0, - }) - .take(capacity) - .collect::>(); - - let msghdrs = names_v6 - .iter_mut() - .zip(iovecs.iter_mut()) - .map(|(msg_name, msg_iov)| libc::msghdr { - msg_name: msg_name as *mut _ as *mut libc::c_void, - msg_namelen: core::mem::size_of::() as u32, - msg_iov: msg_iov as *mut _, - msg_iovlen: 1, - msg_control: null_mut(), - msg_controllen: 0, - msg_flags: 0, - }) - .collect::>(); - - (Vec::new(), names_v6, msghdrs) - }; - - Self { - likely_next_free_index: 0, - network_address: config.network.address.ip(), - names_v4, - names_v6, - buffers, - iovecs, - msghdrs, - free: ::std::iter::repeat(true).take(capacity).collect(), - receiver_is_ipv4: ::std::iter::repeat(true).take(capacity).collect(), - response_types: ::std::iter::repeat(ResponseType::Connect) - .take(capacity) - .collect(), + }), + bytes: UnsafeCell::new([0; BUF_LEN]), + iovec: UnsafeCell::new(libc::iovec { + iov_base: null_mut(), + iov_len: 0, + }), + msghdr: UnsafeCell::new(libc::msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: null_mut(), + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }), + free: true, + receiver_is_ipv4: true, + response_type: ResponseType::Connect, } } - pub fn receiver_is_ipv4(&mut self, index: usize) -> bool { - self.receiver_is_ipv4[index] + /// # Safety + /// + /// - SendBuffer must be stored at a fixed location in memory + unsafe fn setup_pointers(&mut self, socket_is_ipv4: bool) { + let iovec = &mut *self.iovec.get(); + + iovec.iov_base = self.bytes.get() as *mut libc::c_void; + iovec.iov_len = (&*self.bytes.get()).len(); + + let msghdr = &mut *self.msghdr.get(); + + msghdr.msg_iov = self.iovec.get(); + + if socket_is_ipv4 { + msghdr.msg_name = self.name_v4.get() as *mut libc::c_void; + msghdr.msg_namelen = core::mem::size_of::() as u32; + } else { + msghdr.msg_name = self.name_v6.get() as *mut libc::c_void; + msghdr.msg_namelen = core::mem::size_of::() as u32; + } } - pub fn response_type(&mut self, index: usize) -> ResponseType { - self.response_types[index] + /// # Safety + /// + /// - SendBuffer must be stored at a fixed location in memory + /// - SendBuffer.setup_pointers must have been called previously + unsafe fn prepare_entry( + &mut self, + response: &Response, + addr: CanonicalSocketAddr, + socket_is_ipv4: bool, + ) -> Result { + // Set receiver socket addr + if socket_is_ipv4 { + self.receiver_is_ipv4 = true; + + let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() { + addr + } else { + panic!("ipv6 address in ipv4 mode"); + }; + + let name = &mut *self.name_v4.get(); + + name.sin_port = addr.port().to_be(); + name.sin_addr.s_addr = u32::from(*addr.ip()).to_be(); + } else { + self.receiver_is_ipv4 = addr.is_ipv4(); + + let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() { + addr + } else { + panic!("ipv4 address when ipv6 or ipv6-mapped address expected"); + }; + + let name = &mut *self.name_v6.get(); + + name.sin6_port = addr.port().to_be(); + name.sin6_addr.s6_addr = addr.ip().octets(); + } + + let bytes = (&mut *self.bytes.get()).as_mut_slice(); + + let mut cursor = Cursor::new(bytes); + + match response.write(&mut cursor) { + Ok(()) => { + (&mut *self.iovec.get()).iov_len = cursor.position() as usize; + + self.response_type = ResponseType::from_response(response); + self.free = false; + + let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build(); + + Ok(sqe) + } + Err(err) => Err(Error::SerializationFailed(err)), + } + } +} + +pub struct SendBuffers { + likely_next_free_index: usize, + socket_is_ipv4: bool, + buffers: Box<[SendBuffer]>, +} + +impl SendBuffers { + pub fn new(config: &Config, capacity: usize) -> Self { + let socket_is_ipv4 = config.network.address.is_ipv4(); + + let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers()) + .take(capacity) + .collect::>() + .into_boxed_slice(); + + for buffer in buffers.iter_mut() { + // Safety: OK because buffers are stored in fixed memory location + unsafe { + buffer.setup_pointers(socket_is_ipv4); + } + } + + Self { + likely_next_free_index: 0, + socket_is_ipv4, + buffers, + } + } + + pub fn response_type_and_ipv4(&self, index: usize) -> (ResponseType, bool) { + let buffer = self.buffers.get(index).unwrap(); + + (buffer.response_type, buffer.receiver_is_ipv4) } pub fn mark_index_as_free(&mut self, index: usize) { - self.free[index] = true; + self.buffers[index].free = true; } /// Call after going through completion queue @@ -154,64 +213,32 @@ impl SendBuffers { ) -> Result { let index = self.next_free_index()?; - // Set receiver socket addr - if self.network_address.is_ipv4() { - let msg_name = self.names_v4.index_mut(index); - let addr = addr.get_ipv4().unwrap(); + let buffer = self.buffers.index_mut(index); - msg_name.sin_port = addr.port().to_be(); - msg_name.sin_addr.s_addr = if let IpAddr::V4(addr) = addr.ip() { - u32::from(addr).to_be() - } else { - panic!("ipv6 address in ipv4 mode"); - }; + // Safety: OK because buffers are stored in fixed memory location + // and buffer pointers were set up in SendBuffers::new() + unsafe { + match buffer.prepare_entry(response, addr, self.socket_is_ipv4) { + Ok(entry) => { + self.likely_next_free_index = index + 1; - self.receiver_is_ipv4[index] = true; - } else { - let msg_name = self.names_v6.index_mut(index); - let addr = addr.get_ipv6_mapped(); - - msg_name.sin6_port = addr.port().to_be(); - msg_name.sin6_addr.s6_addr = if let IpAddr::V6(addr) = addr.ip() { - addr.octets() - } else { - panic!("ipv4 address when ipv6 or ipv6-mapped address expected"); - }; - - self.receiver_is_ipv4[index] = addr.is_ipv4(); - } - - let mut cursor = Cursor::new(self.buffers.index_mut(index).as_mut_slice()); - - match response.write(&mut cursor) { - Ok(()) => { - self.iovecs[index].iov_len = cursor.position() as usize; - self.response_types[index] = ResponseType::from_response(response); - self.free[index] = false; - - self.likely_next_free_index = index + 1; - - let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdrs.index_mut(index)) - .build() - .user_data(index as u64); - - Ok(sqe) + Ok(entry.user_data(index as u64)) + } + Err(err) => Err(err), } - Err(err) => Err(Error::SerializationFailed(err)), } } fn next_free_index(&self) -> Result { - if self.likely_next_free_index >= self.free.len() { + if self.likely_next_free_index >= self.buffers.len() { return Err(Error::NoBuffers); } - for (i, free) in self.free[self.likely_next_free_index..] + for (i, buffer) in self.buffers[self.likely_next_free_index..] .iter() - .copied() .enumerate() { - if free { + if buffer.free { return Ok(self.likely_next_free_index + i); } }