mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-04-01 18:25:30 +00:00
udp: uring: rewrite SendBuffers to use UnsafeCell
This commit is contained in:
parent
2e67f11caf
commit
339feb3d0a
1 changed files with 118 additions and 110 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{io::Cursor, net::IpAddr, ops::IndexMut, ptr::null_mut};
|
use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut};
|
||||||
|
|
||||||
use aquatic_common::CanonicalSocketAddr;
|
use aquatic_common::CanonicalSocketAddr;
|
||||||
use aquatic_udp_protocol::Response;
|
use aquatic_udp_protocol::Response;
|
||||||
|
|
@ -32,114 +32,110 @@ impl ResponseType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SendBuffers {
|
struct SendBuffer {
|
||||||
likely_next_free_index: usize,
|
name_v4: UnsafeCell<libc::sockaddr_in>,
|
||||||
network_address: IpAddr,
|
name_v6: UnsafeCell<libc::sockaddr_in6>,
|
||||||
names_v4: Vec<libc::sockaddr_in>,
|
bytes: UnsafeCell<[u8; BUF_LEN]>,
|
||||||
names_v6: Vec<libc::sockaddr_in6>,
|
iovec: UnsafeCell<libc::iovec>,
|
||||||
buffers: Vec<[u8; BUF_LEN]>,
|
msghdr: UnsafeCell<libc::msghdr>,
|
||||||
iovecs: Vec<libc::iovec>,
|
free: bool,
|
||||||
msghdrs: Vec<libc::msghdr>,
|
|
||||||
free: Vec<bool>,
|
|
||||||
// Only used for statistics
|
// Only used for statistics
|
||||||
receiver_is_ipv4: Vec<bool>,
|
receiver_is_ipv4: bool,
|
||||||
// Only used for statistics
|
// Only used for statistics
|
||||||
response_types: Vec<ResponseType>,
|
response_type: ResponseType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SendBuffers {
|
impl SendBuffer {
|
||||||
pub fn new(config: &Config, capacity: usize) -> Self {
|
fn new() -> Self {
|
||||||
let mut buffers = ::std::iter::repeat([0u8; BUF_LEN])
|
Self {
|
||||||
.take(capacity)
|
name_v4: UnsafeCell::new(libc::sockaddr_in {
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let mut iovecs = buffers
|
|
||||||
.iter_mut()
|
|
||||||
.map(|buffer| libc::iovec {
|
|
||||||
iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
|
|
||||||
iov_len: buffer.len(),
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let (names_v4, names_v6, msghdrs) = if config.network.address.is_ipv4() {
|
|
||||||
let mut names_v4 = ::std::iter::repeat(libc::sockaddr_in {
|
|
||||||
sin_family: libc::AF_INET as u16,
|
sin_family: libc::AF_INET as u16,
|
||||||
sin_port: 0,
|
sin_port: 0,
|
||||||
sin_addr: libc::in_addr { s_addr: 0 },
|
sin_addr: libc::in_addr { s_addr: 0 },
|
||||||
sin_zero: [0; 8],
|
sin_zero: [0; 8],
|
||||||
})
|
}),
|
||||||
.take(capacity)
|
name_v6: UnsafeCell::new(libc::sockaddr_in6 {
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let msghdrs = names_v4
|
|
||||||
.iter_mut()
|
|
||||||
.zip(iovecs.iter_mut())
|
|
||||||
.map(|(msg_name, msg_iov)| libc::msghdr {
|
|
||||||
msg_name: msg_name as *mut _ as *mut libc::c_void,
|
|
||||||
msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32,
|
|
||||||
msg_iov: msg_iov as *mut _,
|
|
||||||
msg_iovlen: 1,
|
|
||||||
msg_control: null_mut(),
|
|
||||||
msg_controllen: 0,
|
|
||||||
msg_flags: 0,
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
(names_v4, Vec::new(), msghdrs)
|
|
||||||
} else {
|
|
||||||
let mut names_v6 = ::std::iter::repeat(libc::sockaddr_in6 {
|
|
||||||
sin6_family: libc::AF_INET6 as u16,
|
sin6_family: libc::AF_INET6 as u16,
|
||||||
sin6_port: 0,
|
sin6_port: 0,
|
||||||
sin6_flowinfo: 0,
|
sin6_flowinfo: 0,
|
||||||
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
|
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
|
||||||
sin6_scope_id: 0,
|
sin6_scope_id: 0,
|
||||||
})
|
}),
|
||||||
|
bytes: UnsafeCell::new([0; BUF_LEN]),
|
||||||
|
iovec: UnsafeCell::new(libc::iovec {
|
||||||
|
iov_base: null_mut(),
|
||||||
|
iov_len: 0,
|
||||||
|
}),
|
||||||
|
msghdr: UnsafeCell::new(libc::msghdr {
|
||||||
|
msg_name: null_mut(),
|
||||||
|
msg_namelen: 0,
|
||||||
|
msg_iov: null_mut(),
|
||||||
|
msg_iovlen: 1,
|
||||||
|
msg_control: null_mut(),
|
||||||
|
msg_controllen: 0,
|
||||||
|
msg_flags: 0,
|
||||||
|
}),
|
||||||
|
free: true,
|
||||||
|
receiver_is_ipv4: true,
|
||||||
|
response_type: ResponseType::Connect,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SendBuffers {
|
||||||
|
likely_next_free_index: usize,
|
||||||
|
socket_is_ipv4: bool,
|
||||||
|
buffers: Box<[SendBuffer]>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SendBuffers {
|
||||||
|
pub fn new(config: &Config, capacity: usize) -> Self {
|
||||||
|
let socket_is_ipv4 = config.network.address.is_ipv4();
|
||||||
|
|
||||||
|
let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new())
|
||||||
.take(capacity)
|
.take(capacity)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>()
|
||||||
|
.into_boxed_slice();
|
||||||
|
|
||||||
let msghdrs = names_v6
|
for buffer in buffers.iter_mut() {
|
||||||
.iter_mut()
|
unsafe {
|
||||||
.zip(iovecs.iter_mut())
|
let iovec = &mut *buffer.iovec.get();
|
||||||
.map(|(msg_name, msg_iov)| libc::msghdr {
|
|
||||||
msg_name: msg_name as *mut _ as *mut libc::c_void,
|
|
||||||
msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32,
|
|
||||||
msg_iov: msg_iov as *mut _,
|
|
||||||
msg_iovlen: 1,
|
|
||||||
msg_control: null_mut(),
|
|
||||||
msg_controllen: 0,
|
|
||||||
msg_flags: 0,
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
(Vec::new(), names_v6, msghdrs)
|
iovec.iov_base = buffer.bytes.get() as *mut libc::c_void;
|
||||||
};
|
iovec.iov_len = (&*buffer.bytes.get()).len();
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
let msghdr = &mut *buffer.msghdr.get();
|
||||||
|
|
||||||
|
msghdr.msg_iov = buffer.iovec.get();
|
||||||
|
|
||||||
|
if socket_is_ipv4 {
|
||||||
|
msghdr.msg_name = buffer.name_v4.get() as *mut libc::c_void;
|
||||||
|
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
|
||||||
|
} else {
|
||||||
|
msghdr.msg_name = buffer.name_v6.get() as *mut libc::c_void;
|
||||||
|
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
likely_next_free_index: 0,
|
likely_next_free_index: 0,
|
||||||
network_address: config.network.address.ip(),
|
socket_is_ipv4,
|
||||||
names_v4,
|
|
||||||
names_v6,
|
|
||||||
buffers,
|
buffers,
|
||||||
iovecs,
|
|
||||||
msghdrs,
|
|
||||||
free: ::std::iter::repeat(true).take(capacity).collect(),
|
|
||||||
receiver_is_ipv4: ::std::iter::repeat(true).take(capacity).collect(),
|
|
||||||
response_types: ::std::iter::repeat(ResponseType::Connect)
|
|
||||||
.take(capacity)
|
|
||||||
.collect(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn receiver_is_ipv4(&mut self, index: usize) -> bool {
|
pub fn receiver_is_ipv4(&mut self, index: usize) -> bool {
|
||||||
self.receiver_is_ipv4[index]
|
self.buffers[index].receiver_is_ipv4
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn response_type(&mut self, index: usize) -> ResponseType {
|
pub fn response_type(&mut self, index: usize) -> ResponseType {
|
||||||
self.response_types[index]
|
self.buffers[index].response_type
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mark_index_as_free(&mut self, index: usize) {
|
pub fn mark_index_as_free(&mut self, index: usize) {
|
||||||
self.free[index] = true;
|
self.buffers[index].free = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call after going through completion queue
|
/// Call after going through completion queue
|
||||||
|
|
@ -154,64 +150,76 @@ impl SendBuffers {
|
||||||
) -> Result<io_uring::squeue::Entry, Error> {
|
) -> Result<io_uring::squeue::Entry, Error> {
|
||||||
let index = self.next_free_index()?;
|
let index = self.next_free_index()?;
|
||||||
|
|
||||||
// Set receiver socket addr
|
let buffer = self.buffers.index_mut(index);
|
||||||
if self.network_address.is_ipv4() {
|
|
||||||
let msg_name = self.names_v4.index_mut(index);
|
|
||||||
let addr = addr.get_ipv4().unwrap();
|
|
||||||
|
|
||||||
msg_name.sin_port = addr.port().to_be();
|
// Set receiver socket addr
|
||||||
msg_name.sin_addr.s_addr = if let IpAddr::V4(addr) = addr.ip() {
|
if self.socket_is_ipv4 {
|
||||||
u32::from(addr).to_be()
|
buffer.receiver_is_ipv4 = true;
|
||||||
|
|
||||||
|
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
|
||||||
|
addr
|
||||||
} else {
|
} else {
|
||||||
panic!("ipv6 address in ipv4 mode");
|
panic!("ipv6 address in ipv4 mode");
|
||||||
};
|
};
|
||||||
|
|
||||||
self.receiver_is_ipv4[index] = true;
|
unsafe {
|
||||||
} else {
|
let name = &mut *buffer.name_v4.get();
|
||||||
let msg_name = self.names_v6.index_mut(index);
|
|
||||||
let addr = addr.get_ipv6_mapped();
|
|
||||||
|
|
||||||
msg_name.sin6_port = addr.port().to_be();
|
name.sin_port = addr.port().to_be();
|
||||||
msg_name.sin6_addr.s6_addr = if let IpAddr::V6(addr) = addr.ip() {
|
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
|
||||||
addr.octets()
|
}
|
||||||
|
} else {
|
||||||
|
buffer.receiver_is_ipv4 = addr.is_ipv4();
|
||||||
|
|
||||||
|
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
|
||||||
|
addr
|
||||||
} else {
|
} else {
|
||||||
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
|
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
|
||||||
};
|
};
|
||||||
|
|
||||||
self.receiver_is_ipv4[index] = addr.is_ipv4();
|
unsafe {
|
||||||
|
let name = &mut *buffer.name_v6.get();
|
||||||
|
|
||||||
|
name.sin6_port = addr.port().to_be();
|
||||||
|
name.sin6_addr.s6_addr = addr.ip().octets();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut cursor = Cursor::new(self.buffers.index_mut(index).as_mut_slice());
|
unsafe {
|
||||||
|
let bytes = (&mut *buffer.bytes.get()).as_mut_slice();
|
||||||
|
|
||||||
match response.write(&mut cursor) {
|
let mut cursor = Cursor::new(bytes);
|
||||||
Ok(()) => {
|
|
||||||
self.iovecs[index].iov_len = cursor.position() as usize;
|
|
||||||
self.response_types[index] = ResponseType::from_response(response);
|
|
||||||
self.free[index] = false;
|
|
||||||
|
|
||||||
self.likely_next_free_index = index + 1;
|
match response.write(&mut cursor) {
|
||||||
|
Ok(()) => {
|
||||||
|
(&mut *buffer.iovec.get()).iov_len = cursor.position() as usize;
|
||||||
|
|
||||||
let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdrs.index_mut(index))
|
buffer.response_type = ResponseType::from_response(response);
|
||||||
.build()
|
buffer.free = false;
|
||||||
.user_data(index as u64);
|
|
||||||
|
|
||||||
Ok(sqe)
|
self.likely_next_free_index = index + 1;
|
||||||
|
|
||||||
|
let sqe = SendMsg::new(SOCKET_IDENTIFIER, buffer.msghdr.get())
|
||||||
|
.build()
|
||||||
|
.user_data(index as u64);
|
||||||
|
|
||||||
|
Ok(sqe)
|
||||||
|
}
|
||||||
|
Err(err) => Err(Error::SerializationFailed(err)),
|
||||||
}
|
}
|
||||||
Err(err) => Err(Error::SerializationFailed(err)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn next_free_index(&self) -> Result<usize, Error> {
|
fn next_free_index(&self) -> Result<usize, Error> {
|
||||||
if self.likely_next_free_index >= self.free.len() {
|
if self.likely_next_free_index >= self.buffers.len() {
|
||||||
return Err(Error::NoBuffers);
|
return Err(Error::NoBuffers);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (i, free) in self.free[self.likely_next_free_index..]
|
for (i, buffer) in self.buffers[self.likely_next_free_index..]
|
||||||
.iter()
|
.iter()
|
||||||
.copied()
|
|
||||||
.enumerate()
|
.enumerate()
|
||||||
{
|
{
|
||||||
if free {
|
if buffer.free {
|
||||||
return Ok(self.likely_next_free_index + i);
|
return Ok(self.likely_next_free_index + i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue