mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-03-31 17:55:36 +00:00
udp: fix io_uring soundness issues
This commit is contained in:
parent
3f2a87b10f
commit
af16a9e682
9 changed files with 300 additions and 282 deletions
|
|
@ -19,9 +19,12 @@ name = "aquatic_udp"
|
|||
|
||||
[features]
|
||||
default = ["prometheus"]
|
||||
cpu-pinning = ["aquatic_common/hwloc"]
|
||||
# Export prometheus metrics
|
||||
prometheus = ["metrics", "metrics-util", "metrics-exporter-prometheus"]
|
||||
# Experimental io_uring support (Linux 6.0 or later required)
|
||||
io-uring = ["dep:io-uring"]
|
||||
# Experimental CPU pinning support
|
||||
cpu-pinning = ["aquatic_common/hwloc"]
|
||||
|
||||
[dependencies]
|
||||
aquatic_common.workspace = true
|
||||
|
|
@ -38,12 +41,8 @@ getrandom = "0.2"
|
|||
hashbrown = { version = "0.14", default-features = false }
|
||||
hdrhistogram = "7"
|
||||
hex = "0.4"
|
||||
io-uring = { version = "0.6", optional = true }
|
||||
libc = "0.2"
|
||||
log = "0.4"
|
||||
metrics = { version = "0.21", optional = true }
|
||||
metrics-util = { version = "0.15", optional = true }
|
||||
metrics-exporter-prometheus = { version = "0.12", optional = true, default-features = false, features = ["http-listener"] }
|
||||
mimalloc = { version = "0.1", default-features = false }
|
||||
mio = { version = "0.8", features = ["net", "os-poll"] }
|
||||
num-format = "0.4"
|
||||
|
|
@ -55,6 +54,14 @@ socket2 = { version = "0.5", features = ["all"] }
|
|||
time = { version = "0.3", features = ["formatting"] }
|
||||
tinytemplate = "1"
|
||||
|
||||
# prometheus feature
|
||||
metrics = { version = "0.21", optional = true }
|
||||
metrics-util = { version = "0.15", optional = true }
|
||||
metrics-exporter-prometheus = { version = "0.12", optional = true, default-features = false, features = ["http-listener"] }
|
||||
|
||||
# io-uring feature
|
||||
io-uring = { version = "0.6", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
hex = "0.4"
|
||||
tempfile = "3"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
mod mio;
|
||||
mod storage;
|
||||
#[cfg(feature = "io-uring")]
|
||||
#[cfg(all(target_os = "linux", feature = "io-uring"))]
|
||||
mod uring;
|
||||
mod validator;
|
||||
|
||||
|
|
@ -18,6 +18,9 @@ use crate::{
|
|||
|
||||
pub use self::validator::ConnectionValidator;
|
||||
|
||||
#[cfg(all(not(target_os = "linux"), feature = "io-uring"))]
|
||||
compile_error!("io_uring feature is only supported on Linux");
|
||||
|
||||
/// Bytes of data transmitted when sending an IPv4 UDP packet, in addition to payload size
|
||||
///
|
||||
/// Consists of:
|
||||
|
|
@ -46,7 +49,7 @@ pub fn run_socket_worker(
|
|||
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
|
||||
priv_dropper: PrivilegeDropper,
|
||||
) {
|
||||
#[cfg(feature = "io-uring")]
|
||||
#[cfg(all(target_os = "linux", feature = "io-uring"))]
|
||||
match self::uring::supported_on_current_kernel() {
|
||||
Ok(()) => {
|
||||
self::uring::SocketWorker::run(
|
||||
|
|
|
|||
|
|
@ -36,8 +36,10 @@ use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6};
|
|||
|
||||
/// Size of each request buffer
|
||||
///
|
||||
/// Enough for scrape request with 20 info hashes
|
||||
const REQUEST_BUF_LEN: usize = 256;
|
||||
/// Needs to fit recvmsg metadata in addition to the payload.
|
||||
///
|
||||
/// The payload of a scrape request with 20 info hashes fits in 256 bytes.
|
||||
const REQUEST_BUF_LEN: usize = 512;
|
||||
|
||||
/// Size of each response buffer
|
||||
///
|
||||
|
|
@ -111,6 +113,7 @@ impl SocketWorker {
|
|||
|
||||
let socket = create_socket(&config, priv_dropper).expect("create socket");
|
||||
let access_list_cache = create_access_list_cache(&shared_state.access_list);
|
||||
|
||||
let send_buffers = SendBuffers::new(&config, send_buffer_entries as usize);
|
||||
let recv_helper = RecvHelper::new(&config);
|
||||
|
||||
|
|
@ -372,9 +375,7 @@ impl SocketWorker {
|
|||
}
|
||||
};
|
||||
|
||||
let buffer = buffer.as_slice();
|
||||
|
||||
let addr = match self.recv_helper.parse(buffer) {
|
||||
let addr = match self.recv_helper.parse(buffer.as_slice()) {
|
||||
Ok((request, addr)) => {
|
||||
self.handle_request(request, addr);
|
||||
|
||||
|
|
@ -413,6 +414,11 @@ impl SocketWorker {
|
|||
Err(self::recv_helper::Error::RecvMsgParseError) => {
|
||||
::log::error!("RecvMsgOut::parse failed");
|
||||
|
||||
return;
|
||||
}
|
||||
Err(self::recv_helper::Error::RecvMsgTruncated) => {
|
||||
::log::warn!("RecvMsgOut::parse failed: sockaddr or payload truncated");
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use std::{
|
||||
cell::UnsafeCell,
|
||||
net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||
ptr::null_mut,
|
||||
};
|
||||
|
|
@ -14,6 +13,7 @@ use super::{SOCKET_IDENTIFIER, USER_DATA_RECV};
|
|||
|
||||
pub enum Error {
|
||||
RecvMsgParseError,
|
||||
RecvMsgTruncated,
|
||||
RequestParseError(RequestParseError, CanonicalSocketAddr),
|
||||
InvalidSocketAddress,
|
||||
}
|
||||
|
|
@ -22,24 +22,24 @@ pub struct RecvHelper {
|
|||
socket_is_ipv4: bool,
|
||||
max_scrape_torrents: u8,
|
||||
#[allow(dead_code)]
|
||||
name_v4: Box<UnsafeCell<libc::sockaddr_in>>,
|
||||
msghdr_v4: Box<UnsafeCell<libc::msghdr>>,
|
||||
name_v4: *const libc::sockaddr_in,
|
||||
msghdr_v4: *const libc::msghdr,
|
||||
#[allow(dead_code)]
|
||||
name_v6: Box<UnsafeCell<libc::sockaddr_in6>>,
|
||||
msghdr_v6: Box<UnsafeCell<libc::msghdr>>,
|
||||
name_v6: *const libc::sockaddr_in6,
|
||||
msghdr_v6: *const libc::msghdr,
|
||||
}
|
||||
|
||||
impl RecvHelper {
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let name_v4 = Box::new(UnsafeCell::new(libc::sockaddr_in {
|
||||
let name_v4 = Box::into_raw(Box::new(libc::sockaddr_in {
|
||||
sin_family: 0,
|
||||
sin_port: 0,
|
||||
sin_addr: libc::in_addr { s_addr: 0 },
|
||||
sin_zero: [0; 8],
|
||||
}));
|
||||
|
||||
let msghdr_v4 = Box::new(UnsafeCell::new(libc::msghdr {
|
||||
msg_name: name_v4.get() as *mut libc::c_void,
|
||||
let msghdr_v4 = Box::into_raw(Box::new(libc::msghdr {
|
||||
msg_name: name_v4 as *mut libc::c_void,
|
||||
msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32,
|
||||
msg_iov: null_mut(),
|
||||
msg_iovlen: 0,
|
||||
|
|
@ -48,7 +48,7 @@ impl RecvHelper {
|
|||
msg_flags: 0,
|
||||
}));
|
||||
|
||||
let name_v6 = Box::new(UnsafeCell::new(libc::sockaddr_in6 {
|
||||
let name_v6 = Box::into_raw(Box::new(libc::sockaddr_in6 {
|
||||
sin6_family: 0,
|
||||
sin6_port: 0,
|
||||
sin6_flowinfo: 0,
|
||||
|
|
@ -56,8 +56,8 @@ impl RecvHelper {
|
|||
sin6_scope_id: 0,
|
||||
}));
|
||||
|
||||
let msghdr_v6 = Box::new(UnsafeCell::new(libc::msghdr {
|
||||
msg_name: name_v6.get() as *mut libc::c_void,
|
||||
let msghdr_v6 = Box::into_raw(Box::new(libc::msghdr {
|
||||
msg_name: name_v6 as *mut libc::c_void,
|
||||
msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32,
|
||||
msg_iov: null_mut(),
|
||||
msg_iovlen: 0,
|
||||
|
|
@ -77,10 +77,10 @@ impl RecvHelper {
|
|||
}
|
||||
|
||||
pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry {
|
||||
let msghdr: *const libc::msghdr = if self.socket_is_ipv4 {
|
||||
self.msghdr_v4.get()
|
||||
let msghdr = if self.socket_is_ipv4 {
|
||||
self.msghdr_v4
|
||||
} else {
|
||||
self.msghdr_v6.get()
|
||||
self.msghdr_v6
|
||||
};
|
||||
|
||||
RecvMsgMulti::new(SOCKET_IDENTIFIER, msghdr, buf_group)
|
||||
|
|
@ -90,51 +90,51 @@ impl RecvHelper {
|
|||
|
||||
pub fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> {
|
||||
let (msg, addr) = if self.socket_is_ipv4 {
|
||||
let msg = unsafe {
|
||||
let msghdr = &*(self.msghdr_v4.get() as *const _);
|
||||
// Safe as long as kernel only reads from the pointer and doesn't
|
||||
// write to it. I think this is the case.
|
||||
let msghdr = unsafe { self.msghdr_v4.read() };
|
||||
|
||||
RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)?
|
||||
};
|
||||
let msg = RecvMsgOut::parse(buffer, &msghdr).map_err(|_| Error::RecvMsgParseError)?;
|
||||
|
||||
let addr = unsafe {
|
||||
let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in);
|
||||
|
||||
SocketAddr::V4(SocketAddrV4::new(
|
||||
u32::from_be(name_data.sin_addr.s_addr).into(),
|
||||
u16::from_be(name_data.sin_port),
|
||||
))
|
||||
};
|
||||
|
||||
if addr.port() == 0 {
|
||||
return Err(Error::InvalidSocketAddress);
|
||||
if msg.is_name_data_truncated() | msg.is_payload_truncated() {
|
||||
return Err(Error::RecvMsgTruncated);
|
||||
}
|
||||
|
||||
let name_data = unsafe { *(msg.name_data().as_ptr() as *const libc::sockaddr_in) };
|
||||
|
||||
let addr = SocketAddr::V4(SocketAddrV4::new(
|
||||
u32::from_be(name_data.sin_addr.s_addr).into(),
|
||||
u16::from_be(name_data.sin_port),
|
||||
));
|
||||
|
||||
(msg, addr)
|
||||
} else {
|
||||
let msg = unsafe {
|
||||
let msghdr = &*(self.msghdr_v6.get() as *const _);
|
||||
// Safe as long as kernel only reads from the pointer and doesn't
|
||||
// write to it. I think this is the case.
|
||||
let msghdr = unsafe { self.msghdr_v6.read() };
|
||||
|
||||
RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)?
|
||||
};
|
||||
let msg = RecvMsgOut::parse(buffer, &msghdr).map_err(|_| Error::RecvMsgParseError)?;
|
||||
|
||||
let addr = unsafe {
|
||||
let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in6);
|
||||
|
||||
SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::from(name_data.sin6_addr.s6_addr),
|
||||
u16::from_be(name_data.sin6_port),
|
||||
u32::from_be(name_data.sin6_flowinfo),
|
||||
u32::from_be(name_data.sin6_scope_id),
|
||||
))
|
||||
};
|
||||
|
||||
if addr.port() == 0 {
|
||||
return Err(Error::InvalidSocketAddress);
|
||||
if msg.is_name_data_truncated() | msg.is_payload_truncated() {
|
||||
return Err(Error::RecvMsgTruncated);
|
||||
}
|
||||
|
||||
let name_data = unsafe { *(msg.name_data().as_ptr() as *const libc::sockaddr_in6) };
|
||||
|
||||
let addr = SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::from(name_data.sin6_addr.s6_addr),
|
||||
u16::from_be(name_data.sin6_port),
|
||||
u32::from_be(name_data.sin6_flowinfo),
|
||||
u32::from_be(name_data.sin6_scope_id),
|
||||
));
|
||||
|
||||
(msg, addr)
|
||||
};
|
||||
|
||||
if addr.port() == 0 {
|
||||
return Err(Error::InvalidSocketAddress);
|
||||
}
|
||||
|
||||
let addr = CanonicalSocketAddr::new(addr);
|
||||
|
||||
let request = Request::from_bytes(msg.payload_data(), self.max_scrape_torrents)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut};
|
||||
use std::{
|
||||
io::Cursor,
|
||||
iter::repeat_with,
|
||||
net::SocketAddr,
|
||||
ptr::{addr_of_mut, null_mut},
|
||||
};
|
||||
|
||||
use aquatic_common::CanonicalSocketAddr;
|
||||
use aquatic_udp_protocol::Response;
|
||||
|
|
@ -13,8 +18,215 @@ pub enum Error {
|
|||
SerializationFailed(std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct SendBuffers {
|
||||
likely_next_free_index: usize,
|
||||
socket_is_ipv4: bool,
|
||||
buffers: Vec<(SendBufferMetadata, *mut SendBuffer)>,
|
||||
}
|
||||
|
||||
impl SendBuffers {
|
||||
pub fn new(config: &Config, capacity: usize) -> Self {
|
||||
let socket_is_ipv4 = config.network.address.is_ipv4();
|
||||
|
||||
let buffers = repeat_with(|| (Default::default(), SendBuffer::new(socket_is_ipv4)))
|
||||
.take(capacity)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Self {
|
||||
likely_next_free_index: 0,
|
||||
socket_is_ipv4,
|
||||
buffers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn response_type_and_ipv4(&self, index: usize) -> (ResponseType, bool) {
|
||||
let meta = &self.buffers.get(index).unwrap().0;
|
||||
|
||||
(meta.response_type, meta.receiver_is_ipv4)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Only safe to call once buffer is no longer referenced by in-flight
|
||||
/// io_uring queue entries
|
||||
pub unsafe fn mark_buffer_as_free(&mut self, index: usize) {
|
||||
self.buffers[index].0.free = true;
|
||||
}
|
||||
|
||||
/// Call after going through completion queue
|
||||
pub fn reset_likely_next_free_index(&mut self) {
|
||||
self.likely_next_free_index = 0;
|
||||
}
|
||||
|
||||
pub fn prepare_entry(
|
||||
&mut self,
|
||||
response: &Response,
|
||||
addr: CanonicalSocketAddr,
|
||||
) -> Result<io_uring::squeue::Entry, Error> {
|
||||
let index = self.next_free_index()?;
|
||||
|
||||
let (buffer_metadata, buffer) = self.buffers.get_mut(index).unwrap();
|
||||
|
||||
// Safe as long as `mark_buffer_as_free` was used correctly
|
||||
let buffer = unsafe { &mut *(*buffer) };
|
||||
|
||||
match buffer.prepare_entry(response, addr, self.socket_is_ipv4, buffer_metadata) {
|
||||
Ok(entry) => {
|
||||
buffer_metadata.free = false;
|
||||
|
||||
self.likely_next_free_index = index + 1;
|
||||
|
||||
Ok(entry.user_data(index as u64))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_free_index(&self) -> Result<usize, Error> {
|
||||
if self.likely_next_free_index >= self.buffers.len() {
|
||||
return Err(Error::NoBuffers);
|
||||
}
|
||||
|
||||
for (i, (meta, _)) in self.buffers[self.likely_next_free_index..]
|
||||
.iter()
|
||||
.enumerate()
|
||||
{
|
||||
if meta.free {
|
||||
return Ok(self.likely_next_free_index + i);
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::NoBuffers)
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure not to hold any reference to this struct while kernel can
|
||||
/// write to its contents
|
||||
struct SendBuffer {
|
||||
name_v4: libc::sockaddr_in,
|
||||
name_v6: libc::sockaddr_in6,
|
||||
bytes: [u8; RESPONSE_BUF_LEN],
|
||||
iovec: libc::iovec,
|
||||
msghdr: libc::msghdr,
|
||||
}
|
||||
|
||||
impl SendBuffer {
|
||||
fn new(socket_is_ipv4: bool) -> *mut Self {
|
||||
let mut instance = Box::new(Self {
|
||||
name_v4: libc::sockaddr_in {
|
||||
sin_family: libc::AF_INET as u16,
|
||||
sin_port: 0,
|
||||
sin_addr: libc::in_addr { s_addr: 0 },
|
||||
sin_zero: [0; 8],
|
||||
},
|
||||
name_v6: libc::sockaddr_in6 {
|
||||
sin6_family: libc::AF_INET6 as u16,
|
||||
sin6_port: 0,
|
||||
sin6_flowinfo: 0,
|
||||
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
|
||||
sin6_scope_id: 0,
|
||||
},
|
||||
bytes: [0; RESPONSE_BUF_LEN],
|
||||
iovec: libc::iovec {
|
||||
iov_base: null_mut(),
|
||||
iov_len: 0,
|
||||
},
|
||||
msghdr: libc::msghdr {
|
||||
msg_name: null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: null_mut(),
|
||||
msg_iovlen: 1,
|
||||
msg_control: null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
},
|
||||
});
|
||||
|
||||
instance.iovec.iov_base = addr_of_mut!(instance.bytes) as *mut libc::c_void;
|
||||
instance.iovec.iov_len = instance.bytes.len();
|
||||
|
||||
instance.msghdr.msg_iov = addr_of_mut!(instance.iovec);
|
||||
|
||||
if socket_is_ipv4 {
|
||||
instance.msghdr.msg_name = addr_of_mut!(instance.name_v4) as *mut libc::c_void;
|
||||
instance.msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
|
||||
} else {
|
||||
instance.msghdr.msg_name = addr_of_mut!(instance.name_v6) as *mut libc::c_void;
|
||||
instance.msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
|
||||
}
|
||||
|
||||
Box::into_raw(instance)
|
||||
}
|
||||
|
||||
fn prepare_entry(
|
||||
&mut self,
|
||||
response: &Response,
|
||||
addr: CanonicalSocketAddr,
|
||||
socket_is_ipv4: bool,
|
||||
metadata: &mut SendBufferMetadata,
|
||||
) -> Result<io_uring::squeue::Entry, Error> {
|
||||
if socket_is_ipv4 {
|
||||
metadata.receiver_is_ipv4 = true;
|
||||
|
||||
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
|
||||
addr
|
||||
} else {
|
||||
panic!("ipv6 address in ipv4 mode");
|
||||
};
|
||||
|
||||
self.name_v4.sin_port = addr.port().to_be();
|
||||
self.name_v4.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
|
||||
} else {
|
||||
// Set receiver protocol type before calling addr.get_ipv6_mapped()
|
||||
metadata.receiver_is_ipv4 = addr.is_ipv4();
|
||||
|
||||
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
|
||||
addr
|
||||
} else {
|
||||
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
|
||||
};
|
||||
|
||||
self.name_v6.sin6_port = addr.port().to_be();
|
||||
self.name_v6.sin6_addr.s6_addr = addr.ip().octets();
|
||||
}
|
||||
|
||||
let mut cursor = Cursor::new(&mut self.bytes[..]);
|
||||
|
||||
match response.write(&mut cursor) {
|
||||
Ok(()) => {
|
||||
self.iovec.iov_len = cursor.position() as usize;
|
||||
|
||||
metadata.response_type = ResponseType::from_response(response);
|
||||
|
||||
Ok(SendMsg::new(SOCKET_IDENTIFIER, addr_of_mut!(self.msghdr)).build())
|
||||
}
|
||||
Err(err) => Err(Error::SerializationFailed(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SendBufferMetadata {
|
||||
free: bool,
|
||||
/// Only used for statistics
|
||||
receiver_is_ipv4: bool,
|
||||
/// Only used for statistics
|
||||
response_type: ResponseType,
|
||||
}
|
||||
|
||||
impl Default for SendBufferMetadata {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
free: true,
|
||||
receiver_is_ipv4: true,
|
||||
response_type: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub enum ResponseType {
|
||||
#[default]
|
||||
Connect,
|
||||
Announce,
|
||||
Scrape,
|
||||
|
|
@ -31,221 +243,3 @@ impl ResponseType {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SendBuffer {
|
||||
name_v4: UnsafeCell<libc::sockaddr_in>,
|
||||
name_v6: UnsafeCell<libc::sockaddr_in6>,
|
||||
bytes: UnsafeCell<[u8; RESPONSE_BUF_LEN]>,
|
||||
iovec: UnsafeCell<libc::iovec>,
|
||||
msghdr: UnsafeCell<libc::msghdr>,
|
||||
free: bool,
|
||||
/// Only used for statistics
|
||||
receiver_is_ipv4: bool,
|
||||
/// Only used for statistics
|
||||
response_type: ResponseType,
|
||||
}
|
||||
|
||||
impl SendBuffer {
|
||||
fn new_with_null_pointers() -> Self {
|
||||
Self {
|
||||
name_v4: UnsafeCell::new(libc::sockaddr_in {
|
||||
sin_family: libc::AF_INET as u16,
|
||||
sin_port: 0,
|
||||
sin_addr: libc::in_addr { s_addr: 0 },
|
||||
sin_zero: [0; 8],
|
||||
}),
|
||||
name_v6: UnsafeCell::new(libc::sockaddr_in6 {
|
||||
sin6_family: libc::AF_INET6 as u16,
|
||||
sin6_port: 0,
|
||||
sin6_flowinfo: 0,
|
||||
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
|
||||
sin6_scope_id: 0,
|
||||
}),
|
||||
bytes: UnsafeCell::new([0; RESPONSE_BUF_LEN]),
|
||||
iovec: UnsafeCell::new(libc::iovec {
|
||||
iov_base: null_mut(),
|
||||
iov_len: 0,
|
||||
}),
|
||||
msghdr: UnsafeCell::new(libc::msghdr {
|
||||
msg_name: null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: null_mut(),
|
||||
msg_iovlen: 1,
|
||||
msg_control: null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
}),
|
||||
free: true,
|
||||
receiver_is_ipv4: true,
|
||||
response_type: ResponseType::Connect,
|
||||
}
|
||||
}
|
||||
|
||||
fn setup_pointers(&mut self, socket_is_ipv4: bool) {
|
||||
unsafe {
|
||||
let iovec = &mut *self.iovec.get();
|
||||
|
||||
iovec.iov_base = self.bytes.get() as *mut libc::c_void;
|
||||
iovec.iov_len = (&*self.bytes.get()).len();
|
||||
|
||||
let msghdr = &mut *self.msghdr.get();
|
||||
|
||||
msghdr.msg_iov = self.iovec.get();
|
||||
|
||||
if socket_is_ipv4 {
|
||||
msghdr.msg_name = self.name_v4.get() as *mut libc::c_void;
|
||||
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
|
||||
} else {
|
||||
msghdr.msg_name = self.name_v6.get() as *mut libc::c_void;
|
||||
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - SendBuffer must be stored at a fixed location in memory
|
||||
/// - SendBuffer.setup_pointers must have been called while stored at that
|
||||
/// fixed location
|
||||
/// - Contents of struct fields wrapped in UnsafeCell can NOT be accessed
|
||||
/// simultaneously to this function call
|
||||
unsafe fn prepare_entry(
|
||||
&mut self,
|
||||
response: &Response,
|
||||
addr: CanonicalSocketAddr,
|
||||
socket_is_ipv4: bool,
|
||||
) -> Result<io_uring::squeue::Entry, Error> {
|
||||
// Set receiver socket addr
|
||||
if socket_is_ipv4 {
|
||||
self.receiver_is_ipv4 = true;
|
||||
|
||||
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
|
||||
addr
|
||||
} else {
|
||||
panic!("ipv6 address in ipv4 mode");
|
||||
};
|
||||
|
||||
let name = &mut *self.name_v4.get();
|
||||
|
||||
name.sin_port = addr.port().to_be();
|
||||
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
|
||||
} else {
|
||||
// Set receiver protocol type before calling addr.get_ipv6_mapped()
|
||||
self.receiver_is_ipv4 = addr.is_ipv4();
|
||||
|
||||
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
|
||||
addr
|
||||
} else {
|
||||
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
|
||||
};
|
||||
|
||||
let name = &mut *self.name_v6.get();
|
||||
|
||||
name.sin6_port = addr.port().to_be();
|
||||
name.sin6_addr.s6_addr = addr.ip().octets();
|
||||
}
|
||||
|
||||
let bytes = (&mut *self.bytes.get()).as_mut_slice();
|
||||
|
||||
let mut cursor = Cursor::new(bytes);
|
||||
|
||||
match response.write(&mut cursor) {
|
||||
Ok(()) => {
|
||||
(&mut *self.iovec.get()).iov_len = cursor.position() as usize;
|
||||
|
||||
self.response_type = ResponseType::from_response(response);
|
||||
self.free = false;
|
||||
|
||||
Ok(SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build())
|
||||
}
|
||||
Err(err) => Err(Error::SerializationFailed(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SendBuffers {
|
||||
likely_next_free_index: usize,
|
||||
socket_is_ipv4: bool,
|
||||
buffers: Box<[SendBuffer]>,
|
||||
}
|
||||
|
||||
impl SendBuffers {
|
||||
pub fn new(config: &Config, capacity: usize) -> Self {
|
||||
let socket_is_ipv4 = config.network.address.is_ipv4();
|
||||
|
||||
let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers())
|
||||
.take(capacity)
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice();
|
||||
|
||||
for buffer in buffers.iter_mut() {
|
||||
buffer.setup_pointers(socket_is_ipv4);
|
||||
}
|
||||
|
||||
Self {
|
||||
likely_next_free_index: 0,
|
||||
socket_is_ipv4,
|
||||
buffers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn response_type_and_ipv4(&self, index: usize) -> (ResponseType, bool) {
|
||||
let buffer = self.buffers.get(index).unwrap();
|
||||
|
||||
(buffer.response_type, buffer.receiver_is_ipv4)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Only safe to call once buffer is no longer referenced by in-flight
|
||||
/// io_uring queue entries
|
||||
pub unsafe fn mark_buffer_as_free(&mut self, index: usize) {
|
||||
self.buffers[index].free = true;
|
||||
}
|
||||
|
||||
/// Call after going through completion queue
|
||||
pub fn reset_likely_next_free_index(&mut self) {
|
||||
self.likely_next_free_index = 0;
|
||||
}
|
||||
|
||||
pub fn prepare_entry(
|
||||
&mut self,
|
||||
response: &Response,
|
||||
addr: CanonicalSocketAddr,
|
||||
) -> Result<io_uring::squeue::Entry, Error> {
|
||||
let index = self.next_free_index()?;
|
||||
|
||||
let buffer = self.buffers.index_mut(index);
|
||||
|
||||
// Safety: OK because buffers are stored in fixed memory location,
|
||||
// buffer pointers were set up in SendBuffers::new() and pointers to
|
||||
// SendBuffer UnsafeCell contents are not accessed elsewhere
|
||||
unsafe {
|
||||
match buffer.prepare_entry(response, addr, self.socket_is_ipv4) {
|
||||
Ok(entry) => {
|
||||
self.likely_next_free_index = index + 1;
|
||||
|
||||
Ok(entry.user_data(index as u64))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn next_free_index(&self) -> Result<usize, Error> {
|
||||
if self.likely_next_free_index >= self.buffers.len() {
|
||||
return Err(Error::NoBuffers);
|
||||
}
|
||||
|
||||
for (i, buffer) in self.buffers[self.likely_next_free_index..]
|
||||
.iter()
|
||||
.enumerate()
|
||||
{
|
||||
if buffer.free {
|
||||
return Ok(self.likely_next_free_index + i);
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::NoBuffers)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue