Move all crates to new crates dir

This commit is contained in:
Joakim Frostegård 2023-10-18 23:53:41 +02:00
parent 3835da22ac
commit 9b032f7e24
128 changed files with 27 additions and 26 deletions

61
crates/udp/Cargo.toml Normal file
View file

@@ -0,0 +1,61 @@
# Manifest for aquatic_udp, the UDP BitTorrent tracker crate
[package]
name = "aquatic_udp"
description = "High-performance open UDP BitTorrent tracker"
keywords = ["udp", "server", "peer-to-peer", "torrent", "bittorrent"]
# Shared metadata inherited from the workspace root Cargo.toml
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
readme.workspace = true
rust-version.workspace = true
# The crate builds both a library and a binary with the same name
[lib]
name = "aquatic_udp"
[[bin]]
name = "aquatic_udp"
[features]
default = ["prometheus"]
# Pin worker threads to CPU cores (pulls in hwloc via aquatic_common)
cpu-pinning = ["aquatic_common/hwloc"]
# Prometheus metrics endpoint support
prometheus = ["metrics", "metrics-util", "metrics-exporter-prometheus"]
# Optional io_uring socket backend (used when kernel support is present)
io-uring = ["dep:io-uring"]
[dependencies]
aquatic_common.workspace = true
aquatic_toml_config.workspace = true
aquatic_udp_protocol.workspace = true
anyhow = "1"
blake3 = "1"
cfg-if = "1"
compact_str = "0.7"
constant_time_eq = "0.3"
crossbeam-channel = "0.5"
getrandom = "0.2"
hashbrown = { version = "0.14", default-features = false }
hdrhistogram = "7"
hex = "0.4"
io-uring = { version = "0.6", optional = true }
libc = "0.2"
log = "0.4"
metrics = { version = "0.21", optional = true }
metrics-util = { version = "0.15", optional = true }
metrics-exporter-prometheus = { version = "0.12", optional = true, default-features = false, features = ["http-listener"] }
# Global allocator used by the binary (see main.rs)
mimalloc = { version = "0.1", default-features = false }
mio = { version = "0.8", features = ["net", "os-poll"] }
num-format = "0.4"
rand = { version = "0.8", features = ["small_rng"] }
serde = { version = "1", features = ["derive"] }
signal-hook = { version = "0.3" }
slab = "0.4"
socket2 = { version = "0.5", features = ["all"] }
time = { version = "0.3", features = ["formatting"] }
tinytemplate = "1"
[dev-dependencies]
hex = "0.4"
tempfile = "3"
quickcheck = "1"
quickcheck_macros = "1"

254
crates/udp/src/common.rs Normal file
View file

@@ -0,0 +1,254 @@
use std::collections::BTreeMap;
use std::hash::Hash;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use crossbeam_channel::{Sender, TrySendError};
use aquatic_common::access_list::AccessListArcSwap;
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::*;
use hdrhistogram::Histogram;
use crate::config::Config;
/// Size of the reusable packet buffer in socket workers; must fit the
/// largest possible response (checked by the `test_buffer_size` test below)
pub const BUFFER_SIZE: usize = 8192;

/// A scrape request whose info hashes have been split across swarm workers
#[derive(Debug)]
pub struct PendingScrapeRequest {
    // Key of the corresponding pending response in the socket worker's slab
    pub slab_key: usize,
    // Info hashes keyed by usize index — presumably the position in the
    // original scrape request; confirm against PendingScrapeResponseSlab
    pub info_hashes: BTreeMap<usize, InfoHash>,
}

/// One swarm worker's partial answer to a split scrape request
#[derive(Debug)]
pub struct PendingScrapeResponse {
    // Key of the corresponding pending response in the socket worker's slab
    pub slab_key: usize,
    // Statistics keyed by the same indices as PendingScrapeRequest::info_hashes
    pub torrent_stats: BTreeMap<usize, TorrentScrapeStatistics>,
}

/// A request that passed connection id validation, sent from socket
/// workers to swarm workers
#[derive(Debug)]
pub enum ConnectedRequest {
    Announce(AnnounceRequest),
    Scrape(PendingScrapeRequest),
}

/// A response generated by a swarm worker, sent back to a socket worker
#[derive(Debug)]
pub enum ConnectedResponse {
    AnnounceIpv4(AnnounceResponse<Ipv4Addr>),
    AnnounceIpv6(AnnounceResponse<Ipv6Addr>),
    Scrape(PendingScrapeResponse),
}
/// Index of a socket worker, passed along with requests so swarm workers
/// know where to send responses
#[derive(Clone, Copy, Debug)]
pub struct SocketWorkerIndex(pub usize);

/// Index of a swarm worker
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct SwarmWorkerIndex(pub usize);

impl SwarmWorkerIndex {
    /// Pick the swarm worker responsible for an info hash: the first byte
    /// of the hash modulo the number of swarm workers
    pub fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self {
        Self(info_hash.0[0] as usize % config.swarm_workers)
    }
}
/// Sends validated requests from one socket worker to the swarm workers
pub struct ConnectedRequestSender {
    // Index of the owning socket worker, sent along with each request so
    // the swarm worker knows where to route the response
    index: SocketWorkerIndex,
    // One channel per swarm worker
    senders: Vec<Sender<(SocketWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>>,
}

impl ConnectedRequestSender {
    pub fn new(
        index: SocketWorkerIndex,
        senders: Vec<Sender<(SocketWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>>,
    ) -> Self {
        Self { index, senders }
    }

    /// Try sending a request to the given swarm worker.
    ///
    /// If the channel is full, the request is dropped and an error is
    /// logged. Panics if the channel is disconnected, since that means the
    /// receiving worker is gone.
    pub fn try_send_to(
        &self,
        index: SwarmWorkerIndex,
        request: ConnectedRequest,
        addr: CanonicalSocketAddr,
    ) {
        match self.senders[index.0].try_send((self.index, request, addr)) {
            Ok(()) => {}
            Err(TrySendError::Full(_)) => {
                ::log::error!("Request channel {} is full, dropping request. Try increasing number of swarm workers or raising config.worker_channel_size.", index.0)
            }
            Err(TrySendError::Disconnected(_)) => {
                panic!("Request channel {} is disconnected", index.0);
            }
        }
    }
}
/// Sends responses from a swarm worker back to the socket workers
pub struct ConnectedResponseSender {
    // One channel per socket worker
    senders: Vec<Sender<(ConnectedResponse, CanonicalSocketAddr)>>,
}

impl ConnectedResponseSender {
    pub fn new(senders: Vec<Sender<(ConnectedResponse, CanonicalSocketAddr)>>) -> Self {
        Self { senders }
    }

    /// Try sending a response to the given socket worker.
    ///
    /// If the channel is full, the response is dropped and an error is
    /// logged. Panics if the channel is disconnected, since that means the
    /// receiving worker is gone.
    pub fn try_send_to(
        &self,
        index: SocketWorkerIndex,
        response: ConnectedResponse,
        addr: CanonicalSocketAddr,
    ) {
        match self.senders[index.0].try_send((response, addr)) {
            Ok(()) => {}
            Err(TrySendError::Full(_)) => {
                ::log::error!("Response channel {} is full, dropping response. Try increasing number of socket workers or raising config.worker_channel_size.", index.0)
            }
            Err(TrySendError::Disconnected(_)) => {
                panic!("Response channel {} is disconnected", index.0);
            }
        }
    }
}
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]
pub enum PeerStatus {
    Seeding,
    Leeching,
    Stopped,
}

impl PeerStatus {
    /// Map an announce event and remaining byte count to a peer status.
    ///
    /// A `Stopped` event always wins; otherwise, zero bytes left means the
    /// peer is seeding. Most announces in practice hit the leeching arm.
    #[inline]
    pub fn from_event_and_bytes_left(event: AnnounceEvent, bytes_left: NumberOfBytes) -> Self {
        match (event, bytes_left.0) {
            (AnnounceEvent::Stopped, _) => Self::Stopped,
            (_, 0) => Self::Seeding,
            _ => Self::Leeching,
        }
    }
}
/// Messages sent to the statistics worker over a dedicated channel
pub enum StatisticsMessage {
    // Histograms of peers per torrent, per IP version
    Ipv4PeerHistogram(Histogram<u64>),
    Ipv6PeerHistogram(Histogram<u64>),
    // Peer lifecycle events — presumably used for peer client statistics;
    // confirm against the statistics worker
    PeerAdded(PeerId),
    PeerRemoved(PeerId),
}
/// Tracker statistics for one IP version, updated atomically by socket and
/// swarm workers.
pub struct Statistics {
    pub requests_received: AtomicUsize,
    pub responses_sent_connect: AtomicUsize,
    pub responses_sent_announce: AtomicUsize,
    pub responses_sent_scrape: AtomicUsize,
    pub responses_sent_error: AtomicUsize,
    pub bytes_received: AtomicUsize,
    pub bytes_sent: AtomicUsize,
    // One counter slot per swarm worker (see `new`)
    pub torrents: Vec<AtomicUsize>,
    pub peers: Vec<AtomicUsize>,
}

impl Statistics {
    /// Create a zero-initialized statistics instance with one torrent and
    /// peer counter per swarm worker.
    pub fn new(num_swarm_workers: usize) -> Self {
        Self {
            requests_received: Default::default(),
            responses_sent_connect: Default::default(),
            responses_sent_announce: Default::default(),
            responses_sent_scrape: Default::default(),
            responses_sent_error: Default::default(),
            bytes_received: Default::default(),
            bytes_sent: Default::default(),
            torrents: Self::create_atomic_usize_vec(num_swarm_workers),
            peers: Self::create_atomic_usize_vec(num_swarm_workers),
        }
    }

    /// Build a vec of `len` zeroed atomic counters.
    ///
    /// `AtomicUsize` is not `Clone`, so `vec![...; len]` cannot be used.
    /// Passing the function directly avoids a redundant closure.
    fn create_atomic_usize_vec(len: usize) -> Vec<AtomicUsize> {
        ::std::iter::repeat_with(AtomicUsize::default)
            .take(len)
            .collect()
    }
}
/// Shared tracker state, cheaply clonable (all fields are `Arc`s) so each
/// worker thread can hold its own handle
#[derive(Clone)]
pub struct State {
    // Access list controlling which info hashes are allowed
    pub access_list: Arc<AccessListArcSwap>,
    // Statistics kept separately per IP version
    pub statistics_ipv4: Arc<Statistics>,
    pub statistics_ipv6: Arc<Statistics>,
}

impl State {
    /// Create fresh state with empty access list and zeroed statistics
    pub fn new(num_swarm_workers: usize) -> Self {
        Self {
            access_list: Arc::new(AccessListArcSwap::default()),
            statistics_ipv4: Arc::new(Statistics::new(num_swarm_workers)),
            statistics_ipv6: Arc::new(Statistics::new(num_swarm_workers)),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::net::Ipv6Addr;
    use crate::config::Config;
    use super::*;

    // Exhaustively checks the (event, bytes left) -> PeerStatus mapping
    #[test]
    fn test_peer_status_from_event_and_bytes_left() {
        use crate::common::*;
        use PeerStatus::*;
        let f = PeerStatus::from_event_and_bytes_left;
        assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes(0)));
        assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes(1)));
        assert_eq!(Seeding, f(AnnounceEvent::Started, NumberOfBytes(0)));
        assert_eq!(Leeching, f(AnnounceEvent::Started, NumberOfBytes(1)));
        assert_eq!(Seeding, f(AnnounceEvent::Completed, NumberOfBytes(0)));
        assert_eq!(Leeching, f(AnnounceEvent::Completed, NumberOfBytes(1)));
        assert_eq!(Seeding, f(AnnounceEvent::None, NumberOfBytes(0)));
        assert_eq!(Leeching, f(AnnounceEvent::None, NumberOfBytes(1)));
    }

    // Assumes that announce response with maximum amount of ipv6 peers will
    // be the longest
    #[test]
    fn test_buffer_size() {
        use aquatic_udp_protocol::*;
        let config = Config::default();
        // Build a response with the maximum configured number of peers and
        // check that its serialized form fits in BUFFER_SIZE
        let peers = ::std::iter::repeat(ResponsePeer {
            ip_address: Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1),
            port: Port(1),
        })
        .take(config.protocol.max_response_peers)
        .collect();
        let response = Response::AnnounceIpv6(AnnounceResponse {
            transaction_id: TransactionId(1),
            announce_interval: AnnounceInterval(1),
            seeders: NumberOfPeers(1),
            leechers: NumberOfPeers(1),
            peers,
        });
        let mut buf = Vec::new();
        response.write(&mut buf).unwrap();
        println!("Buffer len: {}", buf.len());
        assert!(buf.len() <= BUFFER_SIZE);
    }
}

267
crates/udp/src/config.rs Normal file
View file

@@ -0,0 +1,267 @@
use std::{net::SocketAddr, path::PathBuf};
use aquatic_common::{access_list::AccessListConfig, privileges::PrivilegeConfig};
use cfg_if::cfg_if;
use serde::Deserialize;
use aquatic_common::cli::LogLevel;
use aquatic_toml_config::TomlConfig;
/// aquatic_udp configuration
// NOTE: plain `//` comments are used for added notes so the TomlConfig
// derive output is unaffected
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
    /// Number of socket workers. Increase with core count
    ///
    /// Socket workers receive requests from clients and parse them.
    /// Responses to connect requests are sent back immediately. Announce and
    /// scrape requests are passed on to swarm workers, which generate
    /// responses and send them back to the socket worker, which sends them
    /// to the client.
    pub socket_workers: usize,
    /// Number of swarm workers. One is enough in almost all cases
    ///
    /// Swarm workers receive parsed announce and scrape requests from socket
    /// workers, generate responses and send them back to the socket workers.
    pub swarm_workers: usize,
    // Log level used when initializing logging (see cli::Config impl below)
    pub log_level: LogLevel,
    /// Maximum number of items in each channel passing requests/responses
    /// between workers. A value of zero means that the channels will be of
    /// unbounded size.
    pub worker_channel_size: usize,
    /// How long to block waiting for requests in swarm workers.
    ///
    /// Higher values means that with zero traffic, the worker will not
    /// unnecessarily cause the CPU to wake up as often. However, high values
    /// (something like larger than 1000) combined with very low traffic can
    /// cause delays in torrent cleaning.
    pub request_channel_recv_timeout_ms: u64,
    // Sub-configurations, each documented on its own struct
    pub network: NetworkConfig,
    pub protocol: ProtocolConfig,
    pub statistics: StatisticsConfig,
    pub cleaning: CleaningConfig,
    pub privileges: PrivilegeConfig,
    pub access_list: AccessListConfig,
    // Worker-to-core pinning, only available with the cpu-pinning feature
    #[cfg(feature = "cpu-pinning")]
    pub cpu_pinning: aquatic_common::cpu_pinning::asc::CpuPinningConfigAsc,
}
impl Default for Config {
    /// Defaults: one socket worker, one swarm worker, error-level logging,
    /// unbounded worker channels, and default sub-configurations.
    fn default() -> Self {
        Self {
            socket_workers: 1,
            swarm_workers: 1,
            log_level: LogLevel::Error,
            worker_channel_size: 0,
            request_channel_recv_timeout_ms: 100,
            network: Default::default(),
            protocol: Default::default(),
            statistics: Default::default(),
            cleaning: Default::default(),
            privileges: Default::default(),
            access_list: Default::default(),
            #[cfg(feature = "cpu-pinning")]
            cpu_pinning: Default::default(),
        }
    }
}
impl aquatic_common::cli::Config for Config {
    /// Expose the configured log level to the shared CLI runner
    fn get_log_level(&self) -> Option<LogLevel> {
        Some(self.log_level)
    }
}
// Network/socket settings shared by the mio and io_uring backends
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct NetworkConfig {
    /// Bind to this address
    pub address: SocketAddr,
    /// Only allow access over IPv6
    pub only_ipv6: bool,
    /// Size of socket recv buffer. Use 0 for OS default.
    ///
    /// This setting can have a big impact on dropped packages. It might
    /// require changing system defaults. Some examples of commands to set
    /// values for different operating systems:
    ///
    /// macOS:
    /// $ sudo sysctl net.inet.udp.recvspace=6000000
    ///
    /// Linux:
    /// $ sudo sysctl -w net.core.rmem_max=104857600
    /// $ sudo sysctl -w net.core.rmem_default=104857600
    pub socket_recv_buffer_size: usize,
    /// Poll event capacity (mio backend only)
    pub poll_event_capacity: usize,
    /// Poll timeout in milliseconds (mio backend only)
    pub poll_timeout_ms: u64,
    /// Number of ring entries (io_uring backend only)
    ///
    /// Will be rounded to next power of two if not already one. Increasing
    /// this value can help throughput up to a certain point.
    #[cfg(feature = "io-uring")]
    pub ring_size: u16,
    /// Store this many responses at most for retrying (once) on send failure
    /// (mio backend only)
    ///
    /// Useful on operating systems that do not provide an udp send buffer,
    /// such as FreeBSD. Setting the value to zero disables resending
    /// functionality.
    pub resend_buffer_max_len: usize,
}
impl NetworkConfig {
    /// Whether IPv4 peers are handled: true when bound to an IPv4 address,
    /// or when bound to IPv6 without `only_ipv6` (dual-stack operation)
    pub fn ipv4_active(&self) -> bool {
        self.address.is_ipv4() || !self.only_ipv6
    }
    /// Whether IPv6 peers are handled: true when bound to an IPv6 address
    pub fn ipv6_active(&self) -> bool {
        self.address.is_ipv6()
    }
}
impl Default for NetworkConfig {
    fn default() -> Self {
        Self {
            // Listen on all IPv4 interfaces, port 3000
            address: SocketAddr::from(([0, 0, 0, 0], 3000)),
            only_ipv6: false,
            // 512 KiB recv buffer
            socket_recv_buffer_size: 4096 * 128,
            poll_event_capacity: 4096,
            poll_timeout_ms: 50,
            #[cfg(feature = "io-uring")]
            ring_size: 1024,
            // 0 disables the resend buffer
            resend_buffer_max_len: 0,
        }
    }
}
// BitTorrent UDP protocol parameters
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ProtocolConfig {
    /// Maximum number of torrents to allow in scrape request
    pub max_scrape_torrents: u8,
    /// Maximum number of peers to return in announce response
    pub max_response_peers: usize,
    /// Ask peers to announce this often (seconds)
    pub peer_announce_interval: i32,
}

impl Default for ProtocolConfig {
    fn default() -> Self {
        Self {
            max_scrape_torrents: 70,
            max_response_peers: 50,
            // 15 minutes
            peer_announce_interval: 60 * 15,
        }
    }
}
// Statistics collection and reporting settings
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct StatisticsConfig {
    /// Collect and print/write statistics this often (seconds)
    pub interval: u64,
    /// Collect statistics on number of peers per torrent
    ///
    /// Will increase time taken for torrent cleaning.
    pub torrent_peer_histograms: bool,
    /// Collect statistics on peer clients.
    ///
    /// Also, see `prometheus_peer_id_prefixes`.
    ///
    /// Expect a certain CPU hit (maybe 5% higher consumption) and a bit higher
    /// memory use
    pub peer_clients: bool,
    /// Print statistics to standard output
    pub print_to_stdout: bool,
    /// Save statistics as HTML to a file
    pub write_html_to_file: bool,
    /// Path to save HTML file to
    pub html_file_path: PathBuf,
    /// Run a prometheus endpoint
    #[cfg(feature = "prometheus")]
    pub run_prometheus_endpoint: bool,
    /// Address to run prometheus endpoint on
    #[cfg(feature = "prometheus")]
    pub prometheus_endpoint_address: SocketAddr,
    /// Serve information on all peer id prefixes on the prometheus endpoint.
    ///
    /// Requires `peer_clients` to be activated.
    ///
    /// May consume quite a bit of CPU and RAM, since data on every single peer
    /// client will be reported continuously on the endpoint
    #[cfg(feature = "prometheus")]
    pub prometheus_peer_id_prefixes: bool,
}
impl StatisticsConfig {
    cfg_if! {
        if #[cfg(feature = "prometheus")] {
            /// Whether statistics collection is enabled: a nonzero interval
            /// plus at least one output (stdout, HTML file or prometheus)
            pub fn active(&self) -> bool {
                self.interval != 0
                    && (self.print_to_stdout
                        || self.write_html_to_file
                        || self.run_prometheus_endpoint)
            }
        } else {
            /// Whether statistics collection is enabled: a nonzero interval
            /// plus at least one output (stdout or HTML file)
            pub fn active(&self) -> bool {
                self.interval != 0 && (self.print_to_stdout || self.write_html_to_file)
            }
        }
    }
}
impl Default for StatisticsConfig {
    fn default() -> Self {
        Self {
            // Report every 5 seconds; all outputs off by default, so
            // statistics collection is inactive (see `active`)
            interval: 5,
            torrent_peer_histograms: false,
            peer_clients: false,
            print_to_stdout: false,
            write_html_to_file: false,
            html_file_path: "tmp/statistics.html".into(),
            #[cfg(feature = "prometheus")]
            run_prometheus_endpoint: false,
            #[cfg(feature = "prometheus")]
            prometheus_endpoint_address: SocketAddr::from(([0, 0, 0, 0], 9000)),
            #[cfg(feature = "prometheus")]
            prometheus_peer_id_prefixes: false,
        }
    }
}
// Intervals and ages for periodic cleaning of tracker state
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct CleaningConfig {
    /// Clean torrents this often (seconds)
    pub torrent_cleaning_interval: u64,
    /// Clean pending scrape responses this often (seconds)
    ///
    /// In regular operation, there should be no pending scrape responses
    /// lingering for long enough to have to be cleaned up this way.
    pub pending_scrape_cleaning_interval: u64,
    /// Allow clients to use a connection token for this long (seconds)
    pub max_connection_age: u32,
    /// Remove peers who have not announced for this long (seconds)
    pub max_peer_age: u32,
    /// Remove pending scrape responses that have not been returned from swarm
    /// workers for this long (seconds)
    pub max_pending_scrape_age: u32,
}

