From 2e67f11caf1978ad1c462b30e4fad96fbe9f082a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Tue, 7 Mar 2023 19:01:37 +0100 Subject: [PATCH] udp: add experimental io_uring implementation (#131) * WIP: add udp uring support * WIP: fix udp uring address parsing * WIP: udp uring: resubmit recv when needed * WIP: udp uring: add OutMessageStorage, send swarm responses * WIP: udp uring: increase ring entries to 1024 * WIP: udp uring: add constants * WIP: udp uring: use sqpoll, avoid kernel calls * WIP: udp uring: disable sqpoll * WIP: udp uring: use VecDeque for local responses * udp uring: enable setup_coop_taskrun * udp uring: add RecvMsgStorage * udp: improve split of uring and mio implementations * udp uring: clean up * udp uring: initial ipv6 support * udp uring: improve helper structs * udp uring: clean up, use constants for important data * udp: share create_socket fn between implementations * udp uring: improve send buffer free index finding * udp uring: work on SendBuffers.try_add * udp uring: split into modules * udp uring: Rename RecvMsgMultiHelper to RecvHelper * udp uring: improve SendBuffers * udp uring: fix copyright attribution in buf_ring module * udp uring: stop always consuming 100% cpu * udp uring: clean up * udp uring: add handle_recv_cqe * udp uring: move local_responses into SocketWorker * udp uring: move timeout_timespec into SocketWorker * Update TODO * udp: make io-uring optional * Update TODO * udp uring: enqueue timeout before sends * udp uring: move likely empty buffer tracking logic into SendBuffers * udp uring: improve error handling and logging * udp uring: keep one timeout submitted at a time * udp uring: update pending_scrape_valid_until * udp uring: add second timeout for cleaning * Update TODO * udp uring: store resubmittable squeue entries in a Vec * udp uring: add comment, remove a log statement * Update TODO * Update TODO * udp: io_uring: fall back to mio if io_uring support not recent enough * udp: uring: 
add bytes_received statistics * udp: uring: add bytes_sent statistics * udp: uring: add more statistics * Update TODO * udp: uring: improve SendBuffers code * udp: uring: remove unneeded squeue sync calls * udp: uring: replace buf_ring impl with one from tokio-uring * udp: uring: store ring in TLS so it can be used in Drop impls * udp: uring: store BufRing in SocketWorker * udp: uring: silence buf_ring dead code warnings, improve comment * Update TODO * udp: uring: improve CurrentRing docs, use anonymous struct field * udp: uring: improve ring setup * udp: uring: get ipv6 working * udp: uring: make ring entry count configurable, use more send entries * udp: uring: log number of pending responses (info level) * udp: uring: improve comment on send_buffer_entries calculation * udp: improve config comments * udp: uring: add to responses stats when they are confirmed as sent * Update TODO * udp: uring: enable IoUring setup_submit_all * Update README --- Cargo.lock | 11 + README.md | 13 +- TODO.md | 7 +- aquatic_udp/Cargo.toml | 2 + aquatic_udp/src/config.rs | 9 + aquatic_udp/src/lib.rs | 5 +- aquatic_udp/src/workers/socket/mio.rs | 432 ++++++++ aquatic_udp/src/workers/socket/mod.rs | 465 +-------- .../src/workers/socket/uring/buf_ring.rs | 947 ++++++++++++++++++ aquatic_udp/src/workers/socket/uring/mod.rs | 508 ++++++++++ .../src/workers/socket/uring/recv_helper.rs | 121 +++ .../src/workers/socket/uring/send_buffers.rs | 221 ++++ scripts/run-aquatic-udp.sh | 2 +- 13 files changed, 2315 insertions(+), 428 deletions(-) create mode 100644 aquatic_udp/src/workers/socket/mio.rs create mode 100644 aquatic_udp/src/workers/socket/uring/buf_ring.rs create mode 100644 aquatic_udp/src/workers/socket/uring/mod.rs create mode 100644 aquatic_udp/src/workers/socket/uring/recv_helper.rs create mode 100644 aquatic_udp/src/workers/socket/uring/send_buffers.rs diff --git a/Cargo.lock b/Cargo.lock index 6764f9f..b61227a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -229,6 +229,7 @@ 
dependencies = [ "hashbrown 0.13.2", "hdrhistogram", "hex", + "io-uring", "libc", "log", "metrics", @@ -1489,6 +1490,16 @@ dependencies = [ "memoffset 0.8.0", ] +[[package]] +name = "io-uring" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd1e1a01cfb924fd8c5c43b6827965db394f5a3a16c599ce03452266e1cf984c" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + [[package]] name = "ipnet" version = "2.7.1" diff --git a/README.md b/README.md index c0a938d..ee8c52c 100644 --- a/README.md +++ b/README.md @@ -196,12 +196,15 @@ This is the most mature of the implementations. I consider it ready for producti More details are available [here](./documents/aquatic-udp-load-test-2023-01-11.pdf). -#### Optimisation attempts that didn't work out +#### io_uring -* Using glommio -* Using io-uring -* Using zerocopy + vectored sends for responses -* Using sendmmsg +An experimental io_uring backend can be compiled in by passing the `io-uring` +feature. Currently, Linux 6.0 or later is required. The application will +attempt to fall back to the mio backend if your kernel is not supported. + +```sh +cargo build --release -p aquatic_udp --features "io-uring" +``` ### aquatic_http: HTTP BitTorrent tracker diff --git a/TODO.md b/TODO.md index e86bada..2e3608b 100644 --- a/TODO.md +++ b/TODO.md @@ -2,6 +2,12 @@ ## High priority +* udp uring + * uneven performance? + * thiserror? + * CI + * uring load test? + * ws: wait for crates release of glommio with membarrier fix (PR #558) * Release new version * More non-CI integration tests? 
@@ -9,7 +15,6 @@ ## Medium priority -* Consider replacing unmaintained indexmap-amortized with plain indexmap * Run cargo-fuzz on protocol crates * udp: support link to arbitrary homepage as well as embedded tracker URL in statistics page diff --git a/aquatic_udp/Cargo.toml b/aquatic_udp/Cargo.toml index 505595a..099caa4 100644 --- a/aquatic_udp/Cargo.toml +++ b/aquatic_udp/Cargo.toml @@ -20,6 +20,7 @@ name = "aquatic_udp" default = ["prometheus"] cpu-pinning = ["aquatic_common/hwloc"] prometheus = ["metrics", "metrics-exporter-prometheus"] +io-uring = ["dep:io-uring"] [dependencies] aquatic_common.workspace = true @@ -35,6 +36,7 @@ getrandom = "0.2" hashbrown = { version = "0.13", default-features = false } hdrhistogram = "7" hex = "0.4" +io-uring = { version = "0.5", optional = true } libc = "0.2" log = "0.4" metrics = { version = "0.20", optional = true } diff --git a/aquatic_udp/src/config.rs b/aquatic_udp/src/config.rs index cf310d9..c5da489 100644 --- a/aquatic_udp/src/config.rs +++ b/aquatic_udp/src/config.rs @@ -85,8 +85,15 @@ pub struct NetworkConfig { /// $ sudo sysctl -w net.core.rmem_max=104857600 /// $ sudo sysctl -w net.core.rmem_default=104857600 pub socket_recv_buffer_size: usize, + /// Poll event capacity (mio backend) pub poll_event_capacity: usize, + /// Poll timeout in milliseconds (mio backend) pub poll_timeout_ms: u64, + /// Number of ring entries (io_uring backend) + /// + /// Will be rounded to next power of two if not already one + #[cfg(feature = "io-uring")] + pub ring_entries: u16, /// Store this many responses at most for retrying (once) on send failure /// /// Useful on operating systems that do not provide an udp send buffer, @@ -112,6 +119,8 @@ impl Default for NetworkConfig { socket_recv_buffer_size: 4096 * 128, poll_event_capacity: 4096, poll_timeout_ms: 50, + #[cfg(feature = "io-uring")] + ring_entries: 1024, resend_buffer_max_len: 0, } } diff --git a/aquatic_udp/src/lib.rs b/aquatic_udp/src/lib.rs index 10b65ef..a869c4f 100644 
--- a/aquatic_udp/src/lib.rs +++ b/aquatic_udp/src/lib.rs @@ -20,8 +20,7 @@ use common::{ ConnectedRequestSender, ConnectedResponseSender, SocketWorkerIndex, State, SwarmWorkerIndex, }; use config::Config; -use workers::socket::validator::ConnectionValidator; -use workers::socket::SocketWorker; +use workers::socket::ConnectionValidator; pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker"; pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -122,7 +121,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { WorkerIndex::SocketWorker(i), ); - SocketWorker::run( + workers::socket::run_socket_worker( sentinel, state, config, diff --git a/aquatic_udp/src/workers/socket/mio.rs b/aquatic_udp/src/workers/socket/mio.rs new file mode 100644 index 0000000..c248b3c --- /dev/null +++ b/aquatic_udp/src/workers/socket/mio.rs @@ -0,0 +1,432 @@ +use std::io::{Cursor, ErrorKind}; +use std::sync::atomic::Ordering; +use std::time::{Duration, Instant}; + +use aquatic_common::access_list::AccessListCache; +use aquatic_common::ServerStartInstant; +use crossbeam_channel::Receiver; +use mio::net::UdpSocket; +use mio::{Events, Interest, Poll, Token}; + +use aquatic_common::{ + access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr, + PanicSentinel, ValidUntil, +}; +use aquatic_udp_protocol::*; + +use crate::common::*; +use crate::config::Config; + +use super::storage::PendingScrapeResponseSlab; +use super::validator::ConnectionValidator; +use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; + +pub struct SocketWorker { + config: Config, + shared_state: State, + request_sender: ConnectedRequestSender, + response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + access_list_cache: AccessListCache, + validator: ConnectionValidator, + server_start_instant: ServerStartInstant, + pending_scrape_responses: PendingScrapeResponseSlab, + socket: UdpSocket, + buffer: [u8; BUFFER_SIZE], +} + +impl SocketWorker 
{ + pub fn run( + _sentinel: PanicSentinel, + shared_state: State, + config: Config, + validator: ConnectionValidator, + server_start_instant: ServerStartInstant, + request_sender: ConnectedRequestSender, + response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + priv_dropper: PrivilegeDropper, + ) { + let socket = + UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); + let access_list_cache = create_access_list_cache(&shared_state.access_list); + + let mut worker = Self { + config, + shared_state, + validator, + server_start_instant, + request_sender, + response_receiver, + access_list_cache, + pending_scrape_responses: Default::default(), + socket, + buffer: [0; BUFFER_SIZE], + }; + + worker.run_inner(); + } + + pub fn run_inner(&mut self) { + let mut local_responses = Vec::new(); + let mut opt_resend_buffer = + (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new()); + + let mut events = Events::with_capacity(self.config.network.poll_event_capacity); + let mut poll = Poll::new().expect("create poll"); + + poll.registry() + .register(&mut self.socket, Token(0), Interest::READABLE) + .expect("register poll"); + + let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms); + + let pending_scrape_cleaning_duration = + Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval); + + let mut pending_scrape_valid_until = ValidUntil::new( + self.server_start_instant, + self.config.cleaning.max_pending_scrape_age, + ); + let mut last_pending_scrape_cleaning = Instant::now(); + + let mut iter_counter = 0usize; + + loop { + poll.poll(&mut events, Some(poll_timeout)) + .expect("failed polling"); + + for event in events.iter() { + if event.is_readable() { + self.read_and_handle_requests(&mut local_responses, pending_scrape_valid_until); + } + } + + // If resend buffer is enabled, send any responses in it + if let Some(resend_buffer) = opt_resend_buffer.as_mut() { + for (response, 
addr) in resend_buffer.drain(..) { + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut None, + response, + addr, + ); + } + } + + // Send any connect and error responses generated by this socket worker + for (response, addr) in local_responses.drain(..) { + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut opt_resend_buffer, + response, + addr, + ); + } + + // Check channel for any responses generated by swarm workers + for (response, addr) in self.response_receiver.try_iter() { + let opt_response = match response { + ConnectedResponse::Scrape(r) => self + .pending_scrape_responses + .add_and_get_finished(r) + .map(Response::Scrape), + ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)), + ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)), + }; + + if let Some(response) = opt_response { + Self::send_response( + &self.config, + &self.shared_state, + &mut self.socket, + &mut self.buffer, + &mut opt_resend_buffer, + response, + addr, + ); + } + } + + // Run periodic ValidUntil updates and state cleaning + if iter_counter % 256 == 0 { + let seconds_since_start = self.server_start_instant.seconds_elapsed(); + + pending_scrape_valid_until = ValidUntil::new_with_now( + seconds_since_start, + self.config.cleaning.max_pending_scrape_age, + ); + + let now = Instant::now(); + + if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { + self.pending_scrape_responses.clean(seconds_since_start); + + last_pending_scrape_cleaning = now; + } + } + + iter_counter = iter_counter.wrapping_add(1); + } + } + + fn read_and_handle_requests( + &mut self, + local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, + pending_scrape_valid_until: ValidUntil, + ) { + let mut requests_received_ipv4: usize = 0; + let mut requests_received_ipv6: usize = 0; + let mut bytes_received_ipv4: usize = 0; + let mut bytes_received_ipv6 = 
0; + + loop { + match self.socket.recv_from(&mut self.buffer[..]) { + Ok((bytes_read, src)) => { + if src.port() == 0 { + ::log::info!("Ignored request from {} because source port is zero", src); + + continue; + } + + let src = CanonicalSocketAddr::new(src); + + let request_parsable = match Request::from_bytes( + &self.buffer[..bytes_read], + self.config.protocol.max_scrape_torrents, + ) { + Ok(request) => { + self.handle_request( + local_responses, + pending_scrape_valid_until, + request, + src, + ); + + true + } + Err(err) => { + ::log::debug!("Request::from_bytes error: {:?}", err); + + if let RequestParseError::Sendable { + connection_id, + transaction_id, + err, + } = err + { + if self.validator.connection_id_valid(src, connection_id) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + local_responses.push((response.into(), src)); + } + } + + false + } + }; + + // Update statistics for converted address + if src.is_ipv4() { + if request_parsable { + requests_received_ipv4 += 1; + } + bytes_received_ipv4 += bytes_read + EXTRA_PACKET_SIZE_IPV4; + } else { + if request_parsable { + requests_received_ipv6 += 1; + } + bytes_received_ipv6 += bytes_read + EXTRA_PACKET_SIZE_IPV6; + } + } + Err(err) if err.kind() == ErrorKind::WouldBlock => { + break; + } + Err(err) => { + ::log::warn!("recv_from error: {:#}", err); + } + } + } + + if self.config.statistics.active() { + self.shared_state + .statistics_ipv4 + .requests_received + .fetch_add(requests_received_ipv4, Ordering::Relaxed); + self.shared_state + .statistics_ipv6 + .requests_received + .fetch_add(requests_received_ipv6, Ordering::Relaxed); + self.shared_state + .statistics_ipv4 + .bytes_received + .fetch_add(bytes_received_ipv4, Ordering::Relaxed); + self.shared_state + .statistics_ipv6 + .bytes_received + .fetch_add(bytes_received_ipv6, Ordering::Relaxed); + } + } + + fn handle_request( + &mut self, + local_responses: &mut Vec<(Response, 
CanonicalSocketAddr)>, + pending_scrape_valid_until: ValidUntil, + request: Request, + src: CanonicalSocketAddr, + ) { + let access_list_mode = self.config.access_list.mode; + + match request { + Request::Connect(request) => { + let connection_id = self.validator.create_connection_id(src); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, + }); + + local_responses.push((response, src)) + } + Request::Announce(request) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + if self + .access_list_cache + .load() + .allows(access_list_mode, &request.info_hash.0) + { + let worker_index = + SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); + + self.request_sender.try_send_to( + worker_index, + ConnectedRequest::Announce(request), + src, + ); + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + local_responses.push((response, src)) + } + } + } + Request::Scrape(request) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + let split_requests = self.pending_scrape_responses.prepare_split_requests( + &self.config, + request, + pending_scrape_valid_until, + ); + + for (swarm_worker_index, request) in split_requests { + self.request_sender.try_send_to( + swarm_worker_index, + ConnectedRequest::Scrape(request), + src, + ); + } + } + } + } + } + + fn send_response( + config: &Config, + shared_state: &State, + socket: &mut UdpSocket, + buffer: &mut [u8], + opt_resend_buffer: &mut Option>, + response: Response, + canonical_addr: CanonicalSocketAddr, + ) { + let mut cursor = Cursor::new(buffer); + + if let Err(err) = response.write(&mut cursor) { + ::log::error!("Converting response to bytes failed: {:#}", err); + + return; + } + + let bytes_written = cursor.position() as usize; + + let addr = if config.network.address.is_ipv4() { + 
canonical_addr + .get_ipv4() + .expect("found peer ipv6 address while running bound to ipv4 address") + } else { + canonical_addr.get_ipv6_mapped() + }; + + match socket.send_to(&cursor.get_ref()[..bytes_written], addr) { + Ok(amt) if config.statistics.active() => { + let stats = if canonical_addr.is_ipv4() { + let stats = &shared_state.statistics_ipv4; + + stats + .bytes_sent + .fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + + stats + } else { + let stats = &shared_state.statistics_ipv6; + + stats + .bytes_sent + .fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + + stats + }; + + match response { + Response::Connect(_) => { + stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed); + } + Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { + stats + .responses_sent_announce + .fetch_add(1, Ordering::Relaxed); + } + Response::Scrape(_) => { + stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed); + } + Response::Error(_) => { + stats.responses_sent_error.fetch_add(1, Ordering::Relaxed); + } + } + } + Ok(_) => (), + Err(err) => match opt_resend_buffer.as_mut() { + Some(resend_buffer) + if (err.raw_os_error() == Some(libc::ENOBUFS)) + || (err.kind() == ErrorKind::WouldBlock) => + { + if resend_buffer.len() < config.network.resend_buffer_max_len { + ::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); + + resend_buffer.push((response, canonical_addr)); + } else { + ::log::warn!("Response resend buffer full, dropping response"); + } + } + _ => { + ::log::warn!("Sending response to {} failed: {:#}", addr, err); + } + }, + } + } +} diff --git a/aquatic_udp/src/workers/socket/mod.rs b/aquatic_udp/src/workers/socket/mod.rs index c1da08d..45bd030 100644 --- a/aquatic_udp/src/workers/socket/mod.rs +++ b/aquatic_udp/src/workers/socket/mod.rs @@ -1,29 +1,22 @@ +mod mio; mod storage; -pub mod validator; - -use std::io::{Cursor, ErrorKind}; -use std::sync::atomic::Ordering; -use 
std::time::{Duration, Instant}; +#[cfg(feature = "io-uring")] +mod uring; +mod validator; use anyhow::Context; -use aquatic_common::access_list::AccessListCache; -use aquatic_common::ServerStartInstant; +use aquatic_common::{ + privileges::PrivilegeDropper, CanonicalSocketAddr, PanicSentinel, ServerStartInstant, +}; use crossbeam_channel::Receiver; -use mio::net::UdpSocket; -use mio::{Events, Interest, Poll, Token}; use socket2::{Domain, Protocol, Socket, Type}; -use aquatic_common::{ - access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr, - PanicSentinel, ValidUntil, +use crate::{ + common::{ConnectedRequestSender, ConnectedResponse, State}, + config::Config, }; -use aquatic_udp_protocol::*; -use crate::common::*; -use crate::config::Config; - -use storage::PendingScrapeResponseSlab; -use validator::ConnectionValidator; +pub use self::validator::ConnectionValidator; /// Bytes of data transmitted when sending an IPv4 UDP packet, in addition to payload size /// @@ -43,414 +36,50 @@ const EXTRA_PACKET_SIZE_IPV4: usize = 8 + 18 + 20 + 8; /// - 8 bit udp header const EXTRA_PACKET_SIZE_IPV6: usize = 8 + 18 + 40 + 8; -pub struct SocketWorker { - config: Config, +pub fn run_socket_worker( + sentinel: PanicSentinel, shared_state: State, - request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, - access_list_cache: AccessListCache, + config: Config, validator: ConnectionValidator, server_start_instant: ServerStartInstant, - pending_scrape_responses: PendingScrapeResponseSlab, - socket: UdpSocket, - buffer: [u8; BUFFER_SIZE], -} - -impl SocketWorker { - pub fn run( - _sentinel: PanicSentinel, - shared_state: State, - config: Config, - validator: ConnectionValidator, - server_start_instant: ServerStartInstant, - request_sender: ConnectedRequestSender, - response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, - priv_dropper: PrivilegeDropper, - ) { - let socket = - 
UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); - let access_list_cache = create_access_list_cache(&shared_state.access_list); - - let mut worker = Self { - config, - shared_state, - validator, - server_start_instant, - request_sender, - response_receiver, - access_list_cache, - pending_scrape_responses: Default::default(), - socket, - buffer: [0; BUFFER_SIZE], - }; - - worker.run_inner(); - } - - pub fn run_inner(&mut self) { - let mut local_responses = Vec::new(); - let mut opt_resend_buffer = - (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new()); - - let mut events = Events::with_capacity(self.config.network.poll_event_capacity); - let mut poll = Poll::new().expect("create poll"); - - poll.registry() - .register(&mut self.socket, Token(0), Interest::READABLE) - .expect("register poll"); - - let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms); - - let pending_scrape_cleaning_duration = - Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval); - - let mut pending_scrape_valid_until = ValidUntil::new( - self.server_start_instant, - self.config.cleaning.max_pending_scrape_age, - ); - let mut last_pending_scrape_cleaning = Instant::now(); - - let mut iter_counter = 0usize; - - loop { - poll.poll(&mut events, Some(poll_timeout)) - .expect("failed polling"); - - for event in events.iter() { - if event.is_readable() { - self.read_and_handle_requests(&mut local_responses, pending_scrape_valid_until); - } - } - - // If resend buffer is enabled, send any responses in it - if let Some(resend_buffer) = opt_resend_buffer.as_mut() { - for (response, addr) in resend_buffer.drain(..) { - Self::send_response( - &self.config, - &self.shared_state, - &mut self.socket, - &mut self.buffer, - &mut None, - response, - addr, - ); - } - } - - // Send any connect and error responses generated by this socket worker - for (response, addr) in local_responses.drain(..) 
{ - Self::send_response( - &self.config, - &self.shared_state, - &mut self.socket, - &mut self.buffer, - &mut opt_resend_buffer, - response, - addr, - ); - } - - // Check channel for any responses generated by swarm workers - for (response, addr) in self.response_receiver.try_iter() { - let opt_response = match response { - ConnectedResponse::Scrape(r) => self - .pending_scrape_responses - .add_and_get_finished(r) - .map(Response::Scrape), - ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)), - ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)), - }; - - if let Some(response) = opt_response { - Self::send_response( - &self.config, - &self.shared_state, - &mut self.socket, - &mut self.buffer, - &mut opt_resend_buffer, - response, - addr, - ); - } - } - - // Run periodic ValidUntil updates and state cleaning - if iter_counter % 256 == 0 { - let seconds_since_start = self.server_start_instant.seconds_elapsed(); - - pending_scrape_valid_until = ValidUntil::new_with_now( - seconds_since_start, - self.config.cleaning.max_pending_scrape_age, - ); - - let now = Instant::now(); - - if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { - self.pending_scrape_responses.clean(seconds_since_start); - - last_pending_scrape_cleaning = now; - } - } - - iter_counter = iter_counter.wrapping_add(1); - } - } - - fn read_and_handle_requests( - &mut self, - local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - pending_scrape_valid_until: ValidUntil, - ) { - let mut requests_received_ipv4: usize = 0; - let mut requests_received_ipv6: usize = 0; - let mut bytes_received_ipv4: usize = 0; - let mut bytes_received_ipv6 = 0; - - loop { - match self.socket.recv_from(&mut self.buffer[..]) { - Ok((bytes_read, src)) => { - if src.port() == 0 { - ::log::info!("Ignored request from {} because source port is zero", src); - - continue; - } - - let src = CanonicalSocketAddr::new(src); - - let request_parsable = match 
Request::from_bytes( - &self.buffer[..bytes_read], - self.config.protocol.max_scrape_torrents, - ) { - Ok(request) => { - self.handle_request( - local_responses, - pending_scrape_valid_until, - request, - src, - ); - - true - } - Err(err) => { - ::log::debug!("Request::from_bytes error: {:?}", err); - - if let RequestParseError::Sendable { - connection_id, - transaction_id, - err, - } = err - { - if self.validator.connection_id_valid(src, connection_id) { - let response = ErrorResponse { - transaction_id, - message: err.right_or("Parse error").into(), - }; - - local_responses.push((response.into(), src)); - } - } - - false - } - }; - - // Update statistics for converted address - if src.is_ipv4() { - if request_parsable { - requests_received_ipv4 += 1; - } - bytes_received_ipv4 += bytes_read + EXTRA_PACKET_SIZE_IPV4; - } else { - if request_parsable { - requests_received_ipv6 += 1; - } - bytes_received_ipv6 += bytes_read + EXTRA_PACKET_SIZE_IPV6; - } - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - ::log::warn!("recv_from error: {:#}", err); - } - } - } - - if self.config.statistics.active() { - self.shared_state - .statistics_ipv4 - .requests_received - .fetch_add(requests_received_ipv4, Ordering::Relaxed); - self.shared_state - .statistics_ipv6 - .requests_received - .fetch_add(requests_received_ipv6, Ordering::Relaxed); - self.shared_state - .statistics_ipv4 - .bytes_received - .fetch_add(bytes_received_ipv4, Ordering::Relaxed); - self.shared_state - .statistics_ipv6 - .bytes_received - .fetch_add(bytes_received_ipv6, Ordering::Relaxed); - } - } - - fn handle_request( - &mut self, - local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, - pending_scrape_valid_until: ValidUntil, - request: Request, - src: CanonicalSocketAddr, - ) { - let access_list_mode = self.config.access_list.mode; - - match request { - Request::Connect(request) => { - let connection_id = self.validator.create_connection_id(src); - - let response 
= Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - local_responses.push((response, src)) - } - Request::Announce(request) => { - if self - .validator - .connection_id_valid(src, request.connection_id) - { - if self - .access_list_cache - .load() - .allows(access_list_mode, &request.info_hash.0) - { - let worker_index = - SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); - - self.request_sender.try_send_to( - worker_index, - ConnectedRequest::Announce(request), - src, - ); - } else { - let response = Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - }); - - local_responses.push((response, src)) - } - } - } - Request::Scrape(request) => { - if self - .validator - .connection_id_valid(src, request.connection_id) - { - let split_requests = self.pending_scrape_responses.prepare_split_requests( - &self.config, - request, - pending_scrape_valid_until, - ); - - for (swarm_worker_index, request) in split_requests { - self.request_sender.try_send_to( - swarm_worker_index, - ConnectedRequest::Scrape(request), - src, - ); - } - } - } - } - } - - fn send_response( - config: &Config, - shared_state: &State, - socket: &mut UdpSocket, - buffer: &mut [u8], - opt_resend_buffer: &mut Option>, - response: Response, - canonical_addr: CanonicalSocketAddr, - ) { - let mut cursor = Cursor::new(buffer); - - if let Err(err) = response.write(&mut cursor) { - ::log::error!("Converting response to bytes failed: {:#}", err); + request_sender: ConnectedRequestSender, + response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + priv_dropper: PrivilegeDropper, +) { + #[cfg(feature = "io-uring")] + match self::uring::supported_on_current_kernel() { + Ok(()) => { + self::uring::SocketWorker::run( + sentinel, + shared_state, + config, + validator, + server_start_instant, + request_sender, + response_receiver, + priv_dropper, + ); return; } - - let 
bytes_written = cursor.position() as usize; - - let addr = if config.network.address.is_ipv4() { - canonical_addr - .get_ipv4() - .expect("found peer ipv6 address while running bound to ipv4 address") - } else { - canonical_addr.get_ipv6_mapped() - }; - - match socket.send_to(&cursor.get_ref()[..bytes_written], addr) { - Ok(amt) if config.statistics.active() => { - let stats = if canonical_addr.is_ipv4() { - let stats = &shared_state.statistics_ipv4; - - stats - .bytes_sent - .fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - - stats - } else { - let stats = &shared_state.statistics_ipv6; - - stats - .bytes_sent - .fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); - - stats - }; - - match response { - Response::Connect(_) => { - stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed); - } - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { - stats - .responses_sent_announce - .fetch_add(1, Ordering::Relaxed); - } - Response::Scrape(_) => { - stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed); - } - Response::Error(_) => { - stats.responses_sent_error.fetch_add(1, Ordering::Relaxed); - } - } - } - Ok(_) => (), - Err(err) => match opt_resend_buffer.as_mut() { - Some(resend_buffer) - if (err.raw_os_error() == Some(libc::ENOBUFS)) - || (err.kind() == ErrorKind::WouldBlock) => - { - if resend_buffer.len() < config.network.resend_buffer_max_len { - ::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); - - resend_buffer.push((response, canonical_addr)); - } else { - ::log::warn!("Response resend buffer full, dropping response"); - } - } - _ => { - ::log::warn!("Sending response to {} failed: {:#}", addr, err); - } - }, + Err(err) => { + ::log::warn!( + "Falling back to mio because of lacking kernel io_uring support: {:#}", + err + ); } } + + self::mio::SocketWorker::run( + sentinel, + shared_state, + config, + validator, + server_start_instant, + request_sender, + 
response_receiver, + priv_dropper, + ); } fn create_socket( diff --git a/aquatic_udp/src/workers/socket/uring/buf_ring.rs b/aquatic_udp/src/workers/socket/uring/buf_ring.rs new file mode 100644 index 0000000..7862753 --- /dev/null +++ b/aquatic_udp/src/workers/socket/uring/buf_ring.rs @@ -0,0 +1,947 @@ +// Copyright (c) 2021 Carl Lerche +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +// Copied (with slight modifications) from +// - https://github.com/FrankReh/tokio-uring/tree/9387c92c98138451f7d760432a04b0b95a406f22/src/buf/bufring +// - https://github.com/FrankReh/tokio-uring/blob/9387c92c98138451f7d760432a04b0b95a406f22/src/buf/bufgroup/mod.rs + +//! Module for the io_uring device's buf_ring feature. 
+ +// Developer's note about io_uring return codes when a buf_ring is used: +// +// While a buf_ring pool is exhaused, new calls to read that are, or are not, ready to read will +// fail with the 105 error, "no buffers", while existing calls that were waiting to become ready to +// read will not fail. Only when the data becomes ready to read will they fail, if the buffer ring +// is still empty at that time. This makes sense when thinking about it from how the kernel +// implements the start of a read command; it can be confusing when first working with these +// commands from the userland perspective. + +// While the file! calls yield the clippy false positive. +#![allow(clippy::print_literal)] + +use io_uring::types; +use std::cell::Cell; +use std::io; +use std::rc::Rc; +use std::sync::atomic::{self, AtomicU16}; + +use super::CurrentRing; + +/// The buffer group ID. +/// +/// The creater of a buffer group is responsible for picking a buffer group id +/// that does not conflict with other buffer group ids also being registered with the uring +/// interface. +pub(crate) type Bgid = u16; + +// Future: Maybe create a bgid module with a trivial implementation of a type that tracks the next +// bgid to use. The crate's driver could do that perhaps, but there could be a benefit to tracking +// them across multiple thread's drivers. So there is flexibility in not building it into the +// driver. + +/// The buffer ID. Buffer ids are assigned and used by the crate and probably are not visible +/// to the crate user. +pub(crate) type Bid = u16; + +/// This tracks a buffer that has been filled in by the kernel, having gotten the memory +/// from a buffer ring, and returned to userland via a cqe entry. +pub struct BufX { + bgroup: BufRing, + bid: Bid, + len: usize, +} + +impl BufX { + // # Safety + // + // The bid must be the buffer id supplied by the kernel as having been chosen and written to. + // The length of the buffer must represent the length written to by the kernel. 
+ pub(crate) unsafe fn new(bgroup: BufRing, bid: Bid, len: usize) -> Self { + // len will already have been checked against the buf_capacity + // so it is guaranteed that len <= bgroup.buf_capacity. + + Self { bgroup, bid, len } + } + + /// Return the number of bytes initialized. + /// + /// This value initially came from the kernel, as reported in the cqe. This value may have been + /// modified with a call to the IoBufMut::set_init method. + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Return true if this represents an empty buffer. The length reported by the kernel was 0. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Return the capacity of this buffer. + #[inline] + pub fn cap(&self) -> usize { + self.bgroup.buf_capacity(self.bid) + } + + /// Return a byte slice reference. + #[inline] + pub fn as_slice(&self) -> &[u8] { + let p = self.bgroup.stable_ptr(self.bid); + // Safety: the pointer returned by stable_ptr is valid for the lifetime of self, + // and self's len is set when the kernel reports the amount of data that was + // written into the buffer. + unsafe { std::slice::from_raw_parts(p, self.len) } + } + + /// Return a mutable byte slice reference. + #[inline] + pub fn as_slice_mut(&mut self) -> &mut [u8] { + let p = self.bgroup.stable_mut_ptr(self.bid); + // Safety: the pointer returned by stable_mut_ptr is valid for the lifetime of self, + // and self's len is set when the kernel reports the amount of data that was + // written into the buffer. In addition, we hold a &mut reference to self. + unsafe { std::slice::from_raw_parts_mut(p, self.len) } + } + + // Future: provide access to the uninit space between len and cap if the buffer is being + // repurposed before being dropped. The set_init below does that too. +} + +impl Drop for BufX { + fn drop(&mut self) { + // Add the buffer back to the bgroup, for the kernel to reuse. + // Safety: this function may only be called by the buffer's drop function. 
+ unsafe { self.bgroup.dropping_bid(self.bid) }; + } +} + +/* +unsafe impl crate::buf::IoBuf for BufX { + fn stable_ptr(&self) -> *const u8 { + self.bgroup.stable_ptr(self.bid) + } + + fn bytes_init(&self) -> usize { + self.len + } + + fn bytes_total(&self) -> usize { + self.cap() + } +} + +unsafe impl crate::buf::IoBufMut for BufX { + fn stable_mut_ptr(&mut self) -> *mut u8 { + self.bgroup.stable_mut_ptr(self.bid) + } + + unsafe fn set_init(&mut self, init_len: usize) { + if self.len < init_len { + let cap = self.bgroup.buf_capacity(self.bid); + assert!(init_len <= cap); + self.len = init_len; + } + } +} +*/ + +impl From for Vec { + fn from(item: BufX) -> Self { + item.as_slice().to_vec() + } +} + +/// A `BufRing` represents the ring and the buffers used with the kernel's io_uring buf_ring +/// feature. +/// +/// In this implementation, it is both the ring of buffer entries and the actual buffer +/// allocations. +/// +/// A BufRing is created through the [`Builder`] and can be registered automatically by the +/// builder's `build` step or at a later time by the user. Registration involves informing the +/// kernel of the ring's dimensions and its identifier (its buffer group id, which goes by the name +/// `bgid`). +/// +/// Multiple buf_rings, here multiple BufRings, can be created and registered. BufRings are +/// reference counted to ensure their memory is live while their BufX buffers are live. When a BufX +/// buffer is dropped, it releases itself back to the BufRing from which it came allowing it to be +/// reused by the kernel. +/// +/// It is perhaps worth pointing out that it is the ring itself that is registered with the kernel, +/// not the buffers per se. While a given buf_ring cannot have it size changed dynamically, the +/// buffers that are pushed to the ring by userland, and later potentially re-pushed in the ring, +/// can change. The buffers can be of different sizes and they could come from different allocation +/// blocks. 
This implementation does not provide that flexibility. Each BufRing comes with its own +/// equal length buffer allocation. And when a BufRing buffer, a BufX, is dropped, its id is pushed +/// back to the ring. +/// +/// This is the one and only `Provided Buffers` implementation in `tokio_uring` at the moment and +/// in this version, is a purely concrete type, with a concrete BufX type for buffers that are +/// returned by operations like `recv_provbuf` to the userland application. +/// +/// Aside from the register and unregister steps, there are no syscalls used to pass buffers to the +/// kernel. The ring contains a tail memory address that this userland type updates as buffers are +/// added to the ring and which the kernel reads when it needs to pull a buffer from the ring. The +/// kernel does not have a head pointer address that it updates for the userland. The userland +/// (this type), is expected to avoid overwriting the head of the circular ring by keeping track of +/// how many buffers were added to the ring and how many have been returned through the CQE +/// mechanism. This particular implementation does not track the count because all buffers are +/// allocated at the beginning, by the builder, and only its own buffers that came back via a CQE +/// are ever added back to the ring, so it should be impossible to overflow the ring. +#[derive(Clone, Debug)] +pub struct BufRing { + // RawBufRing uses cell for fields where necessary. + raw: Rc, +} + +// Methods the BufX needs. + +impl BufRing { + pub(crate) fn buf_capacity(&self, _: Bid) -> usize { + self.raw.buf_capacity_i() + } + + pub(crate) fn stable_ptr(&self, bid: Bid) -> *const u8 { + // Will panic if bid is out of range. + self.raw.stable_ptr_i(bid) + } + + pub(crate) fn stable_mut_ptr(&mut self, bid: Bid) -> *mut u8 { + // Safety: self is &mut, we're good. 
+ unsafe { self.raw.stable_mut_ptr_i(bid) } + } + + // # Safety + // + // `dropping_bid` should only be called by the buffer's drop function because once called, the + // buffer may be given back to the kernel for reuse. + pub(crate) unsafe fn dropping_bid(&self, bid: Bid) { + self.raw.dropping_bid_i(bid); + } +} + +// Methods the io operations need. + +impl BufRing { + pub(crate) fn bgid(&self) -> Bgid { + self.raw.bgid() + } + + // # Safety + // + // The res and flags values are used to lookup a buffer and set its initialized length. + // The caller is responsible for these being correct. This is expected to be called + // when these two values are received from the kernel via a CQE and we rely on the kernel to + // give us correct information. + pub(crate) unsafe fn get_buf(&self, res: u32, flags: u32) -> io::Result> { + let bid = match io_uring::cqueue::buffer_select(flags) { + Some(bid) => bid, + None => { + // Have seen res == 0, flags == 4 with a TCP socket. res == 0 we take to mean the + // socket is empty so return None to show there is no buffer returned, which should + // be interpreted to mean there is no more data to read from this file or socket. 
+ if res == 0 { + return Ok(None); + } + + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "BufRing::get_buf failed as the buffer bit, IORING_CQE_F_BUFFER, was missing from flags, res = {}, flags = {}", + res, flags) + )); + } + }; + + let len = res as usize; + + /* + let flags = flags & !io_uring::sys::IORING_CQE_F_BUFFER; // for tracing flags + println!( + "{}:{}: get_buf res({res})=len({len}) flags({:#x})->bid({bid})\n\n", + file!(), + line!(), + flags + ); + */ + + assert!(len <= self.raw.buf_len); + + // TODO maybe later + // #[cfg(any(debug, feature = "cautious"))] + // { + // let mut debug_bitmap = self.debug_bitmap.borrow_mut(); + // let m = 1 << (bid % 8); + // assert!(debug_bitmap[(bid / 8) as usize] & m == m); + // debug_bitmap[(bid / 8) as usize] &= !m; + // } + + self.raw.metric_getting_another(); + /* + println!( + "{}:{}: get_buf cur {}, min {}", + file!(), + line!(), + self.possible_cur.get(), + self.possible_min.get(), + ); + */ + + // Safety: the len provided to BufX::new is given to us from the kernel. + Ok(Some(unsafe { BufX::new(self.clone(), bid, len) })) + } +} + +#[derive(Debug, Copy, Clone)] +/// Build the arguments to call build() that returns a [`BufRing`]. +/// +/// Refer to the methods descriptions for details. +#[allow(dead_code)] +pub struct Builder { + page_size: usize, + bgid: Bgid, + ring_entries: u16, + buf_cnt: u16, + buf_len: usize, + buf_align: usize, + ring_pad: usize, + bufend_align: usize, + + skip_register: bool, +} + +#[allow(dead_code)] +impl Builder { + /// Create a new Builder with the given buffer group ID and defaults. + /// + /// The buffer group ID, `bgid`, is the id the kernel's io_uring device uses to identify the + /// provided buffer pool to use by operations that are posted to the device. + /// + /// The user is responsible for picking a bgid that does not conflict with other buffer groups + /// that have been registered with the same uring interface. 
+ pub fn new(bgid: Bgid) -> Builder { + Builder { + page_size: 4096, + bgid, + ring_entries: 128, + buf_cnt: 0, + buf_len: 4096, + buf_align: 0, + ring_pad: 0, + bufend_align: 0, + skip_register: false, + } + } + + /// The page size of the kernel. Defaults to 4096. + /// + /// The io_uring device requires the BufRing is allocated on the start of a page, i.e. with a + /// page size alignment. + /// + /// The caller should determine the page size, and may want to cache the info if multiple buf + /// rings are to be created. Crates are available to get this information or the user may want + /// to call the libc sysconf directly: + /// + /// use libc::{_SC_PAGESIZE, sysconf}; + /// let page_size: usize = unsafe { sysconf(_SC_PAGESIZE) as usize }; + pub fn page_size(mut self, page_size: usize) -> Builder { + self.page_size = page_size; + self + } + + /// The number of ring entries to create for the buffer ring. + /// + /// This defaults to 128 or the `buf_cnt`, whichever is larger. + /// + /// The number will be made a power of 2, and will be the maximum of the ring_entries setting + /// and the buf_cnt setting. The interface will enforce a maximum of 2^15 (32768) so it can do + /// rollover calculation. + /// + /// Each ring entry is 16 bytes. + pub fn ring_entries(mut self, ring_entries: u16) -> Builder { + self.ring_entries = ring_entries; + self + } + + /// The number of buffers to allocate. If left zero, the ring_entries value will be used and + /// that value defaults to 128. + pub fn buf_cnt(mut self, buf_cnt: u16) -> Builder { + self.buf_cnt = buf_cnt; + self + } + + /// The length of each allocated buffer. Defaults to 4096. + /// + /// Non-alignment values are possible and `buf_align` can be used to allocate each buffer on + /// an alignment buffer, even if the buffer length is not desired to equal the alignment. + pub fn buf_len(mut self, buf_len: usize) -> Builder { + self.buf_len = buf_len; + self + } + + /// The alignment of the first buffer allocated. 
+ /// + /// Generally not needed. + /// + /// The buffers are allocated right after the ring unless `ring_pad` is used and generally the + /// buffers are allocated contiguous to one another unless the `buf_len` is set to something + /// different. + pub fn buf_align(mut self, buf_align: usize) -> Builder { + self.buf_align = buf_align; + self + } + + /// Pad to place after ring to ensure separation between rings and first buffer. + /// + /// Generally not needed but may be useful if the ring's end and the buffers' start are to have + /// some separation, perhaps for cacheline reasons. + pub fn ring_pad(mut self, ring_pad: usize) -> Builder { + self.ring_pad = ring_pad; + self + } + + /// The alignment of the end of the buffer allocated. To keep other things out of a cache line + /// or out of a page, if that's desired. + pub fn bufend_align(mut self, bufend_align: usize) -> Builder { + self.bufend_align = bufend_align; + self + } + + /// Skip automatic registration. The caller can manually invoke the buf_ring.register() + /// function later. Regardless, the unregister() method will be called automatically when the + /// BufRing goes out of scope if the caller hadn't manually called buf_ring.unregister() + /// already. + pub fn skip_auto_register(mut self, skip: bool) -> Builder { + self.skip_register = skip; + self + } + + /// Return a BufRing, having computed the layout for the single aligned allocation + /// of both the buffer ring elements and the buffers themselves. + /// + /// If auto_register was left enabled, register the BufRing with the driver. + pub fn build(&self) -> io::Result { + let mut b: Builder = *self; + + // Two cases where both buf_cnt and ring_entries are set to the max of the two. 
+ if b.buf_cnt == 0 || b.ring_entries < b.buf_cnt { + let max = std::cmp::max(b.ring_entries, b.buf_cnt); + b.buf_cnt = max; + b.ring_entries = max; + } + + // Don't allow the next_power_of_two calculation to be done if already larger than 2^15 + // because 2^16 reads back as 0 in a u16. And the interface doesn't allow for ring_entries + // larger than 2^15 anyway, so this is a good place to catch it. Here we return a unique + // error that is more descriptive than the InvalidArg that would come from the interface. + if b.ring_entries > (1 << 15) { + return Err(io::Error::new( + io::ErrorKind::Other, + "ring_entries exceeded 32768", + )); + } + + // Requirement of the interface is the ring entries is a power of two, making its and our + // mask calculation trivial. + b.ring_entries = b.ring_entries.next_power_of_two(); + + Ok(BufRing { + raw: Rc::new(RawBufRing::new(NewArgs { + page_size: b.page_size, + bgid: b.bgid, + ring_entries: b.ring_entries, + buf_cnt: b.buf_cnt, + buf_len: b.buf_len, + buf_align: b.buf_align, + ring_pad: b.ring_pad, + bufend_align: b.bufend_align, + auto_register: !b.skip_register, + })?), + }) + } +} + +// Trivial helper struct for this module. +struct NewArgs { + page_size: usize, + bgid: Bgid, + ring_entries: u16, + buf_cnt: u16, + buf_len: usize, + buf_align: usize, + ring_pad: usize, + bufend_align: usize, + auto_register: bool, +} + +#[derive(Debug)] +struct RawBufRing { + bgid: Bgid, + + // Keep mask rather than ring size because mask is used often, ring size not. + //ring_entries: u16, // Invariants: > 0, power of 2, max 2^15 (32768). + ring_entries_mask: u16, // Invariant one less than ring_entries which is > 0, power of 2, max 2^15 (32768). + + buf_cnt: u16, // Invariants: > 0, <= ring_entries. + buf_len: usize, // Invariant: > 0. + layout: std::alloc::Layout, + ring_addr: *const types::BufRingEntry, // Invariant: constant. + buffers_addr: *mut u8, // Invariant: constant. 
+ local_tail: Cell, + tail_addr: *const AtomicU16, + registered: Cell, + + // The first `possible` field is a best effort at tracking the current buffer pool usage and + // from that, tracking the lowest level that has been reached. The two are an attempt at + // letting the user check the sizing needs of their buf_ring pool. + // + // We don't really know how deep the uring device has gone into the pool because we never see + // its head value and it can be taking buffers from the ring, in-flight, while we add buffers + // back to the ring. All we know is when a CQE arrives and a buffer lookup is performed, a + // buffer has already been taken from the pool, and when the buffer is dropped, we add it back + // to the the ring and it is about to be considered part of the pool again. + possible_cur: Cell, + possible_min: Cell, + // + // TODO maybe later + // #[cfg(any(debug, feature = "cautious"))] + // debug_bitmap: RefCell>, +} + +impl RawBufRing { + fn new(new_args: NewArgs) -> io::Result { + #[allow(non_upper_case_globals)] + const trace: bool = false; + + let NewArgs { + page_size, + bgid, + ring_entries, + buf_cnt, + buf_len, + buf_align, + ring_pad, + bufend_align, + auto_register, + } = new_args; + + // Check that none of the important args are zero and the ring_entries is at least large + // enough to hold all the buffers and that ring_entries is a power of 2. + + if (buf_cnt == 0) + || (buf_cnt > ring_entries) + || (buf_len == 0) + || ((ring_entries & (ring_entries - 1)) != 0) + { + return Err(io::Error::from(io::ErrorKind::InvalidInput)); + } + + // entry_size is 16 bytes. 
+ let entry_size = std::mem::size_of::(); + let mut ring_size = entry_size * (ring_entries as usize); + if trace { + println!( + "{}:{}: entry_size {} * ring_entries {} = ring_size {} {:#x}", + file!(), + line!(), + entry_size, + ring_entries, + ring_size, + ring_size, + ); + } + + ring_size += ring_pad; + + if trace { + println!( + "{}:{}: after +ring_pad {} ring_size {} {:#x}", + file!(), + line!(), + ring_pad, + ring_size, + ring_size, + ); + } + + if buf_align > 0 { + let buf_align = buf_align.next_power_of_two(); + ring_size = (ring_size + (buf_align - 1)) & !(buf_align - 1); + if trace { + println!( + "{}:{}: after buf_align ring_size {} {:#x}", + file!(), + line!(), + ring_size, + ring_size, + ); + } + } + let buf_size = buf_len * (buf_cnt as usize); + assert!(ring_size != 0); + assert!(buf_size != 0); + let mut tot_size: usize = ring_size + buf_size; + if trace { + println!( + "{}:{}: ring_size {} {:#x} + buf_size {} {:#x} = tot_size {} {:#x}", + file!(), + line!(), + ring_size, + ring_size, + buf_size, + buf_size, + tot_size, + tot_size + ); + } + if bufend_align > 0 { + // for example, if bufend_align is 4096, would make total size a multiple of pages + let bufend_align = bufend_align.next_power_of_two(); + tot_size = (tot_size + (bufend_align - 1)) & !(bufend_align - 1); + if trace { + println!( + "{}:{}: after bufend_align tot_size {} {:#x}", + file!(), + line!(), + tot_size, + tot_size, + ); + } + } + + let align: usize = page_size; // alignment must be at least the page size + let align = align.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align(tot_size, align).unwrap(); + + assert!(layout.size() >= ring_size); + // Safety: we are assured layout has nonzero size, we passed the assert just above. + let ring_addr: *mut u8 = unsafe { std::alloc::alloc_zeroed(layout) }; + + // Buffers starts after the ring_size. 
+ // Safety: are we assured the address and the offset are in bounds because the ring_addr is + // the value we got from the alloc call, and the layout.size was shown to be at least as + // large as the ring_size. + let buffers_addr: *mut u8 = unsafe { ring_addr.add(ring_size) }; + if trace { + println!( + "{}:{}: ring_addr {} {:#x}, layout: size {} align {}", + file!(), + line!(), + ring_addr as u64, + ring_addr as u64, + layout.size(), + layout.align() + ); + println!( + "{}:{}: buffers_addr {} {:#x}", + file!(), + line!(), + buffers_addr as u64, + buffers_addr as u64, + ); + } + + let ring_addr: *const types::BufRingEntry = ring_addr as _; + + // Safety: the ring_addr passed into tail is the start of the ring. It is both the start of + // the ring and the first entry in the ring. + let tail_addr = unsafe { types::BufRingEntry::tail(ring_addr) } as *const AtomicU16; + + let ring_entries_mask = ring_entries - 1; + assert!((ring_entries & ring_entries_mask) == 0); + + let buf_ring = RawBufRing { + bgid, + ring_entries_mask, + buf_cnt, + buf_len, + layout, + ring_addr, + buffers_addr, + local_tail: Cell::new(0), + tail_addr, + registered: Cell::new(false), + possible_cur: Cell::new(0), + possible_min: Cell::new(buf_cnt), + // + // TODO maybe later + // #[cfg(any(debug, feature = "cautious"))] + // debug_bitmap: RefCell::new(std::vec![0; ((buf_cnt+7)/8) as usize]), + }; + + // Question had come up: where should the initial buffers be added to the ring? + // Here when the ring is created, even before it is registered potentially? + // Or after registration? + // + // For this type, BufRing, we are adding the buffers to the ring as the last part of creating the BufRing, + // even before registration is optionally performed. + // + // We've seen the registration to be successful, even when the ring starts off empty. + + // Add the buffers here where the ring is created. 
+ + for bid in 0..buf_cnt { + buf_ring.buf_ring_add(bid); + } + buf_ring.buf_ring_sync(); + + // The default is to register the buffer ring right here. There is usually no reason the + // caller should want to register it some time later. + // + // Perhaps the caller wants to allocate the buffer ring before the CONTEXT driver is in + // place - that would be a reason to delay the register call until later. + + if auto_register { + buf_ring.register()?; + } + Ok(buf_ring) + } + + /// Register the buffer ring with the kernel. + /// Normally this is done automatically when building a BufRing. + /// + /// This method must be called in the context of a `tokio-uring` runtime. + /// The registration persists for the lifetime of the runtime, unless + /// revoked by the [`unregister`] method. Dropping the + /// instance this method has been called on does revoke + /// the registration and deallocate the buffer space. + /// + /// [`unregister`]: Self::unregister + /// + /// # Errors + /// + /// If a `Provided Buffers` group with the same `bgid` is already registered, the function + /// returns an error. + fn register(&self) -> io::Result<()> { + let bgid = self.bgid; + //println!("{}:{}: register bgid {bgid}", file!(), line!()); + + // Future: move to separate public function so other buf_ring implementations + // can register, and unregister, the same way. + + let res = CurrentRing::with(|ring| { + ring.submitter() + .register_buf_ring(self.ring_addr as _, self.ring_entries(), bgid) + }); + // println!("{}:{}: res {:?}", file!(), line!(), res); + + if let Err(e) = res { + match e.raw_os_error() { + Some(22) => { + // using buf_ring requires kernel 5.19 or greater. + // TODO turn these eprintln into new, more expressive error being returned. + // TODO what convention should we follow in this crate for adding information + // onto an error? 
+ eprintln!( + "buf_ring.register returned {e}, most likely indicating this kernel is not 5.19+", + ); + } + Some(17) => { + // Registering a duplicate bgid is not allowed. There is an `unregister` + // operations that can remove the first. + eprintln!( + "buf_ring.register returned `{e}`, indicating the attempted buffer group id {bgid} was already registered", + ); + } + _ => { + eprintln!("buf_ring.register returned `{e}` for group id {bgid}"); + } + } + return Err(e); + }; + + self.registered.set(true); + + res + } + + /// Unregister the buffer ring from the io_uring. + /// Normally this is done automatically when the BufRing goes out of scope. + /// + /// Warning: requires the CONTEXT driver is already in place or will panic. + fn unregister(&self) -> io::Result<()> { + // If not registered, make this a no-op. + if !self.registered.get() { + return Ok(()); + } + + self.registered.set(false); + + let bgid = self.bgid; + + CurrentRing::with(|ring| ring.submitter().unregister_buf_ring(bgid)) + } + + /// Returns the buffer group id. + #[inline] + fn bgid(&self) -> Bgid { + self.bgid + } + + fn metric_getting_another(&self) { + self.possible_cur.set(self.possible_cur.get() - 1); + self.possible_min.set(std::cmp::min( + self.possible_min.get(), + self.possible_cur.get(), + )); + } + + // # Safety + // + // Dropping a duplicate bid is likely to cause undefined behavior + // as the kernel uses the same buffer for different data concurrently. + unsafe fn dropping_bid_i(&self, bid: Bid) { + self.buf_ring_add(bid); + self.buf_ring_sync(); + } + + #[inline] + fn buf_capacity_i(&self) -> usize { + self.buf_len as _ + } + + #[inline] + // # Panic + // + // This function will panic if given a bid that is not within the valid range 0..self.buf_cnt. 
+ fn stable_ptr_i(&self, bid: Bid) -> *const u8 { + assert!(bid < self.buf_cnt); + let offset: usize = self.buf_len * (bid as usize); + // Safety: buffers_addr is an u8 pointer and was part of an allocation large enough to hold + // buf_cnt number of buf_len buffers. buffers_addr, buf_cnt and buf_len are treated as + // constants and bid was just asserted to be less than buf_cnt. + unsafe { self.buffers_addr.add(offset) } + } + + // # Safety + // + // This may only be called by an owned or &mut object. + // + // # Panic + // This will panic if bid is out of range. + #[inline] + unsafe fn stable_mut_ptr_i(&self, bid: Bid) -> *mut u8 { + assert!(bid < self.buf_cnt); + let offset: usize = self.buf_len * (bid as usize); + // Safety: buffers_addr is an u8 pointer and was part of an allocation large enough to hold + // buf_cnt number of buf_len buffers. buffers_addr, buf_cnt and buf_len are treated as + // constants and bid was just asserted to be less than buf_cnt. + self.buffers_addr.add(offset) + } + + #[inline] + fn ring_entries(&self) -> u16 { + self.ring_entries_mask + 1 + } + + #[inline] + fn mask(&self) -> u16 { + self.ring_entries_mask + } + + // Writes to a ring entry and updates our local copy of the tail. + // + // Adds the buffer known by its buffer id to the buffer ring. The buffer's address and length + // are known given its bid. + // + // This does not sync the new tail value. The caller should use `buf_ring_sync` for that. + // + // Panics if the bid is out of range. + fn buf_ring_add(&self, bid: Bid) { + // Compute address of current tail position, increment the local copy of the tail. Then + // write the buffer's address, length and bid into the current tail entry. + + let cur_tail = self.local_tail.get(); + self.local_tail.set(cur_tail.wrapping_add(1)); + let ring_idx = cur_tail & self.mask(); + + let ring_addr = self.ring_addr as *mut types::BufRingEntry; + + // Safety: + // 1. 
the pointer address (ring_addr), is set and const at self creation time, + // and points to a block of memory at least as large as the number of ring_entries, + // 2. the mask used to create ring_idx is one less than + // the number of ring_entries, and ring_entries was tested to be a power of two, + // So the address gotten by adding ring_idx entries to ring_addr is guaranteed to + // be a valid address of a ring entry. + let entry = unsafe { &mut *ring_addr.add(ring_idx as usize) }; + + entry.set_addr(self.stable_ptr_i(bid) as _); + entry.set_len(self.buf_len as _); + entry.set_bid(bid); + + // Update accounting. + self.possible_cur.set(self.possible_cur.get() + 1); + + // TODO maybe later + // #[cfg(any(debug, feature = "cautious"))] + // { + // let mut debug_bitmap = self.debug_bitmap.borrow_mut(); + // let m = 1 << (bid % 8); + // assert!(debug_bitmap[(bid / 8) as usize] & m == 0); + // debug_bitmap[(bid / 8) as usize] |= m; + // } + } + + // Make 'count' new buffers visible to the kernel. Called after + // io_uring_buf_ring_add() has been called 'count' times to fill in new + // buffers. + #[inline] + fn buf_ring_sync(&self) { + // Safety: dereferencing this raw pointer is safe. The tail_addr was computed once at init + // to refer to the tail address in the ring and is held const for self's lifetime. + unsafe { + (*self.tail_addr).store(self.local_tail.get(), atomic::Ordering::Release); + } + // The liburing code did io_uring_smp_store_release(&br.tail, local_tail); + } + + // Return the possible_min buffer pool size. + #[allow(dead_code)] + fn possible_min(&self) -> u16 { + self.possible_min.get() + } + + // Return the possible_min buffer pool size and reset to allow fresh counting going forward. 
+ #[allow(dead_code)] + fn possible_min_and_reset(&self) -> u16 { + let res = self.possible_min.get(); + self.possible_min.set(self.buf_cnt); + res + } +} + +impl Drop for RawBufRing { + fn drop(&mut self) { + if self.registered.get() { + _ = self.unregister(); + } + // Safety: the ptr and layout are treated as constant, and ptr (ring_addr) was assigned by + // a call to std::alloc::alloc_zeroed using the same layout. + unsafe { std::alloc::dealloc(self.ring_addr as *mut u8, self.layout) }; + } +} diff --git a/aquatic_udp/src/workers/socket/uring/mod.rs b/aquatic_udp/src/workers/socket/uring/mod.rs new file mode 100644 index 0000000..b68ddc6 --- /dev/null +++ b/aquatic_udp/src/workers/socket/uring/mod.rs @@ -0,0 +1,508 @@ +mod buf_ring; +mod recv_helper; +mod send_buffers; + +use std::cell::RefCell; +use std::collections::VecDeque; +use std::net::UdpSocket; +use std::ops::DerefMut; +use std::os::fd::AsRawFd; +use std::sync::atomic::Ordering; + +use anyhow::Context; +use aquatic_common::access_list::AccessListCache; +use aquatic_common::ServerStartInstant; +use crossbeam_channel::Receiver; +use io_uring::opcode::Timeout; +use io_uring::types::{Fixed, Timespec}; +use io_uring::{IoUring, Probe}; + +use aquatic_common::{ + access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr, + PanicSentinel, ValidUntil, +}; +use aquatic_udp_protocol::*; + +use crate::common::*; +use crate::config::Config; + +use self::buf_ring::BufRing; +use self::recv_helper::RecvHelper; +use self::send_buffers::{ResponseType, SendBuffers}; + +use super::storage::PendingScrapeResponseSlab; +use super::validator::ConnectionValidator; +use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; + +const BUF_LEN: usize = 8192; + +const USER_DATA_RECV: u64 = u64::MAX; +const USER_DATA_PULSE_TIMEOUT: u64 = u64::MAX - 1; +const USER_DATA_CLEANING_TIMEOUT: u64 = u64::MAX - 2; + +const SOCKET_IDENTIFIER: Fixed = Fixed(0); + +thread_local! 
{ + /// Store IoUring instance here so that it can be accessed in BufRing::drop + pub static CURRENT_RING: CurrentRing = CurrentRing(RefCell::new(None)); +} + +pub struct CurrentRing(RefCell>); + +impl CurrentRing { + fn with(mut f: F) -> T + where + F: FnMut(&mut IoUring) -> T, + { + CURRENT_RING.with(|r| { + let mut opt_ring = r.0.borrow_mut(); + + f(Option::as_mut(opt_ring.deref_mut()).expect("IoUring not set")) + }) + } +} + +pub struct SocketWorker { + config: Config, + shared_state: State, + request_sender: ConnectedRequestSender, + response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + access_list_cache: AccessListCache, + validator: ConnectionValidator, + server_start_instant: ServerStartInstant, + pending_scrape_responses: PendingScrapeResponseSlab, + send_buffers: SendBuffers, + recv_helper: RecvHelper, + local_responses: VecDeque<(Response, CanonicalSocketAddr)>, + pulse_timeout: Timespec, + cleaning_timeout: Timespec, + buf_ring: BufRing, + #[allow(dead_code)] + socket: UdpSocket, +} + +impl SocketWorker { + pub fn run( + _sentinel: PanicSentinel, + shared_state: State, + config: Config, + validator: ConnectionValidator, + server_start_instant: ServerStartInstant, + request_sender: ConnectedRequestSender, + response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, + priv_dropper: PrivilegeDropper, + ) { + let ring_entries = config.network.ring_entries.next_power_of_two(); + // Bias ring towards sending to prevent build-up of unsent responses + let send_buffer_entries = ring_entries - (ring_entries / 4); + + let socket = create_socket(&config, priv_dropper).expect("create socket"); + let access_list_cache = create_access_list_cache(&shared_state.access_list); + let send_buffers = SendBuffers::new(&config, send_buffer_entries as usize); + let recv_helper = RecvHelper::new(&config); + let cleaning_timeout = + Timespec::new().sec(config.cleaning.pending_scrape_cleaning_interval); + + let ring = IoUring::builder() + 
.setup_coop_taskrun() + .setup_single_issuer() + .setup_submit_all() + .build(ring_entries.into()) + .unwrap(); + + ring.submitter() + .register_files(&[socket.as_raw_fd()]) + .unwrap(); + + // Store ring in thread local storage before creating BufRing + CURRENT_RING.with(|r| *r.0.borrow_mut() = Some(ring)); + + let buf_ring = buf_ring::Builder::new(0) + .ring_entries(ring_entries) + .buf_len(BUF_LEN) + .build() + .unwrap(); + + let mut worker = Self { + config, + shared_state, + validator, + server_start_instant, + request_sender, + response_receiver, + access_list_cache, + pending_scrape_responses: Default::default(), + send_buffers, + recv_helper, + local_responses: Default::default(), + pulse_timeout: Timespec::new().sec(1), + cleaning_timeout, + buf_ring, + socket, + }; + + CurrentRing::with(|ring| worker.run_inner(ring)); + } + + pub fn run_inner(&mut self, ring: &mut IoUring) { + let mut pending_scrape_valid_until = ValidUntil::new( + self.server_start_instant, + self.config.cleaning.max_pending_scrape_age, + ); + + let recv_entry = self + .recv_helper + .create_entry(self.buf_ring.bgid().try_into().unwrap()); + // This timeout makes it possible to avoid busy-polling and enables + // regular updates of pending_scrape_valid_until + let pulse_timeout_entry = Timeout::new(&self.pulse_timeout as *const _) + .build() + .user_data(USER_DATA_PULSE_TIMEOUT); + let cleaning_timeout_entry = Timeout::new(&self.cleaning_timeout as *const _) + .build() + .user_data(USER_DATA_CLEANING_TIMEOUT); + + let mut squeue_buf = vec![ + recv_entry.clone(), + pulse_timeout_entry.clone(), + cleaning_timeout_entry.clone(), + ]; + + loop { + for sqe in squeue_buf.drain(..) 
{ + unsafe { ring.submission().push(&sqe).unwrap() }; + } + + let mut num_send_added = 0; + + let sq_space = { + let sq = ring.submission(); + + sq.capacity() - sq.len() + }; + + // Enqueue local responses + for _ in 0..sq_space { + if let Some((response, addr)) = self.local_responses.pop_front() { + match self.send_buffers.prepare_entry(&response, addr) { + Ok(entry) => { + unsafe { ring.submission().push(&entry).unwrap() }; + + num_send_added += 1; + } + Err(send_buffers::Error::NoBuffers) => { + self.local_responses.push_front((response, addr)); + + break; + } + Err(send_buffers::Error::SerializationFailed(err)) => { + ::log::error!("write response to buffer: {:#}", err); + } + } + } else { + break; + } + } + + let sq_space = { + let sq = ring.submission(); + + sq.capacity() - sq.len() + }; + + // Enqueue swarm worker responses + 'outer: for _ in 0..sq_space { + let (response, addr) = loop { + match self.response_receiver.try_recv() { + Ok((ConnectedResponse::AnnounceIpv4(response), addr)) => { + break (Response::AnnounceIpv4(response), addr); + } + Ok((ConnectedResponse::AnnounceIpv6(response), addr)) => { + break (Response::AnnounceIpv6(response), addr); + } + Ok((ConnectedResponse::Scrape(response), addr)) => { + if let Some(response) = + self.pending_scrape_responses.add_and_get_finished(response) + { + break (Response::Scrape(response), addr); + } + } + Err(_) => { + break 'outer; + } + } + }; + + match self.send_buffers.prepare_entry(&response, addr) { + Ok(entry) => { + unsafe { ring.submission().push(&entry).unwrap() }; + + num_send_added += 1; + } + Err(send_buffers::Error::NoBuffers) => { + self.local_responses.push_back((response, addr)); + + break; + } + Err(send_buffers::Error::SerializationFailed(err)) => { + ::log::error!("write response to buffer: {:#}", err); + } + } + } + + // Wait for all sendmsg entries to complete, as well as at least + // one recvmsg or timeout, in order to avoid busy-polling if there + // is no incoming data. 
+ ring.submitter() + .submit_and_wait(num_send_added + 1) + .unwrap(); + + for cqe in ring.completion() { + match cqe.user_data() { + USER_DATA_RECV => { + self.handle_recv_cqe(pending_scrape_valid_until, &cqe); + + if !io_uring::cqueue::more(cqe.flags()) { + squeue_buf.push(recv_entry.clone()); + } + } + USER_DATA_PULSE_TIMEOUT => { + pending_scrape_valid_until = ValidUntil::new( + self.server_start_instant, + self.config.cleaning.max_pending_scrape_age, + ); + + ::log::info!( + "pending responses: {} local, {} swarm", + self.local_responses.len(), + self.response_receiver.len() + ); + + squeue_buf.push(pulse_timeout_entry.clone()); + } + USER_DATA_CLEANING_TIMEOUT => { + self.pending_scrape_responses + .clean(self.server_start_instant.seconds_elapsed()); + + squeue_buf.push(cleaning_timeout_entry.clone()); + } + send_buffer_index => { + let result = cqe.result(); + + if result < 0 { + ::log::error!( + "send: {:#}", + ::std::io::Error::from_raw_os_error(-result) + ); + } else if self.config.statistics.active() { + let send_buffer_index = send_buffer_index as usize; + + let (statistics, extra_bytes) = + if self.send_buffers.receiver_is_ipv4(send_buffer_index) { + (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4) + } else { + (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6) + }; + + statistics + .bytes_sent + .fetch_add(result as usize + extra_bytes, Ordering::Relaxed); + + let response_counter = + match self.send_buffers.response_type(send_buffer_index) { + ResponseType::Connect => &statistics.responses_sent_connect, + ResponseType::Announce => &statistics.responses_sent_announce, + ResponseType::Scrape => &statistics.responses_sent_scrape, + ResponseType::Error => &statistics.responses_sent_error, + }; + + response_counter.fetch_add(1, Ordering::Relaxed); + } + + self.send_buffers + .mark_index_as_free(send_buffer_index as usize); + } + } + } + + self.send_buffers.reset_index(); + } + } + + fn handle_recv_cqe( + &mut self, + 
pending_scrape_valid_until: ValidUntil, + cqe: &io_uring::cqueue::Entry, + ) { + let result = cqe.result(); + + if result < 0 { + // Will produce ENOBUFS if there were no free buffers + ::log::warn!("recv: {:#}", ::std::io::Error::from_raw_os_error(-result)); + + return; + } + + let buffer = unsafe { + match self.buf_ring.get_buf(result as u32, cqe.flags()) { + Ok(Some(buffer)) => buffer, + Ok(None) => { + ::log::error!("Couldn't get buffer"); + + return; + } + Err(err) => { + ::log::error!("Couldn't get buffer: {:#}", err); + + return; + } + } + }; + + let buffer = buffer.as_slice(); + + let (res_request, addr) = self.recv_helper.parse(buffer); + + match res_request { + Ok(request) => self.handle_request(pending_scrape_valid_until, request, addr), + Err(RequestParseError::Sendable { + connection_id, + transaction_id, + err, + }) => { + ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); + + if self.validator.connection_id_valid(addr, connection_id) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + self.local_responses.push_back((response.into(), addr)); + } + } + Err(RequestParseError::Unsendable { err }) => { + ::log::debug!("Couldn't parse request from {:?}: {}", addr, err); + } + } + + if self.config.statistics.active() { + if addr.is_ipv4() { + self.shared_state + .statistics_ipv4 + .bytes_received + .fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + self.shared_state + .statistics_ipv4 + .requests_received + .fetch_add(1, Ordering::Relaxed); + } else { + self.shared_state + .statistics_ipv6 + .bytes_received + .fetch_add(buffer.len() + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + self.shared_state + .statistics_ipv6 + .requests_received + .fetch_add(1, Ordering::Relaxed); + } + } + } + + fn handle_request( + &mut self, + pending_scrape_valid_until: ValidUntil, + request: Request, + src: CanonicalSocketAddr, + ) { + let access_list_mode = 
self.config.access_list.mode; + + match request { + Request::Connect(request) => { + let connection_id = self.validator.create_connection_id(src); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, + }); + + self.local_responses.push_back((response, src)); + } + Request::Announce(request) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + if self + .access_list_cache + .load() + .allows(access_list_mode, &request.info_hash.0) + { + let worker_index = + SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); + + self.request_sender.try_send_to( + worker_index, + ConnectedRequest::Announce(request), + src, + ); + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + self.local_responses.push_back((response, src)) + } + } + } + Request::Scrape(request) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + let split_requests = self.pending_scrape_responses.prepare_split_requests( + &self.config, + request, + pending_scrape_valid_until, + ); + + for (swarm_worker_index, request) in split_requests { + self.request_sender.try_send_to( + swarm_worker_index, + ConnectedRequest::Scrape(request), + src, + ); + } + } + } + } + } +} + +pub fn supported_on_current_kernel() -> anyhow::Result<()> { + let opcodes = [ + // We can't probe for RecvMsgMulti, so we probe for SendZc, which was + // also introduced in Linux 6.0 + io_uring::opcode::SendZc::CODE, + ]; + + let ring = IoUring::new(1).with_context(|| "create ring")?; + + let mut probe = Probe::new(); + + ring.submitter() + .register_probe(&mut probe) + .with_context(|| "register probe")?; + + for opcode in opcodes { + if !probe.is_supported(opcode) { + return Err(anyhow::anyhow!( + "io_uring opcode {:b} not supported", + opcode + )); + } + } + + Ok(()) +} diff --git 
a/aquatic_udp/src/workers/socket/uring/recv_helper.rs b/aquatic_udp/src/workers/socket/uring/recv_helper.rs new file mode 100644 index 0000000..920cb2f --- /dev/null +++ b/aquatic_udp/src/workers/socket/uring/recv_helper.rs @@ -0,0 +1,121 @@ +use std::{ + net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + ptr::null_mut, +}; + +use aquatic_common::CanonicalSocketAddr; +use aquatic_udp_protocol::{Request, RequestParseError}; +use io_uring::{opcode::RecvMsgMulti, types::RecvMsgOut}; + +use crate::config::Config; + +use super::{SOCKET_IDENTIFIER, USER_DATA_RECV}; + +pub struct RecvHelper { + network_address: IpAddr, + max_scrape_torrents: u8, + #[allow(dead_code)] + name_v4: Box<libc::sockaddr_in>, + msghdr_v4: Box<libc::msghdr>, + #[allow(dead_code)] + name_v6: Box<libc::sockaddr_in6>, + msghdr_v6: Box<libc::msghdr>, +} + +impl RecvHelper { + pub fn new(config: &Config) -> Self { + let mut name_v4 = Box::new(libc::sockaddr_in { + sin_family: 0, + sin_port: 0, + sin_addr: libc::in_addr { s_addr: 0 }, + sin_zero: [0; 8], + }); + + let msghdr_v4 = Box::new(libc::msghdr { + msg_name: &mut name_v4 as *mut _ as *mut libc::c_void, + msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32, + msg_iov: null_mut(), + msg_iovlen: 0, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }); + + let mut name_v6 = Box::new(libc::sockaddr_in6 { + sin6_family: 0, + sin6_port: 0, + sin6_flowinfo: 0, + sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, + sin6_scope_id: 0, + }); + + let msghdr_v6 = Box::new(libc::msghdr { + msg_name: &mut name_v6 as *mut _ as *mut libc::c_void, + msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32, + msg_iov: null_mut(), + msg_iovlen: 0, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }); + + Self { + network_address: config.network.address.ip(), + max_scrape_torrents: config.protocol.max_scrape_torrents, + name_v4, + msghdr_v4, + name_v6, + msghdr_v6, + } + } + + pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry { + let msghdr: *const libc::msghdr = if
self.network_address.is_ipv4() { + &*self.msghdr_v4 + } else { + &*self.msghdr_v6 + }; + + RecvMsgMulti::new(SOCKET_IDENTIFIER, msghdr, buf_group) + .build() + .user_data(USER_DATA_RECV) + } + + pub fn parse( + &self, + buffer: &[u8], + ) -> (Result<Request, RequestParseError>, CanonicalSocketAddr) { + let msghdr = if self.network_address.is_ipv4() { + &self.msghdr_v4 + } else { + &self.msghdr_v6 + }; + + let msg = RecvMsgOut::parse(buffer, msghdr).unwrap(); + + let addr = unsafe { + if self.network_address.is_ipv4() { + let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in); + + SocketAddr::V4(SocketAddrV4::new( + u32::from_be(name_data.sin_addr.s_addr).into(), + u16::from_be(name_data.sin_port), + )) + } else { + let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in6); + + SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from(name_data.sin6_addr.s6_addr), + u16::from_be(name_data.sin6_port), + u32::from_be(name_data.sin6_flowinfo), + u32::from_be(name_data.sin6_scope_id), + )) + } + }; + + ( + Request::from_bytes(msg.payload_data(), self.max_scrape_torrents), + CanonicalSocketAddr::new(addr), + ) + } +} diff --git a/aquatic_udp/src/workers/socket/uring/send_buffers.rs b/aquatic_udp/src/workers/socket/uring/send_buffers.rs new file mode 100644 index 0000000..0eae144 --- /dev/null +++ b/aquatic_udp/src/workers/socket/uring/send_buffers.rs @@ -0,0 +1,221 @@ +use std::{io::Cursor, net::IpAddr, ops::IndexMut, ptr::null_mut}; + +use aquatic_common::CanonicalSocketAddr; +use aquatic_udp_protocol::Response; +use io_uring::opcode::SendMsg; + +use crate::config::Config; + +use super::{BUF_LEN, SOCKET_IDENTIFIER}; + +pub enum Error { + NoBuffers, + SerializationFailed(std::io::Error), +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ResponseType { + Connect, + Announce, + Scrape, + Error, +} + +impl ResponseType { + fn from_response(response: &Response) -> Self { + match response { + Response::Connect(_) => Self::Connect, + Response::AnnounceIpv4(_) |
Response::AnnounceIpv6(_) => Self::Announce, + Response::Scrape(_) => Self::Scrape, + Response::Error(_) => Self::Error, + } + } +} + +pub struct SendBuffers { + likely_next_free_index: usize, + network_address: IpAddr, + names_v4: Vec<libc::sockaddr_in>, + names_v6: Vec<libc::sockaddr_in6>, + buffers: Vec<[u8; BUF_LEN]>, + iovecs: Vec<libc::iovec>, + msghdrs: Vec<libc::msghdr>, + free: Vec<bool>, + // Only used for statistics + receiver_is_ipv4: Vec<bool>, + // Only used for statistics + response_types: Vec<ResponseType>, +} + +impl SendBuffers { + pub fn new(config: &Config, capacity: usize) -> Self { + let mut buffers = ::std::iter::repeat([0u8; BUF_LEN]) + .take(capacity) + .collect::<Vec<_>>(); + + let mut iovecs = buffers + .iter_mut() + .map(|buffer| libc::iovec { + iov_base: buffer.as_mut_ptr() as *mut libc::c_void, + iov_len: buffer.len(), + }) + .collect::<Vec<_>>(); + + let (names_v4, names_v6, msghdrs) = if config.network.address.is_ipv4() { + let mut names_v4 = ::std::iter::repeat(libc::sockaddr_in { + sin_family: libc::AF_INET as u16, + sin_port: 0, + sin_addr: libc::in_addr { s_addr: 0 }, + sin_zero: [0; 8], + }) + .take(capacity) + .collect::<Vec<_>>(); + + let msghdrs = names_v4 + .iter_mut() + .zip(iovecs.iter_mut()) + .map(|(msg_name, msg_iov)| libc::msghdr { + msg_name: msg_name as *mut _ as *mut libc::c_void, + msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32, + msg_iov: msg_iov as *mut _, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }) + .collect::<Vec<_>>(); + + (names_v4, Vec::new(), msghdrs) + } else { + let mut names_v6 = ::std::iter::repeat(libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as u16, + sin6_port: 0, + sin6_flowinfo: 0, + sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, + sin6_scope_id: 0, + }) + .take(capacity) + .collect::<Vec<_>>(); + + let msghdrs = names_v6 + .iter_mut() + .zip(iovecs.iter_mut()) + .map(|(msg_name, msg_iov)| libc::msghdr { + msg_name: msg_name as *mut _ as *mut libc::c_void, + msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32, + msg_iov: msg_iov as *mut _, + msg_iovlen: 1, + msg_control: null_mut(), +
msg_controllen: 0, + msg_flags: 0, + }) + .collect::<Vec<_>>(); + + (Vec::new(), names_v6, msghdrs) + }; + + Self { + likely_next_free_index: 0, + network_address: config.network.address.ip(), + names_v4, + names_v6, + buffers, + iovecs, + msghdrs, + free: ::std::iter::repeat(true).take(capacity).collect(), + receiver_is_ipv4: ::std::iter::repeat(true).take(capacity).collect(), + response_types: ::std::iter::repeat(ResponseType::Connect) + .take(capacity) + .collect(), + } + } + + pub fn receiver_is_ipv4(&mut self, index: usize) -> bool { + self.receiver_is_ipv4[index] + } + + pub fn response_type(&mut self, index: usize) -> ResponseType { + self.response_types[index] + } + + pub fn mark_index_as_free(&mut self, index: usize) { + self.free[index] = true; + } + + /// Call after going through completion queue + pub fn reset_index(&mut self) { + self.likely_next_free_index = 0; + } + + pub fn prepare_entry( + &mut self, + response: &Response, + addr: CanonicalSocketAddr, + ) -> Result<io_uring::squeue::Entry, Error> { + let index = self.next_free_index()?; + + // Set receiver socket addr + if self.network_address.is_ipv4() { + let msg_name = self.names_v4.index_mut(index); + let addr = addr.get_ipv4().unwrap(); + + msg_name.sin_port = addr.port().to_be(); + msg_name.sin_addr.s_addr = if let IpAddr::V4(addr) = addr.ip() { + u32::from(addr).to_be() + } else { + panic!("ipv6 address in ipv4 mode"); + }; + + self.receiver_is_ipv4[index] = true; + } else { + let msg_name = self.names_v6.index_mut(index); + let addr = addr.get_ipv6_mapped(); + + msg_name.sin6_port = addr.port().to_be(); + msg_name.sin6_addr.s6_addr = if let IpAddr::V6(addr) = addr.ip() { + addr.octets() + } else { + panic!("ipv4 address when ipv6 or ipv6-mapped address expected"); + }; + + self.receiver_is_ipv4[index] = addr.is_ipv4(); + } + + let mut cursor = Cursor::new(self.buffers.index_mut(index).as_mut_slice()); + + match response.write(&mut cursor) { + Ok(()) => { + self.iovecs[index].iov_len = cursor.position() as usize; +
self.response_types[index] = ResponseType::from_response(response); + self.free[index] = false; + + self.likely_next_free_index = index + 1; + + let sqe = SendMsg::new(SOCKET_IDENTIFIER, self.msghdrs.index_mut(index)) + .build() + .user_data(index as u64); + + Ok(sqe) + } + Err(err) => Err(Error::SerializationFailed(err)), + } + } + + fn next_free_index(&self) -> Result<usize, Error> { + if self.likely_next_free_index >= self.free.len() { + return Err(Error::NoBuffers); + } + + for (i, free) in self.free[self.likely_next_free_index..] + .iter() + .copied() + .enumerate() + { + if free { + return Ok(self.likely_next_free_index + i); + } + } + + Err(Error::NoBuffers) + } +} diff --git a/scripts/run-aquatic-udp.sh b/scripts/run-aquatic-udp.sh index 0007289..2438388 100755 --- a/scripts/run-aquatic-udp.sh +++ b/scripts/run-aquatic-udp.sh @@ -2,4 +2,4 @@ . ./scripts/env-native-cpu-without-avx-512 -cargo run --profile "release-debug" -p aquatic_udp -- $@ +cargo run --profile "release-debug" -p aquatic_udp --features "io-uring" -- $@