diff --git a/aquatic_udp/src/workers/socket/uring/send_buffers.rs b/aquatic_udp/src/workers/socket/uring/send_buffers.rs index 0eae144..cd711f0 100644 --- a/aquatic_udp/src/workers/socket/uring/send_buffers.rs +++ b/aquatic_udp/src/workers/socket/uring/send_buffers.rs @@ -1,4 +1,4 @@ -use std::{io::Cursor, net::IpAddr, ops::IndexMut, ptr::null_mut}; +use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut}; use aquatic_common::CanonicalSocketAddr; use aquatic_udp_protocol::Response; @@ -32,114 +32,110 @@ impl ResponseType { } } -pub struct SendBuffers { - likely_next_free_index: usize, - network_address: IpAddr, - names_v4: Vec, - names_v6: Vec, - buffers: Vec<[u8; BUF_LEN]>, - iovecs: Vec, - msghdrs: Vec, - free: Vec, +struct SendBuffer { + name_v4: UnsafeCell, + name_v6: UnsafeCell, + bytes: UnsafeCell<[u8; BUF_LEN]>, + iovec: UnsafeCell, + msghdr: UnsafeCell, + free: bool, // Only used for statistics - receiver_is_ipv4: Vec, + receiver_is_ipv4: bool, // Only used for statistics - response_types: Vec, + response_type: ResponseType, } -impl SendBuffers { - pub fn new(config: &Config, capacity: usize) -> Self { - let mut buffers = ::std::iter::repeat([0u8; BUF_LEN]) - .take(capacity) - .collect::>(); - - let mut iovecs = buffers - .iter_mut() - .map(|buffer| libc::iovec { - iov_base: buffer.as_mut_ptr() as *mut libc::c_void, - iov_len: buffer.len(), - }) - .collect::>(); - - let (names_v4, names_v6, msghdrs) = if config.network.address.is_ipv4() { - let mut names_v4 = ::std::iter::repeat(libc::sockaddr_in { +impl SendBuffer { + fn new() -> Self { + Self { + name_v4: UnsafeCell::new(libc::sockaddr_in { sin_family: libc::AF_INET as u16, sin_port: 0, sin_addr: libc::in_addr { s_addr: 0 }, sin_zero: [0; 8], - }) - .take(capacity) - .collect::>(); - - let msghdrs = names_v4 - .iter_mut() - .zip(iovecs.iter_mut()) - .map(|(msg_name, msg_iov)| libc::msghdr { - msg_name: msg_name as *mut _ as *mut libc::c_void, - msg_namelen: core::mem::size_of::() as u32, - msg_iov: msg_iov as *mut _, - msg_iovlen: 1, - msg_control: null_mut(), - msg_controllen: 0, - msg_flags: 0, - }) - .collect::>(); - - (names_v4, Vec::new(), msghdrs) - } else { - let mut names_v6 = ::std::iter::repeat(libc::sockaddr_in6 { + }), + name_v6: UnsafeCell::new(libc::sockaddr_in6 { sin6_family: libc::AF_INET6 as u16, sin6_port: 0, sin6_flowinfo: 0, sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, sin6_scope_id: 0, - }) + }), + bytes: UnsafeCell::new([0; BUF_LEN]), + iovec: UnsafeCell::new(libc::iovec { + iov_base: null_mut(), + iov_len: 0, + }), + msghdr: UnsafeCell::new(libc::msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: null_mut(), + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }), + free: true, + receiver_is_ipv4: true, + response_type: ResponseType::Connect, + } + } +} + +pub struct SendBuffers { + likely_next_free_index: usize, + socket_is_ipv4: bool, + buffers: Box<[SendBuffer]>, +} + +impl SendBuffers { + pub fn new(config: &Config, capacity: usize) -> Self { + let socket_is_ipv4 = config.network.address.is_ipv4(); + + let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new()) .take(capacity) - .collect::>(); + .collect::>() + .into_boxed_slice(); - let msghdrs = names_v6 - .iter_mut() - .zip(iovecs.iter_mut()) - .map(|(msg_name, msg_iov)| libc::msghdr { - msg_name: msg_name as *mut _ as *mut libc::c_void, - msg_namelen: core::mem::size_of::() as u32, - msg_iov: msg_iov as *mut _, - msg_iovlen: 1, - msg_control: null_mut(), - msg_controllen: 0, - msg_flags: 0, - }) - .collect::>(); + for buffer in buffers.iter_mut() { + unsafe { + let iovec = &mut *buffer.iovec.get(); - (Vec::new(), names_v6, msghdrs) - }; + iovec.iov_base = buffer.bytes.get() as *mut libc::c_void; + iovec.iov_len = (&*buffer.bytes.get()).len(); + } + unsafe { + let msghdr = &mut *buffer.msghdr.get(); + + msghdr.msg_iov = buffer.iovec.get(); + + if socket_is_ipv4 { + msghdr.msg_name = buffer.name_v4.get() as *mut libc::c_void; + msghdr.msg_namelen = core::mem::size_of::() as u32; + } else { + msghdr.msg_name = buffer.name_v6.get() as *mut libc::c_void; + msghdr.msg_namelen = core::mem::size_of::() as u32; + } + } + } Self { likely_next_free_index: 0, - network_address: config.network.address.ip(), - names_v4, - names_v6, + socket_is_ipv4, buffers, - iovecs, - msghdrs, - free: ::std::iter::repeat(true).take(capacity).collect(), - receiver_is_ipv4: ::std::iter::repeat(true).take(capacity).collect(), - response_types: ::std::iter::repeat(ResponseType::Connect) - .take(capacity) - .collect(), } } pub fn receiver_is_ipv4(&mut self, index: usize) -> bool { - self.receiver_is_ipv4[index] + self.buffers[index].receiver_is_ipv4 } pub fn response_type(&mut self, index: usize) -> ResponseType { - self.response_types[index] + self.buffers[index].response_type } pub fn mark_index_as_free(&mut self, index: usize) { - self.free[index] = true; + self.buffers[index].free = true; } /// Call after going through completion queue @@ -154,64 +150,76 @@ impl SendBuffers { ) -> Result { let index = self.next_free_index()?; - // Set receiver socket addr - if self.network_address.is_ipv4() { - let msg_name = self.names_v4.index_mut(index); - let addr = addr.get_ipv4().unwrap(); + let buffer = self.buffers.index_mut(index); - msg_name.sin_port = addr.port().to_be(); - msg_name.sin_addr.s_addr = if let IpAddr::V4(addr) = addr.ip() { - u32::from(addr).to_be() + // Set receiver socket addr + if self.socket_is_ipv4 { + buffer.receiver_is_ipv4 = true; + + let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() { + addr } else { panic!("ipv6 address in ipv4 mode"); }; - self.receiver_is_ipv4[index] = true; - } else { - let msg_name = self.names_v6.index_mut(index); - let addr = addr.get_ipv6_mapped(); + unsafe { + let name = &mut *buffer.name_v4.get(); - msg_name.sin6_port = addr.port().to_be(); - msg_name.sin6_addr.s6_addr = if let IpAddr::V6(addr) = addr.ip() { - addr.octets() + name.sin_port = addr.port().to_be(); + name.sin_addr.s_addr = u32::from(*addr.ip()).to_be(); + } + } else { + buffer.receiver_is_ipv4 = addr.is_ipv4(); + + let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() { + addr } else { panic!("ipv4 address when ipv6 or ipv6-mapped address expected"); }; - self.receiver_is_ipv4[index] = addr.is_ipv4(); + unsafe { + let name = &mut *buffer.name_v6.get(); + + name.sin6_port = addr.port().to_be(); + name.sin6_addr.s6_addr = addr.ip().octets(); + } } - let mut cursor = Cursor::new(self.buffers.index_mut(index).as_mut_slice()); + unsafe { + let bytes = (&mut *buffer.bytes.get()).as_mut_slice(); - match response.write(&mut cursor) { - Ok(()) => { - self.iovecs[index].iov_len = cursor.position() as usize; - self.response_types[index] = ResponseType::from_response(response); - self.free[index] = false; + let mut cursor = Cursor::new(bytes); - self.likely_next_free_index = index + 1; + match response.write(&mut cursor) { + Ok(()) => { + (&mut *buffer.iovec.get()).iov_len = cursor.position() as usize; - let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdrs.index_mut(index)) - .build() - .user_data(index as u64); + buffer.response_type = ResponseType::from_response(response); + buffer.free = false; - Ok(sqe) + self.likely_next_free_index = index + 1; + + let sqe = SendMsg::new(SOCKET_IDENTIFIER, buffer.msghdr.get()) + .build() + .user_data(index as u64); + + Ok(sqe) + } + Err(err) => Err(Error::SerializationFailed(err)), } - Err(err) => Err(Error::SerializationFailed(err)), } } fn next_free_index(&self) -> Result { - if self.likely_next_free_index >= self.free.len() { + if self.likely_next_free_index >= self.buffers.len() { return Err(Error::NoBuffers); } - for (i, free) in self.free[self.likely_next_free_index..] + for (i, buffer) in self.buffers[self.likely_next_free_index..] .iter() - .copied() .enumerate() { - if free { + if buffer.free { return Ok(self.likely_next_free_index + i); } }