impl Default for CleaningConfig {
    fn default() -> Self {
        Self {
            // 2 minutes
            torrent_cleaning_interval: 60 * 2,
            // 10 minutes
            pending_scrape_cleaning_interval: 60 * 10,
            // 2 minutes
            max_connection_age: 60 * 2,
            // 20 minutes
            max_peer_age: 60 * 20,
            max_pending_scrape_age: 60,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Config;
    // Macro-generated test — presumably a TOML serialize/deserialize
    // round-trip of the default Config; see aquatic_toml_config
    ::aquatic_toml_config::gen_serialize_deserialize_test!(Config);
}

210
crates/udp/src/lib.rs Normal file
View file

@@ -0,0 +1,210 @@
pub mod common;
pub mod config;
pub mod workers;
use std::collections::BTreeMap;
use std::thread::Builder;
use std::time::Duration;
use anyhow::Context;
use crossbeam_channel::{bounded, unbounded};
use signal_hook::consts::{SIGTERM, SIGUSR1};
use signal_hook::iterator::Signals;
use aquatic_common::access_list::update_access_list;
#[cfg(feature = "cpu-pinning")]
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
use aquatic_common::privileges::PrivilegeDropper;
use aquatic_common::{PanicSentinelWatcher, ServerStartInstant};
use common::{
ConnectedRequestSender, ConnectedResponseSender, SocketWorkerIndex, State, SwarmWorkerIndex,
};
use config::Config;
use workers::socket::ConnectionValidator;
// Name and version passed to the shared CLI runner (see main.rs)
pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker";
pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Run aquatic_udp with the given configuration, blocking until SIGTERM is
/// received.
///
/// Spawns swarm workers and socket workers connected by crossbeam channels,
/// optionally starts a statistics worker (and prometheus endpoint), then
/// waits for signals:
/// - SIGUSR1: reload the access list from disk (best effort)
/// - SIGTERM: shut down, returning an error if any worker thread panicked
///
/// # Errors
///
/// Returns an error if signal handling, connection validator setup, the
/// initial access list load, thread spawning or prometheus endpoint
/// installation fails.
pub fn run(config: Config) -> ::anyhow::Result<()> {
    let mut signals = Signals::new([SIGUSR1, SIGTERM])?;
    let state = State::new(config.swarm_workers);
    let connection_validator = ConnectionValidator::new(&config)?;
    let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel();
    let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
    update_access_list(&config.access_list, &state.access_list)?;
    let mut request_senders = Vec::new();
    let mut request_receivers = BTreeMap::new();
    let mut response_senders = Vec::new();
    let mut response_receivers = BTreeMap::new();
    let (statistics_sender, statistics_receiver) = unbounded();
    let server_start_instant = ServerStartInstant::new();
    // Create channels: one request channel per swarm worker and one
    // response channel per socket worker. A worker_channel_size of zero
    // means unbounded channels.
    for i in 0..config.swarm_workers {
        let (request_sender, request_receiver) = if config.worker_channel_size == 0 {
            unbounded()
        } else {
            bounded(config.worker_channel_size)
        };
        request_senders.push(request_sender);
        request_receivers.insert(i, request_receiver);
    }
    for i in 0..config.socket_workers {
        let (response_sender, response_receiver) = if config.worker_channel_size == 0 {
            unbounded()
        } else {
            bounded(config.worker_channel_size)
        };
        response_senders.push(response_sender);
        response_receivers.insert(i, response_receiver);
    }
    // Spawn swarm workers
    for i in 0..config.swarm_workers {
        let sentinel = sentinel.clone();
        let config = config.clone();
        let state = state.clone();
        // BTreeMap::remove already yields an owned receiver; the previous
        // extra .clone() was redundant
        let request_receiver = request_receivers.remove(&i).unwrap();
        let response_sender = ConnectedResponseSender::new(response_senders.clone());
        let statistics_sender = statistics_sender.clone();
        Builder::new()
            .name(format!("swarm-{:02}", i + 1))
            .spawn(move || {
                #[cfg(feature = "cpu-pinning")]
                pin_current_if_configured_to(
                    &config.cpu_pinning,
                    config.socket_workers,
                    config.swarm_workers,
                    WorkerIndex::SwarmWorker(i),
                );
                workers::swarm::run_swarm_worker(
                    sentinel,
                    config,
                    state,
                    server_start_instant,
                    request_receiver,
                    response_sender,
                    statistics_sender,
                    SwarmWorkerIndex(i),
                )
            })
            .with_context(|| "spawn swarm worker")?;
    }
    // Spawn socket workers
    for i in 0..config.socket_workers {
        let sentinel = sentinel.clone();
        let state = state.clone();
        let config = config.clone();
        let connection_validator = connection_validator.clone();
        let request_sender =
            ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone());
        let response_receiver = response_receivers.remove(&i).unwrap();
        let priv_dropper = priv_dropper.clone();
        Builder::new()
            .name(format!("socket-{:02}", i + 1))
            .spawn(move || {
                #[cfg(feature = "cpu-pinning")]
                pin_current_if_configured_to(
                    &config.cpu_pinning,
                    config.socket_workers,
                    config.swarm_workers,
                    WorkerIndex::SocketWorker(i),
                );
                workers::socket::run_socket_worker(
                    sentinel,
                    state,
                    config,
                    connection_validator,
                    server_start_instant,
                    request_sender,
                    response_receiver,
                    priv_dropper,
                );
            })
            .with_context(|| "spawn socket worker")?;
    }
    // Optionally spawn the statistics worker and prometheus endpoint
    if config.statistics.active() {
        let sentinel = sentinel.clone();
        let state = state.clone();
        let config = config.clone();
        #[cfg(feature = "prometheus")]
        if config.statistics.run_prometheus_endpoint {
            use metrics_exporter_prometheus::PrometheusBuilder;
            use metrics_util::MetricKindMask;
            // Drop metrics that haven't been updated for two statistics
            // intervals
            PrometheusBuilder::new()
                .idle_timeout(
                    MetricKindMask::ALL,
                    Some(Duration::from_secs(config.statistics.interval * 2)),
                )
                .with_http_listener(config.statistics.prometheus_endpoint_address)
                .install()
                .with_context(|| {
                    format!(
                        "Install prometheus endpoint on {}",
                        config.statistics.prometheus_endpoint_address
                    )
                })?;
        }
        Builder::new()
            .name("statistics".into())
            .spawn(move || {
                #[cfg(feature = "cpu-pinning")]
                pin_current_if_configured_to(
                    &config.cpu_pinning,
                    config.socket_workers,
                    config.swarm_workers,
                    WorkerIndex::Util,
                );
                workers::statistics::run_statistics_worker(
                    sentinel,
                    config,
                    state,
                    statistics_receiver,
                );
            })
            .with_context(|| "spawn statistics worker")?;
    }
    #[cfg(feature = "cpu-pinning")]
    pin_current_if_configured_to(
        &config.cpu_pinning,
        config.socket_workers,
        config.swarm_workers,
        WorkerIndex::Util,
    );
    // Block on the signal iterator until shutdown is requested
    for signal in &mut signals {
        match signal {
            SIGUSR1 => {
                // Best-effort reload; the error is deliberately ignored —
                // presumably so a bad access list file doesn't take down a
                // running tracker
                let _ = update_access_list(&config.access_list, &state.access_list);
            }
            SIGTERM => {
                if sentinel_watcher.panic_was_triggered() {
                    return Err(anyhow::anyhow!("worker thread panicked"));
                }
                break;
            }
            _ => unreachable!(),
        }
    }
    Ok(())
}

11
crates/udp/src/main.rs Normal file
View file

@@ -0,0 +1,11 @@
// Use mimalloc as the global allocator
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

/// Binary entry point: delegate CLI/config handling to the shared runner,
/// which eventually calls `aquatic_udp::run`
fn main() {
    aquatic_common::cli::run_app_with_cli_and_config::<aquatic_udp::config::Config>(
        aquatic_udp::APP_NAME,
        aquatic_udp::APP_VERSION,
        aquatic_udp::run,
        None,
    )
}

View file

@@ -0,0 +1,3 @@
// Socket workers: receive, parse and respond to UDP requests
pub mod socket;
// Statistics worker: collects and reports tracker statistics
pub mod statistics;
// Swarm workers: hold torrent state, answer announce/scrape requests
pub mod swarm;

View file

@@ -0,0 +1,432 @@
use std::io::{Cursor, ErrorKind};
use std::sync::atomic::Ordering;
use std::time::{Duration, Instant};
use aquatic_common::access_list::AccessListCache;
use aquatic_common::ServerStartInstant;
use crossbeam_channel::Receiver;
use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token};
use aquatic_common::{
access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr,
PanicSentinel, ValidUntil,
};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
use super::storage::PendingScrapeResponseSlab;
use super::validator::ConnectionValidator;
use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6};
/// mio-based socket worker: reads requests from a single UDP socket,
/// parses them, forwards announce/scrape requests to swarm workers and
/// sends responses back to peers
pub struct SocketWorker {
    config: Config,
    // Shared access list and statistics counters
    shared_state: State,
    // Channel for passing validated requests to swarm workers
    request_sender: ConnectedRequestSender,
    // Channel for receiving responses generated by swarm workers
    response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
    access_list_cache: AccessListCache,
    // Creates and validates connection ids
    validator: ConnectionValidator,
    server_start_instant: ServerStartInstant,
    // Scrape responses split across swarm workers, awaiting all parts
    pending_scrape_responses: PendingScrapeResponseSlab,
    socket: UdpSocket,
    // Reusable buffer for receiving requests and serializing responses
    buffer: [u8; BUFFER_SIZE],
}
impl SocketWorker {
    /// Create the socket and worker state, then run the event loop.
    ///
    /// Blocks for the lifetime of the worker. Panics if the socket can't be
    /// created.
    pub fn run(
        _sentinel: PanicSentinel,
        shared_state: State,
        config: Config,
        validator: ConnectionValidator,
        server_start_instant: ServerStartInstant,
        request_sender: ConnectedRequestSender,
        response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
        priv_dropper: PrivilegeDropper,
    ) {
        let socket =
            UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket"));
        let access_list_cache = create_access_list_cache(&shared_state.access_list);
        let mut worker = Self {
            config,
            shared_state,
            validator,
            server_start_instant,
            request_sender,
            response_receiver,
            access_list_cache,
            pending_scrape_responses: Default::default(),
            socket,
            buffer: [0; BUFFER_SIZE],
        };
        worker.run_inner();
    }

    /// Main event loop: poll the socket for readability, handle incoming
    /// requests, send out resend-buffer, local and swarm-worker responses,
    /// and periodically clean stale pending scrape responses.
    ///
    /// Never returns.
    pub fn run_inner(&mut self) {
        // Connect and error responses generated locally by this worker
        let mut local_responses = Vec::new();
        // Only allocated when the resend feature is enabled in config
        let mut opt_resend_buffer =
            (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new());
        let mut events = Events::with_capacity(self.config.network.poll_event_capacity);
        let mut poll = Poll::new().expect("create poll");
        poll.registry()
            .register(&mut self.socket, Token(0), Interest::READABLE)
            .expect("register poll");
        let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms);
        let pending_scrape_cleaning_duration =
            Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval);
        let mut pending_scrape_valid_until = ValidUntil::new(
            self.server_start_instant,
            self.config.cleaning.max_pending_scrape_age,
        );
        let mut last_pending_scrape_cleaning = Instant::now();
        let mut iter_counter = 0usize;
        loop {
            poll.poll(&mut events, Some(poll_timeout))
                .expect("failed polling");
            for event in events.iter() {
                if event.is_readable() {
                    self.read_and_handle_requests(&mut local_responses, pending_scrape_valid_until);
                }
            }
            // If resend buffer is enabled, send any responses in it
            // (passing None as resend buffer, so a second failure drops the
            // response instead of re-queueing it indefinitely)
            if let Some(resend_buffer) = opt_resend_buffer.as_mut() {
                for (response, addr) in resend_buffer.drain(..) {
                    Self::send_response(
                        &self.config,
                        &self.shared_state,
                        &mut self.socket,
                        &mut self.buffer,
                        &mut None,
                        response,
                        addr,
                    );
                }
            }
            // Send any connect and error responses generated by this socket worker
            for (response, addr) in local_responses.drain(..) {
                Self::send_response(
                    &self.config,
                    &self.shared_state,
                    &mut self.socket,
                    &mut self.buffer,
                    &mut opt_resend_buffer,
                    response,
                    addr,
                );
            }
            // Check channel for any responses generated by swarm workers
            for (response, addr) in self.response_receiver.try_iter() {
                let opt_response = match response {
                    // A scrape response is only sendable once all swarm
                    // workers have contributed their part
                    ConnectedResponse::Scrape(r) => self
                        .pending_scrape_responses
                        .add_and_get_finished(r)
                        .map(Response::Scrape),
                    ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)),
                    ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)),
                };
                if let Some(response) = opt_response {
                    Self::send_response(
                        &self.config,
                        &self.shared_state,
                        &mut self.socket,
                        &mut self.buffer,
                        &mut opt_resend_buffer,
                        response,
                        addr,
                    );
                }
            }
            // Run periodic ValidUntil updates and state cleaning
            // (every 256th iteration, to keep the common path cheap)
            if iter_counter % 256 == 0 {
                let seconds_since_start = self.server_start_instant.seconds_elapsed();
                pending_scrape_valid_until = ValidUntil::new_with_now(
                    seconds_since_start,
                    self.config.cleaning.max_pending_scrape_age,
                );
                let now = Instant::now();
                if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
                    self.pending_scrape_responses.clean(seconds_since_start);
                    last_pending_scrape_cleaning = now;
                }
            }
            iter_counter = iter_counter.wrapping_add(1);
        }
    }

    /// Drain the socket of pending packets, parsing and handling each one,
    /// then update received-traffic statistics in one batch.
    fn read_and_handle_requests(
        &mut self,
        local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
        pending_scrape_valid_until: ValidUntil,
    ) {
        let mut requests_received_ipv4: usize = 0;
        let mut requests_received_ipv6: usize = 0;
        let mut bytes_received_ipv4: usize = 0;
        let mut bytes_received_ipv6 = 0;
        loop {
            match self.socket.recv_from(&mut self.buffer[..]) {
                Ok((bytes_read, src)) => {
                    // Reject packets claiming source port zero
                    if src.port() == 0 {
                        ::log::info!("Ignored request from {} because source port is zero", src);
                        continue;
                    }
                    let src = CanonicalSocketAddr::new(src);
                    let request_parsable = match Request::from_bytes(
                        &self.buffer[..bytes_read],
                        self.config.protocol.max_scrape_torrents,
                    ) {
                        Ok(request) => {
                            self.handle_request(
                                local_responses,
                                pending_scrape_valid_until,
                                request,
                                src,
                            );
                            true
                        }
                        Err(err) => {
                            ::log::debug!("Request::from_bytes error: {:?}", err);
                            // Error responses are only sent for parse errors
                            // carrying a valid connection id — presumably to
                            // avoid responding to spoofed sources
                            if let RequestParseError::Sendable {
                                connection_id,
                                transaction_id,
                                err,
                            } = err
                            {
                                if self.validator.connection_id_valid(src, connection_id) {
                                    let response = ErrorResponse {
                                        transaction_id,
                                        message: err.right_or("Parse error").into(),
                                    };
                                    local_responses.push((response.into(), src));
                                }
                            }
                            false
                        }
                    };
                    // Update statistics for converted address
                    if src.is_ipv4() {
                        if request_parsable {
                            requests_received_ipv4 += 1;
                        }
                        bytes_received_ipv4 += bytes_read + EXTRA_PACKET_SIZE_IPV4;
                    } else {
                        if request_parsable {
                            requests_received_ipv6 += 1;
                        }
                        bytes_received_ipv6 += bytes_read + EXTRA_PACKET_SIZE_IPV6;
                    }
                }
                Err(err) if err.kind() == ErrorKind::WouldBlock => {
                    // Socket drained
                    break;
                }
                Err(err) => {
                    ::log::warn!("recv_from error: {:#}", err);
                }
            }
        }
        if self.config.statistics.active() {
            self.shared_state
                .statistics_ipv4
                .requests_received
                .fetch_add(requests_received_ipv4, Ordering::Relaxed);
            self.shared_state
                .statistics_ipv6
                .requests_received
                .fetch_add(requests_received_ipv6, Ordering::Relaxed);
            self.shared_state
                .statistics_ipv4
                .bytes_received
                .fetch_add(bytes_received_ipv4, Ordering::Relaxed);
            self.shared_state
                .statistics_ipv6
                .bytes_received
                .fetch_add(bytes_received_ipv6, Ordering::Relaxed);
        }
    }

    /// Handle one parsed request: answer connect requests locally, validate
    /// the connection id for announce/scrape requests and forward them to
    /// the responsible swarm worker(s).
    fn handle_request(
        &mut self,
        local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
        pending_scrape_valid_until: ValidUntil,
        request: Request,
        src: CanonicalSocketAddr,
    ) {
        let access_list_mode = self.config.access_list.mode;
        match request {
            Request::Connect(request) => {
                // Connect requests are answered immediately by this worker
                let connection_id = self.validator.create_connection_id(src);
                let response = Response::Connect(ConnectResponse {
                    connection_id,
                    transaction_id: request.transaction_id,
                });
                local_responses.push((response, src))
            }
            Request::Announce(request) => {
                if self
                    .validator
                    .connection_id_valid(src, request.connection_id)
                {
                    if self
                        .access_list_cache
                        .load()
                        .allows(access_list_mode, &request.info_hash.0)
                    {
                        // Route to the swarm worker owning this info hash
                        let worker_index =
                            SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash);
                        self.request_sender.try_send_to(
                            worker_index,
                            ConnectedRequest::Announce(request),
                            src,
                        );
                    } else {
                        let response = Response::Error(ErrorResponse {
                            transaction_id: request.transaction_id,
                            message: "Info hash not allowed".into(),
                        });
                        local_responses.push((response, src))
                    }
                }
            }
            Request::Scrape(request) => {
                if self
                    .validator
                    .connection_id_valid(src, request.connection_id)
                {
                    // A scrape may cover torrents owned by different swarm
                    // workers; split it and register a pending response
                    let split_requests = self.pending_scrape_responses.prepare_split_requests(
                        &self.config,
                        request,
                        pending_scrape_valid_until,
                    );
                    for (swarm_worker_index, request) in split_requests {
                        self.request_sender.try_send_to(
                            swarm_worker_index,
                            ConnectedRequest::Scrape(request),
                            src,
                        );
                    }
                }
            }
        }
    }

    /// Serialize a response into `buffer` and send it to the peer.
    ///
    /// On success, updates sent-traffic statistics. On transient send
    /// failures (ENOBUFS / WouldBlock), queues the response in
    /// `opt_resend_buffer` (if enabled) for one later retry; other errors
    /// are logged and the response is dropped.
    fn send_response(
        config: &Config,
        shared_state: &State,
        socket: &mut UdpSocket,
        buffer: &mut [u8],
        opt_resend_buffer: &mut Option<Vec<(Response, CanonicalSocketAddr)>>,
        response: Response,
        canonical_addr: CanonicalSocketAddr,
    ) {
        let mut cursor = Cursor::new(buffer);
        if let Err(err) = response.write(&mut cursor) {
            ::log::error!("Converting response to bytes failed: {:#}", err);
            return;
        }
        let bytes_written = cursor.position() as usize;
        // Convert the canonical address back to the form matching the
        // socket's address family
        let addr = if config.network.address.is_ipv4() {
            canonical_addr
                .get_ipv4()
                .expect("found peer ipv6 address while running bound to ipv4 address")
        } else {
            canonical_addr.get_ipv6_mapped()
        };
        match socket.send_to(&cursor.get_ref()[..bytes_written], addr) {
            Ok(amt) if config.statistics.active() => {
                let stats = if canonical_addr.is_ipv4() {
                    let stats = &shared_state.statistics_ipv4;
                    stats
                        .bytes_sent
                        .fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed);
                    stats
                } else {
                    let stats = &shared_state.statistics_ipv6;
                    stats
                        .bytes_sent
                        .fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed);
                    stats
                };
                match response {
                    Response::Connect(_) => {
                        stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed);
                    }
                    Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => {
                        stats
                            .responses_sent_announce
                            .fetch_add(1, Ordering::Relaxed);
                    }
                    Response::Scrape(_) => {
                        stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed);
                    }
                    Response::Error(_) => {
                        stats.responses_sent_error.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }
            Ok(_) => (),
            Err(err) => match opt_resend_buffer.as_mut() {
                Some(resend_buffer)
                    if (err.raw_os_error() == Some(libc::ENOBUFS))
                        || (err.kind() == ErrorKind::WouldBlock) =>
                {
                    if resend_buffer.len() < config.network.resend_buffer_max_len {
                        ::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err);
                        resend_buffer.push((response, canonical_addr));
                    } else {
                        ::log::warn!("Response resend buffer full, dropping response");
                    }
                }
                _ => {
                    ::log::warn!("Sending response to {} failed: {:#}", addr, err);
                }
            },
        }
    }
}

View file

@@ -0,0 +1,128 @@
mod mio;
mod storage;
#[cfg(feature = "io-uring")]
mod uring;
mod validator;
use anyhow::Context;
use aquatic_common::{
privileges::PrivilegeDropper, CanonicalSocketAddr, PanicSentinel, ServerStartInstant,
};
use crossbeam_channel::Receiver;
use socket2::{Domain, Protocol, Socket, Type};
use crate::{
common::{ConnectedRequestSender, ConnectedResponse, State},
config::Config,
};
pub use self::validator::ConnectionValidator;
/// Bytes of data transmitted when sending an IPv4 UDP packet, in addition to payload size
///
/// Consists of:
/// - 8 byte ethernet frame preamble
/// - 14 + 4 byte MAC header and checksum
/// - 20 byte IPv4 header
/// - 8 byte udp header
const EXTRA_PACKET_SIZE_IPV4: usize = 8 + 18 + 20 + 8;
/// Bytes of data transmitted when sending an IPv6 UDP packet, in addition to payload size
///
/// Consists of:
/// - 8 byte ethernet frame preamble
/// - 14 + 4 byte MAC header and checksum
/// - 40 byte IPv6 header
/// - 8 byte udp header
const EXTRA_PACKET_SIZE_IPV6: usize = 8 + 18 + 40 + 8;
/// Entry point for a socket worker thread.
///
/// With the "io-uring" feature enabled, checks whether the running kernel
/// supports the required io_uring functionality and, if so, runs the
/// io_uring-based worker. Otherwise (or when the feature is disabled at
/// compile time), falls back to the mio-based worker.
pub fn run_socket_worker(
    sentinel: PanicSentinel,
    shared_state: State,
    config: Config,
    validator: ConnectionValidator,
    server_start_instant: ServerStartInstant,
    request_sender: ConnectedRequestSender,
    response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
    priv_dropper: PrivilegeDropper,
) {
    #[cfg(feature = "io-uring")]
    match self::uring::supported_on_current_kernel() {
        Ok(()) => {
            self::uring::SocketWorker::run(
                sentinel,
                shared_state,
                config,
                validator,
                server_start_instant,
                request_sender,
                response_receiver,
                priv_dropper,
            );
            // The io_uring worker handled the whole lifetime of this worker
            // thread; don't also start the mio worker below.
            return;
        }
        Err(err) => {
            ::log::warn!(
                "Falling back to mio because of lacking kernel io_uring support: {:#}",
                err
            );
        }
    }
    self::mio::SocketWorker::run(
        sentinel,
        shared_state,
        config,
        validator,
        server_start_instant,
        request_sender,
        response_receiver,
        priv_dropper,
    );
}
/// Create a UDP socket for the configured address: nonblocking, with
/// SO_REUSEPORT set, optionally IPV6_V6ONLY and a custom receive buffer
/// size, bound to `config.network.address`. Privileges are dropped after
/// the socket has been created and bound.
fn create_socket(
    config: &Config,
    priv_dropper: PrivilegeDropper,
) -> anyhow::Result<::std::net::UdpSocket> {
    let domain = if config.network.address.is_ipv4() {
        Domain::IPV4
    } else {
        Domain::IPV6
    };
    let sock = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
    if config.network.only_ipv6 {
        sock.set_only_v6(true)
            .with_context(|| "socket: set only ipv6")?;
    }
    sock.set_reuse_port(true)
        .with_context(|| "socket: set reuse port")?;
    sock.set_nonblocking(true)
        .with_context(|| "socket: set nonblocking")?;
    // A configured size of zero means "keep the OS default". Failure to
    // resize the receive buffer is logged but deliberately not fatal.
    match config.network.socket_recv_buffer_size {
        0 => {}
        recv_buffer_size => {
            if let Err(err) = sock.set_recv_buffer_size(recv_buffer_size) {
                ::log::error!(
                    "socket: failed setting recv buffer to {}: {:?}",
                    recv_buffer_size,
                    err
                );
            }
        }
    }
    sock.bind(&config.network.address.into())
        .with_context(|| format!("socket: bind to {}", config.network.address))?;
    priv_dropper.after_socket_creation()?;
    Ok(sock.into())
}

View file

