diff --git a/aquatic_udp/src/workers/socket/uring/recv_helper.rs b/aquatic_udp/src/workers/socket/uring/recv_helper.rs index 96fd551..808c2f9 100644 --- a/aquatic_udp/src/workers/socket/uring/recv_helper.rs +++ b/aquatic_udp/src/workers/socket/uring/recv_helper.rs @@ -1,6 +1,6 @@ use std::{ cell::UnsafeCell, - net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, ptr::null_mut, }; @@ -13,7 +13,7 @@ use crate::config::Config; use super::{SOCKET_IDENTIFIER, USER_DATA_RECV}; pub struct RecvHelper { - network_address: IpAddr, + socket_is_ipv4: bool, max_scrape_torrents: u8, #[allow(dead_code)] name_v4: Box>, @@ -61,7 +61,7 @@ impl RecvHelper { })); Self { - network_address: config.network.address.ip(), + socket_is_ipv4: config.network.address.is_ipv4(), max_scrape_torrents: config.protocol.max_scrape_torrents, name_v4, msghdr_v4, @@ -71,7 +71,7 @@ impl RecvHelper { } pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry { - let msghdr: *const libc::msghdr = if self.network_address.is_ipv4() { + let msghdr: *const libc::msghdr = if self.socket_is_ipv4 { self.msghdr_v4.get() } else { self.msghdr_v6.get() @@ -86,25 +86,31 @@ impl RecvHelper { &self, buffer: &[u8], ) -> (Result, CanonicalSocketAddr) { - let msghdr = unsafe { - if self.network_address.is_ipv4() { - &*(self.msghdr_v4.get() as *const _) - } else { - &*(self.msghdr_v6.get() as *const _) - } - }; + let (msg, addr) = if self.socket_is_ipv4 { + let msg = unsafe { + let msghdr = &*(self.msghdr_v4.get() as *const _); - let msg = RecvMsgOut::parse(buffer, msghdr).unwrap(); + RecvMsgOut::parse(buffer, msghdr).unwrap() + }; - let addr = unsafe { - if self.network_address.is_ipv4() { + let addr = unsafe { let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in); SocketAddr::V4(SocketAddrV4::new( u32::from_be(name_data.sin_addr.s_addr).into(), u16::from_be(name_data.sin_port), )) - } else { + }; + + (msg, addr) + } else { + let msg = unsafe { + let msghdr = &*(self.msghdr_v6.get() as *const _); + + RecvMsgOut::parse(buffer, msghdr).unwrap() + }; + + let addr = unsafe { let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in6); SocketAddr::V6(SocketAddrV6::new( @@ -113,7 +119,9 @@ impl RecvHelper { u32::from_be(name_data.sin6_flowinfo), u32::from_be(name_data.sin6_scope_id), )) - } + }; + + (msg, addr) }; (