udp: uring: improve SendBuffers code

This commit is contained in:
Joakim Frostegård 2023-03-08 13:33:14 +01:00
parent 339feb3d0a
commit fa93f38d82

View file

@ -46,7 +46,7 @@ struct SendBuffer {
} }
impl SendBuffer { impl SendBuffer {
fn new() -> Self { fn new_with_null_pointers() -> Self {
Self { Self {
name_v4: UnsafeCell::new(libc::sockaddr_in { name_v4: UnsafeCell::new(libc::sockaddr_in {
sin_family: libc::AF_INET as u16, sin_family: libc::AF_INET as u16,
@ -80,6 +80,86 @@ impl SendBuffer {
response_type: ResponseType::Connect, response_type: ResponseType::Connect,
} }
} }
/// # Safety
///
/// - SendBuffer must be stored at a fixed location in memory
unsafe fn setup_pointers(&mut self, socket_is_ipv4: bool) {
let iovec = &mut *self.iovec.get();
iovec.iov_base = self.bytes.get() as *mut libc::c_void;
iovec.iov_len = (&*self.bytes.get()).len();
let msghdr = &mut *self.msghdr.get();
msghdr.msg_iov = self.iovec.get();
if socket_is_ipv4 {
msghdr.msg_name = self.name_v4.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
} else {
msghdr.msg_name = self.name_v6.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
}
}
/// # Safety
///
/// - SendBuffer must be stored at a fixed location in memory
/// - SendBuffer.setup_pointers must have been called previously
unsafe fn prepare_entry(
&mut self,
response: &Response,
addr: CanonicalSocketAddr,
socket_is_ipv4: bool,
) -> Result<io_uring::squeue::Entry, Error> {
// Set receiver socket addr
if socket_is_ipv4 {
self.receiver_is_ipv4 = true;
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
addr
} else {
panic!("ipv6 address in ipv4 mode");
};
let name = &mut *self.name_v4.get();
name.sin_port = addr.port().to_be();
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
} else {
self.receiver_is_ipv4 = addr.is_ipv4();
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
addr
} else {
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
};
let name = &mut *self.name_v6.get();
name.sin6_port = addr.port().to_be();
name.sin6_addr.s6_addr = addr.ip().octets();
}
let bytes = (&mut *self.bytes.get()).as_mut_slice();
let mut cursor = Cursor::new(bytes);
match response.write(&mut cursor) {
Ok(()) => {
(&mut *self.iovec.get()).iov_len = cursor.position() as usize;
self.response_type = ResponseType::from_response(response);
self.free = false;
let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build();
Ok(sqe)
}
Err(err) => Err(Error::SerializationFailed(err)),
}
}
} }
pub struct SendBuffers { pub struct SendBuffers {
@ -92,30 +172,15 @@ impl SendBuffers {
pub fn new(config: &Config, capacity: usize) -> Self { pub fn new(config: &Config, capacity: usize) -> Self {
let socket_is_ipv4 = config.network.address.is_ipv4(); let socket_is_ipv4 = config.network.address.is_ipv4();
let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new()) let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers())
.take(capacity) .take(capacity)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.into_boxed_slice(); .into_boxed_slice();
for buffer in buffers.iter_mut() { for buffer in buffers.iter_mut() {
// Safety: OK because buffers are stored in fixed memory location
unsafe { unsafe {
let iovec = &mut *buffer.iovec.get(); buffer.setup_pointers(socket_is_ipv4);
iovec.iov_base = buffer.bytes.get() as *mut libc::c_void;
iovec.iov_len = (&*buffer.bytes.get()).len();
}
unsafe {
let msghdr = &mut *buffer.msghdr.get();
msghdr.msg_iov = buffer.iovec.get();
if socket_is_ipv4 {
msghdr.msg_name = buffer.name_v4.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
} else {
msghdr.msg_name = buffer.name_v6.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
}
} }
} }
@ -152,60 +217,16 @@ impl SendBuffers {
let buffer = self.buffers.index_mut(index); let buffer = self.buffers.index_mut(index);
// Set receiver socket addr // Safety: OK because buffers are stored in fixed memory location
if self.socket_is_ipv4 { // and buffer pointers were set up in SendBuffers::new()
buffer.receiver_is_ipv4 = true;
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
addr
} else {
panic!("ipv6 address in ipv4 mode");
};
unsafe {
let name = &mut *buffer.name_v4.get();
name.sin_port = addr.port().to_be();
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
}
} else {
buffer.receiver_is_ipv4 = addr.is_ipv4();
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
addr
} else {
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
};
unsafe {
let name = &mut *buffer.name_v6.get();
name.sin6_port = addr.port().to_be();
name.sin6_addr.s6_addr = addr.ip().octets();
}
}
unsafe { unsafe {
let bytes = (&mut *buffer.bytes.get()).as_mut_slice(); match buffer.prepare_entry(response, addr, self.socket_is_ipv4) {
Ok(entry) => {
let mut cursor = Cursor::new(bytes);
match response.write(&mut cursor) {
Ok(()) => {
(&mut *buffer.iovec.get()).iov_len = cursor.position() as usize;
buffer.response_type = ResponseType::from_response(response);
buffer.free = false;
self.likely_next_free_index = index + 1; self.likely_next_free_index = index + 1;
let sqe = SendMsg::new(SOCKET_IDENTIFIER, buffer.msghdr.get()) Ok(entry.user_data(index as u64))
.build()
.user_data(index as u64);
Ok(sqe)
} }
Err(err) => Err(Error::SerializationFailed(err)), Err(err) => Err(err),
} }
} }
} }