@ -0,0 +1,219 @@
use std::collections::BTreeMap;
use hashbrown::HashMap;
use slab::Slab;
use aquatic_common::{SecondsSinceServerStart, ValidUntil};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
/// State for a scrape request that was split across multiple swarm workers
/// and is awaiting their partial responses.
#[derive(Debug)]
pub struct PendingScrapeResponseSlabEntry {
    // Number of swarm worker responses not yet received
    num_pending: usize,
    // Entry is removed (with a warning) by `clean` once this deadline passes
    valid_until: ValidUntil,
    // Statistics received so far, keyed by the info hash's position in the
    // original request so response ordering can be restored
    torrent_stats: BTreeMap<usize, TorrentScrapeStatistics>,
    transaction_id: TransactionId,
}
/// Pending scrape responses, indexed by the slab key embedded in the
/// split requests sent to swarm workers.
#[derive(Default)]
pub struct PendingScrapeResponseSlab(Slab<PendingScrapeResponseSlabEntry>);
impl PendingScrapeResponseSlab {
    /// Split a scrape request into one partial request per involved swarm
    /// worker and record a slab entry tracking how many partial responses
    /// are expected.
    ///
    /// Returns an empty collection (and records no slab entry) if the
    /// request contains no info hashes.
    pub fn prepare_split_requests(
        &mut self,
        config: &Config,
        request: ScrapeRequest,
        valid_until: ValidUntil,
    ) -> impl IntoIterator<Item = (SwarmWorkerIndex, PendingScrapeRequest)> {
        // At most one split request per swarm worker, and never more than
        // there are info hashes
        let capacity = config.swarm_workers.min(request.info_hashes.len());
        let mut split_requests: HashMap<SwarmWorkerIndex, PendingScrapeRequest> =
            HashMap::with_capacity(capacity);
        if request.info_hashes.is_empty() {
            ::log::warn!(
                "Attempted to prepare PendingScrapeResponseSlab entry with zero info hashes"
            );
            return split_requests;
        }
        let vacant_entry = self.0.vacant_entry();
        let slab_key = vacant_entry.key();
        for (i, info_hash) in request.info_hashes.into_iter().enumerate() {
            let split_request = split_requests
                // Pass `config` directly: it is already a `&Config`, so the
                // previous `&config` was a needless double borrow
                .entry(SwarmWorkerIndex::from_info_hash(config, info_hash))
                .or_insert_with(|| PendingScrapeRequest {
                    slab_key,
                    info_hashes: BTreeMap::new(),
                });
            // Key by position in the original request so the combined
            // response can be returned in request order
            split_request.info_hashes.insert(i, info_hash);
        }
        vacant_entry.insert(PendingScrapeResponseSlabEntry {
            num_pending: split_requests.len(),
            valid_until,
            torrent_stats: Default::default(),
            transaction_id: request.transaction_id,
        });
        split_requests
    }
    /// Record a partial response from a swarm worker. Returns the complete
    /// `ScrapeResponse` once partial responses from all involved workers
    /// have arrived; `None` (with a warning for unknown keys) otherwise.
    pub fn add_and_get_finished(
        &mut self,
        response: PendingScrapeResponse,
    ) -> Option<ScrapeResponse> {
        let finished = if let Some(entry) = self.0.get_mut(response.slab_key) {
            entry.num_pending -= 1;
            // BTreeMap::extend takes any IntoIterator, so no explicit
            // `.into_iter()` is needed
            entry.torrent_stats.extend(response.torrent_stats);
            entry.num_pending == 0
        } else {
            ::log::warn!(
                "PendingScrapeResponseSlab.add didn't find entry for key {:?}",
                response.slab_key
            );
            false
        };
        if finished {
            let entry = self.0.remove(response.slab_key);
            Some(ScrapeResponse {
                transaction_id: entry.transaction_id,
                torrent_stats: entry.torrent_stats.into_values().collect(),
            })
        } else {
            None
        }
    }
    /// Remove entries whose deadline has passed (logged, since they indicate
    /// that some partial responses never arrived) and release excess memory.
    pub fn clean(&mut self, now: SecondsSinceServerStart) {
        self.0.retain(|k, v| {
            if v.valid_until.valid(now) {
                true
            } else {
                ::log::warn!(
                    "Unconsumed PendingScrapeResponseSlab entry. {:?}: {:?}",
                    k,
                    v
                );
                false
            }
        });
        self.0.shrink_to_fit();
    }
}
#[cfg(test)]
mod tests {
    use aquatic_common::ServerStartInstant;
    use quickcheck::TestResult;
    use quickcheck_macros::quickcheck;
    use super::*;
    /// Property test: splitting arbitrary scrape requests across workers and
    /// feeding back one partial response per split request must yield exactly
    /// one combined response per request and leave the slab empty.
    #[quickcheck]
    fn test_pending_scrape_response_slab(
        request_data: Vec<(i32, i64, u8)>,
        swarm_workers: u8,
    ) -> TestResult {
        // Zero swarm workers is not a valid configuration
        if swarm_workers == 0 {
            return TestResult::discard();
        }
        let mut config = Config::default();
        config.swarm_workers = swarm_workers as usize;
        let valid_until = ValidUntil::new(ServerStartInstant::new(), 1);
        let mut map = PendingScrapeResponseSlab::default();
        let mut requests = Vec::new();
        for (t, c, b) in request_data {
            // Requests with zero info hashes are rejected by
            // prepare_split_requests, so discard such inputs
            if b == 0 {
                return TestResult::discard();
            }
            let mut info_hashes = Vec::new();
            for i in 0..b {
                let info_hash = InfoHash([i; 20]);
                info_hashes.push(info_hash);
            }
            let request = ScrapeRequest {
                transaction_id: TransactionId(t),
                connection_id: ConnectionId(c),
                info_hashes,
            };
            requests.push(request);
        }
        let mut all_split_requests = Vec::new();
        for request in requests.iter() {
            let split_requests =
                map.prepare_split_requests(&config, request.to_owned(), valid_until);
            all_split_requests.push(
                split_requests
                    .into_iter()
                    .collect::<Vec<(SwarmWorkerIndex, PendingScrapeRequest)>>(),
            );
        }
        // One slab entry per request must have been recorded
        assert_eq!(map.0.len(), requests.len());
        let mut responses = Vec::new();
        for split_requests in all_split_requests {
            for (worker_index, split_request) in split_requests {
                assert!(worker_index.0 < swarm_workers as usize);
                // Fabricate per-info-hash statistics derived from the hash
                // bytes so responses are deterministic
                let torrent_stats = split_request
                    .info_hashes
                    .into_iter()
                    .map(|(i, info_hash)| {
                        (
                            i,
                            TorrentScrapeStatistics {
                                seeders: NumberOfPeers((info_hash.0[0]) as i32),
                                leechers: NumberOfPeers(0),
                                completed: NumberOfDownloads(0),
                            },
                        )
                    })
                    .collect();
                let response = PendingScrapeResponse {
                    slab_key: split_request.slab_key,
                    torrent_stats,
                };
                if let Some(response) = map.add_and_get_finished(response) {
                    responses.push(response);
                }
            }
        }
        // Every entry must have been consumed and combined exactly once
        assert!(map.0.is_empty());
        assert_eq!(responses.len(), requests.len());
        TestResult::from_bool(true)
    }
}

View file

@ -0,0 +1,947 @@
// Copyright (c) 2021 Carl Lerche
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
// Copied (with slight modifications) from
// - https://github.com/FrankReh/tokio-uring/tree/9387c92c98138451f7d760432a04b0b95a406f22/src/buf/bufring
// - https://github.com/FrankReh/tokio-uring/blob/9387c92c98138451f7d760432a04b0b95a406f22/src/buf/bufgroup/mod.rs
//! Module for the io_uring device's buf_ring feature.
// Developer's note about io_uring return codes when a buf_ring is used:
//
// While a buf_ring pool is exhaused, new calls to read that are, or are not, ready to read will
// fail with the 105 error, "no buffers", while existing calls that were waiting to become ready to
// read will not fail. Only when the data becomes ready to read will they fail, if the buffer ring
// is still empty at that time. This makes sense when thinking about it from how the kernel
// implements the start of a read command; it can be confusing when first working with these
// commands from the userland perspective.
// While the file! calls yield the clippy false positive.
#![allow(clippy::print_literal)]
use io_uring::types;
use std::cell::Cell;
use std::io;
use std::rc::Rc;
use std::sync::atomic::{self, AtomicU16};
use super::CurrentRing;
/// The buffer group ID.
///
/// The creator of a buffer group is responsible for picking a buffer group id
/// that does not conflict with other buffer group ids also being registered with the uring
/// interface.
pub(crate) type Bgid = u16;
// Future: Maybe create a bgid module with a trivial implementation of a type that tracks the next
// bgid to use. The crate's driver could do that perhaps, but there could be a benefit to tracking
// them across multiple thread's drivers. So there is flexibility in not building it into the
// driver.
/// The buffer ID. Buffer ids are assigned and used by the crate and probably are not visible
/// to the crate user.
pub(crate) type Bid = u16;
/// This tracks a buffer that has been filled in by the kernel, having gotten the memory
/// from a buffer ring, and returned to userland via a cqe entry.
pub struct BufX {
    // Ring this buffer came from; the buffer id is handed back to it on drop
    bgroup: BufRing,
    // Buffer id within the ring
    bid: Bid,
    // Number of bytes the kernel reported as written (via the cqe)
    len: usize,
}
impl BufX {
    /// # Safety
    ///
    /// The bid must be the buffer id supplied by the kernel as having been chosen and written to.
    /// The length of the buffer must represent the length written to by the kernel.
    pub(crate) unsafe fn new(bgroup: BufRing, bid: Bid, len: usize) -> Self {
        // len will already have been checked against the buf_capacity
        // so it is guaranteed that len <= bgroup.buf_capacity.
        Self { bgroup, bid, len }
    }
    /// Return the number of bytes initialized.
    ///
    /// This value initially came from the kernel, as reported in the cqe. This value may have been
    /// modified with a call to the IoBufMut::set_init method.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }
    /// Return true if this represents an empty buffer. The length reported by the kernel was 0.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Return the capacity of this buffer.
    #[inline]
    pub fn cap(&self) -> usize {
        self.bgroup.buf_capacity(self.bid)
    }
    /// Return a byte slice reference covering the kernel-initialized bytes.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        let p = self.bgroup.stable_ptr(self.bid);
        // Safety: the pointer returned by stable_ptr is valid for the lifetime of self,
        // and self's len is set when the kernel reports the amount of data that was
        // written into the buffer.
        unsafe { std::slice::from_raw_parts(p, self.len) }
    }
    /// Return a mutable byte slice reference covering the kernel-initialized bytes.
    #[inline]
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        let p = self.bgroup.stable_mut_ptr(self.bid);
        // Safety: the pointer returned by stable_mut_ptr is valid for the lifetime of self,
        // and self's len is set when the kernel reports the amount of data that was
        // written into the buffer. In addition, we hold a &mut reference to self.
        unsafe { std::slice::from_raw_parts_mut(p, self.len) }
    }
    // Future: provide access to the uninit space between len and cap if the buffer is being
    // repurposed before being dropped. The set_init below does that too.
}
impl Drop for BufX {
    fn drop(&mut self) {
        // Add the buffer back to the bgroup, for the kernel to reuse.
        // Safety: this function may only be called by the buffer's drop function,
        // which is exactly where we are; the bid is this buffer's own id.
        unsafe { self.bgroup.dropping_bid(self.bid) };
    }
}
/*
unsafe impl crate::buf::IoBuf for BufX {
fn stable_ptr(&self) -> *const u8 {
self.bgroup.stable_ptr(self.bid)
}
fn bytes_init(&self) -> usize {
self.len
}
fn bytes_total(&self) -> usize {
self.cap()
}
}
unsafe impl crate::buf::IoBufMut for BufX {
fn stable_mut_ptr(&mut self) -> *mut u8 {
self.bgroup.stable_mut_ptr(self.bid)
}
unsafe fn set_init(&mut self, init_len: usize) {
if self.len < init_len {
let cap = self.bgroup.buf_capacity(self.bid);
assert!(init_len <= cap);
self.len = init_len;
}
}
}
*/
impl From<BufX> for Vec<u8> {
    /// Copy the kernel-initialized bytes into an owned vector, releasing the
    /// ring buffer when `item` is dropped at the end of this call.
    fn from(item: BufX) -> Self {
        Vec::from(item.as_slice())
    }
}
/// A `BufRing` represents the ring and the buffers used with the kernel's io_uring buf_ring
/// feature.
///
/// In this implementation, it is both the ring of buffer entries and the actual buffer
/// allocations.
///
/// A BufRing is created through the [`Builder`] and can be registered automatically by the
/// builder's `build` step or at a later time by the user. Registration involves informing the
/// kernel of the ring's dimensions and its identifier (its buffer group id, which goes by the name
/// `bgid`).
///
/// Multiple buf_rings, here multiple BufRings, can be created and registered. BufRings are
/// reference counted to ensure their memory is live while their BufX buffers are live. When a BufX
/// buffer is dropped, it releases itself back to the BufRing from which it came allowing it to be
/// reused by the kernel.
///
/// It is perhaps worth pointing out that it is the ring itself that is registered with the kernel,
/// not the buffers per se. While a given buf_ring cannot have it size changed dynamically, the
/// buffers that are pushed to the ring by userland, and later potentially re-pushed in the ring,
/// can change. The buffers can be of different sizes and they could come from different allocation
/// blocks. This implementation does not provide that flexibility. Each BufRing comes with its own
/// equal length buffer allocation. And when a BufRing buffer, a BufX, is dropped, its id is pushed
/// back to the ring.
///
/// This is the one and only `Provided Buffers` implementation in `tokio_uring` at the moment and
/// in this version, is a purely concrete type, with a concrete BufX type for buffers that are
/// returned by operations like `recv_provbuf` to the userland application.
///
/// Aside from the register and unregister steps, there are no syscalls used to pass buffers to the
/// kernel. The ring contains a tail memory address that this userland type updates as buffers are
/// added to the ring and which the kernel reads when it needs to pull a buffer from the ring. The
/// kernel does not have a head pointer address that it updates for the userland. The userland
/// (this type), is expected to avoid overwriting the head of the circular ring by keeping track of
/// how many buffers were added to the ring and how many have been returned through the CQE
/// mechanism. This particular implementation does not track the count because all buffers are
/// allocated at the beginning, by the builder, and only its own buffers that came back via a CQE
/// are ever added back to the ring, so it should be impossible to overflow the ring.
#[derive(Clone, Debug)]
pub struct BufRing {
    // RawBufRing uses cell for fields where necessary.
    // Rc keeps the single allocation alive while clones of the ring and any
    // outstanding BufX buffers referencing it exist.
    raw: Rc<RawBufRing>,
}
// Methods the BufX needs.
impl BufRing {
    /// Capacity of any buffer in the ring; all buffers share the same length.
    pub(crate) fn buf_capacity(&self, _: Bid) -> usize {
        self.raw.buf_capacity_i()
    }
    pub(crate) fn stable_ptr(&self, bid: Bid) -> *const u8 {
        // Will panic if bid is out of range.
        self.raw.stable_ptr_i(bid)
    }
    pub(crate) fn stable_mut_ptr(&mut self, bid: Bid) -> *mut u8 {
        // Safety: self is &mut, we're good.
        unsafe { self.raw.stable_mut_ptr_i(bid) }
    }
    /// # Safety
    ///
    /// `dropping_bid` should only be called by the buffer's drop function because once called, the
    /// buffer may be given back to the kernel for reuse.
    pub(crate) unsafe fn dropping_bid(&self, bid: Bid) {
        self.raw.dropping_bid_i(bid);
    }
}
// Methods the io operations need.
impl BufRing {
    /// Returns the buffer group id this ring was built with.
    pub(crate) fn bgid(&self) -> Bgid {
        self.raw.bgid()
    }
    /// # Safety
    ///
    /// The res and flags values are used to lookup a buffer and set its initialized length.
    /// The caller is responsible for these being correct. This is expected to be called
    /// when these two values are received from the kernel via a CQE and we rely on the kernel to
    /// give us correct information.
    pub(crate) unsafe fn get_buf(&self, res: u32, flags: u32) -> io::Result<Option<BufX>> {
        // When IORING_CQE_F_BUFFER is set, the chosen buffer id is encoded in
        // the CQE flags.
        let bid = match io_uring::cqueue::buffer_select(flags) {
            Some(bid) => bid,
            None => {
                // Have seen res == 0, flags == 4 with a TCP socket. res == 0 we take to mean the
                // socket is empty so return None to show there is no buffer returned, which should
                // be interpreted to mean there is no more data to read from this file or socket.
                if res == 0 {
                    return Ok(None);
                }
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!(
                        "BufRing::get_buf failed as the buffer bit, IORING_CQE_F_BUFFER, was missing from flags, res = {}, flags = {}",
                        res, flags)
                ));
            }
        };
        let len = res as usize;
        /*
        let flags = flags & !io_uring::sys::IORING_CQE_F_BUFFER; // for tracing flags
        println!(
            "{}:{}: get_buf res({res})=len({len}) flags({:#x})->bid({bid})\n\n",
            file!(),
            line!(),
            flags
        );
        */
        assert!(len <= self.raw.buf_len);
        // TODO maybe later
        // #[cfg(any(debug, feature = "cautious"))]
        // {
        //     let mut debug_bitmap = self.debug_bitmap.borrow_mut();
        //     let m = 1 << (bid % 8);
        //     assert!(debug_bitmap[(bid / 8) as usize] & m == m);
        //     debug_bitmap[(bid / 8) as usize] &= !m;
        // }
        self.raw.metric_getting_another();
        /*
        println!(
            "{}:{}: get_buf cur {}, min {}",
            file!(),
            line!(),
            self.possible_cur.get(),
            self.possible_min.get(),
        );
        */
        // Safety: the len provided to BufX::new is given to us from the kernel.
        Ok(Some(unsafe { BufX::new(self.clone(), bid, len) }))
    }
}
#[derive(Debug, Copy, Clone)]
/// Build the arguments to call build() that returns a [`BufRing`].
///
/// Refer to the methods descriptions for details.
#[allow(dead_code)]
pub struct Builder {
    // Kernel page size; the allocation must be page aligned
    page_size: usize,
    // Buffer group id to register with the kernel
    bgid: Bgid,
    // Number of ring entries; normalized to a power of two in build()
    ring_entries: u16,
    // Number of buffers; 0 means "use ring_entries"
    buf_cnt: u16,
    // Length of each buffer in bytes
    buf_len: usize,
    // Optional alignment of the first buffer
    buf_align: usize,
    // Optional padding between the ring and the first buffer
    ring_pad: usize,
    // Optional alignment of the end of the allocation
    bufend_align: usize,
    // When true, build() does not register the ring with the kernel
    skip_register: bool,
}
#[allow(dead_code)]
impl Builder {
    /// Create a new Builder with the given buffer group ID and defaults.
    ///
    /// The buffer group ID, `bgid`, is the id the kernel's io_uring device uses to identify the
    /// provided buffer pool to use by operations that are posted to the device.
    ///
    /// The user is responsible for picking a bgid that does not conflict with other buffer groups
    /// that have been registered with the same uring interface.
    pub fn new(bgid: Bgid) -> Builder {
        Builder {
            page_size: 4096,
            bgid,
            ring_entries: 128,
            buf_cnt: 0,
            buf_len: 4096,
            buf_align: 0,
            ring_pad: 0,
            bufend_align: 0,
            skip_register: false,
        }
    }
    /// The page size of the kernel. Defaults to 4096.
    ///
    /// The io_uring device requires the BufRing is allocated on the start of a page, i.e. with a
    /// page size alignment.
    ///
    /// The caller should determine the page size, and may want to cache the info if multiple buf
    /// rings are to be created. Crates are available to get this information or the user may want
    /// to call the libc sysconf directly:
    ///
    /// use libc::{_SC_PAGESIZE, sysconf};
    /// let page_size: usize = unsafe { sysconf(_SC_PAGESIZE) as usize };
    pub fn page_size(mut self, page_size: usize) -> Builder {
        self.page_size = page_size;
        self
    }
    /// The number of ring entries to create for the buffer ring.
    ///
    /// This defaults to 128 or the `buf_cnt`, whichever is larger.
    ///
    /// The number will be made a power of 2, and will be the maximum of the ring_entries setting
    /// and the buf_cnt setting. The interface will enforce a maximum of 2^15 (32768) so it can do
    /// rollover calculation.
    ///
    /// Each ring entry is 16 bytes.
    pub fn ring_entries(mut self, ring_entries: u16) -> Builder {
        self.ring_entries = ring_entries;
        self
    }
    /// The number of buffers to allocate. If left zero, the ring_entries value will be used and
    /// that value defaults to 128.
    pub fn buf_cnt(mut self, buf_cnt: u16) -> Builder {
        self.buf_cnt = buf_cnt;
        self
    }
    /// The length of each allocated buffer. Defaults to 4096.
    ///
    /// Non-alignment values are possible and `buf_align` can be used to allocate each buffer on
    /// an alignment buffer, even if the buffer length is not desired to equal the alignment.
    pub fn buf_len(mut self, buf_len: usize) -> Builder {
        self.buf_len = buf_len;
        self
    }
    /// The alignment of the first buffer allocated.
    ///
    /// Generally not needed.
    ///
    /// The buffers are allocated right after the ring unless `ring_pad` is used and generally the
    /// buffers are allocated contiguous to one another unless the `buf_len` is set to something
    /// different.
    pub fn buf_align(mut self, buf_align: usize) -> Builder {
        self.buf_align = buf_align;
        self
    }
    /// Pad to place after ring to ensure separation between rings and first buffer.
    ///
    /// Generally not needed but may be useful if the ring's end and the buffers' start are to have
    /// some separation, perhaps for cacheline reasons.
    pub fn ring_pad(mut self, ring_pad: usize) -> Builder {
        self.ring_pad = ring_pad;
        self
    }
    /// The alignment of the end of the buffer allocated. To keep other things out of a cache line
    /// or out of a page, if that's desired.
    pub fn bufend_align(mut self, bufend_align: usize) -> Builder {
        self.bufend_align = bufend_align;
        self
    }
    /// Skip automatic registration. The caller can manually invoke the buf_ring.register()
    /// function later. Regardless, the unregister() method will be called automatically when the
    /// BufRing goes out of scope if the caller hadn't manually called buf_ring.unregister()
    /// already.
    pub fn skip_auto_register(mut self, skip: bool) -> Builder {
        self.skip_register = skip;
        self
    }
    /// Return a BufRing, having computed the layout for the single aligned allocation
    /// of both the buffer ring elements and the buffers themselves.
    ///
    /// If auto_register was left enabled, register the BufRing with the driver.
    pub fn build(&self) -> io::Result<BufRing> {
        // Work on a copy so the builder itself stays reusable.
        let mut b: Builder = *self;
        // Two cases where both buf_cnt and ring_entries are set to the max of the two.
        if b.buf_cnt == 0 || b.ring_entries < b.buf_cnt {
            let max = std::cmp::max(b.ring_entries, b.buf_cnt);
            b.buf_cnt = max;
            b.ring_entries = max;
        }
        // Don't allow the next_power_of_two calculation to be done if already larger than 2^15
        // because 2^16 reads back as 0 in a u16. And the interface doesn't allow for ring_entries
        // larger than 2^15 anyway, so this is a good place to catch it. Here we return a unique
        // error that is more descriptive than the InvalidArg that would come from the interface.
        if b.ring_entries > (1 << 15) {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "ring_entries exceeded 32768",
            ));
        }
        // Requirement of the interface is the ring entries is a power of two, making its and our
        // mask calculation trivial.
        b.ring_entries = b.ring_entries.next_power_of_two();
        Ok(BufRing {
            raw: Rc::new(RawBufRing::new(NewArgs {
                page_size: b.page_size,
                bgid: b.bgid,
                ring_entries: b.ring_entries,
                buf_cnt: b.buf_cnt,
                buf_len: b.buf_len,
                buf_align: b.buf_align,
                ring_pad: b.ring_pad,
                bufend_align: b.bufend_align,
                auto_register: !b.skip_register,
            })?),
        })
    }
}
// Trivial helper struct for this module.
// Mirrors the Builder fields after build() has normalized them, plus the
// derived auto_register flag.
struct NewArgs {
    page_size: usize,
    bgid: Bgid,
    ring_entries: u16,
    buf_cnt: u16,
    buf_len: usize,
    buf_align: usize,
    ring_pad: usize,
    bufend_align: usize,
    auto_register: bool,
}
#[derive(Debug)]
struct RawBufRing {
    bgid: Bgid,
    // Keep mask rather than ring size because mask is used often, ring size not.
    //ring_entries: u16, // Invariants: > 0, power of 2, max 2^15 (32768).
    ring_entries_mask: u16, // Invariant one less than ring_entries which is > 0, power of 2, max 2^15 (32768).
    buf_cnt: u16,   // Invariants: > 0, <= ring_entries.
    buf_len: usize, // Invariant: > 0.
    layout: std::alloc::Layout,
    ring_addr: *const types::BufRingEntry, // Invariant: constant.
    buffers_addr: *mut u8,                 // Invariant: constant.
    // Userland copy of the ring tail, published to tail_addr on sync.
    local_tail: Cell<u16>,
    // Address of the shared tail within the ring allocation, read by the kernel.
    tail_addr: *const AtomicU16,
    // Whether the ring is currently registered with the kernel.
    registered: Cell<bool>,
    // The first `possible` field is a best effort at tracking the current buffer pool usage and
    // from that, tracking the lowest level that has been reached. The two are an attempt at
    // letting the user check the sizing needs of their buf_ring pool.
    //
    // We don't really know how deep the uring device has gone into the pool because we never see
    // its head value and it can be taking buffers from the ring, in-flight, while we add buffers
    // back to the ring. All we know is when a CQE arrives and a buffer lookup is performed, a
    // buffer has already been taken from the pool, and when the buffer is dropped, we add it back
    // to the the ring and it is about to be considered part of the pool again.
    possible_cur: Cell<u16>,
    possible_min: Cell<u16>,
    //
    // TODO maybe later
    // #[cfg(any(debug, feature = "cautious"))]
    // debug_bitmap: RefCell<std::vec::Vec<u8>>,
}
impl RawBufRing {
/// Compute the layout for, and make, a single page-aligned allocation holding
/// both the ring entries and all buffers, add every buffer id to the ring,
/// and (unless `auto_register` is false) register the ring with the kernel.
fn new(new_args: NewArgs) -> io::Result<RawBufRing> {
    #[allow(non_upper_case_globals)]
    const trace: bool = false;
    let NewArgs {
        page_size,
        bgid,
        ring_entries,
        buf_cnt,
        buf_len,
        buf_align,
        ring_pad,
        bufend_align,
        auto_register,
    } = new_args;
    // Check that none of the important args are zero and the ring_entries is at least large
    // enough to hold all the buffers and that ring_entries is a power of 2.
    if (buf_cnt == 0)
        || (buf_cnt > ring_entries)
        || (buf_len == 0)
        || ((ring_entries & (ring_entries - 1)) != 0)
    {
        return Err(io::Error::from(io::ErrorKind::InvalidInput));
    }
    // entry_size is 16 bytes.
    let entry_size = std::mem::size_of::<types::BufRingEntry>();
    let mut ring_size = entry_size * (ring_entries as usize);
    if trace {
        println!(
            "{}:{}: entry_size {} * ring_entries {} = ring_size {} {:#x}",
            file!(),
            line!(),
            entry_size,
            ring_entries,
            ring_size,
            ring_size,
        );
    }
    ring_size += ring_pad;
    if trace {
        println!(
            "{}:{}: after +ring_pad {} ring_size {} {:#x}",
            file!(),
            line!(),
            ring_pad,
            ring_size,
            ring_size,
        );
    }
    if buf_align > 0 {
        let buf_align = buf_align.next_power_of_two();
        // Round the ring size up so the first buffer lands on the alignment.
        ring_size = (ring_size + (buf_align - 1)) & !(buf_align - 1);
        if trace {
            println!(
                "{}:{}: after buf_align ring_size {} {:#x}",
                file!(),
                line!(),
                ring_size,
                ring_size,
            );
        }
    }
    let buf_size = buf_len * (buf_cnt as usize);
    assert!(ring_size != 0);
    assert!(buf_size != 0);
    let mut tot_size: usize = ring_size + buf_size;
    if trace {
        println!(
            "{}:{}: ring_size {} {:#x} + buf_size {} {:#x} = tot_size {} {:#x}",
            file!(),
            line!(),
            ring_size,
            ring_size,
            buf_size,
            buf_size,
            tot_size,
            tot_size
        );
    }
    if bufend_align > 0 {
        // for example, if bufend_align is 4096, would make total size a multiple of pages
        let bufend_align = bufend_align.next_power_of_two();
        tot_size = (tot_size + (bufend_align - 1)) & !(bufend_align - 1);
        if trace {
            println!(
                "{}:{}: after bufend_align tot_size {} {:#x}",
                file!(),
                line!(),
                tot_size,
                tot_size,
            );
        }
    }
    let align: usize = page_size; // alignment must be at least the page size
    let align = align.next_power_of_two();
    let layout = std::alloc::Layout::from_size_align(tot_size, align).unwrap();
    assert!(layout.size() >= ring_size);
    // Safety: we are assured layout has nonzero size, we passed the assert just above.
    // NOTE(review): alloc_zeroed can return null on allocation failure; that
    // case is not checked here — confirm whether that is acceptable upstream.
    let ring_addr: *mut u8 = unsafe { std::alloc::alloc_zeroed(layout) };
    // Buffers starts after the ring_size.
    // Safety: are we assured the address and the offset are in bounds because the ring_addr is
    // the value we got from the alloc call, and the layout.size was shown to be at least as
    // large as the ring_size.
    let buffers_addr: *mut u8 = unsafe { ring_addr.add(ring_size) };
    if trace {
        println!(
            "{}:{}: ring_addr {} {:#x}, layout: size {} align {}",
            file!(),
            line!(),
            ring_addr as u64,
            ring_addr as u64,
            layout.size(),
            layout.align()
        );
        println!(
            "{}:{}: buffers_addr {} {:#x}",
            file!(),
            line!(),
            buffers_addr as u64,
            buffers_addr as u64,
        );
    }
    let ring_addr: *const types::BufRingEntry = ring_addr as _;
    // Safety: the ring_addr passed into tail is the start of the ring. It is both the start of
    // the ring and the first entry in the ring.
    let tail_addr = unsafe { types::BufRingEntry::tail(ring_addr) } as *const AtomicU16;
    let ring_entries_mask = ring_entries - 1;
    assert!((ring_entries & ring_entries_mask) == 0);
    let buf_ring = RawBufRing {
        bgid,
        ring_entries_mask,
        buf_cnt,
        buf_len,
        layout,
        ring_addr,
        buffers_addr,
        local_tail: Cell::new(0),
        tail_addr,
        registered: Cell::new(false),
        possible_cur: Cell::new(0),
        possible_min: Cell::new(buf_cnt),
        //
        // TODO maybe later
        // #[cfg(any(debug, feature = "cautious"))]
        // debug_bitmap: RefCell::new(std::vec![0; ((buf_cnt+7)/8) as usize]),
    };
    // Question had come up: where should the initial buffers be added to the ring?
    // Here when the ring is created, even before it is registered potentially?
    // Or after registration?
    //
    // For this type, BufRing, we are adding the buffers to the ring as the last part of creating the BufRing,
    // even before registration is optionally performed.
    //
    // We've seen the registration to be successful, even when the ring starts off empty.
    // Add the buffers here where the ring is created.
    for bid in 0..buf_cnt {
        buf_ring.buf_ring_add(bid);
    }
    buf_ring.buf_ring_sync();
    // The default is to register the buffer ring right here. There is usually no reason the
    // caller should want to register it some time later.
    //
    // Perhaps the caller wants to allocate the buffer ring before the CONTEXT driver is in
    // place - that would be a reason to delay the register call until later.
    if auto_register {
        buf_ring.register()?;
    }
    Ok(buf_ring)
}
/// Register the buffer ring with the kernel.
/// Normally this is done automatically when building a BufRing.
///
/// This method must be called in the context of a `tokio-uring` runtime.
/// The registration persists for the lifetime of the runtime, unless
/// revoked by the [`unregister`] method. Dropping the
/// instance this method has been called on does revoke
/// the registration and deallocate the buffer space.
///
/// [`unregister`]: Self::unregister
///
/// # Errors
///
/// If a `Provided Buffers` group with the same `bgid` is already registered, the function
/// returns an error.
fn register(&self) -> io::Result<()> {
    let bgid = self.bgid;
    //println!("{}:{}: register bgid {bgid}", file!(), line!());
    // Future: move to separate public function so other buf_ring implementations
    // can register, and unregister, the same way.
    let res = CurrentRing::with(|ring| unsafe {
        ring.submitter()
            .register_buf_ring(self.ring_addr as _, self.ring_entries(), bgid)
    });
    // println!("{}:{}: res {:?}", file!(), line!(), res);
    if let Err(e) = res {
        match e.raw_os_error() {
            // 22 == EINVAL
            Some(22) => {
                // using buf_ring requires kernel 5.19 or greater.
                // TODO turn these eprintln into new, more expressive error being returned.
                // TODO what convention should we follow in this crate for adding information
                // onto an error?
                eprintln!(
                    "buf_ring.register returned {e}, most likely indicating this kernel is not 5.19+",
                );
            }
            // 17 == EEXIST
            Some(17) => {
                // Registering a duplicate bgid is not allowed. There is an `unregister`
                // operations that can remove the first.
                eprintln!(
                    "buf_ring.register returned `{e}`, indicating the attempted buffer group id {bgid} was already registered",
                );
            }
            _ => {
                eprintln!("buf_ring.register returned `{e}` for group id {bgid}");
            }
        }
        return Err(e);
    };
    self.registered.set(true);
    // res must be Ok here: the Err arm above returned early.
    res
}
/// Unregister the buffer ring from the io_uring.
/// Normally this is done automatically when the BufRing goes out of scope.
///
/// Warning: requires the CONTEXT driver is already in place or will panic.
fn unregister(&self) -> io::Result<()> {
    // If not registered, make this a no-op.
    if !self.registered.get() {
        return Ok(());
    }
    // The flag is cleared before the syscall, so even a failed unregister
    // attempt will not be retried.
    self.registered.set(false);
    let bgid = self.bgid;
    CurrentRing::with(|ring| ring.submitter().unregister_buf_ring(bgid))
}
/// Returns the buffer group id this ring was built with, as used for
/// registration and for buffer-selecting submission entries.
#[inline]
fn bgid(&self) -> Bgid {
    self.bgid
}
/// Accounting hook called when a buffer is handed out: decrement the count
/// of buffers possibly available to the kernel and track the low-water mark.
fn metric_getting_another(&self) {
    let now_available = self.possible_cur.get() - 1;
    self.possible_cur.set(now_available);
    // Record a new low-water mark if we just dipped below it.
    if now_available < self.possible_min.get() {
        self.possible_min.set(now_available);
    }
}
// Returns the buffer identified by `bid` to the ring and immediately makes
// it visible to the kernel again.
//
// # Safety
//
// Dropping a duplicate bid is likely to cause undefined behavior
// as the kernel uses the same buffer for different data concurrently.
unsafe fn dropping_bid_i(&self, bid: Bid) {
    self.buf_ring_add(bid);
    self.buf_ring_sync();
}
// Capacity in bytes of a single buffer in this ring.
#[inline]
fn buf_capacity_i(&self) -> usize {
    self.buf_len as _
}
// Returns a read-only pointer to the start of buffer `bid` within the
// single allocation backing all buffers.
#[inline]
// # Panic
//
// This function will panic if given a bid that is not within the valid range 0..self.buf_cnt.
fn stable_ptr_i(&self, bid: Bid) -> *const u8 {
    assert!(bid < self.buf_cnt);
    let offset: usize = self.buf_len * (bid as usize);
    // Safety: buffers_addr is an u8 pointer and was part of an allocation large enough to hold
    // buf_cnt number of buf_len buffers. buffers_addr, buf_cnt and buf_len are treated as
    // constants and bid was just asserted to be less than buf_cnt.
    unsafe { self.buffers_addr.add(offset) }
}
// Returns a mutable pointer to the start of buffer `bid`; mutable twin of
// `stable_ptr_i`.
//
// # Safety
//
// This may only be called by an owned or &mut object.
//
// # Panic
// This will panic if bid is out of range.
#[inline]
unsafe fn stable_mut_ptr_i(&self, bid: Bid) -> *mut u8 {
    assert!(bid < self.buf_cnt);
    let offset: usize = self.buf_len * (bid as usize);
    // Safety: buffers_addr is an u8 pointer and was part of an allocation large enough to hold
    // buf_cnt number of buf_len buffers. buffers_addr, buf_cnt and buf_len are treated as
    // constants and bid was just asserted to be less than buf_cnt.
    self.buffers_addr.add(offset)
}
// Number of entries in the ring. The stored mask is entries - 1, and the
// entry count is a power of two, so adding one back recovers it.
#[inline]
fn ring_entries(&self) -> u16 {
    self.ring_entries_mask + 1
}
// Bit mask used to wrap tail indices into the ring (entries - 1).
#[inline]
fn mask(&self) -> u16 {
    self.ring_entries_mask
}
// Writes to a ring entry and updates our local copy of the tail.
//
// Adds the buffer known by its buffer id to the buffer ring. The buffer's address and length
// are known given its bid.
//
// This does not sync the new tail value. The caller should use `buf_ring_sync` for that.
//
// Panics if the bid is out of range.
fn buf_ring_add(&self, bid: Bid) {
    // Compute address of current tail position, increment the local copy of the tail. Then
    // write the buffer's address, length and bid into the current tail entry.
    // wrapping_add: the tail counter is free-running; masking below handles wrap.
    let cur_tail = self.local_tail.get();
    self.local_tail.set(cur_tail.wrapping_add(1));
    let ring_idx = cur_tail & self.mask();
    let ring_addr = self.ring_addr as *mut types::BufRingEntry;
    // Safety:
    // 1. the pointer address (ring_addr), is set and const at self creation time,
    //    and points to a block of memory at least as large as the number of ring_entries,
    // 2. the mask used to create ring_idx is one less than
    //    the number of ring_entries, and ring_entries was tested to be a power of two,
    // So the address gotten by adding ring_idx entries to ring_addr is guaranteed to
    // be a valid address of a ring entry.
    let entry = unsafe { &mut *ring_addr.add(ring_idx as usize) };
    entry.set_addr(self.stable_ptr_i(bid) as _);
    entry.set_len(self.buf_len as _);
    entry.set_bid(bid);
    // Update accounting: one more buffer is (about to be) available to the kernel.
    self.possible_cur.set(self.possible_cur.get() + 1);
    // TODO maybe later
    // #[cfg(any(debug, feature = "cautious"))]
    // {
    //     let mut debug_bitmap = self.debug_bitmap.borrow_mut();
    //     let m = 1 << (bid % 8);
    //     assert!(debug_bitmap[(bid / 8) as usize] & m == 0);
    //     debug_bitmap[(bid / 8) as usize] |= m;
    // }
}
// Make 'count' new buffers visible to the kernel. Called after
// io_uring_buf_ring_add() has been called 'count' times to fill in new
// buffers.
#[inline]
fn buf_ring_sync(&self) {
    // Safety: dereferencing this raw pointer is safe. The tail_addr was computed once at init
    // to refer to the tail address in the ring and is held const for self's lifetime.
    // Release ordering publishes the entry writes done in buf_ring_add
    // before the kernel observes the new tail.
    unsafe {
        (*self.tail_addr).store(self.local_tail.get(), atomic::Ordering::Release);
    }
    // The liburing code did io_uring_smp_store_release(&br.tail, local_tail);
}
// Return the possible_min buffer pool size: the lowest number of buffers
// observed to be available since the last reset.
#[allow(dead_code)]
fn possible_min(&self) -> u16 {
    self.possible_min.get()
}
// Return the possible_min buffer pool size and reset to allow fresh counting going forward.
// `Cell::replace` returns the previous value while storing the reset value
// (buf_cnt, the maximum) in a single step.
#[allow(dead_code)]
fn possible_min_and_reset(&self) -> u16 {
    self.possible_min.replace(self.buf_cnt)
}
}
impl Drop for RawBufRing {
    /// Unregister from the kernel (best-effort; errors are ignored during
    /// drop) and free the ring + buffer allocation.
    fn drop(&mut self) {
        if self.registered.get() {
            _ = self.unregister();
        }
        // Safety: the ptr and layout are treated as constant, and ptr (ring_addr) was assigned by
        // a call to std::alloc::alloc_zeroed using the same layout.
        unsafe { std::alloc::dealloc(self.ring_addr as *mut u8, self.layout) };
    }
}

