aquatic/aquatic_udp/src/workers/socket/uring/send_buffers.rs
2023-03-09 21:57:37 +01:00

251 lines
7.6 KiB
Rust

use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut};
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::Response;
use io_uring::opcode::SendMsg;
use crate::config::Config;
use super::{RESPONSE_BUF_LEN, SOCKET_IDENTIFIER};
/// Reasons why a send buffer could not be prepared for submission.
#[derive(Debug)]
pub enum Error {
    /// No free buffer was found at or after the current search position.
    NoBuffers,
    /// Writing the response into the buffer failed.
    SerializationFailed(std::io::Error),
}
/// Category of the response stored in a send buffer.
///
/// Recorded per buffer for statistics purposes only.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ResponseType {
    /// Reply to a connect request.
    Connect,
    /// Reply to an announce request (IPv4 and IPv6 variants alike).
    Announce,
    /// Reply to a scrape request.
    Scrape,
    /// Error reply.
    Error,
}
impl ResponseType {
fn from_response(response: &Response) -> Self {
match response {
Response::Connect(_) => Self::Connect,
Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => Self::Announce,
Response::Scrape(_) => Self::Scrape,
Response::Error(_) => Self::Error,
}
}
}
/// Staging area for one outgoing UDP response, submitted through io_uring
/// as a `SendMsg` operation.
///
/// The `UnsafeCell` fields are linked by raw pointers stored inside this same
/// struct (see `setup_pointers`):
/// - `msghdr.msg_iov` points at `iovec`.
/// - `iovec.iov_base` points at `bytes`.
/// - `msghdr.msg_name` points at either `name_v4` or `name_v6`.
///
/// A `SendBuffer` must therefore stay at a fixed memory location once those
/// pointers have been set up.
struct SendBuffer {
/// Destination address used as `msg_name` when the socket is IPv4.
name_v4: UnsafeCell<libc::sockaddr_in>,
/// Destination address used as `msg_name` when the socket is IPv6.
name_v6: UnsafeCell<libc::sockaddr_in6>,
/// Serialized response payload.
bytes: UnsafeCell<[u8; RESPONSE_BUF_LEN]>,
/// The single I/O vector describing the payload in `bytes`.
iovec: UnsafeCell<libc::iovec>,
/// Message header whose address is passed to `SendMsg`.
msghdr: UnsafeCell<libc::msghdr>,
/// `true` when the buffer may be (re)used. It is set to `false` when an entry
/// is built from it, and must only be set back once no in-flight io_uring
/// entry references it.
free: bool,
/// Only used for statistics
receiver_is_ipv4: bool,
/// Only used for statistics
response_type: ResponseType,
}
impl SendBuffer {
/// Creates a buffer in its initial state:
/// - Address structs zeroed except for their family field.
/// - A zeroed payload.
/// - Null `iovec`/`msghdr` pointers.
///
/// `setup_pointers` must be called once the buffer sits at its final memory
/// location; until then the message header is unusable.
fn new_with_null_pointers() -> Self {
Self {
name_v4: UnsafeCell::new(libc::sockaddr_in {
sin_family: libc::AF_INET as u16,
sin_port: 0,
sin_addr: libc::in_addr { s_addr: 0 },
sin_zero: [0; 8],
}),
name_v6: UnsafeCell::new(libc::sockaddr_in6 {
sin6_family: libc::AF_INET6 as u16,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
sin6_scope_id: 0,
}),
bytes: UnsafeCell::new([0; RESPONSE_BUF_LEN]),
// Null until setup_pointers links it to `bytes`.
iovec: UnsafeCell::new(libc::iovec {
iov_base: null_mut(),
iov_len: 0,
}),
// Name and iov pointers are filled in by setup_pointers.
msghdr: UnsafeCell::new(libc::msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: null_mut(),
// Exactly one I/O vector: the `iovec` field.
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
}),
free: true,
// Placeholder statistics values until prepare_entry records real ones.
receiver_is_ipv4: true,
response_type: ResponseType::Connect,
}
}
/// Links the internal pointers to this struct's own fields:
/// - `iovec` to `bytes` (full length).
/// - `msghdr.msg_iov` to `iovec`.
/// - `msghdr.msg_name` to `name_v4` or `name_v6`, depending on
///   `socket_is_ipv4`.
///
/// The stored pointers refer to fields of `self`, so they dangle if the
/// `SendBuffer` is moved afterwards.
fn setup_pointers(&mut self, socket_is_ipv4: bool) {
unsafe {
let iovec = &mut *self.iovec.get();
iovec.iov_base = self.bytes.get() as *mut libc::c_void;
// Full capacity for now; prepare_entry shrinks it to the serialized length.
iovec.iov_len = (&*self.bytes.get()).len();
let msghdr = &mut *self.msghdr.get();
msghdr.msg_iov = self.iovec.get();
if socket_is_ipv4 {
msghdr.msg_name = self.name_v4.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
} else {
msghdr.msg_name = self.name_v6.get() as *mut libc::c_void;
msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
}
}
}
/// Writes the destination address and the serialized `response` into this
/// buffer, then builds an io_uring `SendMsg` entry pointing at `msghdr`.
///
/// On success:
/// - The buffer is marked as not free.
/// - `response_type` is recorded.
/// - `iovec.iov_len` is set to the number of bytes written.
///
/// This function does not set the entry's user data; the caller does.
///
/// # Errors
///
/// Returns `Error::SerializationFailed` wrapping the I/O error if
/// `response.write` into the buffer fails. The buffer then stays free,
/// although its address fields and `receiver_is_ipv4` have already been
/// overwritten.
///
/// # Panics
///
/// - In IPv4 mode, if `addr.get_ipv4()` does not yield an IPv4 address.
/// - In IPv6 mode, if `addr.get_ipv6_mapped()` does not yield an IPv6
///   address.
///
/// # Safety
///
/// - SendBuffer must be stored at a fixed location in memory
/// - SendBuffer.setup_pointers must have been called while stored at that
/// fixed location
/// - Contents of struct fields wrapped in UnsafeCell can NOT be accessed
/// simultaneously to this function call
unsafe fn prepare_entry(
&mut self,
response: &Response,
addr: CanonicalSocketAddr,
socket_is_ipv4: bool,
) -> Result<io_uring::squeue::Entry, Error> {
// Set receiver socket addr
if socket_is_ipv4 {
self.receiver_is_ipv4 = true;
let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
addr
} else {
panic!("ipv6 address in ipv4 mode");
};
let name = &mut *self.name_v4.get();
// Port and address are stored in network (big-endian) byte order.
name.sin_port = addr.port().to_be();
name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
} else {
// Set receiver protocol type before calling addr.get_ipv6_mapped()
self.receiver_is_ipv4 = addr.is_ipv4();
let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
addr
} else {
panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
};
let name = &mut *self.name_v6.get();
// Port in network byte order; octets are already in wire order.
name.sin6_port = addr.port().to_be();
name.sin6_addr.s6_addr = addr.ip().octets();
}
// Serialize into the whole fixed-size payload array.
let bytes = (&mut *self.bytes.get()).as_mut_slice();
let mut cursor = Cursor::new(bytes);
match response.write(&mut cursor) {
Ok(()) => {
// Send only the bytes actually written, not the whole buffer.
(&mut *self.iovec.get()).iov_len = cursor.position() as usize;
self.response_type = ResponseType::from_response(response);
// Buffer is now referenced by the entry returned below.
self.free = false;
Ok(SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build())
}
Err(err) => Err(Error::SerializationFailed(err)),
}
}
}
/// Pool of send buffers, each handed out for one outgoing response at a time.
///
/// The buffers live in a boxed slice that is never resized. The raw pointers
/// each buffer stores into its own fields therefore stay valid for the
/// lifetime of the pool.
pub struct SendBuffers {
    /// Heap-allocated buffers; not reallocated after construction.
    buffers: Box<[SendBuffer]>,
    /// Whether the socket (and thus every buffer's `msg_name`) is IPv4.
    socket_is_ipv4: bool,
    /// Position from which the search for a free buffer starts.
    likely_next_free_index: usize,
}
impl SendBuffers {
pub fn new(config: &Config, capacity: usize) -> Self {
let socket_is_ipv4 = config.network.address.is_ipv4();
let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers())
.take(capacity)
.collect::<Vec<_>>()
.into_boxed_slice();
for buffer in buffers.iter_mut() {
buffer.setup_pointers(socket_is_ipv4);
}
Self {
likely_next_free_index: 0,
socket_is_ipv4,
buffers,
}
}
pub fn response_type_and_ipv4(&self, index: usize) -> (ResponseType, bool) {
let buffer = self.buffers.get(index).unwrap();
(buffer.response_type, buffer.receiver_is_ipv4)
}
/// # Safety
///
/// Only safe to call once buffer is no longer referenced by in-flight
/// io_uring queue entries
pub unsafe fn mark_index_as_free(&mut self, index: usize) {
self.buffers[index].free = true;
}
/// Call after going through completion queue
pub fn reset_index(&mut self) {
self.likely_next_free_index = 0;
}
pub fn prepare_entry(
&mut self,
response: &Response,
addr: CanonicalSocketAddr,
) -> Result<io_uring::squeue::Entry, Error> {
let index = self.next_free_index()?;
let buffer = self.buffers.index_mut(index);
// Safety: OK because buffers are stored in fixed memory location,
// buffer pointers were set up in SendBuffers::new() and pointers to
// SendBuffer UnsafeCell contents are not accessed elsewhere
unsafe {
match buffer.prepare_entry(response, addr, self.socket_is_ipv4) {
Ok(entry) => {
self.likely_next_free_index = index + 1;
Ok(entry.user_data(index as u64))
}
Err(err) => Err(err),
}
}
}
fn next_free_index(&self) -> Result<usize, Error> {
if self.likely_next_free_index >= self.buffers.len() {
return Err(Error::NoBuffers);
}
for (i, buffer) in self.buffers[self.likely_next_free_index..]
.iter()
.enumerate()
{
if buffer.free {
return Ok(self.likely_next_free_index + i);
}
}
Err(Error::NoBuffers)
}
}