Merge pull request #132 from greatest-ape/work-2023-03-08

udp: uring: use UnsafeCell, code improvements
This commit is contained in:
Joakim Frostegård 2023-03-08 20:04:41 +01:00 committed by GitHub
commit f84d80a7e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 295 additions and 226 deletions

View file

@ -3,6 +3,8 @@
## High priority
* udp uring
* should queues be synced?
* miri
* uneven performance?
* thiserror?
* CI

View file

@ -196,7 +196,7 @@ impl SocketWorker {
break;
}
Err(send_buffers::Error::SerializationFailed(err)) => {
::log::error!("write response to buffer: {:#}", err);
::log::error!("Failed serializing response: {:#}", err);
}
}
} else {
@ -245,7 +245,7 @@ impl SocketWorker {
break;
}
Err(send_buffers::Error::SerializationFailed(err)) => {
::log::error!("write response to buffer: {:#}", err);
::log::error!("Failed serializing response: {:#}", err);
}
}
}
@ -291,30 +291,31 @@ impl SocketWorker {
if result < 0 {
::log::error!(
"send: {:#}",
"Couldn't send response: {:#}",
::std::io::Error::from_raw_os_error(-result)
);
} else if self.config.statistics.active() {
let send_buffer_index = send_buffer_index as usize;
let (statistics, extra_bytes) =
if self.send_buffers.receiver_is_ipv4(send_buffer_index) {
(&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4)
} else {
(&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6)
};
let (response_type, receiver_is_ipv4) =
self.send_buffers.response_type_and_ipv4(send_buffer_index);
let (statistics, extra_bytes) = if receiver_is_ipv4 {
(&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4)
} else {
(&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6)
};
statistics
.bytes_sent
.fetch_add(result as usize + extra_bytes, Ordering::Relaxed);
let response_counter =
match self.send_buffers.response_type(send_buffer_index) {
ResponseType::Connect => &statistics.responses_sent_connect,
ResponseType::Announce => &statistics.responses_sent_announce,
ResponseType::Scrape => &statistics.responses_sent_scrape,
ResponseType::Error => &statistics.responses_sent_error,
};
let response_counter = match response_type {
ResponseType::Connect => &statistics.responses_sent_connect,
ResponseType::Announce => &statistics.responses_sent_announce,
ResponseType::Scrape => &statistics.responses_sent_scrape,
ResponseType::Error => &statistics.responses_sent_error,
};
response_counter.fetch_add(1, Ordering::Relaxed);
}
@ -337,8 +338,14 @@ impl SocketWorker {
let result = cqe.result();
if result < 0 {
// Will produce ENOBUFS if there were no free buffers
::log::warn!("recv: {:#}", ::std::io::Error::from_raw_os_error(-result));
if -result == libc::ENOBUFS {
::log::warn!("recv failed due to lack of buffers, try increasing ring size");
} else {
::log::warn!(
"recv failed: {:#}",
::std::io::Error::from_raw_os_error(-result)
);
}
return;
}
@ -347,12 +354,12 @@ impl SocketWorker {
match self.buf_ring.get_buf(result as u32, cqe.flags()) {
Ok(Some(buffer)) => buffer,
Ok(None) => {
::log::error!("Couldn't get buffer");
::log::error!("Couldn't get recv buffer");
return;
}
Err(err) => {
::log::error!("Couldn't get buffer: {:#}", err);
::log::error!("Couldn't get recv buffer: {:#}", err);
return;
}
@ -361,51 +368,60 @@ impl SocketWorker {
let buffer = buffer.as_slice();
let (res_request, addr) = self.recv_helper.parse(buffer);
let addr = match self.recv_helper.parse(buffer) {
Ok((request, addr)) => {
self.handle_request(pending_scrape_valid_until, request, addr);
match res_request {
Ok(request) => self.handle_request(pending_scrape_valid_until, request, addr),
Err(RequestParseError::Sendable {
connection_id,
transaction_id,
err,
}) => {
::log::debug!("Couldn't parse request from {:?}: {}", addr, err);
if self.validator.connection_id_valid(addr, connection_id) {
let response = ErrorResponse {
addr
}
Err(self::recv_helper::Error::RequestParseError(err, addr)) => {
match err {
RequestParseError::Sendable {
connection_id,
transaction_id,
message: err.right_or("Parse error").into(),
};
err,
} => {
::log::debug!("Couldn't parse request from {:?}: {}", addr, err);
self.local_responses.push_back((response.into(), addr));
if self.validator.connection_id_valid(addr, connection_id) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
self.local_responses.push_back((response.into(), addr));
}
}
RequestParseError::Unsendable { err } => {
::log::debug!("Couldn't parse request from {:?}: {}", addr, err);
}
}
addr
}
Err(RequestParseError::Unsendable { err }) => {
::log::debug!("Couldn't parse request from {:?}: {}", addr, err);
Err(self::recv_helper::Error::InvalidSocketAddress) => {
::log::debug!("Ignored request claiming to be from port 0");
return;
}
}
Err(self::recv_helper::Error::RecvMsgParseError) => {
::log::error!("RecvMsgOut::parse failed");
return;
}
};
if self.config.statistics.active() {
if addr.is_ipv4() {
self.shared_state
.statistics_ipv4
.bytes_received
.fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed);
self.shared_state
.statistics_ipv4
.requests_received
.fetch_add(1, Ordering::Relaxed);
let (statistics, extra_bytes) = if addr.is_ipv4() {
(&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4)
} else {
self.shared_state
.statistics_ipv6
.bytes_received
.fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed);
self.shared_state
.statistics_ipv6
.requests_received
.fetch_add(1, Ordering::Relaxed);
}
(&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6)
};
statistics
.bytes_received
.fetch_add(buffer.len() + extra_bytes, Ordering::Relaxed);
statistics.requests_received.fetch_add(1, Ordering::Relaxed);
}
}

View file

@ -1,5 +1,6 @@
use std::{
net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
cell::UnsafeCell,
net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
ptr::null_mut,
};
@ -11,56 +12,62 @@ use crate::config::Config;
use super::{SOCKET_IDENTIFIER, USER_DATA_RECV};
pub enum Error {
RecvMsgParseError,
RequestParseError(RequestParseError, CanonicalSocketAddr),
InvalidSocketAddress,
}
pub struct RecvHelper {
network_address: IpAddr,
socket_is_ipv4: bool,
max_scrape_torrents: u8,
#[allow(dead_code)]
name_v4: Box<libc::sockaddr_in>,
msghdr_v4: Box<libc::msghdr>,
name_v4: Box<UnsafeCell<libc::sockaddr_in>>,
msghdr_v4: Box<UnsafeCell<libc::msghdr>>,
#[allow(dead_code)]
name_v6: Box<libc::sockaddr_in6>,
msghdr_v6: Box<libc::msghdr>,
name_v6: Box<UnsafeCell<libc::sockaddr_in6>>,
msghdr_v6: Box<UnsafeCell<libc::msghdr>>,
}
impl RecvHelper {
pub fn new(config: &Config) -> Self {
let mut name_v4 = Box::new(libc::sockaddr_in {
let name_v4 = Box::new(UnsafeCell::new(libc::sockaddr_in {
sin_family: 0,
sin_port: 0,
sin_addr: libc::in_addr { s_addr: 0 },
sin_zero: [0; 8],
});
}));
let msghdr_v4 = Box::new(libc::msghdr {
msg_name: &mut name_v4 as *mut _ as *mut libc::c_void,
let msghdr_v4 = Box::new(UnsafeCell::new(libc::msghdr {
msg_name: name_v4.get() as *mut libc::c_void,
msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32,
msg_iov: null_mut(),
msg_iovlen: 0,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
});
}));
let mut name_v6 = Box::new(libc::sockaddr_in6 {
let name_v6 = Box::new(UnsafeCell::new(libc::sockaddr_in6 {
sin6_family: 0,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
sin6_scope_id: 0,
});
}));
let msghdr_v6 = Box::new(libc::msghdr {
msg_name: &mut name_v6 as *mut _ as *mut libc::c_void,
let msghdr_v6 = Box::new(UnsafeCell::new(libc::msghdr {
msg_name: name_v6.get() as *mut libc::c_void,
msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32,
msg_iov: null_mut(),
msg_iovlen: 0,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
});
}));
Self {
network_address: config.network.address.ip(),
socket_is_ipv4: config.network.address.is_ipv4(),
max_scrape_torrents: config.protocol.max_scrape_torrents,
name_v4,
msghdr_v4,
@ -70,10 +77,10 @@ impl RecvHelper {
}
pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry {
let msghdr: *const libc::msghdr = if self.network_address.is_ipv4() {
&*self.msghdr_v4
let msghdr: *const libc::msghdr = if self.socket_is_ipv4 {
self.msghdr_v4.get()
} else {
&*self.msghdr_v6
self.msghdr_v6.get()
};
RecvMsgMulti::new(SOCKET_IDENTIFIER, msghdr, buf_group)
@ -81,27 +88,36 @@ impl RecvHelper {
.user_data(USER_DATA_RECV)
}
pub fn parse(
&self,
buffer: &[u8],
) -> (Result<Request, RequestParseError>, CanonicalSocketAddr) {
let msghdr = if self.network_address.is_ipv4() {
&self.msghdr_v4
} else {
&self.msghdr_v6
};
pub fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> {
let (msg, addr) = if self.socket_is_ipv4 {
let msg = unsafe {
let msghdr = &*(self.msghdr_v4.get() as *const _);
let msg = RecvMsgOut::parse(buffer, msghdr).unwrap();
RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)?
};
let addr = unsafe {
if self.network_address.is_ipv4() {
let addr = unsafe {
let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in);
SocketAddr::V4(SocketAddrV4::new(
u32::from_be(name_data.sin_addr.s_addr).into(),
u16::from_be(name_data.sin_port),
))
} else {
};
if addr.port() == 0 {
return Err(Error::InvalidSocketAddress);
}
(msg, addr)
} else {
let msg = unsafe {
let msghdr = &*(self.msghdr_v6.get() as *const _);
RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)?
};
let addr = unsafe {
let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in6);
SocketAddr::V6(SocketAddrV6::new(
@ -110,12 +126,20 @@ impl RecvHelper {
u32::from_be(name_data.sin6_flowinfo),
u32::from_be(name_data.sin6_scope_id),
))
};
if addr.port() == 0 {
return Err(Error::InvalidSocketAddress);
}
(msg, addr)
};
(
Request::from_bytes(msg.payload_data(), self.max_scrape_torrents),
CanonicalSocketAddr::new(addr),
)
let addr = CanonicalSocketAddr::new(addr);
let request = Request::from_bytes(msg.payload_data(), self.max_scrape_torrents)
.map_err(|err| Error::RequestParseError(err, addr))?;
Ok((request, addr))
}
}

View file

@ -1,4 +1,4 @@
use std::{io::Cursor, net::IpAddr, ops::IndexMut, ptr::null_mut};
use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut};
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::Response;
@ -32,114 +32,173 @@ impl ResponseType {
}
}
pub struct SendBuffers {
likely_next_free_index: usize,
network_address: IpAddr,
names_v4: Vec<libc::sockaddr_in>,
names_v6: Vec<libc::sockaddr_in6>,
buffers: Vec<[u8; BUF_LEN]>,
iovecs: Vec<libc::iovec>,
msghdrs: Vec<libc::msghdr>,
free: Vec<bool>,
struct SendBuffer {
name_v4: UnsafeCell<libc::sockaddr_in>,
name_v6: UnsafeCell<libc::sockaddr_in6>,
bytes: UnsafeCell<[u8; BUF_LEN]>,
iovec: UnsafeCell<libc::iovec>,
msghdr: UnsafeCell<libc::msghdr>,
free: bool,
// Only used for statistics
receiver_is_ipv4: Vec<bool>,
receiver_is_ipv4: bool,
// Only used for statistics
response_types: Vec<ResponseType>,
response_type: ResponseType,
}
impl SendBuffers {
pub fn new(config: &Config, capacity: usize) -> Self {
let mut buffers = ::std::iter::repeat([0u8; BUF_LEN])
.take(capacity)
.collect::<Vec<_>>();
let mut iovecs = buffers
.iter_mut()
.map(|buffer| libc::iovec {
iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
iov_len: buffer.len(),
})
.collect::<Vec<_>>();
let (names_v4, names_v6, msghdrs) = if config.network.address.is_ipv4() {
let mut names_v4 = ::std::iter::repeat(libc::sockaddr_in {
impl SendBuffer {
fn new_with_null_pointers() -> Self {
Self {
name_v4: UnsafeCell::new(libc::sockaddr_in {
sin_family: libc::AF_INET as u16,
sin_port: 0,
sin_addr: libc::in_addr { s_addr: 0 },
sin_zero: [0; 8],
})
.take(capacity)
.collect::<Vec<_>>();
let msghdrs = names_v4
.iter_mut()
.zip(iovecs.iter_mut())
.map(|(msg_name, msg_iov)| libc::msghdr {
msg_name: msg_name as *mut _ as *mut libc::c_void,
msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32,
msg_iov: msg_iov as *mut _,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
})
.collect::<Vec<_>>();
(names_v4, Vec::new(), msghdrs)
} else {
let mut names_v6 = ::std::iter::repeat(libc::sockaddr_in6 {
}),
name_v6: UnsafeCell::new(libc::sockaddr_in6 {
sin6_family: libc::AF_INET6 as u16,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
sin6_scope_id: 0,
})
.take(capacity)
.collect::<Vec<_>>();
let msghdrs = names_v6
.iter_mut()
.zip(iovecs.iter_mut())
.map(|(msg_name, msg_iov)| libc::msghdr {
msg_name: msg_name as *mut _ as *mut libc::c_void,
msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32,
msg_iov: msg_iov as *mut _,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
})
.collect::<Vec<_>>();
(Vec::new(), names_v6, msghdrs)
};
Self {
likely_next_free_index: 0,
network_address: config.network.address.ip(),
names_v4,
names_v6,
buffers,
iovecs,
msghdrs,
free: ::std::iter::repeat(true).take(capacity).collect(),
receiver_is_ipv4: ::std::iter::repeat(true).take(capacity).collect(),
response_types: ::std::iter::repeat(ResponseType::Connect)
.take(capacity)
.collect(),
}),
bytes: UnsafeCell::new([0; BUF_LEN]),
iovec: UnsafeCell::new(libc::iovec {
iov_base: null_mut(),
iov_len: 0,
}),
msghdr: UnsafeCell::new(libc::msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: null_mut(),
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
}),
free: true,
receiver_is_ipv4: true,
response_type: ResponseType::Connect,
}
}
pub fn receiver_is_ipv4(&mut self, index: usize) -> bool {
self.receiver_is_ipv4[index]
/// # Safety
///
/// - SendBuffer must be stored at a fixed location in memory
unsafe fn setup_pointers(&mut self, socket_is_ipv4: bool) {
let iovec = &mut *self.iovec.get();
iovec.iov_base = self.bytes.get() as *mut libc::c_void;
iovec.iov_len = (&*self.bytes.get()).len();
let msghdr = &mut *self.msghdr.get();
msghdr.msg_iov = self.iovec.get();
if socket_is_ipv4 {
msghdr.msg_name = self.name_v4.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
} else {
msghdr.msg_name = self.name_v6.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
}
}
pub fn response_type(&mut self, index: usize) -> ResponseType {
self.response_types[index]
/// # Safety
///
/// - SendBuffer must be stored at a fixed location in memory
/// - SendBuffer.setup_pointers must have been called previously
unsafe fn prepare_entry(
&mut self,
response: &Response,
addr: CanonicalSocketAddr,
socket_is_ipv4: bool,
) -> Result<io_uring::squeue::Entry, Error> {
// Set receiver socket addr
if socket_is_ipv4 {
self.receiver_is_ipv4 = true;
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
addr
} else {
panic!("ipv6 address in ipv4 mode");
};
let name = &mut *self.name_v4.get();
name.sin_port = addr.port().to_be();
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
} else {
self.receiver_is_ipv4 = addr.is_ipv4();
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
addr
} else {
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
};
let name = &mut *self.name_v6.get();
name.sin6_port = addr.port().to_be();
name.sin6_addr.s6_addr = addr.ip().octets();
}
let bytes = (&mut *self.bytes.get()).as_mut_slice();
let mut cursor = Cursor::new(bytes);
match response.write(&mut cursor) {
Ok(()) => {
(&mut *self.iovec.get()).iov_len = cursor.position() as usize;
self.response_type = ResponseType::from_response(response);
self.free = false;
let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build();
Ok(sqe)
}
Err(err) => Err(Error::SerializationFailed(err)),
}
}
}
pub struct SendBuffers {
likely_next_free_index: usize,
socket_is_ipv4: bool,
buffers: Box<[SendBuffer]>,
}
impl SendBuffers {
pub fn new(config: &Config, capacity: usize) -> Self {
let socket_is_ipv4 = config.network.address.is_ipv4();
let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers())
.take(capacity)
.collect::<Vec<_>>()
.into_boxed_slice();
for buffer in buffers.iter_mut() {
// Safety: OK because buffers are stored in fixed memory location
unsafe {
buffer.setup_pointers(socket_is_ipv4);
}
}
Self {
likely_next_free_index: 0,
socket_is_ipv4,
buffers,
}
}
pub fn response_type_and_ipv4(&self, index: usize) -> (ResponseType, bool) {
let buffer = self.buffers.get(index).unwrap();
(buffer.response_type, buffer.receiver_is_ipv4)
}
pub fn mark_index_as_free(&mut self, index: usize) {
self.free[index] = true;
self.buffers[index].free = true;
}
/// Call after going through completion queue
@ -154,64 +213,32 @@ impl SendBuffers {
) -> Result<io_uring::squeue::Entry, Error> {
let index = self.next_free_index()?;
// Set receiver socket addr
if self.network_address.is_ipv4() {
let msg_name = self.names_v4.index_mut(index);
let addr = addr.get_ipv4().unwrap();
let buffer = self.buffers.index_mut(index);
msg_name.sin_port = addr.port().to_be();
msg_name.sin_addr.s_addr = if let IpAddr::V4(addr) = addr.ip() {
u32::from(addr).to_be()
} else {
panic!("ipv6 address in ipv4 mode");
};
// Safety: OK because buffers are stored in fixed memory location
// and buffer pointers were set up in SendBuffers::new()
unsafe {
match buffer.prepare_entry(response, addr, self.socket_is_ipv4) {
Ok(entry) => {
self.likely_next_free_index = index + 1;
self.receiver_is_ipv4[index] = true;
} else {
let msg_name = self.names_v6.index_mut(index);
let addr = addr.get_ipv6_mapped();
msg_name.sin6_port = addr.port().to_be();
msg_name.sin6_addr.s6_addr = if let IpAddr::V6(addr) = addr.ip() {
addr.octets()
} else {
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
};
self.receiver_is_ipv4[index] = addr.is_ipv4();
}
let mut cursor = Cursor::new(self.buffers.index_mut(index).as_mut_slice());
match response.write(&mut cursor) {
Ok(()) => {
self.iovecs[index].iov_len = cursor.position() as usize;
self.response_types[index] = ResponseType::from_response(response);
self.free[index] = false;
self.likely_next_free_index = index + 1;
let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdrs.index_mut(index))
.build()
.user_data(index as u64);
Ok(sqe)
Ok(entry.user_data(index as u64))
}
Err(err) => Err(err),
}
Err(err) => Err(Error::SerializationFailed(err)),
}
}
fn next_free_index(&self) -> Result<usize, Error> {
if self.likely_next_free_index >= self.free.len() {
if self.likely_next_free_index >= self.buffers.len() {
return Err(Error::NoBuffers);
}
for (i, free) in self.free[self.likely_next_free_index..]
for (i, buffer) in self.buffers[self.likely_next_free_index..]
.iter()
.copied()
.enumerate()
{
if free {
if buffer.free {
return Ok(self.likely_next_free_index + i);
}
}