View file

@ -0,0 +1,548 @@
mod buf_ring;
mod recv_helper;
mod send_buffers;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::net::UdpSocket;
use std::ops::DerefMut;
use std::os::fd::AsRawFd;
use std::sync::atomic::Ordering;
use anyhow::Context;
use aquatic_common::access_list::AccessListCache;
use aquatic_common::ServerStartInstant;
use crossbeam_channel::Receiver;
use io_uring::opcode::Timeout;
use io_uring::types::{Fixed, Timespec};
use io_uring::{IoUring, Probe};
use aquatic_common::{
access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr,
PanicSentinel, ValidUntil,
};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
use self::buf_ring::BufRing;
use self::recv_helper::RecvHelper;
use self::send_buffers::{ResponseType, SendBuffers};
use super::storage::PendingScrapeResponseSlab;
use super::validator::ConnectionValidator;
use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6};
/// Size of each request buffer
///
/// Enough for scrape request with 20 info hashes
const REQUEST_BUF_LEN: usize = 256;
/// Size of each response buffer
///
/// Enough for:
/// - IPv6 announce response with 112 peers
/// - scrape response for 170 info hashes
const RESPONSE_BUF_LEN: usize = 2048;
/// Completion user_data tag for the multishot recvmsg entry
const USER_DATA_RECV: u64 = u64::MAX;
/// Completion user_data tag for the one-second pulse timeout
const USER_DATA_PULSE_TIMEOUT: u64 = u64::MAX - 1;
/// Completion user_data tag for the pending-scrape cleaning timeout.
/// All other user_data values are interpreted as send buffer indices.
const USER_DATA_CLEANING_TIMEOUT: u64 = u64::MAX - 2;
/// The UDP socket is registered with the ring as fixed file index 0
const SOCKET_IDENTIFIER: Fixed = Fixed(0);
thread_local! {
    /// Store IoUring instance here so that it can be accessed in BufRing::drop
    // `None` until `SocketWorker::run` stores the ring for this thread.
    pub static CURRENT_RING: CurrentRing = CurrentRing(RefCell::new(None));
}
/// Newtype wrapper around the thread-local `Option<IoUring>` slot.
pub struct CurrentRing(RefCell<Option<IoUring>>);
impl CurrentRing {
    /// Run `f` with mutable access to this thread's `IoUring`.
    ///
    /// Panics if no ring has been stored in `CURRENT_RING` yet.
    fn with<F, T>(mut f: F) -> T
    where
        F: FnMut(&mut IoUring) -> T,
    {
        CURRENT_RING.with(|slot| {
            let mut guard = slot.0.borrow_mut();
            // Auto-deref through RefMut reaches Option::as_mut.
            let ring = guard.as_mut().expect("IoUring not set");
            f(ring)
        })
    }
}
/// Per-thread io_uring-based UDP socket worker: receives requests, forwards
/// them to swarm workers, and sends responses back to peers.
pub struct SocketWorker {
    config: Config,
    shared_state: State,
    /// Channel for forwarding validated requests to swarm workers
    request_sender: ConnectedRequestSender,
    /// Channel of responses produced by swarm workers
    response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
    access_list_cache: AccessListCache,
    /// Creator/validator of connection ids (anti-spoofing)
    validator: ConnectionValidator,
    server_start_instant: ServerStartInstant,
    // Kept alive for the worker's lifetime; I/O goes through the ring's
    // registered fixed file instead of this handle directly.
    #[allow(dead_code)]
    socket: UdpSocket,
    pending_scrape_responses: PendingScrapeResponseSlab,
    /// Kernel-provided buffers used for incoming packets
    buf_ring: BufRing,
    /// Fixed buffers used for outgoing sendmsg entries
    send_buffers: SendBuffers,
    recv_helper: RecvHelper,
    /// Locally-generated responses (connect, errors, completed scrapes) plus
    /// swarm responses that could not get a send buffer
    local_responses: VecDeque<(Response, CanonicalSocketAddr)>,
    /// SQEs to (re)submit at the top of each event loop iteration
    resubmittable_sqe_buf: Vec<io_uring::squeue::Entry>,
    recv_sqe: io_uring::squeue::Entry,
    pulse_timeout_sqe: io_uring::squeue::Entry,
    cleaning_timeout_sqe: io_uring::squeue::Entry,
    pending_scrape_valid_until: ValidUntil,
}
impl SocketWorker {
    /// Thread entry point: create the socket and ring, register resources,
    /// build the worker and enter the event loop (never returns normally).
    pub fn run(
        _sentinel: PanicSentinel,
        shared_state: State,
        config: Config,
        validator: ConnectionValidator,
        server_start_instant: ServerStartInstant,
        request_sender: ConnectedRequestSender,
        response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
        priv_dropper: PrivilegeDropper,
    ) {
        let ring_entries = config.network.ring_size.next_power_of_two();
        // Try to fill up the ring with send requests
        let send_buffer_entries = ring_entries;
        let socket = create_socket(&config, priv_dropper).expect("create socket");
        let access_list_cache = create_access_list_cache(&shared_state.access_list);
        let send_buffers = SendBuffers::new(&config, send_buffer_entries as usize);
        let recv_helper = RecvHelper::new(&config);
        let ring = IoUring::builder()
            .setup_coop_taskrun()
            .setup_single_issuer()
            .setup_submit_all()
            .build(ring_entries.into())
            .unwrap();
        // Register the socket as fixed file 0 (see SOCKET_IDENTIFIER).
        ring.submitter()
            .register_files(&[socket.as_raw_fd()])
            .unwrap();
        // Store ring in thread local storage before creating BufRing
        CURRENT_RING.with(|r| *r.0.borrow_mut() = Some(ring));
        let buf_ring = buf_ring::Builder::new(0)
            .ring_entries(ring_entries)
            .buf_len(REQUEST_BUF_LEN)
            .build()
            .unwrap();
        let recv_sqe = recv_helper.create_entry(buf_ring.bgid().try_into().unwrap());
        // This timeout enables regular updates of pending_scrape_valid_until
        // and wakes the main loop to send any pending responses in the case
        // of no incoming requests
        // NOTE: the Timespec boxes below are intentionally leaked via
        // Box::into_raw so the pointers stay valid for every resubmission.
        let pulse_timeout_sqe = {
            let timespec_ptr = Box::into_raw(Box::new(Timespec::new().sec(1))) as *const _;
            Timeout::new(timespec_ptr)
                .build()
                .user_data(USER_DATA_PULSE_TIMEOUT)
        };
        let cleaning_timeout_sqe = {
            let timespec_ptr = Box::into_raw(Box::new(
                Timespec::new().sec(config.cleaning.pending_scrape_cleaning_interval),
            )) as *const _;
            Timeout::new(timespec_ptr)
                .build()
                .user_data(USER_DATA_CLEANING_TIMEOUT)
        };
        let resubmittable_sqe_buf = vec![
            recv_sqe.clone(),
            pulse_timeout_sqe.clone(),
            cleaning_timeout_sqe.clone(),
        ];
        let pending_scrape_valid_until =
            ValidUntil::new(server_start_instant, config.cleaning.max_pending_scrape_age);
        let mut worker = Self {
            config,
            shared_state,
            validator,
            server_start_instant,
            request_sender,
            response_receiver,
            access_list_cache,
            pending_scrape_responses: Default::default(),
            send_buffers,
            recv_helper,
            local_responses: Default::default(),
            buf_ring,
            recv_sqe,
            pulse_timeout_sqe,
            cleaning_timeout_sqe,
            resubmittable_sqe_buf,
            socket,
            pending_scrape_valid_until,
        };
        CurrentRing::with(|ring| worker.run_inner(ring));
    }
    /// Main event loop: push pending SQEs, fill remaining submission queue
    /// space with sendmsg entries, submit-and-wait, then drain completions.
    fn run_inner(&mut self, ring: &mut IoUring) {
        loop {
            // Re-arm recv/timeouts that completed in the previous iteration.
            for sqe in self.resubmittable_sqe_buf.drain(..) {
                unsafe { ring.submission().push(&sqe).unwrap() };
            }
            let sq_space = {
                let sq = ring.submission();
                sq.capacity() - sq.len()
            };
            let mut num_send_added = 0;
            // Enqueue local responses
            for _ in 0..sq_space {
                if let Some((response, addr)) = self.local_responses.pop_front() {
                    match self.send_buffers.prepare_entry(&response, addr) {
                        Ok(entry) => {
                            unsafe { ring.submission().push(&entry).unwrap() };
                            num_send_added += 1;
                        }
                        Err(send_buffers::Error::NoBuffers) => {
                            // Out of send buffers: put the response back and retry next loop.
                            self.local_responses.push_front((response, addr));
                            break;
                        }
                        Err(send_buffers::Error::SerializationFailed(err)) => {
                            ::log::error!("Failed serializing response: {:#}", err);
                        }
                    }
                } else {
                    break;
                }
            }
            // Enqueue swarm worker responses
            for _ in 0..(sq_space - num_send_added) {
                if let Some((response, addr)) = self.get_next_swarm_response() {
                    match self.send_buffers.prepare_entry(&response, addr) {
                        Ok(entry) => {
                            unsafe { ring.submission().push(&entry).unwrap() };
                            num_send_added += 1;
                        }
                        Err(send_buffers::Error::NoBuffers) => {
                            // Park the response locally so it isn't lost.
                            self.local_responses.push_back((response, addr));
                            break;
                        }
                        Err(send_buffers::Error::SerializationFailed(err)) => {
                            ::log::error!("Failed serializing response: {:#}", err);
                        }
                    }
                } else {
                    break;
                }
            }
            // Wait for all sendmsg entries to complete. If none were added,
            // wait for at least one recvmsg or timeout in order to avoid
            // busy-polling if there is no incoming data.
            ring.submitter()
                .submit_and_wait(num_send_added.max(1))
                .unwrap();
            for cqe in ring.completion() {
                self.handle_cqe(cqe);
            }
            // All send completions have been seen, so the free-buffer scan
            // can start from the beginning again.
            self.send_buffers.reset_likely_next_free_index();
        }
    }
    /// Dispatch one completion entry based on its user_data tag.
    fn handle_cqe(&mut self, cqe: io_uring::cqueue::Entry) {
        match cqe.user_data() {
            USER_DATA_RECV => {
                self.handle_recv_cqe(&cqe);
                // Multishot recv stops producing completions when the `more`
                // flag is unset; re-arm it then.
                if !io_uring::cqueue::more(cqe.flags()) {
                    self.resubmittable_sqe_buf.push(self.recv_sqe.clone());
                }
            }
            USER_DATA_PULSE_TIMEOUT => {
                self.pending_scrape_valid_until = ValidUntil::new(
                    self.server_start_instant,
                    self.config.cleaning.max_pending_scrape_age,
                );
                ::log::info!(
                    "pending responses: {} local, {} swarm",
                    self.local_responses.len(),
                    self.response_receiver.len()
                );
                self.resubmittable_sqe_buf
                    .push(self.pulse_timeout_sqe.clone());
            }
            USER_DATA_CLEANING_TIMEOUT => {
                self.pending_scrape_responses
                    .clean(self.server_start_instant.seconds_elapsed());
                self.resubmittable_sqe_buf
                    .push(self.cleaning_timeout_sqe.clone());
            }
            // Any other user_data value is the index of the send buffer used
            // by a completed sendmsg.
            send_buffer_index => {
                let result = cqe.result();
                if result < 0 {
                    ::log::error!(
                        "Couldn't send response: {:#}",
                        ::std::io::Error::from_raw_os_error(-result)
                    );
                } else if self.config.statistics.active() {
                    let send_buffer_index = send_buffer_index as usize;
                    let (response_type, receiver_is_ipv4) =
                        self.send_buffers.response_type_and_ipv4(send_buffer_index);
                    let (statistics, extra_bytes) = if receiver_is_ipv4 {
                        (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4)
                    } else {
                        (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6)
                    };
                    statistics
                        .bytes_sent
                        .fetch_add(result as usize + extra_bytes, Ordering::Relaxed);
                    let response_counter = match response_type {
                        ResponseType::Connect => &statistics.responses_sent_connect,
                        ResponseType::Announce => &statistics.responses_sent_announce,
                        ResponseType::Scrape => &statistics.responses_sent_scrape,
                        ResponseType::Error => &statistics.responses_sent_error,
                    };
                    response_counter.fetch_add(1, Ordering::Relaxed);
                }
                // Safety: OK because cqe using buffer has been returned and
                // contents will no longer be accessed by kernel
                unsafe {
                    self.send_buffers
                        .mark_buffer_as_free(send_buffer_index as usize);
                }
            }
        }
    }
    /// Handle one recvmsg completion: fetch the provided buffer, parse the
    /// request, dispatch it and update receive statistics.
    fn handle_recv_cqe(&mut self, cqe: &io_uring::cqueue::Entry) {
        let result = cqe.result();
        if result < 0 {
            if -result == libc::ENOBUFS {
                ::log::info!("recv failed due to lack of buffers. If increasing ring size doesn't help, get faster hardware");
            } else {
                ::log::warn!(
                    "recv failed: {:#}",
                    ::std::io::Error::from_raw_os_error(-result)
                );
            }
            return;
        }
        // Safety-relevant: the buffer is returned to the kernel when the
        // guard is dropped at the end of this function.
        let buffer = unsafe {
            match self.buf_ring.get_buf(result as u32, cqe.flags()) {
                Ok(Some(buffer)) => buffer,
                Ok(None) => {
                    ::log::error!("Couldn't get recv buffer");
                    return;
                }
                Err(err) => {
                    ::log::error!("Couldn't get recv buffer: {:#}", err);
                    return;
                }
            }
        };
        let buffer = buffer.as_slice();
        let addr = match self.recv_helper.parse(buffer) {
            Ok((request, addr)) => {
                self.handle_request(request, addr);
                addr
            }
            Err(self::recv_helper::Error::RequestParseError(err, addr)) => {
                match err {
                    RequestParseError::Sendable {
                        connection_id,
                        transaction_id,
                        err,
                    } => {
                        ::log::debug!("Couldn't parse request from {:?}: {}", addr, err);
                        // Only answer parse errors from clients holding a
                        // valid connection id (anti-spoofing).
                        if self.validator.connection_id_valid(addr, connection_id) {
                            let response = ErrorResponse {
                                transaction_id,
                                message: err.right_or("Parse error").into(),
                            };
                            self.local_responses.push_back((response.into(), addr));
                        }
                    }
                    RequestParseError::Unsendable { err } => {
                        ::log::debug!("Couldn't parse request from {:?}: {}", addr, err);
                    }
                }
                addr
            }
            Err(self::recv_helper::Error::InvalidSocketAddress) => {
                ::log::debug!("Ignored request claiming to be from port 0");
                return;
            }
            Err(self::recv_helper::Error::RecvMsgParseError) => {
                ::log::error!("RecvMsgOut::parse failed");
                return;
            }
        };
        if self.config.statistics.active() {
            let (statistics, extra_bytes) = if addr.is_ipv4() {
                (&self.shared_state.statistics_ipv4, EXTRA_PACKET_SIZE_IPV4)
            } else {
                (&self.shared_state.statistics_ipv6, EXTRA_PACKET_SIZE_IPV6)
            };
            statistics
                .bytes_received
                .fetch_add(buffer.len() + extra_bytes, Ordering::Relaxed);
            statistics.requests_received.fetch_add(1, Ordering::Relaxed);
        }
    }
    /// Route a parsed request: answer connects locally, forward valid
    /// announces/scrapes to the responsible swarm worker(s).
    fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) {
        let access_list_mode = self.config.access_list.mode;
        match request {
            Request::Connect(request) => {
                let connection_id = self.validator.create_connection_id(src);
                let response = Response::Connect(ConnectResponse {
                    connection_id,
                    transaction_id: request.transaction_id,
                });
                self.local_responses.push_back((response, src));
            }
            Request::Announce(request) => {
                if self
                    .validator
                    .connection_id_valid(src, request.connection_id)
                {
                    if self
                        .access_list_cache
                        .load()
                        .allows(access_list_mode, &request.info_hash.0)
                    {
                        // Announce goes to the swarm worker owning this info hash.
                        let worker_index =
                            SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash);
                        self.request_sender.try_send_to(
                            worker_index,
                            ConnectedRequest::Announce(request),
                            src,
                        );
                    } else {
                        let response = Response::Error(ErrorResponse {
                            transaction_id: request.transaction_id,
                            message: "Info hash not allowed".into(),
                        });
                        self.local_responses.push_back((response, src))
                    }
                }
            }
            Request::Scrape(request) => {
                if self
                    .validator
                    .connection_id_valid(src, request.connection_id)
                {
                    // A scrape can span info hashes owned by several swarm
                    // workers; split it and reassemble via
                    // pending_scrape_responses.
                    let split_requests = self.pending_scrape_responses.prepare_split_requests(
                        &self.config,
                        request,
                        self.pending_scrape_valid_until,
                    );
                    for (swarm_worker_index, request) in split_requests {
                        self.request_sender.try_send_to(
                            swarm_worker_index,
                            ConnectedRequest::Scrape(request),
                            src,
                        );
                    }
                }
            }
        }
    }
    /// Pull the next sendable response from the swarm worker channel.
    /// Partial scrape responses are accumulated; returns None when the
    /// channel is empty.
    fn get_next_swarm_response(&mut self) -> Option<(Response, CanonicalSocketAddr)> {
        loop {
            match self.response_receiver.try_recv() {
                Ok((ConnectedResponse::AnnounceIpv4(response), addr)) => {
                    return Some((Response::AnnounceIpv4(response), addr));
                }
                Ok((ConnectedResponse::AnnounceIpv6(response), addr)) => {
                    return Some((Response::AnnounceIpv6(response), addr));
                }
                Ok((ConnectedResponse::Scrape(response), addr)) => {
                    if let Some(response) =
                        self.pending_scrape_responses.add_and_get_finished(response)
                    {
                        return Some((Response::Scrape(response), addr));
                    }
                }
                Err(_) => {
                    return None;
                }
            }
        }
    }
}
/// Check whether this kernel supports the io_uring opcodes the worker needs.
///
/// We can't probe for RecvMsgMulti, so we probe for SendZc, which was also
/// introduced in Linux 6.0. Returns an error naming the first unsupported
/// opcode, if any.
pub fn supported_on_current_kernel() -> anyhow::Result<()> {
    let ring = IoUring::new(1).with_context(|| "create ring")?;
    let mut probe = Probe::new();
    ring.submitter()
        .register_probe(&mut probe)
        .with_context(|| "register probe")?;
    let required_opcodes = [io_uring::opcode::SendZc::CODE];
    match required_opcodes
        .iter()
        .copied()
        .find(|&opcode| !probe.is_supported(opcode))
    {
        Some(opcode) => Err(anyhow::anyhow!(
            "io_uring opcode {:b} not supported",
            opcode
        )),
        None => Ok(()),
    }
}

