diff --git a/aquatic_udp/src/workers/socket/uring/send_buffers.rs b/aquatic_udp/src/workers/socket/uring/send_buffers.rs index cd711f0..6a8d7dd 100644 --- a/aquatic_udp/src/workers/socket/uring/send_buffers.rs +++ b/aquatic_udp/src/workers/socket/uring/send_buffers.rs @@ -46,7 +46,7 @@ struct SendBuffer { } impl SendBuffer { - fn new() -> Self { + fn new_with_null_pointers() -> Self { Self { name_v4: UnsafeCell::new(libc::sockaddr_in { sin_family: libc::AF_INET as u16, @@ -80,6 +80,86 @@ impl SendBuffer { response_type: ResponseType::Connect, } } + + /// # Safety + /// + /// - SendBuffer must be stored at a fixed location in memory + unsafe fn setup_pointers(&mut self, socket_is_ipv4: bool) { + let iovec = &mut *self.iovec.get(); + + iovec.iov_base = self.bytes.get() as *mut libc::c_void; + iovec.iov_len = (&*self.bytes.get()).len(); + + let msghdr = &mut *self.msghdr.get(); + + msghdr.msg_iov = self.iovec.get(); + + if socket_is_ipv4 { + msghdr.msg_name = self.name_v4.get() as *mut libc::c_void; + msghdr.msg_namelen = core::mem::size_of::() as u32; + } else { + msghdr.msg_name = self.name_v6.get() as *mut libc::c_void; + msghdr.msg_namelen = core::mem::size_of::() as u32; + } + } + + /// # Safety + /// + /// - SendBuffer must be stored at a fixed location in memory + /// - SendBuffer.setup_pointers must have been called previously + unsafe fn prepare_entry( + &mut self, + response: &Response, + addr: CanonicalSocketAddr, + socket_is_ipv4: bool, + ) -> Result { + // Set receiver socket addr + if socket_is_ipv4 { + self.receiver_is_ipv4 = true; + + let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() { + addr + } else { + panic!("ipv6 address in ipv4 mode"); + }; + + let name = &mut *self.name_v4.get(); + + name.sin_port = addr.port().to_be(); + name.sin_addr.s_addr = u32::from(*addr.ip()).to_be(); + } else { + self.receiver_is_ipv4 = addr.is_ipv4(); + + let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() { + addr + } else { + panic!("ipv4 address when ipv6 or ipv6-mapped address expected"); + }; + + let name = &mut *self.name_v6.get(); + + name.sin6_port = addr.port().to_be(); + name.sin6_addr.s6_addr = addr.ip().octets(); + } + + let bytes = (&mut *self.bytes.get()).as_mut_slice(); + + let mut cursor = Cursor::new(bytes); + + match response.write(&mut cursor) { + Ok(()) => { + (&mut *self.iovec.get()).iov_len = cursor.position() as usize; + + self.response_type = ResponseType::from_response(response); + self.free = false; + + let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build(); + + Ok(sqe) + } + Err(err) => Err(Error::SerializationFailed(err)), + } + } } pub struct SendBuffers { @@ -92,30 +172,15 @@ impl SendBuffers { pub fn new(config: &Config, capacity: usize) -> Self { let socket_is_ipv4 = config.network.address.is_ipv4(); - let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new()) + let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers()) .take(capacity) .collect::>() .into_boxed_slice(); for buffer in buffers.iter_mut() { + // Safety: OK because buffers are stored in fixed memory location unsafe { - let iovec = &mut *buffer.iovec.get(); - - iovec.iov_base = buffer.bytes.get() as *mut libc::c_void; - iovec.iov_len = (&*buffer.bytes.get()).len(); - } - unsafe { - let msghdr = &mut *buffer.msghdr.get(); - - msghdr.msg_iov = buffer.iovec.get(); - - if socket_is_ipv4 { - msghdr.msg_name = buffer.name_v4.get() as *mut libc::c_void; - msghdr.msg_namelen = core::mem::size_of::() as u32; - } else { - msghdr.msg_name = buffer.name_v6.get() as *mut libc::c_void; - msghdr.msg_namelen = core::mem::size_of::() as u32; - } + buffer.setup_pointers(socket_is_ipv4); } } @@ -152,60 +217,16 @@ impl SendBuffers { let buffer = self.buffers.index_mut(index); - // Set receiver socket addr - if self.socket_is_ipv4 { - buffer.receiver_is_ipv4 = true; - - let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() { - addr - } else { - panic!("ipv6 address in ipv4 mode"); - }; - - unsafe { - let name = &mut *buffer.name_v4.get(); - - name.sin_port = addr.port().to_be(); - name.sin_addr.s_addr = u32::from(*addr.ip()).to_be(); - } - } else { - buffer.receiver_is_ipv4 = addr.is_ipv4(); - - let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() { - addr - } else { - panic!("ipv4 address when ipv6 or ipv6-mapped address expected"); - }; - - unsafe { - let name = &mut *buffer.name_v6.get(); - - name.sin6_port = addr.port().to_be(); - name.sin6_addr.s6_addr = addr.ip().octets(); - } - } - + // Safety: OK because buffers are stored in fixed memory location + // and buffer pointers were set up in SendBuffers::new() unsafe { - let bytes = (&mut *buffer.bytes.get()).as_mut_slice(); - - let mut cursor = Cursor::new(bytes); - - match response.write(&mut cursor) { - Ok(()) => { - (&mut *buffer.iovec.get()).iov_len = cursor.position() as usize; - - buffer.response_type = ResponseType::from_response(response); - buffer.free = false; - + match buffer.prepare_entry(response, addr, self.socket_is_ipv4) { + Ok(entry) => { self.likely_next_free_index = index + 1; - let sqe = SendMsg::new(SOCKET_IDENTIFIER, buffer.msghdr.get()) - .build() - .user_data(index as u64); - - Ok(sqe) + Ok(entry.user_data(index as u64)) } - Err(err) => Err(Error::SerializationFailed(err)), + Err(err) => Err(err), } } }