WIP: udp io-uring experiments

This commit is contained in:
Joakim Frostegård 2021-11-12 13:30:50 +01:00
parent f93db6a9f2
commit c949bde532
4 changed files with 548 additions and 2 deletions

20
Cargo.lock generated
View file

@ -179,12 +179,15 @@ dependencies = [
"aquatic_cli_helpers",
"aquatic_common",
"aquatic_udp_protocol",
"bytemuck",
"cfg-if",
"crossbeam-channel",
"futures-lite",
"glommio",
"hex",
"histogram",
"io-uring",
"libc",
"log",
"mimalloc",
"mio",
@ -194,6 +197,7 @@ dependencies = [
"rand",
"serde",
"signal-hook",
"slab",
"socket2 0.4.2",
]
@ -446,6 +450,12 @@ version = "3.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9df67f7bf9ef8498769f994239c45613ef0c5899415fb58e9add412d2c1a538"
[[package]]
name = "bytemuck"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b"
[[package]]
name = "byteorder"
version = "1.4.3"
@ -1130,6 +1140,16 @@ dependencies = [
"memoffset 0.5.6",
]
[[package]]
name = "io-uring"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d75829ed9377bab6c90039fe47b9d84caceb4b5063266142e21bcce6550cda8"
dependencies = [
"bitflags",
"libc",
]
[[package]]
name = "itertools"
version = "0.10.1"

View file

@ -18,7 +18,7 @@ name = "aquatic_udp"
default = ["with-mio"]
cpu-pinning = ["aquatic_common/cpu-pinning"]
with-glommio = ["cpu-pinning", "glommio", "futures-lite"]
with-mio = ["crossbeam-channel", "histogram", "mio", "socket2"]
with-mio = ["crossbeam-channel", "histogram", "mio", "socket2", "io-uring", "libc", "bytemuck"]
[dependencies]
anyhow = "1"
@ -32,6 +32,7 @@ mimalloc = { version = "0.1", default-features = false }
parking_lot = "0.11"
rand = { version = "0.8", features = ["small_rng"] }
serde = { version = "1", features = ["derive"] }
slab = "0.4"
signal-hook = { version = "0.3" }
# mio
@ -39,6 +40,9 @@ crossbeam-channel = { version = "0.5", optional = true }
histogram = { version = "0.6", optional = true }
mio = { version = "0.7", features = ["udp", "os-poll", "os-util"], optional = true }
socket2 = { version = "0.4.1", features = ["all"], optional = true }
io-uring = { version = "0.5", optional = true }
libc = { version = "0.2", optional = true }
bytemuck = { version = "1", optional = true }
# glommio
glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5", optional = true }

View file

@ -17,6 +17,7 @@ use crate::config::Config;
pub mod common;
pub mod handlers;
pub mod network;
pub mod network_uring;
pub mod tasks;
use common::State;
@ -98,7 +99,7 @@ pub fn run_inner(config: Config, state: State) -> ::anyhow::Result<()> {
WorkerIndex::SocketWorker(i),
);
network::run_socket_worker(
network_uring::run_socket_worker(
state,
config,
i,

View file

@ -0,0 +1,521 @@
use std::io::Cursor;
use std::mem::size_of_val;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::os::unix::prelude::{AsRawFd};
use std::ptr::{null_mut};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use aquatic_common::access_list::{AccessListCache, create_access_list_cache};
use crossbeam_channel::{Receiver, Sender};
use io_uring::SubmissionQueue;
use io_uring::types::{Fixed, Timespec};
use libc::{c_void, in_addr, iovec, msghdr, sockaddr_in};
use rand::prelude::{Rng, SeedableRng, StdRng};
use slab::Slab;
use socket2::{Domain, Protocol, Socket, Type};
use aquatic_udp_protocol::{IpVersion, Request, Response};
use crate::common::handlers::*;
use crate::common::network::ConnectionMap;
use crate::common::*;
use crate::config::Config;
use super::common::*;
const RING_SIZE: usize = 128;
const MAX_RECV_EVENTS: usize = 1;
const MAX_SEND_EVENTS: usize = RING_SIZE - MAX_RECV_EVENTS - 1;
const NUM_BUFFERS: usize = MAX_RECV_EVENTS + MAX_SEND_EVENTS;
#[derive(Clone, Copy, Debug, PartialEq)]
enum UserData {
RecvMsg {
slab_key: usize,
},
SendMsg {
slab_key: usize,
},
Timeout,
}
impl UserData {
fn get_buffer_index(&self) -> usize {
match self {
Self::RecvMsg { slab_key } => {
*slab_key
}
Self::SendMsg { slab_key } => {
slab_key + MAX_RECV_EVENTS
}
Self::Timeout => {
unreachable!()
}
}
}
}
impl From<u64> for UserData {
fn from(mut n: u64) -> UserData {
let bytes = bytemuck::bytes_of_mut(&mut n);
let t = bytes[7];
bytes[7] = 0;
match t {
0 => Self::RecvMsg {
slab_key: n as usize,
},
1 => Self::SendMsg {
slab_key: n as usize,
},
2 => Self::Timeout,
_ => unreachable!(),
}
}
}
impl Into<u64> for UserData {
fn into(self) -> u64 {
match self {
Self::RecvMsg { slab_key } => {
let mut out = slab_key as u64;
bytemuck::bytes_of_mut(&mut out)[7] = 0;
out
}
Self::SendMsg { slab_key } => {
let mut out = slab_key as u64;
bytemuck::bytes_of_mut(&mut out)[7] = 1;
out
}
Self::Timeout => {
let mut out = 0u64;
bytemuck::bytes_of_mut(&mut out)[7] = 2;
out
}
}
}
}
pub fn run_socket_worker(
state: State,
config: Config,
token_num: usize,
request_sender: Sender<(ConnectedRequest, SocketAddr)>,
response_receiver: Receiver<(ConnectedResponse, SocketAddr)>,
num_bound_sockets: Arc<AtomicUsize>,
) {
let mut rng = StdRng::from_entropy();
let socket = create_socket(&config);
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
let mut connections = ConnectionMap::default();
let mut access_list_cache = create_access_list_cache(&state.access_list);
let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new();
let cleaning_duration = Duration::from_secs(config.cleaning.connection_cleaning_interval);
let mut iter_counter = 0usize;
let mut last_cleaning = Instant::now();
let mut buffers: Vec<[u8; MAX_PACKET_SIZE]> = (0..NUM_BUFFERS).map(|_| [0; MAX_PACKET_SIZE]).collect();
let mut sockaddrs_ipv4 = [
sockaddr_in {
sin_addr: in_addr {
s_addr: 0,
},
sin_port: 0,
sin_family: 0,
sin_zero: Default::default(),
}
; NUM_BUFFERS
];
let mut iovs: Vec<iovec> = (0..NUM_BUFFERS).map(|i| {
let iov_base = buffers[i].as_mut_ptr() as *mut c_void;
let iov_len = MAX_PACKET_SIZE;
iovec {
iov_base,
iov_len,
}
}).collect();
let mut msghdrs: Vec<msghdr> = (0..NUM_BUFFERS).map(|i| {
let msg_iov: *mut iovec = &mut iovs[i];
let msg_name: *mut sockaddr_in = &mut sockaddrs_ipv4[i];
msghdr {
msg_name: msg_name as *mut c_void,
msg_namelen: size_of_val(&sockaddrs_ipv4[i]) as u32,
msg_iov,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
}
}).collect();
let timeout = Timespec::new().nsec(500_000_000);
let mut timeout_set = false;
let mut recv_entries = Slab::with_capacity(MAX_RECV_EVENTS);
let mut send_entries = Slab::with_capacity(MAX_SEND_EVENTS);
let mut ring = io_uring::IoUring::new(RING_SIZE as u32).unwrap();
let (submitter, mut sq, mut cq) = ring.split();
submitter.register_files(&[socket.as_raw_fd()]).unwrap();
let fd = Fixed(0);
loop {
while let Some(entry) = cq.next() {
let user_data: UserData = entry.user_data().into();
match user_data {
UserData::RecvMsg { slab_key } => {
recv_entries.remove(slab_key);
let result = entry.result();
if result < 0 {
::log::info!("recvmsg error {}: {:#}", result, ::std::io::Error::from_raw_os_error(-result));
} else if result == 0 {
::log::info!("recvmsg error: 0 bytes read");
} else {
let buffer_index = user_data.get_buffer_index();
let buffer_len = result as usize;
let src = SocketAddrV4::new(
Ipv4Addr::from(u32::from_be(sockaddrs_ipv4[buffer_index].sin_addr.s_addr)),
u16::from_be(sockaddrs_ipv4[buffer_index].sin_port),
);
let res_request =
Request::from_bytes(&buffers[buffer_index][..buffer_len], config.protocol.max_scrape_torrents);
handle_request(
&config,
&state,
&mut connections,
&mut access_list_cache,
&mut rng,
&request_sender,
&mut local_responses,
res_request,
SocketAddr::V4(src),
);
}
}
UserData::SendMsg { slab_key } => {
send_entries.remove(slab_key);
if entry.result() < 0 {
::log::info!("recvmsg error: {:#}", ::std::io::Error::from_raw_os_error(-entry.result()));
}
}
UserData::Timeout => {
timeout_set = false;
}
}
}
for _ in 0..(MAX_RECV_EVENTS - recv_entries.len()) {
let slab_key = recv_entries.insert(());
let user_data = UserData::RecvMsg { slab_key };
let buffer_index = user_data.get_buffer_index();
let buf_ptr: *mut msghdr = &mut msghdrs[buffer_index];
let entry = io_uring::opcode::RecvMsg::new(fd, buf_ptr).build().user_data(user_data.into());
unsafe {
sq.push(&entry).unwrap();
}
}
if !timeout_set {
let user_data = UserData::Timeout;
let timespec_ptr: *const Timespec = &timeout;
let entry = io_uring::opcode::Timeout::new(timespec_ptr).build().user_data(user_data.into());
unsafe {
sq.push(&entry).unwrap();
}
timeout_set = true;
}
let num_local_to_queue = (MAX_SEND_EVENTS - send_entries.len()).min(local_responses.len());
for (response, addr) in local_responses.drain(local_responses.len() - num_local_to_queue..) {
queue_response(&mut sq, fd, &mut send_entries, &mut buffers, &mut iovs, &mut sockaddrs_ipv4, &mut msghdrs, response, addr);
}
for (response, addr) in response_receiver.try_iter().take(MAX_SEND_EVENTS - send_entries.len()) {
queue_response(&mut sq, fd, &mut send_entries, &mut buffers, &mut iovs, &mut sockaddrs_ipv4, &mut msghdrs, response.into(), addr);
}
if iter_counter % 32 == 0 {
let now = Instant::now();
if now > last_cleaning + cleaning_duration {
connections.clean();
last_cleaning = now;
}
}
let all_responses_sent = local_responses.is_empty() & response_receiver.is_empty();
let wait_for_num = if all_responses_sent {
send_entries.len() + recv_entries.len()
} else {
send_entries.len()
};
sq.sync();
submitter.submit_and_wait(wait_for_num).unwrap();
sq.sync();
cq.sync();
iter_counter = iter_counter.wrapping_add(1);
}
}
fn queue_response(
sq: &mut SubmissionQueue,
fd: Fixed,
send_events: &mut Slab<()>,
buffers: &mut [[u8; MAX_PACKET_SIZE]],
iovs: &mut [iovec],
sockaddrs: &mut [sockaddr_in],
msghdrs: &mut [msghdr],
response: Response,
src: SocketAddr,
) {
let slab_key = send_events.insert(());
let user_data = UserData::SendMsg { slab_key };
let buffer_index = user_data.get_buffer_index();
let mut cursor = Cursor::new(&mut buffers[buffer_index][..]);
match response.write(&mut cursor, ip_version_from_ip(src.ip())) {
Ok(()) => {
iovs[buffer_index].iov_len = cursor.position() as usize;
let src = if let SocketAddr::V4(src) = src {
src
} else {
return; // FIXME
};
sockaddrs[buffer_index].sin_addr.s_addr = u32::to_be((*src.ip()).into());
sockaddrs[buffer_index].sin_port = u16::to_be(src.port());
}
Err(err) => {
::log::error!("Response::write error: {:?}", err);
}
}
let buf_ptr: *mut msghdr = &mut msghdrs[buffer_index];
let entry = io_uring::opcode::SendMsg::new(fd, buf_ptr).build().user_data(user_data.into());
unsafe {
sq.push(&entry).unwrap();
}
}
fn create_socket(config: &Config) -> ::std::net::UdpSocket {
let socket = if config.network.address.is_ipv4() {
Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
} else {
Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
}
.expect("create socket");
socket.set_reuse_port(true).expect("socket: set reuse port");
socket
.set_nonblocking(true)
.expect("socket: set nonblocking");
socket
.bind(&config.network.address.into())
.unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err));
let recv_buffer_size = config.network.socket_recv_buffer_size;
if recv_buffer_size != 0 {
if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) {
::log::error!(
"socket: failed setting recv buffer to {}: {:?}",
recv_buffer_size,
err
);
}
}
socket.into()
}
#[inline]
fn handle_request(
config: &Config,
state: &State,
connections: &mut ConnectionMap,
access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
local_responses: &mut Vec<(Response, SocketAddr)>,
res_request: Result<Request, RequestParseError>,
src: SocketAddr,
) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode;
match res_request {
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(connection_id, src, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
{
if let Err(err) = request_sender
.try_send((ConnectedRequest::Announce(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
let request = ConnectedRequest::Scrape {
request,
original_indices: Vec::new(),
};
if let Err(err) = request_sender.try_send((request, src)) {
::log::warn!("request_sender.try_send failed: {:?}", err)
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if connections.contains(connection_id, src) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_responses.push((response.into(), src));
}
}
}
}
}
fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
match ip {
IpAddr::V4(_) => IpVersion::IPv4,
IpAddr::V6(ip) => {
if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
IpVersion::IPv4
} else {
IpVersion::IPv6
}
}
}
}
#[cfg(test)]
mod tests {
use quickcheck::Arbitrary;
use quickcheck_macros::quickcheck;
use super::*;
impl quickcheck::Arbitrary for UserData {
fn arbitrary(g: &mut quickcheck::Gen) -> Self {
match (bool::arbitrary(g), bool::arbitrary(g)) {
(false, b) => {
let slab_key: u32 = Arbitrary::arbitrary(g);
let slab_key = slab_key as usize;
if b {
UserData::RecvMsg {
slab_key
}
} else {
UserData::SendMsg {
slab_key
}
}
}
_ => {
UserData::Timeout
}
}
}
}
#[quickcheck]
fn test_user_data_identity(a: UserData) -> bool {
let n: u64 = a.into();
let b = UserData::from(n);
a == b
}
}