View file

@ -0,0 +1,145 @@
use std::{
cell::UnsafeCell,
net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
ptr::null_mut,
};
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::{Request, RequestParseError};
use io_uring::{opcode::RecvMsgMulti, types::RecvMsgOut};
use crate::config::Config;
use super::{SOCKET_IDENTIFIER, USER_DATA_RECV};
/// Failure modes when interpreting a completed recvmsg buffer.
pub enum Error {
    /// `RecvMsgOut::parse` rejected the buffer contents
    RecvMsgParseError,
    /// Payload was not a valid tracker request; address is kept so an error
    /// response can be sent back
    RequestParseError(RequestParseError, CanonicalSocketAddr),
    /// Sender claimed source port 0, which is not routable
    InvalidSocketAddress,
}
/// Owns the msghdr templates used by multishot recvmsg entries and knows how
/// to parse the resulting buffers into requests plus source addresses.
pub struct RecvHelper {
    socket_is_ipv4: bool,
    max_scrape_torrents: u8,
    // name_v4/name_v6 are referenced by the msghdr templates below; the
    // #[allow(dead_code)] fields exist to keep those pointees alive.
    #[allow(dead_code)]
    name_v4: Box<UnsafeCell<libc::sockaddr_in>>,
    msghdr_v4: Box<UnsafeCell<libc::msghdr>>,
    #[allow(dead_code)]
    name_v6: Box<UnsafeCell<libc::sockaddr_in6>>,
    msghdr_v6: Box<UnsafeCell<libc::msghdr>>,
}
impl RecvHelper {
    /// Build the zeroed sockaddr storage and msghdr templates for whichever
    /// address family the configured socket uses.
    pub fn new(config: &Config) -> Self {
        let name_v4 = Box::new(UnsafeCell::new(libc::sockaddr_in {
            sin_family: 0,
            sin_port: 0,
            sin_addr: libc::in_addr { s_addr: 0 },
            sin_zero: [0; 8],
        }));
        // The msghdr only describes layout (name length, no iov/control);
        // actual data arrives in ring-provided buffers.
        let msghdr_v4 = Box::new(UnsafeCell::new(libc::msghdr {
            msg_name: name_v4.get() as *mut libc::c_void,
            msg_namelen: core::mem::size_of::<libc::sockaddr_in>() as u32,
            msg_iov: null_mut(),
            msg_iovlen: 0,
            msg_control: null_mut(),
            msg_controllen: 0,
            msg_flags: 0,
        }));
        let name_v6 = Box::new(UnsafeCell::new(libc::sockaddr_in6 {
            sin6_family: 0,
            sin6_port: 0,
            sin6_flowinfo: 0,
            sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
            sin6_scope_id: 0,
        }));
        let msghdr_v6 = Box::new(UnsafeCell::new(libc::msghdr {
            msg_name: name_v6.get() as *mut libc::c_void,
            msg_namelen: core::mem::size_of::<libc::sockaddr_in6>() as u32,
            msg_iov: null_mut(),
            msg_iovlen: 0,
            msg_control: null_mut(),
            msg_controllen: 0,
            msg_flags: 0,
        }));
        Self {
            socket_is_ipv4: config.network.address.is_ipv4(),
            max_scrape_torrents: config.protocol.max_scrape_torrents,
            name_v4,
            msghdr_v4,
            name_v6,
            msghdr_v6,
        }
    }
    /// Create a multishot recvmsg submission entry drawing buffers from the
    /// given provided-buffer group.
    pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry {
        let msghdr: *const libc::msghdr = if self.socket_is_ipv4 {
            self.msghdr_v4.get()
        } else {
            self.msghdr_v6.get()
        };
        RecvMsgMulti::new(SOCKET_IDENTIFIER, msghdr, buf_group)
            .build()
            .user_data(USER_DATA_RECV)
    }
    /// Parse a completed recvmsg buffer into a request and its (canonical)
    /// source address. Requests claiming source port 0 are rejected.
    pub fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> {
        let (msg, addr) = if self.socket_is_ipv4 {
            let msg = unsafe {
                let msghdr = &*(self.msghdr_v4.get() as *const _);
                RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)?
            };
            // Fields in the kernel-written sockaddr are in network byte
            // order, hence the from_be conversions.
            let addr = unsafe {
                let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in);
                SocketAddr::V4(SocketAddrV4::new(
                    u32::from_be(name_data.sin_addr.s_addr).into(),
                    u16::from_be(name_data.sin_port),
                ))
            };
            if addr.port() == 0 {
                return Err(Error::InvalidSocketAddress);
            }
            (msg, addr)
        } else {
            let msg = unsafe {
                let msghdr = &*(self.msghdr_v6.get() as *const _);
                RecvMsgOut::parse(buffer, msghdr).map_err(|_| Error::RecvMsgParseError)?
            };
            let addr = unsafe {
                let name_data = *(msg.name_data().as_ptr() as *const libc::sockaddr_in6);
                SocketAddr::V6(SocketAddrV6::new(
                    Ipv6Addr::from(name_data.sin6_addr.s6_addr),
                    u16::from_be(name_data.sin6_port),
                    u32::from_be(name_data.sin6_flowinfo),
                    u32::from_be(name_data.sin6_scope_id),
                ))
            };
            if addr.port() == 0 {
                return Err(Error::InvalidSocketAddress);
            }
            (msg, addr)
        };
        let addr = CanonicalSocketAddr::new(addr);
        let request = Request::from_bytes(msg.payload_data(), self.max_scrape_torrents)
            .map_err(|err| Error::RequestParseError(err, addr))?;
        Ok((request, addr))
    }
}

View file

@ -0,0 +1,251 @@
use std::{cell::UnsafeCell, io::Cursor, net::SocketAddr, ops::IndexMut, ptr::null_mut};
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::Response;
use io_uring::opcode::SendMsg;
use crate::config::Config;
use super::{RESPONSE_BUF_LEN, SOCKET_IDENTIFIER};
/// Failure modes when preparing a sendmsg entry.
pub enum Error {
    /// All send buffers are currently in flight
    NoBuffers,
    /// The response could not be serialized into the buffer
    SerializationFailed(std::io::Error),
}
/// Response category, recorded per send buffer for statistics reporting.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ResponseType {
    Connect,
    Announce,
    Scrape,
    Error,
}
impl ResponseType {
fn from_response(response: &Response) -> Self {
match response {
Response::Connect(_) => Self::Connect,
Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => Self::Announce,
Response::Scrape(_) => Self::Scrape,
Response::Error(_) => Self::Error,
}
}
}
/// One reusable sendmsg buffer: sockaddr storage, payload bytes, iovec and
/// msghdr, all wired together by raw pointers (see `setup_pointers`).
struct SendBuffer {
    name_v4: UnsafeCell<libc::sockaddr_in>,
    name_v6: UnsafeCell<libc::sockaddr_in6>,
    bytes: UnsafeCell<[u8; RESPONSE_BUF_LEN]>,
    iovec: UnsafeCell<libc::iovec>,
    msghdr: UnsafeCell<libc::msghdr>,
    // True when no in-flight sendmsg entry references this buffer
    free: bool,
    /// Only used for statistics
    receiver_is_ipv4: bool,
    /// Only used for statistics
    response_type: ResponseType,
}
impl SendBuffer {
    /// Create a buffer with zeroed payload and null iovec/msghdr pointers.
    /// `setup_pointers` must be called once the buffer has its final memory
    /// location, since the internal pointers are self-referential.
    fn new_with_null_pointers() -> Self {
        Self {
            name_v4: UnsafeCell::new(libc::sockaddr_in {
                sin_family: libc::AF_INET as u16,
                sin_port: 0,
                sin_addr: libc::in_addr { s_addr: 0 },
                sin_zero: [0; 8],
            }),
            name_v6: UnsafeCell::new(libc::sockaddr_in6 {
                sin6_family: libc::AF_INET6 as u16,
                sin6_port: 0,
                sin6_flowinfo: 0,
                sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
                sin6_scope_id: 0,
            }),
            bytes: UnsafeCell::new([0; RESPONSE_BUF_LEN]),
            iovec: UnsafeCell::new(libc::iovec {
                iov_base: null_mut(),
                iov_len: 0,
            }),
            msghdr: UnsafeCell::new(libc::msghdr {
                msg_name: null_mut(),
                msg_namelen: 0,
                msg_iov: null_mut(),
                msg_iovlen: 1,
                msg_control: null_mut(),
                msg_controllen: 0,
                msg_flags: 0,
            }),
            free: true,
            receiver_is_ipv4: true,
            response_type: ResponseType::Connect,
        }
    }
    /// Point the iovec at the payload bytes and the msghdr at the iovec and
    /// the sockaddr matching the socket's address family.
    fn setup_pointers(&mut self, socket_is_ipv4: bool) {
        unsafe {
            let iovec = &mut *self.iovec.get();
            iovec.iov_base = self.bytes.get() as *mut libc::c_void;
            iovec.iov_len = (&*self.bytes.get()).len();
            let msghdr = &mut *self.msghdr.get();
            msghdr.msg_iov = self.iovec.get();
            if socket_is_ipv4 {
                msghdr.msg_name = self.name_v4.get() as *mut libc::c_void;
                msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in>() as u32;
            } else {
                msghdr.msg_name = self.name_v6.get() as *mut libc::c_void;
                msghdr.msg_namelen = core::mem::size_of::<libc::sockaddr_in6>() as u32;
            }
        }
    }
    /// Serialize `response` into this buffer and build a SendMsg submission
    /// entry for it, marking the buffer as in use on success.
    ///
    /// # Safety
    ///
    /// - SendBuffer must be stored at a fixed location in memory
    /// - SendBuffer.setup_pointers must have been called while stored at that
    ///   fixed location
    /// - Contents of struct fields wrapped in UnsafeCell can NOT be accessed
    ///   simultaneously to this function call
    unsafe fn prepare_entry(
        &mut self,
        response: &Response,
        addr: CanonicalSocketAddr,
        socket_is_ipv4: bool,
    ) -> Result<io_uring::squeue::Entry, Error> {
        // Set receiver socket addr
        if socket_is_ipv4 {
            self.receiver_is_ipv4 = true;
            let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() {
                addr
            } else {
                panic!("ipv6 address in ipv4 mode");
            };
            let name = &mut *self.name_v4.get();
            // Port and address go out in network byte order.
            name.sin_port = addr.port().to_be();
            name.sin_addr.s_addr = u32::from(*addr.ip()).to_be();
        } else {
            // Set receiver protocol type before calling addr.get_ipv6_mapped()
            self.receiver_is_ipv4 = addr.is_ipv4();
            let addr = if let SocketAddr::V6(addr) = addr.get_ipv6_mapped() {
                addr
            } else {
                panic!("ipv4 address when ipv6 or ipv6-mapped address expected");
            };
            let name = &mut *self.name_v6.get();
            name.sin6_port = addr.port().to_be();
            name.sin6_addr.s6_addr = addr.ip().octets();
        }
        let bytes = (&mut *self.bytes.get()).as_mut_slice();
        let mut cursor = Cursor::new(bytes);
        match response.write(&mut cursor) {
            Ok(()) => {
                // Shrink the iovec to the serialized length so only the
                // response bytes are sent.
                (&mut *self.iovec.get()).iov_len = cursor.position() as usize;
                self.response_type = ResponseType::from_response(response);
                self.free = false;
                Ok(SendMsg::new(SOCKET_IDENTIFIER, self.msghdr.get()).build())
            }
            Err(err) => Err(Error::SerializationFailed(err)),
        }
    }
}
/// Fixed pool of `SendBuffer`s with a cheap scan cursor for finding a free one.
pub struct SendBuffers {
    // Scan start hint; reset to 0 after each completion-queue drain
    likely_next_free_index: usize,
    socket_is_ipv4: bool,
    // Boxed slice: buffers never move, which the self-referential pointers require
    buffers: Box<[SendBuffer]>,
}
impl SendBuffers {
    /// Allocate `capacity` buffers and wire up their internal pointers once
    /// they sit at their final heap location.
    pub fn new(config: &Config, capacity: usize) -> Self {
        let socket_is_ipv4 = config.network.address.is_ipv4();
        let mut buffers = ::std::iter::repeat_with(|| SendBuffer::new_with_null_pointers())
            .take(capacity)
            .collect::<Vec<_>>()
            .into_boxed_slice();
        // Pointers must be set AFTER the buffers are in their final location.
        for buffer in buffers.iter_mut() {
            buffer.setup_pointers(socket_is_ipv4);
        }
        Self {
            likely_next_free_index: 0,
            socket_is_ipv4,
            buffers,
        }
    }
    /// Statistics metadata recorded when the buffer at `index` was prepared.
    pub fn response_type_and_ipv4(&self, index: usize) -> (ResponseType, bool) {
        let buffer = self.buffers.get(index).unwrap();
        (buffer.response_type, buffer.receiver_is_ipv4)
    }
    /// # Safety
    ///
    /// Only safe to call once buffer is no longer referenced by in-flight
    /// io_uring queue entries
    pub unsafe fn mark_buffer_as_free(&mut self, index: usize) {
        self.buffers[index].free = true;
    }
    /// Call after going through completion queue
    pub fn reset_likely_next_free_index(&mut self) {
        self.likely_next_free_index = 0;
    }
    /// Serialize `response` into a free buffer and return a sendmsg entry
    /// whose user_data is the buffer index (used to free it on completion).
    pub fn prepare_entry(
        &mut self,
        response: &Response,
        addr: CanonicalSocketAddr,
    ) -> Result<io_uring::squeue::Entry, Error> {
        let index = self.next_free_index()?;
        let buffer = self.buffers.index_mut(index);
        // Safety: OK because buffers are stored in fixed memory location,
        // buffer pointers were set up in SendBuffers::new() and pointers to
        // SendBuffer UnsafeCell contents are not accessed elsewhere
        unsafe {
            match buffer.prepare_entry(response, addr, self.socket_is_ipv4) {
                Ok(entry) => {
                    // Advance the scan hint past this buffer.
                    self.likely_next_free_index = index + 1;
                    Ok(entry.user_data(index as u64))
                }
                Err(err) => Err(err),
            }
        }
    }
    // Scan for a free buffer starting at the hint. Buffers before the hint
    // are only reconsidered after reset_likely_next_free_index().
    fn next_free_index(&self) -> Result<usize, Error> {
        if self.likely_next_free_index >= self.buffers.len() {
            return Err(Error::NoBuffers);
        }
        for (i, buffer) in self.buffers[self.likely_next_free_index..]
            .iter()
            .enumerate()
        {
            if buffer.free {
                return Ok(self.likely_next_free_index + i);
            }
        }
        Err(Error::NoBuffers)
    }
}

