udp: uring: rewrite SendBuffers to use UnsafeCell

This commit is contained in:
Joakim Frostegård 2023-03-08 13:14:35 +01:00
parent 2e67f11caf
commit 339feb3d0a

View file

@ -1,4 +1,4 @@
use std::{io::Cursor, net::IpAddr, ops::IndexMut, ptr::null_mut};
use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut};
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::Response;
@ -32,114 +32,110 @@ impl ResponseType {
}
}
pub struct SendBuffers {
likely_next_free_index: usize,
network_address: IpAddr,
names_v4: Vec<libc::sockaddr_in>,
names_v6: Vec<libc::sockaddr_in6>,
buffers: Vec<[u8; BUF_LEN]>,
iovecs: Vec<libc::iovec>,
msghdrs: Vec<libc::msghdr>,
free: Vec<bool>,
struct SendBuffer {
name_v4: UnsafeCell<libc::sockaddr_in>,
name_v6: UnsafeCell<libc::sockaddr_in6>,
bytes: UnsafeCell<[u8; BUF_LEN]>,
iovec: UnsafeCell<libc::iovec>,
msghdr: UnsafeCell<libc::msghdr>,
free: bool,
// Only used for statistics
receiver_is_ipv4: Vec<bool>,
receiver_is_ipv4: bool,
// Only used for statistics
response_types: Vec<ResponseType>,
response_type: ResponseType,
}
impl SendBuffers {
pub fn new(config: &Config, capacity: usize) -> Self {
let mut buffers = ::std::iter::repeat([0u8; BUF_LEN])
.take(capacity)
.collect::<Vec<_>>();
let mut iovecs = buffers
.iter_mut()
.map(|buffer| libc::iovec {
iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
iov_len: buffer.len(),
})
.collect::<Vec<_>>();
let (names_v4, names_v6, msghdrs) = if config.network.address.is_ipv4() {
let mut names_v4 = ::std::iter::repeat(libc::sockaddr_in {
impl SendBuffer {
fn new() -> Self {
Self {
name_v4: UnsafeCell::new(libc::sockaddr_in {
sin_family: libc::AF_INET as u16,
sin_port: 0,
sin_addr: libc::in_addr { s_addr: 0 },
sin_zero: [0; 8],
})
.take(capacity)
.collect::<Vec<_>>();
let msghdrs = names_v4
.iter_mut()
.zip(iovecs.iter_mut())
.map(|(msg_name, msg_iov)| libc::msghdr {
msg_name: msg_name as *mut _ as *mut libc::c_void,
msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32,
msg_iov: msg_iov as *mut _,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
})
.collect::<Vec<_>>();
(names_v4, Vec::new(), msghdrs)
} else {
let mut names_v6 = ::std::iter::repeat(libc::sockaddr_in6 {
}),
name_v6: UnsafeCell::new(libc::sockaddr_in6 {
sin6_family: libc::AF_INET6 as u16,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
sin6_scope_id: 0,
})
}),
bytes: UnsafeCell::new([0; BUF_LEN]),
iovec: UnsafeCell::new(libc::iovec {
iov_base: null_mut(),
iov_len: 0,
}),
msghdr: UnsafeCell::new(libc::msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: null_mut(),
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
}),
free: true,
receiver_is_ipv4: true,
response_type: ResponseType::Connect,
}
}
}
pub struct SendBuffers {
likely_next_free_index: usize,
socket_is_ipv4: bool,
buffers: Box<[SendBuffer]>,
}
impl SendBuffers {
pub fn new(config: &Config, capacity: usize) -> Self {
let socket_is_ipv4 = config.network.address.is_ipv4();
let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new())
.take(capacity)
.collect::<Vec<_>>();
.collect::<Vec<_>>()
.into_boxed_slice();
let msghdrs = names_v6
.iter_mut()
.zip(iovecs.iter_mut())
.map(|(msg_name, msg_iov)| libc::msghdr {
msg_name: msg_name as *mut _ as *mut libc::c_void,
msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32,
msg_iov: msg_iov as *mut _,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
})
.collect::<Vec<_>>();
for buffer in buffers.iter_mut() {
unsafe {
let iovec = &mut *buffer.iovec.get();
(Vec::new(), names_v6, msghdrs)
};
iovec.iov_base = buffer.bytes.get() as *mut libc::c_void;
iovec.iov_len = (&*buffer.bytes.get()).len();
}
unsafe {
let msghdr = &mut *buffer.msghdr.get();
msghdr.msg_iov = buffer.iovec.get();
if socket_is_ipv4 {
msghdr.msg_name = buffer.name_v4.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
} else {
msghdr.msg_name = buffer.name_v6.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
}
}
}
Self {
likely_next_free_index: 0,
network_address: config.network.address.ip(),
names_v4,
names_v6,
socket_is_ipv4,
buffers,
iovecs,
msghdrs,
free: ::std::iter::repeat(true).take(capacity).collect(),
receiver_is_ipv4: ::std::iter::repeat(true).take(capacity).collect(),
response_types: ::std::iter::repeat(ResponseType::Connect)
.take(capacity)
.collect(),
}
}
pub fn receiver_is_ipv4(&mut self, index: usize) -> bool {
self.receiver_is_ipv4[index]
self.buffers[index].receiver_is_ipv4
}
pub fn response_type(&mut self, index: usize) -> ResponseType {
self.response_types[index]
self.buffers[index].response_type
}
pub fn mark_index_as_free(&mut self, index: usize) {
self.free[index] = true;
self.buffers[index].free = true;
}
/// Call after going through completion queue
@ -154,64 +150,76 @@ impl SendBuffers {
) -> Result<io_uring::squeue::Entry, Error> {
let index = self.next_free_index()?;
// Set receiver socket addr
if self.network_address.is_ipv4() {
let msg_name = self.names_v4.index_mut(index);
let addr = addr.get_ipv4().unwrap();
let buffer = self.buffers.index_mut(index);
msg_name.sin_port = addr.port().to_be();
msg_name.sin_addr.s_addr = if let IpAddr::V4(addr) = addr.ip() {
u32::from(addr).to_be()
// Set receiver socket addr
if self.socket_is_ipv4 {
buffer.receiver_is_ipv4 = true;
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
addr
} else {
panic!("ipv6 address in ipv4 mode");
};
self.receiver_is_ipv4[index] = true;
} else {
let msg_name = self.names_v6.index_mut(index);
let addr = addr.get_ipv6_mapped();
unsafe {
let name = &mut *buffer.name_v4.get();
msg_name.sin6_port = addr.port().to_be();
msg_name.sin6_addr.s6_addr = if let IpAddr::V6(addr) = addr.ip() {
addr.octets()
name.sin_port = addr.port().to_be();
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
}
} else {
buffer.receiver_is_ipv4 = addr.is_ipv4();
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
addr
} else {
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
};
self.receiver_is_ipv4[index] = addr.is_ipv4();
unsafe {
let name = &mut *buffer.name_v6.get();
name.sin6_port = addr.port().to_be();
name.sin6_addr.s6_addr = addr.ip().octets();
}
}
let mut cursor = Cursor::new(self.buffers.index_mut(index).as_mut_slice());
unsafe {
let bytes = (&mut *buffer.bytes.get()).as_mut_slice();
match response.write(&mut cursor) {
Ok(()) => {
self.iovecs[index].iov_len = cursor.position() as usize;
self.response_types[index] = ResponseType::from_response(response);
self.free[index] = false;
let mut cursor = Cursor::new(bytes);
self.likely_next_free_index = index + 1;
match response.write(&mut cursor) {
Ok(()) => {
(&mut *buffer.iovec.get()).iov_len = cursor.position() as usize;
let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdrs.index_mut(index))
.build()
.user_data(index as u64);
buffer.response_type = ResponseType::from_response(response);
buffer.free = false;
Ok(sqe)
self.likely_next_free_index = index + 1;
let sqe = SendMsg::new(SOCKET_IDENTIFIER, buffer.msghdr.get())
.build()
.user_data(index as u64);
Ok(sqe)
}
Err(err) => Err(Error::SerializationFailed(err)),
}
Err(err) => Err(Error::SerializationFailed(err)),
}
}
fn next_free_index(&self) -> Result<usize, Error> {
if self.likely_next_free_index >= self.free.len() {
if self.likely_next_free_index >= self.buffers.len() {
return Err(Error::NoBuffers);
}
for (i, free) in self.free[self.likely_next_free_index..]
for (i, buffer) in self.buffers[self.likely_next_free_index..]
.iter()
.copied()
.enumerate()
{
if free {
if buffer.free {
return Ok(self.likely_next_free_index + i);
}
}