View file

@ -0,0 +1,143 @@
use std::net::IpAddr;
use std::time::Instant;
use anyhow::Context;
use constant_time_eq::constant_time_eq;
use getrandom::getrandom;
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::ConnectionId;
use crate::config::Config;
/// HMAC (BLAKE3) based ConnectionID creator and validator
///
/// Structure of created ConnectionID (bytes making up inner i64):
/// - &[0..4]: connection expiration time as number of seconds after
///   ConnectionValidator instance was created, encoded as u32 bytes.
///   Value fits around 136 years.
/// - &[4..8]: truncated keyed BLAKE3 hash of above 4 bytes and octets of
///   client IP address
///
/// The purpose of using ConnectionIDs is to prevent IP spoofing, mainly to
/// prevent the tracker from being used as an amplification vector for DDoS
/// attacks. By including 32 bits of BLAKE3 keyed hash output in its contents,
/// such abuse should be rendered impractical.
#[derive(Clone)]
pub struct ConnectionValidator {
    /// Reference point for expiration timestamps embedded in ids
    start_time: Instant,
    max_connection_age: u32,
    /// Hasher pre-keyed with process-lifetime random key; shared by clones
    keyed_hasher: blake3::Hasher,
}
impl ConnectionValidator {
    /// Create new instance. Must be created once and cloned if used in several
    /// threads.
    pub fn new(config: &Config) -> anyhow::Result<Self> {
        let mut key = [0; 32];
        getrandom(&mut key)
            .with_context(|| "Couldn't get random bytes for ConnectionValidator key")?;
        let keyed_hasher = blake3::Hasher::new_keyed(&key);
        Ok(Self {
            keyed_hasher,
            start_time: Instant::now(),
            max_connection_age: config.cleaning.max_connection_age,
        })
    }
    /// Create a connection id for `source_addr` that expires
    /// `max_connection_age` seconds from now.
    pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId {
        // to_ne_bytes is fine: ids are created and checked by the same process.
        let valid_until =
            (self.start_time.elapsed().as_secs() as u32 + self.max_connection_age).to_ne_bytes();
        let hash = self.hash(valid_until, source_addr.get().ip());
        let mut connection_id_bytes = [0u8; 8];
        (&mut connection_id_bytes[..4]).copy_from_slice(&valid_until);
        (&mut connection_id_bytes[4..]).copy_from_slice(&hash);
        ConnectionId(i64::from_ne_bytes(connection_id_bytes))
    }
    /// Check that `connection_id` was issued by us for `source_addr` and has
    /// not expired.
    pub fn connection_id_valid(
        &mut self,
        source_addr: CanonicalSocketAddr,
        connection_id: ConnectionId,
    ) -> bool {
        let bytes = connection_id.0.to_ne_bytes();
        let (valid_until, hash) = bytes.split_at(4);
        let valid_until: [u8; 4] = valid_until.try_into().unwrap();
        // Constant-time comparison avoids leaking hash bytes via timing.
        if !constant_time_eq(hash, &self.hash(valid_until, source_addr.get().ip())) {
            return false;
        }
        u32::from_ne_bytes(valid_until) > self.start_time.elapsed().as_secs() as u32
    }
    // Keyed hash over (expiration bytes || IP octets), truncated to 4 bytes.
    // Hasher is reset after each use so state never carries over.
    fn hash(&mut self, valid_until: [u8; 4], ip_addr: IpAddr) -> [u8; 4] {
        self.keyed_hasher.update(&valid_until);
        match ip_addr {
            IpAddr::V4(ip) => self.keyed_hasher.update(&ip.octets()),
            IpAddr::V6(ip) => self.keyed_hasher.update(&ip.octets()),
        };
        let mut hash = [0u8; 4];
        self.keyed_hasher.finalize_xof().fill(&mut hash);
        self.keyed_hasher.reset();
        hash
    }
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use quickcheck_macros::quickcheck;
    use super::*;
    /// Property test: an id issued for one address must validate for that
    /// address (while unexpired) and never for a different address.
    #[quickcheck]
    fn test_connection_validator(
        original_addr: IpAddr,
        different_addr: IpAddr,
        max_connection_age: u32,
    ) -> quickcheck::TestResult {
        let original_addr = CanonicalSocketAddr::new(SocketAddr::new(original_addr, 0));
        let different_addr = CanonicalSocketAddr::new(SocketAddr::new(different_addr, 0));
        // Canonicalization can collapse distinct inputs; skip those cases.
        if original_addr == different_addr {
            return quickcheck::TestResult::discard();
        }
        let mut validator = {
            let mut config = Config::default();
            config.cleaning.max_connection_age = max_connection_age;
            ConnectionValidator::new(&config).unwrap()
        };
        let connection_id = validator.create_connection_id(original_addr);
        let original_valid = validator.connection_id_valid(original_addr, connection_id);
        let different_valid = validator.connection_id_valid(different_addr, connection_id);
        if different_valid {
            return quickcheck::TestResult::failed();
        }
        if max_connection_age == 0 {
            // Zero max age means the id is already expired on creation.
            quickcheck::TestResult::from_bool(!original_valid)
        } else {
            // Note: depends on that running this test takes less than a second
            quickcheck::TestResult::from_bool(original_valid)
        }
    }
}

View file

@ -0,0 +1,318 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use hdrhistogram::Histogram;
use num_format::{Locale, ToFormattedString};
use serde::Serialize;
use crate::common::Statistics;
use crate::config::Config;
/// Aggregates shared atomic counters and per-worker peer histograms into
/// periodic statistics snapshots.
pub struct StatisticsCollector {
    shared: Arc<Statistics>,
    /// When counters were last fetched, for computing per-second rates
    last_update: Instant,
    /// Histograms received so far this round; merged once one has arrived
    /// from every swarm worker
    pending_histograms: Vec<Histogram<u64>>,
    last_complete_histogram: PeerHistogramStatistics,
    // Label value ("4"/"6" presumably — confirm at caller) for prometheus metrics
    #[cfg(feature = "prometheus")]
    ip_version: String,
}
impl StatisticsCollector {
    /// Create a collector reading from the given shared statistics struct.
    pub fn new(shared: Arc<Statistics>, #[cfg(feature = "prometheus")] ip_version: String) -> Self {
        Self {
            shared,
            last_update: Instant::now(),
            pending_histograms: Vec::new(),
            last_complete_histogram: Default::default(),
            #[cfg(feature = "prometheus")]
            ip_version,
        }
    }
    /// Queue one swarm worker's peers-per-torrent histogram. Once one
    /// histogram per swarm worker has been queued, they are merged and
    /// summarized into `last_complete_histogram`.
    pub fn add_histogram(&mut self, config: &Config, histogram: Histogram<u64>) {
        self.pending_histograms.push(histogram);
        if self.pending_histograms.len() == config.swarm_workers {
            self.last_complete_histogram =
                PeerHistogramStatistics::new(self.pending_histograms.drain(..).sum());
        }
    }
    /// Read and reset the shared counters, compute per-second rates since
    /// the previous call, optionally export Prometheus metrics, and return
    /// formatted statistics for display.
    pub fn collect_from_shared(
        &mut self,
        #[cfg(feature = "prometheus")] config: &Config,
    ) -> CollectedStatistics {
        // Each counter is read and zeroed atomically, so no events are lost
        // between collections
        let requests_received = Self::fetch_and_reset(&self.shared.requests_received);
        let responses_sent_connect = Self::fetch_and_reset(&self.shared.responses_sent_connect);
        let responses_sent_announce = Self::fetch_and_reset(&self.shared.responses_sent_announce);
        let responses_sent_scrape = Self::fetch_and_reset(&self.shared.responses_sent_scrape);
        let responses_sent_error = Self::fetch_and_reset(&self.shared.responses_sent_error);
        let bytes_received = Self::fetch_and_reset(&self.shared.bytes_received);
        let bytes_sent = Self::fetch_and_reset(&self.shared.bytes_sent);
        // Torrent/peer gauges are maintained per swarm worker
        let num_torrents_by_worker: Vec<usize> = self
            .shared
            .torrents
            .iter()
            .map(|n| n.load(Ordering::Relaxed))
            .collect();
        let num_peers_by_worker: Vec<usize> = self
            .shared
            .peers
            .iter()
            .map(|n| n.load(Ordering::Relaxed))
            .collect();
        // Seconds since the previous collection, for per-second rates
        let elapsed = {
            let now = Instant::now();
            let elapsed = (now - self.last_update).as_secs_f64();
            self.last_update = now;
            elapsed
        };
        #[cfg(feature = "prometheus")]
        if config.statistics.run_prometheus_endpoint {
            ::metrics::counter!(
                "aquatic_requests_total",
                requests_received.try_into().unwrap(),
                "ip_version" => self.ip_version.clone(),
            );
            ::metrics::counter!(
                "aquatic_responses_total",
                responses_sent_connect.try_into().unwrap(),
                "type" => "connect",
                "ip_version" => self.ip_version.clone(),
            );
            ::metrics::counter!(
                "aquatic_responses_total",
                responses_sent_announce.try_into().unwrap(),
                "type" => "announce",
                "ip_version" => self.ip_version.clone(),
            );
            ::metrics::counter!(
                "aquatic_responses_total",
                responses_sent_scrape.try_into().unwrap(),
                "type" => "scrape",
                "ip_version" => self.ip_version.clone(),
            );
            ::metrics::counter!(
                "aquatic_responses_total",
                responses_sent_error.try_into().unwrap(),
                "type" => "error",
                "ip_version" => self.ip_version.clone(),
            );
            ::metrics::counter!(
                "aquatic_rx_bytes",
                bytes_received.try_into().unwrap(),
                "ip_version" => self.ip_version.clone(),
            );
            ::metrics::counter!(
                "aquatic_tx_bytes",
                bytes_sent.try_into().unwrap(),
                "ip_version" => self.ip_version.clone(),
            );
            for (worker_index, n) in num_torrents_by_worker.iter().copied().enumerate() {
                ::metrics::gauge!(
                    "aquatic_torrents",
                    n as f64,
                    "ip_version" => self.ip_version.clone(),
                    "worker_index" => worker_index.to_string(),
                );
            }
            for (worker_index, n) in num_peers_by_worker.iter().copied().enumerate() {
                ::metrics::gauge!(
                    "aquatic_peers",
                    n as f64,
                    "ip_version" => self.ip_version.clone(),
                    "worker_index" => worker_index.to_string(),
                );
            }
            if config.statistics.torrent_peer_histograms {
                self.last_complete_histogram
                    .update_metrics(self.ip_version.clone());
            }
        }
        let num_peers: usize = num_peers_by_worker.into_iter().sum();
        let num_torrents: usize = num_torrents_by_worker.into_iter().sum();
        let requests_per_second = requests_received as f64 / elapsed;
        let responses_per_second_connect = responses_sent_connect as f64 / elapsed;
        let responses_per_second_announce = responses_sent_announce as f64 / elapsed;
        let responses_per_second_scrape = responses_sent_scrape as f64 / elapsed;
        let responses_per_second_error = responses_sent_error as f64 / elapsed;
        let bytes_received_per_second = bytes_received as f64 / elapsed;
        let bytes_sent_per_second = bytes_sent as f64 / elapsed;
        let responses_per_second_total = responses_per_second_connect
            + responses_per_second_announce
            + responses_per_second_scrape
            + responses_per_second_error;
        CollectedStatistics {
            requests_per_second: (requests_per_second as usize).to_formatted_string(&Locale::en),
            responses_per_second_total: (responses_per_second_total as usize)
                .to_formatted_string(&Locale::en),
            responses_per_second_connect: (responses_per_second_connect as usize)
                .to_formatted_string(&Locale::en),
            responses_per_second_announce: (responses_per_second_announce as usize)
                .to_formatted_string(&Locale::en),
            responses_per_second_scrape: (responses_per_second_scrape as usize)
                .to_formatted_string(&Locale::en),
            responses_per_second_error: (responses_per_second_error as usize)
                .to_formatted_string(&Locale::en),
            rx_mbits: format!("{:.2}", bytes_received_per_second * 8.0 / 1_000_000.0),
            tx_mbits: format!("{:.2}", bytes_sent_per_second * 8.0 / 1_000_000.0),
            num_torrents: num_torrents.to_formatted_string(&Locale::en),
            num_peers: num_peers.to_formatted_string(&Locale::en),
            peer_histogram: self.last_complete_histogram.clone(),
        }
    }
    /// Atomically read a counter's value and reset it to zero.
    fn fetch_and_reset(atomic: &AtomicUsize) -> usize {
        // AND with zero returns the previous value while clearing the counter
        atomic.fetch_and(0, Ordering::Relaxed)
    }
}
/// Statistics for one collection interval, pre-formatted as strings for
/// stdout printing and HTML template rendering.
#[derive(Clone, Debug, Serialize)]
pub struct CollectedStatistics {
    pub requests_per_second: String,
    pub responses_per_second_total: String,
    pub responses_per_second_connect: String,
    pub responses_per_second_announce: String,
    pub responses_per_second_scrape: String,
    pub responses_per_second_error: String,
    // Bandwidth in Mbit/s, formatted with two decimals
    pub rx_mbits: String,
    pub tx_mbits: String,
    pub num_torrents: String,
    pub num_peers: String,
    pub peer_histogram: PeerHistogramStatistics,
}
/// Summary (min, max and selected percentiles) of a peers-per-torrent
/// histogram merged across all swarm workers.
#[derive(Clone, Debug, Serialize, Default)]
pub struct PeerHistogramStatistics {
    pub min: u64,
    pub p10: u64,
    pub p20: u64,
    pub p30: u64,
    pub p40: u64,
    pub p50: u64,
    pub p60: u64,
    pub p70: u64,
    pub p80: u64,
    pub p90: u64,
    pub p95: u64,
    pub p99: u64,
    // 99.9th percentile
    pub p999: u64,
    pub max: u64,
}
impl PeerHistogramStatistics {
    /// Extract min/max and a fixed set of percentiles from a merged
    /// peers-per-torrent histogram.
    fn new(h: Histogram<u64>) -> Self {
        Self {
            min: h.min(),
            p10: h.value_at_percentile(10.0),
            p20: h.value_at_percentile(20.0),
            p30: h.value_at_percentile(30.0),
            p40: h.value_at_percentile(40.0),
            p50: h.value_at_percentile(50.0),
            p60: h.value_at_percentile(60.0),
            p70: h.value_at_percentile(70.0),
            p80: h.value_at_percentile(80.0),
            p90: h.value_at_percentile(90.0),
            p95: h.value_at_percentile(95.0),
            p99: h.value_at_percentile(99.0),
            p999: h.value_at_percentile(99.9),
            max: h.max(),
        }
    }
    /// Export all histogram values as `aquatic_peers_per_torrent` gauges,
    /// one per quantile, labelled with the IP version.
    #[cfg(feature = "prometheus")]
    fn update_metrics(&self, ip_version: String) {
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.min as f64,
            // Fix: was labelled "max", which clobbered the real maximum
            // gauge and left the minimum unexported
            "type" => "min",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p10 as f64,
            "type" => "p10",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p20 as f64,
            "type" => "p20",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p30 as f64,
            "type" => "p30",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p40 as f64,
            "type" => "p40",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p50 as f64,
            "type" => "p50",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p60 as f64,
            "type" => "p60",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p70 as f64,
            "type" => "p70",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p80 as f64,
            "type" => "p80",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p90 as f64,
            "type" => "p90",
            "ip_version" => ip_version.clone(),
        );
        // Fix: p95 was collected and shown in stdout/HTML output but was
        // never exported as a gauge
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p95 as f64,
            "type" => "p95",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p99 as f64,
            "type" => "p99",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.p999 as f64,
            "type" => "p99.9",
            "ip_version" => ip_version.clone(),
        );
        ::metrics::gauge!(
            "aquatic_peers_per_torrent",
            self.max as f64,
            "type" => "max",
            "ip_version" => ip_version.clone(),
        );
    }
}

View file

@ -0,0 +1,306 @@
mod collector;
use std::fs::File;
use std::io::Write;
use std::time::{Duration, Instant};
use anyhow::Context;
use aquatic_common::{IndexMap, PanicSentinel};
use aquatic_udp_protocol::{PeerClient, PeerId};
use compact_str::CompactString;
use crossbeam_channel::Receiver;
use num_format::{Locale, ToFormattedString};
use serde::Serialize;
use time::format_description::well_known::Rfc2822;
use time::OffsetDateTime;
use tinytemplate::TinyTemplate;
use collector::{CollectedStatistics, StatisticsCollector};
use crate::common::*;
use crate::config::Config;
const TEMPLATE_KEY: &str = "statistics";
const TEMPLATE_CONTENTS: &str = include_str!("../../../templates/statistics.html");
const STYLESHEET_CONTENTS: &str = concat!(
"<style>",
include_str!("../../../templates/statistics.css"),
"</style>"
);
/// Values passed to the statistics HTML template on each render.
#[derive(Debug, Serialize)]
struct TemplateData {
    // Inlined CSS (wrapped in a <style> tag) injected into the page head
    stylesheet: String,
    ipv4_active: bool,
    ipv6_active: bool,
    // Whether to render the peers-per-torrent and peer client sections
    extended_active: bool,
    ipv4: CollectedStatistics,
    ipv6: CollectedStatistics,
    // RFC 2822-formatted UTC timestamp of this render
    last_updated: String,
    // Seconds between peer count updates, shown in table captions
    peer_update_interval: String,
    // (client name, formatted count) pairs, sorted by descending count
    peer_clients: Vec<(String, String)>,
}
/// Run the statistics worker loop: periodically drain statistics messages,
/// collect counters from shared state, and output the results to stdout,
/// an HTML file and/or Prometheus metrics, depending on configuration.
pub fn run_statistics_worker(
    _sentinel: PanicSentinel,
    config: Config,
    shared_state: State,
    statistics_receiver: Receiver<StatisticsMessage>,
) {
    // Only track per-peer client data when it is both requested and actually
    // consumed by at least one output (HTML file or Prometheus endpoint)
    let process_peer_client_data = {
        let mut collect = config.statistics.write_html_to_file;
        #[cfg(feature = "prometheus")]
        {
            collect |= config.statistics.run_prometheus_endpoint;
        }
        // Fix: use logical `&&` rather than bitwise `&` on booleans
        collect && config.statistics.peer_clients
    };
    // A template parse failure disables HTML output but doesn't kill the worker
    let opt_tt = if config.statistics.write_html_to_file {
        let mut tt = TinyTemplate::new();
        if let Err(err) = tt.add_template(TEMPLATE_KEY, TEMPLATE_CONTENTS) {
            ::log::error!("Couldn't parse statistics html template: {:#}", err);
            None
        } else {
            Some(tt)
        }
    } else {
        None
    };
    let mut ipv4_collector = StatisticsCollector::new(
        shared_state.statistics_ipv4,
        #[cfg(feature = "prometheus")]
        "4".into(),
    );
    let mut ipv6_collector = StatisticsCollector::new(
        shared_state.statistics_ipv6,
        #[cfg(feature = "prometheus")]
        "6".into(),
    );
    // Store a count to enable not removing peers from the count completely
    // just because they were removed from one torrent
    let mut peers: IndexMap<PeerId, (usize, PeerClient, CompactString)> = IndexMap::default();
    loop {
        let start_time = Instant::now();
        // Drain all messages received since the last iteration
        for message in statistics_receiver.try_iter() {
            match message {
                StatisticsMessage::Ipv4PeerHistogram(h) => ipv4_collector.add_histogram(&config, h),
                StatisticsMessage::Ipv6PeerHistogram(h) => ipv6_collector.add_histogram(&config, h),
                StatisticsMessage::PeerAdded(peer_id) => {
                    if process_peer_client_data {
                        peers
                            .entry(peer_id)
                            .or_insert_with(|| (0, peer_id.client(), peer_id.first_8_bytes_hex()))
                            .0 += 1;
                    }
                }
                StatisticsMessage::PeerRemoved(peer_id) => {
                    if process_peer_client_data {
                        if let Some((count, _, _)) = peers.get_mut(&peer_id) {
                            *count -= 1;
                            if *count == 0 {
                                peers.remove(&peer_id);
                            }
                        }
                    }
                }
            }
        }
        let statistics_ipv4 = ipv4_collector.collect_from_shared(
            #[cfg(feature = "prometheus")]
            &config,
        );
        let statistics_ipv6 = ipv6_collector.collect_from_shared(
            #[cfg(feature = "prometheus")]
            &config,
        );
        let peer_clients = if process_peer_client_data {
            let mut clients: IndexMap<PeerClient, usize> = IndexMap::default();
            #[cfg(feature = "prometheus")]
            let mut prefixes: IndexMap<CompactString, usize> = IndexMap::default();
            // Only count peer_ids once, even if they are in multiple torrents
            for (_, peer_client, prefix) in peers.values() {
                *clients.entry(peer_client.to_owned()).or_insert(0) += 1;
                #[cfg(feature = "prometheus")]
                if config.statistics.run_prometheus_endpoint
                    && config.statistics.prometheus_peer_id_prefixes
                {
                    *prefixes.entry(prefix.to_owned()).or_insert(0) += 1;
                }
            }
            // Sort clients by descending count for display
            clients.sort_unstable_by(|_, a, _, b| b.cmp(a));
            #[cfg(feature = "prometheus")]
            if config.statistics.run_prometheus_endpoint
                && config.statistics.prometheus_peer_id_prefixes
            {
                for (prefix, count) in prefixes {
                    ::metrics::gauge!(
                        "aquatic_peer_id_prefixes",
                        count as f64,
                        "prefix_hex" => prefix.to_string(),
                    );
                }
            }
            let mut client_vec = Vec::with_capacity(clients.len());
            for (client, count) in clients {
                if config.statistics.write_html_to_file {
                    client_vec.push((client.to_string(), count.to_formatted_string(&Locale::en)));
                }
                #[cfg(feature = "prometheus")]
                if config.statistics.run_prometheus_endpoint {
                    ::metrics::gauge!(
                        "aquatic_peer_clients",
                        count as f64,
                        "client" => client.to_string(),
                    );
                }
            }
            client_vec
        } else {
            Vec::new()
        };
        if config.statistics.print_to_stdout {
            println!("General:");
            println!(
                "  access list entries: {}",
                shared_state.access_list.load().len()
            );
            if config.network.ipv4_active() {
                println!("IPv4:");
                print_to_stdout(&config, &statistics_ipv4);
            }
            if config.network.ipv6_active() {
                println!("IPv6:");
                print_to_stdout(&config, &statistics_ipv6);
            }
            println!();
        }
        if let Some(tt) = opt_tt.as_ref() {
            let template_data = TemplateData {
                stylesheet: STYLESHEET_CONTENTS.to_string(),
                ipv4_active: config.network.ipv4_active(),
                ipv6_active: config.network.ipv6_active(),
                extended_active: config.statistics.torrent_peer_histograms,
                ipv4: statistics_ipv4,
                ipv6: statistics_ipv6,
                last_updated: OffsetDateTime::now_utc()
                    .format(&Rfc2822)
                    // Fix: only allocate the fallback string on formatting failure
                    .unwrap_or_else(|_| "(formatting error)".into()),
                peer_update_interval: format!("{}", config.cleaning.torrent_cleaning_interval),
                peer_clients,
            };
            if let Err(err) = save_html_to_file(&config, tt, &template_data) {
                ::log::error!("Couldn't save statistics to file: {:#}", err)
            }
        }
        // Reclaim memory freed by removed peer entries
        peers.shrink_to_fit();
        // Sleep for the remainder of the configured interval, if any
        if let Some(time_remaining) =
            Duration::from_secs(config.statistics.interval).checked_sub(start_time.elapsed())
        {
            ::std::thread::sleep(time_remaining);
        } else {
            ::log::warn!(
                "statistics interval not long enough to process all data, output may be misleading"
            );
        }
    }
}
/// Print collected statistics for one IP version to stdout in an aligned,
/// human-readable format.
fn print_to_stdout(config: &Config, statistics: &CollectedStatistics) {
    // Fix: right-align the out value like the in value (was `{:7}`, which
    // left-aligns strings and broke column alignment)
    println!(
        "  bandwidth: {:>7} Mbit/s in, {:>7} Mbit/s out",
        statistics.rx_mbits, statistics.tx_mbits,
    );
    println!("  requests/second: {:>10}", statistics.requests_per_second);
    println!("  responses/second");
    println!(
        "    total:         {:>10}",
        statistics.responses_per_second_total
    );
    println!(
        "    connect:       {:>10}",
        statistics.responses_per_second_connect
    );
    println!(
        "    announce:      {:>10}",
        statistics.responses_per_second_announce
    );
    println!(
        "    scrape:        {:>10}",
        statistics.responses_per_second_scrape
    );
    println!(
        "    error:         {:>10}",
        statistics.responses_per_second_error
    );
    println!("  torrents:        {:>10}", statistics.num_torrents);
    println!(
        "  peers:           {:>10} (updated every {}s)",
        statistics.num_peers, config.cleaning.torrent_cleaning_interval
    );
    if config.statistics.torrent_peer_histograms {
        println!(
            "  peers per torrent (updated every {}s)",
            config.cleaning.torrent_cleaning_interval
        );
        println!("    min            {:>10}", statistics.peer_histogram.min);
        println!("    p10            {:>10}", statistics.peer_histogram.p10);
        println!("    p20            {:>10}", statistics.peer_histogram.p20);
        println!("    p30            {:>10}", statistics.peer_histogram.p30);
        println!("    p40            {:>10}", statistics.peer_histogram.p40);
        println!("    p50            {:>10}", statistics.peer_histogram.p50);
        println!("    p60            {:>10}", statistics.peer_histogram.p60);
        println!("    p70            {:>10}", statistics.peer_histogram.p70);
        println!("    p80            {:>10}", statistics.peer_histogram.p80);
        println!("    p90            {:>10}", statistics.peer_histogram.p90);
        println!("    p95            {:>10}", statistics.peer_histogram.p95);
        println!("    p99            {:>10}", statistics.peer_histogram.p99);
        println!("    p99.9          {:>10}", statistics.peer_histogram.p999);
        println!("    max            {:>10}", statistics.peer_histogram.max);
    }
}
fn save_html_to_file(
config: &Config,
tt: &TinyTemplate,
template_data: &TemplateData,
) -> anyhow::Result<()> {
let mut file = File::create(&config.statistics.html_file_path).with_context(|| {
format!(
"File path: {}",
&config.statistics.html_file_path.to_string_lossy()
)
})?;
write!(file, "{}", tt.render(TEMPLATE_KEY, template_data)?)?;
Ok(())
}

View file

@ -0,0 +1,200 @@
mod storage;
use std::net::IpAddr;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use aquatic_common::ServerStartInstant;
use crossbeam_channel::Receiver;
use crossbeam_channel::Sender;
use rand::{rngs::SmallRng, SeedableRng};
use aquatic_common::{CanonicalSocketAddr, PanicSentinel, ValidUntil};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
use storage::{TorrentMap, TorrentMaps};
/// Run a swarm worker: receive announce/scrape requests from socket
/// workers, update torrent state, send back responses, and periodically
/// clean stale entries and publish statistics.
pub fn run_swarm_worker(
    _sentinel: PanicSentinel,
    config: Config,
    state: State,
    server_start_instant: ServerStartInstant,
    request_receiver: Receiver<(SocketWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>,
    response_sender: ConnectedResponseSender,
    statistics_sender: Sender<StatisticsMessage>,
    worker_index: SwarmWorkerIndex,
) {
    let mut torrents = TorrentMaps::default();
    let mut rng = SmallRng::from_entropy();
    // Receive with a timeout so periodic tasks below still run when the
    // request channel is idle
    let timeout = Duration::from_millis(config.request_channel_recv_timeout_ms);
    let mut peer_valid_until = ValidUntil::new(server_start_instant, config.cleaning.max_peer_age);
    let cleaning_interval = Duration::from_secs(config.cleaning.torrent_cleaning_interval);
    let statistics_update_interval = Duration::from_secs(config.statistics.interval);
    let mut last_cleaning = Instant::now();
    let mut last_statistics_update = Instant::now();
    let mut iter_counter = 0usize;
    loop {
        if let Ok((sender_index, request, src)) = request_receiver.recv_timeout(timeout) {
            // IPv4 and IPv6 swarms are tracked separately; dispatch on the
            // request type and the source address family
            let response = match (request, src.get().ip()) {
                (ConnectedRequest::Announce(request), IpAddr::V4(ip)) => {
                    let response = handle_announce_request(
                        &config,
                        &mut rng,
                        &statistics_sender,
                        &mut torrents.ipv4,
                        request,
                        ip,
                        peer_valid_until,
                    );
                    ConnectedResponse::AnnounceIpv4(response)
                }
                (ConnectedRequest::Announce(request), IpAddr::V6(ip)) => {
                    let response = handle_announce_request(
                        &config,
                        &mut rng,
                        &statistics_sender,
                        &mut torrents.ipv6,
                        request,
                        ip,
                        peer_valid_until,
                    );
                    ConnectedResponse::AnnounceIpv6(response)
                }
                (ConnectedRequest::Scrape(request), IpAddr::V4(_)) => {
                    ConnectedResponse::Scrape(handle_scrape_request(&mut torrents.ipv4, request))
                }
                (ConnectedRequest::Scrape(request), IpAddr::V6(_)) => {
                    ConnectedResponse::Scrape(handle_scrape_request(&mut torrents.ipv6, request))
                }
            };
            // Send the response back to the socket worker that received the
            // request
            response_sender.try_send_to(sender_index, response, src);
        }
        // Run periodic tasks
        // Only check the clock every 128 iterations to keep the hot path cheap
        if iter_counter % 128 == 0 {
            let now = Instant::now();
            peer_valid_until = ValidUntil::new(server_start_instant, config.cleaning.max_peer_age);
            if now > last_cleaning + cleaning_interval {
                torrents.clean_and_update_statistics(
                    &config,
                    &state,
                    &statistics_sender,
                    &state.access_list,
                    server_start_instant,
                    worker_index,
                );
                last_cleaning = now;
            }
            if config.statistics.active()
                && now > last_statistics_update + statistics_update_interval
            {
                // Publish this worker's torrent counts for the statistics worker
                state.statistics_ipv4.torrents[worker_index.0]
                    .store(torrents.ipv4.num_torrents(), Ordering::Release);
                state.statistics_ipv6.torrents[worker_index.0]
                    .store(torrents.ipv6.num_torrents(), Ordering::Release);
                last_statistics_update = now;
            }
        }
        iter_counter = iter_counter.wrapping_add(1);
    }
}
/// Handle an announce: record/update the peer in the torrent's swarm and
/// build a response containing other peers (unless the peer stopped).
fn handle_announce_request<I: Ip>(
    config: &Config,
    rng: &mut SmallRng,
    statistics_sender: &Sender<StatisticsMessage>,
    torrents: &mut TorrentMap<I>,
    request: AnnounceRequest,
    peer_ip: I,
    peer_valid_until: ValidUntil,
) -> AnnounceResponse<I> {
    // Cap returned peers at the configured maximum; a non-positive
    // "peers wanted" value means the client defers to the tracker
    let peer_limit: usize = match request.peers_wanted.0 {
        wanted if wanted <= 0 => config.protocol.max_response_peers,
        wanted => config
            .protocol
            .max_response_peers
            .min(wanted.try_into().unwrap()),
    };

    let status = PeerStatus::from_event_and_bytes_left(request.event, request.bytes_left);
    let torrent = torrents.0.entry(request.info_hash).or_default();

    torrent.update_peer(
        config,
        statistics_sender,
        request.peer_id,
        peer_ip,
        request.port,
        status,
        peer_valid_until,
    );

    // A stopping peer gets an empty peer list
    let peers = match status {
        PeerStatus::Stopped => Vec::new(),
        _ => torrent.extract_response_peers(rng, request.peer_id, peer_limit),
    };

    AnnounceResponse {
        transaction_id: request.transaction_id,
        announce_interval: AnnounceInterval(config.protocol.peer_announce_interval),
        leechers: NumberOfPeers(torrent.num_leechers().try_into().unwrap_or(i32::MAX)),
        seeders: NumberOfPeers(torrent.num_seeders().try_into().unwrap_or(i32::MAX)),
        peers,
    }
}
/// Handle a scrape: look up seeder/leecher statistics for each requested
/// info hash, reporting zeroes for unknown torrents.
fn handle_scrape_request<I: Ip>(
    torrents: &mut TorrentMap<I>,
    request: PendingScrapeRequest,
) -> PendingScrapeResponse {
    // Reported for info hashes we aren't tracking
    const EMPTY_STATS: TorrentScrapeStatistics = create_torrent_scrape_statistics(0, 0);

    let torrent_stats = request
        .info_hashes
        .into_iter()
        .map(|(i, info_hash)| {
            let stats = match torrents.0.get(&info_hash) {
                Some(torrent_data) => torrent_data.scrape_statistics(),
                None => EMPTY_STATS,
            };

            (i, stats)
        })
        .collect();

    PendingScrapeResponse {
        slab_key: request.slab_key,
        torrent_stats,
    }
}
/// Build a `TorrentScrapeStatistics` from seeder and leecher counts.
/// `completed` is always reported as zero.
#[inline(always)]
const fn create_torrent_scrape_statistics(seeders: i32, leechers: i32) -> TorrentScrapeStatistics {
    TorrentScrapeStatistics {
        seeders: NumberOfPeers(seeders),
        completed: NumberOfDownloads(0), // No implementation planned
        leechers: NumberOfPeers(leechers),
    }
}

View file

@ -0,0 +1,402 @@
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use aquatic_common::IndexMap;
use aquatic_common::SecondsSinceServerStart;
use aquatic_common::ServerStartInstant;
use aquatic_common::{
access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache, AccessListMode},
extract_response_peers, ValidUntil,
};
use aquatic_udp_protocol::*;
use crossbeam_channel::Sender;
use hdrhistogram::Histogram;
use rand::prelude::SmallRng;
use crate::common::*;
use crate::config::Config;
use super::create_torrent_scrape_statistics;
/// A single peer in a torrent swarm.
#[derive(Clone, Debug)]
struct Peer<I: Ip> {
    ip_address: I,
    port: Port,
    // Set when the peer announced with seeding status
    is_seeder: bool,
    // The peer is removed during cleaning once this expires
    valid_until: ValidUntil,
}
impl<I: Ip> Peer<I> {
    /// Convert to the response representation (address and port only).
    fn to_response_peer(&self) -> ResponsePeer<I> {
        ResponsePeer {
            ip_address: self.ip_address,
            port: self.port,
        }
    }
}
type PeerMap<I> = IndexMap<PeerId, Peer<I>>;
/// All peers of a single torrent.
pub struct TorrentData<I: Ip> {
    peers: PeerMap<I>,
    // Kept in sync with `peers` so seeder/leecher counts are O(1)
    num_seeders: usize,
}
impl<I: Ip> TorrentData<I> {
    /// Insert, update or remove a peer according to its announced status,
    /// keeping the seeder count in sync and notifying the statistics worker
    /// of additions/removals when peer client tracking is enabled.
    pub fn update_peer(
        &mut self,
        config: &Config,
        statistics_sender: &Sender<StatisticsMessage>,
        peer_id: PeerId,
        ip_address: I,
        port: Port,
        status: PeerStatus,
        valid_until: ValidUntil,
    ) {
        // Insert/remove and remember any previous entry for this peer id;
        // if the replaced/removed entry was a seeder, the count is adjusted
        // at the end
        let opt_removed_peer = match status {
            PeerStatus::Leeching => {
                let peer = Peer {
                    ip_address,
                    port,
                    is_seeder: false,
                    valid_until,
                };
                self.peers.insert(peer_id, peer)
            }
            PeerStatus::Seeding => {
                let peer = Peer {
                    ip_address,
                    port,
                    is_seeder: true,
                    valid_until,
                };
                self.num_seeders += 1;
                self.peers.insert(peer_id, peer)
            }
            PeerStatus::Stopped => self.peers.remove(&peer_id),
        };
        if config.statistics.peer_clients {
            match (status, opt_removed_peer.is_some()) {
                // We added a new peer
                (PeerStatus::Leeching | PeerStatus::Seeding, false) => {
                    if let Err(_) =
                        statistics_sender.try_send(StatisticsMessage::PeerAdded(peer_id))
                    {
                        // Should never happen in practice
                        ::log::error!("Couldn't send StatisticsMessage::PeerAdded");
                    }
                }
                // We removed an existing peer
                (PeerStatus::Stopped, true) => {
                    if let Err(_) =
                        statistics_sender.try_send(StatisticsMessage::PeerRemoved(peer_id))
                    {
                        // Should never happen in practice
                        ::log::error!("Couldn't send StatisticsMessage::PeerRemoved");
                    }
                }
                _ => (),
            }
        }
        // If a previous entry for this peer id was a seeder, undo its
        // contribution to the seeder count
        if let Some(Peer {
            is_seeder: true, ..
        }) = opt_removed_peer
        {
            self.num_seeders -= 1;
        }
    }
    /// Select up to `max_num_peers_to_take` response peers, excluding the
    /// requesting peer itself.
    pub fn extract_response_peers(
        &self,
        rng: &mut SmallRng,
        peer_id: PeerId,
        max_num_peers_to_take: usize,
    ) -> Vec<ResponsePeer<I>> {
        extract_response_peers(
            rng,
            &self.peers,
            max_num_peers_to_take,
            peer_id,
            Peer::to_response_peer,
        )
    }
    /// Number of leechers (total peers minus seeders).
    pub fn num_leechers(&self) -> usize {
        self.peers.len() - self.num_seeders
    }
    /// Number of seeders.
    pub fn num_seeders(&self) -> usize {
        self.num_seeders
    }
    /// Build scrape statistics (seeders/leechers, clamped to i32).
    pub fn scrape_statistics(&self) -> TorrentScrapeStatistics {
        create_torrent_scrape_statistics(
            self.num_seeders.try_into().unwrap_or(i32::MAX),
            self.num_leechers().try_into().unwrap_or(i32::MAX),
        )
    }
    /// Remove inactive peers and reclaim space
    fn clean(
        &mut self,
        config: &Config,
        statistics_sender: &Sender<StatisticsMessage>,
        now: SecondsSinceServerStart,
    ) {
        self.peers.retain(|peer_id, peer| {
            let keep = peer.valid_until.valid(now);
            if !keep {
                // Keep the seeder count consistent with removals
                if peer.is_seeder {
                    self.num_seeders -= 1;
                }
                if config.statistics.peer_clients {
                    if let Err(_) =
                        statistics_sender.try_send(StatisticsMessage::PeerRemoved(*peer_id))
                    {
                        // Should never happen in practice
                        ::log::error!("Couldn't send StatisticsMessage::PeerRemoved");
                    }
                }
            }
            keep
        });
        if !self.peers.is_empty() {
            self.peers.shrink_to_fit();
        }
    }
}
impl<I: Ip> Default for TorrentData<I> {
fn default() -> Self {
Self {
peers: Default::default(),
num_seeders: 0,
}
}
}
/// Map from info hash to torrent swarm state for one IP version.
#[derive(Default)]
pub struct TorrentMap<I: Ip>(pub IndexMap<InfoHash, TorrentData<I>>);
impl<I: Ip> TorrentMap<I> {
    /// Remove forbidden or inactive torrents, reclaim space and return number of remaining peers
    /// along with an optional peers-per-torrent histogram.
    fn clean_and_get_statistics(
        &mut self,
        config: &Config,
        statistics_sender: &Sender<StatisticsMessage>,
        access_list_cache: &mut AccessListCache,
        access_list_mode: AccessListMode,
        now: SecondsSinceServerStart,
    ) -> (usize, Option<Histogram<u64>>) {
        let mut num_peers = 0;
        // Histogram creation can fail; log and continue without one
        let mut opt_histogram: Option<Histogram<u64>> = if config.statistics.torrent_peer_histograms
        {
            match Histogram::new(3) {
                Ok(histogram) => Some(histogram),
                Err(err) => {
                    ::log::error!("Couldn't create peer histogram: {:#}", err);
                    None
                }
            }
        } else {
            None
        };
        self.0.retain(|info_hash, torrent| {
            // Drop torrents the access list forbids in its current mode
            if !access_list_cache
                .load()
                .allows(access_list_mode, &info_hash.0)
            {
                return false;
            }
            torrent.clean(config, statistics_sender, now);
            num_peers += torrent.peers.len();
            match opt_histogram {
                // Fix idiom: use `!is_empty()` instead of `len() != 0`
                Some(ref mut histogram) if !torrent.peers.is_empty() => {
                    let n = torrent
                        .peers
                        .len()
                        .try_into()
                        .expect("Couldn't fit usize into u64");
                    if let Err(err) = histogram.record(n) {
                        ::log::error!("Couldn't record {} to histogram: {:#}", n, err);
                    }
                }
                _ => (),
            }
            // Drop torrents that have no peers left after cleaning
            !torrent.peers.is_empty()
        });
        self.0.shrink_to_fit();
        (num_peers, opt_histogram)
    }
    /// Number of torrents currently tracked.
    pub fn num_torrents(&self) -> usize {
        self.0.len()
    }
}
/// Torrent state for both address families; IPv4 and IPv6 swarms are
/// tracked separately.
pub struct TorrentMaps {
    pub ipv4: TorrentMap<Ipv4Addr>,
    pub ipv6: TorrentMap<Ipv6Addr>,
}
impl Default for TorrentMaps {
fn default() -> Self {
Self {
ipv4: TorrentMap(Default::default()),
ipv6: TorrentMap(Default::default()),
}
}
}
impl TorrentMaps {
    /// Remove forbidden or inactive torrents, reclaim space and update statistics
    pub fn clean_and_update_statistics(
        &mut self,
        config: &Config,
        state: &State,
        statistics_sender: &Sender<StatisticsMessage>,
        access_list: &Arc<AccessListArcSwap>,
        server_start_instant: ServerStartInstant,
        worker_index: SwarmWorkerIndex,
    ) {
        let mut cache = create_access_list_cache(access_list);
        let mode = config.access_list.mode;
        let now = server_start_instant.seconds_elapsed();
        // Each call returns (num_peers, optional histogram)
        let ipv4 =
            self.ipv4
                .clean_and_get_statistics(config, statistics_sender, &mut cache, mode, now);
        let ipv6 =
            self.ipv6
                .clean_and_get_statistics(config, statistics_sender, &mut cache, mode, now);
        if config.statistics.active() {
            // Publish this worker's peer counts for the statistics worker
            state.statistics_ipv4.peers[worker_index.0].store(ipv4.0, Ordering::Release);
            state.statistics_ipv6.peers[worker_index.0].store(ipv6.0, Ordering::Release);
            // Forward histograms to the statistics worker (best-effort)
            if let Some(message) = ipv4.1.map(StatisticsMessage::Ipv4PeerHistogram) {
                if let Err(err) = statistics_sender.try_send(message) {
                    ::log::error!("couldn't send statistics message: {:#}", err);
                }
            }
            if let Some(message) = ipv6.1.map(StatisticsMessage::Ipv6PeerHistogram) {
                if let Err(err) = statistics_sender.try_send(message) {
                    ::log::error!("couldn't send statistics message: {:#}", err);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::net::Ipv4Addr;
    use quickcheck::{quickcheck, TestResult};
    use rand::thread_rng;
    use super::*;
    // Deterministic peer id from an index (first four bytes hold the index)
    fn gen_peer_id(i: u32) -> PeerId {
        let mut peer_id = PeerId([0; 20]);
        peer_id.0[0..4].copy_from_slice(&i.to_ne_bytes());
        peer_id
    }
    // Deterministic peer whose IP address encodes the index, so every
    // generated peer has a unique address
    fn gen_peer(i: u32) -> Peer<Ipv4Addr> {
        Peer {
            ip_address: Ipv4Addr::from(i.to_be_bytes()),
            port: Port(1),
            is_seeder: false,
            valid_until: ValidUntil::new(ServerStartInstant::new(), 0),
        }
    }
    #[test]
    fn test_extract_response_peers() {
        // Property: for any map size and requested count, extraction returns
        // at most the requested number of unique peers, excluding the sender
        fn prop(data: (u16, u16)) -> TestResult {
            let gen_num_peers = data.0 as u32;
            let req_num_peers = data.1 as usize;
            let mut peer_map: PeerMap<Ipv4Addr> = Default::default();
            let mut opt_sender_key = None;
            let mut opt_sender_peer = None;
            for i in 0..gen_num_peers {
                let key = gen_peer_id(i);
                let peer = gen_peer((i << 16) + i);
                if i == 0 {
                    opt_sender_key = Some(key);
                    opt_sender_peer = Some(peer.to_response_peer());
                }
                peer_map.insert(key, peer);
            }
            let mut rng = thread_rng();
            let peers = extract_response_peers(
                &mut rng,
                &peer_map,
                req_num_peers,
                opt_sender_key.unwrap_or_else(|| gen_peer_id(1)),
                Peer::to_response_peer,
            );
            // Check that number of returned peers is correct
            let mut success = peers.len() <= req_num_peers;
            if req_num_peers >= gen_num_peers as usize {
                // Either all peers, or all except the sender, are returned
                success &= peers.len() == gen_num_peers as usize
                    || peers.len() + 1 == gen_num_peers as usize;
            }
            // Check that returned peers are unique (no overlap) and that sender
            // isn't returned
            let mut ip_addresses = HashSet::with_capacity(peers.len());
            for peer in peers {
                if peer == opt_sender_peer.clone().unwrap()
                    || ip_addresses.contains(&peer.ip_address)
                {
                    success = false;
                    break;
                }
                ip_addresses.insert(peer.ip_address);
            }
            TestResult::from_bool(success)
        }
        quickcheck(prop as fn((u16, u16)) -> TestResult);
    }
}

View file

@ -0,0 +1,22 @@
body {
font-family: arial, sans-serif;
font-size: 16px;
}
table {
border-collapse: collapse
}
caption {
caption-side: bottom;
padding-top: 0.5rem;
}
th, td {
padding: 0.5rem 2rem;
border: 1px solid #ccc;
}
th {
background-color: #eee;
}

View file

@ -0,0 +1,278 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>UDP BitTorrent tracker statistics</title>
{#- Include stylesheet like this to prevent code editor syntax warnings #}
{ stylesheet | unescaped }
</head>
<body>
<h1>UDP BitTorrent tracker statistics</h1>
{#- <p> <strong>Tracker software:</strong> <a href="https://github.com/greatest-ape/aquatic">aquatic_udp</a> </p> #}
<p>
<strong>Updated:</strong> { last_updated } (UTC)
</p>
{{ if ipv4_active }}
<h2>IPv4</h2>
<table>
<caption>* Peer count is updated every { peer_update_interval } seconds</caption>
<tr>
<th scope="row">Number of torrents</th>
<td>{ ipv4.num_torrents }</td>
</tr>
<tr>
<th scope="row">Number of peers</th>
<td>{ ipv4.num_peers } *</td>
</tr>
<tr>
<th scope="row">Requests / second</th>
<td>{ ipv4.requests_per_second }</td>
</tr>
<tr>
<th scope="row">Total responses / second</th>
<td>{ ipv4.responses_per_second_total }</td>
</tr>
<tr>
<th scope="row">Connect responses / second</th>
<td>{ ipv4.responses_per_second_connect }</td>
</tr>
<tr>
<th scope="row">Announce responses / second</th>
<td>{ ipv4.responses_per_second_announce }</td>
</tr>
<tr>
<th scope="row">Scrape responses / second</th>
<td>{ ipv4.responses_per_second_scrape }</td>
</tr>
<tr>
<th scope="row">Error responses / second</th>
<td>{ ipv4.responses_per_second_error }</td>
</tr>
<tr>
<th scope="row">Bandwidth (RX)</th>
<td>{ ipv4.rx_mbits } mbit/s</td>
</tr>
<tr>
<th scope="row">Bandwidth (TX)</th>
<td>{ ipv4.tx_mbits } mbit/s</td>
</tr>
</table>
{{ if extended_active }}
<h3>Peers per torrent</h3>
<table>
<caption>Updated every { peer_update_interval } seconds</caption>
<tr>
<th scope="row">Minimum</th>
<td>{ ipv4.peer_histogram.min }</td>
</tr>
<tr>
<th scope="row">10th percentile</th>
<td>{ ipv4.peer_histogram.p10 }</td>
</tr>
<tr>
<th scope="row">20th percentile</th>
<td>{ ipv4.peer_histogram.p20 }</td>
</tr>
<tr>
<th scope="row">30th percentile</th>
<td>{ ipv4.peer_histogram.p30 }</td>
</tr>
<tr>
<th scope="row">40th percentile</th>
<td>{ ipv4.peer_histogram.p40 }</td>
</tr>
<tr>
<th scope="row">50th percentile</th>
<td>{ ipv4.peer_histogram.p50 }</td>
</tr>
<tr>
<th scope="row">60th percentile</th>
<td>{ ipv4.peer_histogram.p60 }</td>
</tr>
<tr>
<th scope="row">70th percentile</th>
<td>{ ipv4.peer_histogram.p70 }</td>
</tr>
<tr>
<th scope="row">80th percentile</th>
<td>{ ipv4.peer_histogram.p80 }</td>
</tr>
<tr>
<th scope="row">90th percentile</th>
<td>{ ipv4.peer_histogram.p90 }</td>
</tr>
<tr>
<th scope="row">95th percentile</th>
<td>{ ipv4.peer_histogram.p95 }</td>
</tr>
<tr>
<th scope="row">99th percentile</th>
<td>{ ipv4.peer_histogram.p99 }</td>
</tr>
<tr>
<th scope="row">99.9th percentile</th>
<td>{ ipv4.peer_histogram.p999 }</td>
</tr>
<tr>
<th scope="row">Maximum</th>
<td>{ ipv4.peer_histogram.max }</td>
</tr>
</table>
{{ endif }}
{{ endif }}
{{ if ipv6_active }}
<h2>IPv6</h2>
<table>
<caption>* Peer count is updated every { peer_update_interval } seconds</caption>
<tr>
<th scope="row">Number of torrents</th>
<td>{ ipv6.num_torrents }</td>
</tr>
<tr>
<th scope="row">Number of peers</th>
<td>{ ipv6.num_peers } *</td>
</tr>
<tr>
<th scope="row">Requests / second</th>
<td>{ ipv6.requests_per_second }</td>
</tr>
<tr>
<th scope="row">Total responses / second</th>
<td>{ ipv6.responses_per_second_total }</td>
</tr>
<tr>
<th scope="row">Connect responses / second</th>
<td>{ ipv6.responses_per_second_connect }</td>
</tr>
<tr>
<th scope="row">Announce responses / second</th>
<td>{ ipv6.responses_per_second_announce }</td>
</tr>
<tr>
<th scope="row">Scrape responses / second</th>
<td>{ ipv6.responses_per_second_scrape }</td>
</tr>
<tr>
<th scope="row">Error responses / second</th>
<td>{ ipv6.responses_per_second_error }</td>
</tr>
<tr>
<th scope="row">Bandwidth (RX)</th>
<td>{ ipv6.rx_mbits } mbit/s</td>
</tr>
<tr>
<th scope="row">Bandwidth (TX)</th>
<td>{ ipv6.tx_mbits } mbit/s</td>
</tr>
</table>
{{ if extended_active }}
<h3>Peers per torrent</h3>
<table>
<caption>Updated every { peer_update_interval } seconds</caption>
<tr>
<th scope="row">Minimum</th>
<td>{ ipv6.peer_histogram.min }</td>
</tr>
<tr>
<th scope="row">10th percentile</th>
<td>{ ipv6.peer_histogram.p10 }</td>
</tr>
<tr>
<th scope="row">20th percentile</th>
<td>{ ipv6.peer_histogram.p20 }</td>
</tr>
<tr>
<th scope="row">30th percentile</th>
<td>{ ipv6.peer_histogram.p30 }</td>
</tr>
<tr>
<th scope="row">40th percentile</th>
<td>{ ipv6.peer_histogram.p40 }</td>
</tr>
<tr>
<th scope="row">50th percentile</th>
<td>{ ipv6.peer_histogram.p50 }</td>
</tr>
<tr>
<th scope="row">60th percentile</th>
<td>{ ipv6.peer_histogram.p60 }</td>
</tr>
<tr>
<th scope="row">70th percentile</th>
<td>{ ipv6.peer_histogram.p70 }</td>
</tr>
<tr>
<th scope="row">80th percentile</th>
<td>{ ipv6.peer_histogram.p80 }</td>
</tr>
<tr>
<th scope="row">90th percentile</th>
<td>{ ipv6.peer_histogram.p90 }</td>
</tr>
<tr>
<th scope="row">95th percentile</th>
<td>{ ipv6.peer_histogram.p95 }</td>
</tr>
<tr>
<th scope="row">99th percentile</th>
<td>{ ipv6.peer_histogram.p99 }</td>
</tr>
<tr>
<th scope="row">99.9th percentile</th>
<td>{ ipv6.peer_histogram.p999 }</td>
</tr>
<tr>
<th scope="row">Maximum</th>
<td>{ ipv6.peer_histogram.max }</td>
</tr>
</table>
{{ endif }}
{{ endif }}
{{ if extended_active }}
<h2>Peer clients</h2>
<table>
<thead>
<tr>
<th>Client</th>
<th>Count</th>
</tr>
</thead>
<tbody>
{{ for value in peer_clients }}
<tr>
<td>{ value.0 }</td>
<td>{ value.1 }</td>
</tr>
{{ endfor }}
</tbody>
</table>
{{ endif }}
</body>
</html>

View file

@ -0,0 +1,108 @@
mod common;
use common::*;
use std::{
fs::File,
io::Write,
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
time::Duration,
};
use anyhow::Context;
use aquatic_common::access_list::AccessListMode;
use aquatic_udp::config::Config;
use aquatic_udp_protocol::{InfoHash, Response};
#[test]
fn test_access_list_deny() -> anyhow::Result<()> {
    const TRACKER_PORT: u16 = 40_113;

    // In deny mode, the hash present in the access list must be rejected,
    // while any other hash is accepted
    let denied_hash = InfoHash([0; 20]);
    let allowed_hash = InfoHash([1; 20]);

    test_access_list(
        TRACKER_PORT,
        allowed_hash,
        denied_hash,
        denied_hash,
        AccessListMode::Deny,
    )?;

    Ok(())
}
#[test]
fn test_access_list_allow() -> anyhow::Result<()> {
    const TRACKER_PORT: u16 = 40_114;

    // In allow mode, only the hash present in the access list is accepted,
    // while any other hash is rejected
    let allowed_hash = InfoHash([0; 20]);
    let denied_hash = InfoHash([1; 20]);

    test_access_list(
        TRACKER_PORT,
        allowed_hash,
        denied_hash,
        allowed_hash,
        AccessListMode::Allow,
    )?;

    Ok(())
}
/// Shared driver for the access list integration tests.
///
/// Writes `info_hash_in_list` to a temporary access list file, starts a
/// tracker on `tracker_port` configured with that list in the given `mode`,
/// then verifies that announcing `info_hash_fail` yields an error response
/// and announcing `info_hash_success` yields an IPv4 announce response.
fn test_access_list(
    tracker_port: u16,
    info_hash_success: InfoHash,
    info_hash_fail: InfoHash,
    info_hash_in_list: InfoHash,
    mode: AccessListMode,
) -> anyhow::Result<()> {
    // The tempdir guard must stay alive for the whole test so the
    // access list file isn't deleted while the tracker reads it
    let access_list_dir = tempfile::tempdir().with_context(|| "get temporary directory")?;
    let access_list_path = access_list_dir.path().join("access-list.txt");
    let mut access_list_file =
        File::create(&access_list_path).with_context(|| "create access list file")?;
    // One upper-case hex-encoded info hash per line
    writeln!(
        access_list_file,
        "{}",
        hex::encode_upper(info_hash_in_list.0)
    )
    .with_context(|| "write to access list file")?;
    let mut config = Config::default();
    config.network.address.set_port(tracker_port);
    config.access_list.mode = mode;
    config.access_list.path = access_list_path;
    // Spawns the tracker on a background thread (see common::run_tracker)
    run_tracker(config);
    let tracker_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, tracker_port));
    // Port 0: let the OS pick a free port for the simulated peer
    let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
    let socket = UdpSocket::bind(peer_addr)?;
    // Bound read timeout so a silent tracker fails the test instead of hanging
    socket.set_read_timeout(Some(Duration::from_secs(1)))?;
    let connection_id = connect(&socket, tracker_addr).with_context(|| "connect")?;
    // Announce with the disallowed hash: the tracker must answer with an error
    let response = announce(
        &socket,
        tracker_addr,
        connection_id,
        1,
        info_hash_fail,
        10,
        false,
    )
    .with_context(|| "announce")?;
    assert!(
        matches!(response, Response::Error(_)),
        "response should be error but is {:?}",
        response
    );
    // Announce with the allowed hash: the tracker must answer normally
    let response = announce(
        &socket,
        tracker_addr,
        connection_id,
        1,
        info_hash_success,
        10,
        false,
    )
    .with_context(|| "announce")?;
    assert!(matches!(response, Response::AnnounceIpv4(_)));
    Ok(())
}

View file

@ -0,0 +1,123 @@
#![allow(dead_code)]
use std::{
io::Cursor,
net::{SocketAddr, UdpSocket},
time::Duration,
};
use anyhow::Context;
use aquatic_udp::{common::BUFFER_SIZE, config::Config};
use aquatic_udp_protocol::{
common::PeerId, AnnounceEvent, AnnounceRequest, ConnectRequest, ConnectionId, InfoHash,
NumberOfBytes, NumberOfPeers, PeerKey, Port, Request, Response, ScrapeRequest, ScrapeResponse,
TransactionId,
};
/// Start the tracker with the given config on a background thread.
///
/// FIXME: should ideally try different ports and use sync primitives to find
/// out if tracker was successfully started
pub fn run_tracker(config: Config) {
    std::thread::spawn(move || {
        aquatic_udp::run(config).unwrap();
    });

    // Crude startup synchronization: give the tracker a second to bind
    // its socket before tests start sending requests
    std::thread::sleep(Duration::from_secs(1));
}
/// Send a connect request to the tracker and return the connection id.
///
/// Returns an error if sending/receiving fails, or if the tracker
/// answers with anything other than a connect response.
pub fn connect(socket: &UdpSocket, tracker_addr: SocketAddr) -> anyhow::Result<ConnectionId> {
    let request = Request::Connect(ConnectRequest {
        transaction_id: TransactionId(0),
    });

    // `socket` is already a reference; pass it directly instead of
    // taking a needless double borrow (clippy::needless_borrow)
    let response = request_and_response(socket, tracker_addr, request)?;

    if let Response::Connect(response) = response {
        Ok(response.connection_id)
    } else {
        Err(anyhow::anyhow!("not connect response: {:?}", response))
    }
}
/// Announce a simulated peer to the tracker and return the raw response.
///
/// The 20-byte peer id is filled with repetitions of `peer_port`'s bytes,
/// so distinct ports produce distinct, deterministic peer ids. `seeder`
/// selects whether the peer reports zero bytes left (seeder) or one byte
/// left (leecher).
pub fn announce(
    socket: &UdpSocket,
    tracker_addr: SocketAddr,
    connection_id: ConnectionId,
    peer_port: u16,
    info_hash: InfoHash,
    peers_wanted: usize,
    seeder: bool,
) -> anyhow::Result<Response> {
    let mut peer_id = PeerId([0; 20]);

    // Derive the peer id from the port so each simulated peer is unique
    for chunk in peer_id.0.chunks_exact_mut(2) {
        chunk.copy_from_slice(&peer_port.to_ne_bytes());
    }

    let request = Request::Announce(AnnounceRequest {
        connection_id,
        transaction_id: TransactionId(0),
        info_hash,
        peer_id,
        bytes_downloaded: NumberOfBytes(0),
        bytes_uploaded: NumberOfBytes(0),
        bytes_left: NumberOfBytes(if seeder { 0 } else { 1 }),
        event: AnnounceEvent::Started,
        ip_address: None,
        key: PeerKey(0),
        peers_wanted: NumberOfPeers(peers_wanted as i32),
        port: Port(peer_port),
    });

    // Return the Result directly instead of wrapping in Ok(..?)
    // (clippy::needless_question_mark); also drops the needless borrow
    request_and_response(socket, tracker_addr, request)
}
/// Send a scrape request for the given info hashes and return the parsed
/// scrape response.
///
/// Returns an error if sending/receiving fails, or if the tracker answers
/// with anything other than a scrape response.
pub fn scrape(
    socket: &UdpSocket,
    tracker_addr: SocketAddr,
    connection_id: ConnectionId,
    info_hashes: Vec<InfoHash>,
) -> anyhow::Result<ScrapeResponse> {
    let request = Request::Scrape(ScrapeRequest {
        connection_id,
        transaction_id: TransactionId(0),
        info_hashes,
    });

    // Needless borrow removed (clippy::needless_borrow)
    let response = request_and_response(socket, tracker_addr, request)?;

    if let Response::Scrape(response) = response {
        Ok(response)
    } else {
        // Tail expression instead of an explicit `return ...;`
        // (clippy::needless_return), matching `connect` above
        Err(anyhow::anyhow!("not scrape response: {:?}", response))
    }
}
/// Serialize `request`, send it to the tracker, block for one datagram
/// (subject to the socket's read timeout) and parse it as a `Response`.
pub fn request_and_response(
    socket: &UdpSocket,
    tracker_addr: SocketAddr,
    request: Request,
) -> anyhow::Result<Response> {
    let mut buffer = [0u8; BUFFER_SIZE];

    {
        // Write the request into the start of the buffer; the cursor
        // position afterwards is the number of bytes to send
        let mut buffer = Cursor::new(&mut buffer[..]);

        request
            .write(&mut buffer)
            .with_context(|| "write request")?;

        let bytes_written = buffer.position() as usize;

        socket
            .send_to(&(buffer.into_inner())[..bytes_written], tracker_addr)
            .with_context(|| "send request")?;
    }

    let (bytes_read, _) = socket
        .recv_from(&mut buffer)
        .with_context(|| "recv response")?;

    // Return the Result directly instead of wrapping in Ok(..?)
    // (clippy::needless_question_mark)
    Response::from_bytes(&buffer[..bytes_read], true).with_context(|| "parse response")
}

View file

@ -0,0 +1,94 @@
mod common;
use common::*;
use std::{
io::{Cursor, ErrorKind},
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
time::Duration,
};
use anyhow::Context;
use aquatic_udp::{common::BUFFER_SIZE, config::Config};
use aquatic_udp_protocol::{
common::PeerId, AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, NumberOfBytes,
NumberOfPeers, PeerKey, Port, Request, ScrapeRequest, TransactionId,
};
#[test]
/// Requests carrying an invalid connection id must be silently dropped
/// by the tracker (per the UDP tracker protocol), for both announce and
/// scrape requests.
fn test_invalid_connection_id() -> anyhow::Result<()> {
    const TRACKER_PORT: u16 = 40_112;
    let mut config = Config::default();
    config.network.address.set_port(TRACKER_PORT);
    run_tracker(config);
    let tracker_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, TRACKER_PORT));
    // Port 0: let the OS pick a free port for the simulated peer
    let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
    let socket = UdpSocket::bind(peer_addr)?;
    // The timeout is what turns "no response" into a detectable condition
    // in no_response() below
    socket.set_read_timeout(Some(Duration::from_secs(1)))?;
    // Send connect request to make sure that the tracker in fact responds to
    // valid requests
    let connection_id = connect(&socket, tracker_addr).with_context(|| "connect")?;
    // Bitwise NOT of the valid id guarantees a different, invalid id
    let invalid_connection_id = ConnectionId(!connection_id.0);
    let announce_request = Request::Announce(AnnounceRequest {
        connection_id: invalid_connection_id,
        transaction_id: TransactionId(0),
        info_hash: InfoHash([0; 20]),
        peer_id: PeerId([0; 20]),
        bytes_downloaded: NumberOfBytes(0),
        bytes_uploaded: NumberOfBytes(0),
        bytes_left: NumberOfBytes(0),
        event: AnnounceEvent::Started,
        ip_address: None,
        key: PeerKey(0),
        peers_wanted: NumberOfPeers(10),
        port: Port(1),
    });
    let scrape_request = Request::Scrape(ScrapeRequest {
        connection_id: invalid_connection_id,
        transaction_id: TransactionId(0),
        info_hashes: vec![InfoHash([0; 20])],
    });
    no_response(&socket, tracker_addr, announce_request).with_context(|| "announce")?;
    no_response(&socket, tracker_addr, scrape_request).with_context(|| "scrape")?;
    Ok(())
}
/// Send `request` to the tracker and assert that it stays silent.
///
/// Succeeds when the read times out (tracker dropped the request),
/// fails when any datagram is received or on an unexpected socket error.
fn no_response(
    socket: &UdpSocket,
    tracker_addr: SocketAddr,
    request: Request,
) -> anyhow::Result<()> {
    let mut buffer = [0u8; BUFFER_SIZE];

    {
        let mut buffer = Cursor::new(&mut buffer[..]);

        request
            .write(&mut buffer)
            .with_context(|| "write request")?;

        let bytes_written = buffer.position() as usize;

        socket
            .send_to(&(buffer.into_inner())[..bytes_written], tracker_addr)
            .with_context(|| "send request")?;
    }

    match socket.recv_from(&mut buffer) {
        Ok(_) => Err(anyhow::anyhow!("received response")),
        // Per std docs for set_read_timeout, a timed-out read surfaces as
        // WouldBlock on Unix but TimedOut on Windows — accept both so the
        // test is portable
        Err(err) if matches!(err.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) => Ok(()),
        Err(err) => Err(err.into()),
    }
}

View file

@ -0,0 +1,99 @@
mod common;
use common::*;
use std::{
collections::{hash_map::RandomState, HashSet},
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
time::Duration,
};
use anyhow::Context;
use aquatic_udp::config::Config;
use aquatic_udp_protocol::{InfoHash, Response};
#[test]
/// End-to-end test: 20 peers (every third a seeder) connect, announce and
/// scrape against one torrent, verifying peer lists, seeder/leecher counts
/// and scrape statistics after each announce.
fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> {
    const TRACKER_PORT: u16 = 40_111;
    const PEER_PORT_START: u16 = 30_000;
    const PEERS_WANTED: usize = 10;
    let mut config = Config::default();
    config.network.address.set_port(TRACKER_PORT);
    run_tracker(config);
    let tracker_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, TRACKER_PORT));
    // Port 0: each simulated peer binds a fresh OS-assigned port
    let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
    let info_hash = InfoHash([0; 20]);
    // Running totals of what the tracker should report after each announce
    let mut num_seeders = 0;
    let mut num_leechers = 0;
    for i in 0..20 {
        // Every third peer (i = 0, 3, 6, ...) announces as a seeder
        let is_seeder = i % 3 == 0;
        if is_seeder {
            num_seeders += 1;
        } else {
            num_leechers += 1;
        }
        let socket = UdpSocket::bind(peer_addr)?;
        socket.set_read_timeout(Some(Duration::from_secs(1)))?;
        let connection_id = connect(&socket, tracker_addr).with_context(|| "connect")?;
        let announce_response = {
            // The announced port encodes the peer index, so responses can
            // be checked against the expected set of ports below
            let response = announce(
                &socket,
                tracker_addr,
                connection_id,
                PEER_PORT_START + i as u16,
                info_hash,
                PEERS_WANTED,
                is_seeder,
            )
            .with_context(|| "announce")?;
            if let Response::AnnounceIpv4(response) = response {
                response
            } else {
                return Err(anyhow::anyhow!("not announce response: {:?}", response));
            }
        };
        // i peers announced before this one, capped by how many we asked for
        assert_eq!(announce_response.peers.len(), i.min(PEERS_WANTED));
        assert_eq!(announce_response.seeders.0, num_seeders);
        assert_eq!(announce_response.leechers.0, num_leechers);
        let response_peer_ports: HashSet<u16, RandomState> =
            HashSet::from_iter(announce_response.peers.iter().map(|p| p.port.0));
        let expected_peer_ports: HashSet<u16, RandomState> =
            HashSet::from_iter((0..i).map(|i| PEER_PORT_START + i as u16));
        // Once more peers exist than were requested, the tracker returns a
        // subset; before that, it must return exactly all earlier peers
        if i > PEERS_WANTED {
            assert!(response_peer_ports.is_subset(&expected_peer_ports));
        } else {
            assert_eq!(response_peer_ports, expected_peer_ports);
        }
        // Scrape both the active torrent and an unknown one; the latter
        // must report zero seeders and leechers
        let scrape_response = scrape(
            &socket,
            tracker_addr,
            connection_id,
            vec![info_hash, InfoHash([1; 20])],
        )
        .with_context(|| "scrape")?;
        assert_eq!(scrape_response.torrent_stats[0].seeders.0, num_seeders);
        assert_eq!(scrape_response.torrent_stats[0].leechers.0, num_leechers);
        assert_eq!(scrape_response.torrent_stats[1].seeders.0, 0);
        assert_eq!(scrape_response.torrent_stats[1].leechers.0, 0);
    }
    Ok(())
}