Move all crates to new crates dir

This commit is contained in:
Joakim Frostegård 2023-10-18 23:53:41 +02:00
parent 3835da22ac
commit 9b032f7e24
128 changed files with 27 additions and 26 deletions

41
crates/common/Cargo.toml Normal file
View file

@@ -0,0 +1,41 @@
[package]
name = "aquatic_common"
description = "aquatic BitTorrent tracker common code"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
readme.workspace = true
rust-version.workspace = true
[lib]
name = "aquatic_common"
[features]
rustls = ["dep:rustls", "rustls-pemfile"]
[dependencies]
aquatic_toml_config.workspace = true
ahash = "0.8"
anyhow = "1"
arc-swap = "1"
duplicate = "1"
git-testament = "0.2"
hashbrown = "0.14"
hex = "0.4"
indexmap = "2"
libc = "0.2"
log = "0.4"
privdrop = "0.5"
rand = { version = "0.8", features = ["small_rng"] }
serde = { version = "1", features = ["derive"] }
simple_logger = { version = "4", features = ["stderr"] }
toml = "0.5"
# Optional
glommio = { version = "0.8", optional = true }
hwloc = { version = "0.5", optional = true }
rustls = { version = "0.21", optional = true }
rustls-pemfile = { version = "1", optional = true }

196
crates/common/src/access_list.rs Normal file
View file

@@ -0,0 +1,196 @@
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Context;
use aquatic_toml_config::TomlConfig;
use arc_swap::{ArcSwap, Cache};
use hashbrown::HashSet;
use serde::{Deserialize, Serialize};
/// Access list mode. Available modes are allow, deny and off.
#[derive(Clone, Copy, Debug, PartialEq, TomlConfig, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AccessListMode {
    /// Only serve torrents with info hash present in file
    Allow,
    /// Do not serve torrents if info hash present in file
    Deny,
    /// Turn off access list functionality
    Off,
}

impl AccessListMode {
    // Whether access list filtering is enabled at all (any mode except Off)
    pub fn is_on(&self) -> bool {
        !matches!(self, Self::Off)
    }
}
// Configuration of access list functionality (mode plus file location)
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct AccessListConfig {
    pub mode: AccessListMode,
    /// Path to access list file consisting of newline-separated hex-encoded info hashes.
    ///
    /// If using chroot mode, path must be relative to new root.
    pub path: PathBuf,
}

impl Default for AccessListConfig {
    fn default() -> Self {
        Self {
            // Relative path; interpreted against working directory (or new
            // root when chrooting, per the field doc above)
            path: "./access-list.txt".into(),
            // Access list filtering is disabled by default
            mode: AccessListMode::Off,
        }
    }
}
/// Set of 20-byte info hashes loaded from an access list file
#[derive(Default, Clone)]
pub struct AccessList(HashSet<[u8; 20]>);

impl AccessList {
    /// Parse a single hex-encoded info hash line and insert it into the list.
    pub fn insert_from_line(&mut self, line: &str) -> anyhow::Result<()> {
        self.0.insert(parse_info_hash(line)?);

        Ok(())
    }

    /// Build an access list from a file of newline-separated hex-encoded
    /// info hashes. Lines that are empty after trimming are skipped.
    pub fn create_from_path(path: &PathBuf) -> anyhow::Result<Self> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        let mut new_list = Self::default();

        for line in reader.lines() {
            let line = line?;
            let line = line.trim();

            if line.is_empty() {
                continue;
            }

            // `line` is already a &str; no extra borrow needed
            new_list
                .insert_from_line(line)
                .with_context(|| format!("Invalid line in access list: {}", line))?;
        }

        Ok(new_list)
    }

    /// Whether `info_hash` passes the list under the given mode.
    pub fn allows(&self, mode: AccessListMode, info_hash: &[u8; 20]) -> bool {
        match mode {
            AccessListMode::Allow => self.0.contains(info_hash),
            AccessListMode::Deny => !self.0.contains(info_hash),
            AccessListMode::Off => true,
        }
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    // Added alongside len() for clippy `len_without_is_empty` conformance
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// Interface for updating and querying a shared access list
pub trait AccessListQuery {
    /// Reload the access list from the file referenced by `config`
    fn update(&self, config: &AccessListConfig) -> anyhow::Result<()>;
    /// Whether `info_hash_bytes` passes the list under `list_mode`
    fn allows(&self, list_mode: AccessListMode, info_hash_bytes: &[u8; 20]) -> bool;
}

/// Shared access list that can be atomically swapped when reloaded
pub type AccessListArcSwap = ArcSwap<AccessList>;
/// Per-thread cache over the shared access list (see arc_swap::Cache)
pub type AccessListCache = Cache<Arc<AccessListArcSwap>, Arc<AccessList>>;
impl AccessListQuery for AccessListArcSwap {
    /// Rebuild the list from the configured file and atomically publish it.
    fn update(&self, config: &AccessListConfig) -> anyhow::Result<()> {
        self.store(Arc::new(AccessList::create_from_path(&config.path)?));

        Ok(())
    }

    fn allows(&self, mode: AccessListMode, info_hash_bytes: &[u8; 20]) -> bool {
        // Delegate to AccessList::allows instead of duplicating the
        // mode-matching logic here, so the two implementations can't drift
        self.load().allows(mode, info_hash_bytes)
    }
}
/// Create a new per-thread cache over the shared access list.
pub fn create_access_list_cache(arc_swap: &Arc<AccessListArcSwap>) -> AccessListCache {
    let shared = Arc::clone(arc_swap);

    Cache::from(shared)
}
/// Reload the shared access list from disk if access list functionality is
/// enabled, logging the outcome. In `Off` mode this is a successful no-op.
pub fn update_access_list(
    config: &AccessListConfig,
    access_list: &Arc<AccessListArcSwap>,
) -> anyhow::Result<()> {
    if !config.mode.is_on() {
        return Ok(());
    }

    match access_list.update(config) {
        Ok(()) => {
            ::log::info!("Access list updated");

            Ok(())
        }
        Err(err) => {
            ::log::error!("Updating access list failed: {:#}", err);

            Err(err)
        }
    }
}
/// Decode a 40-character hex string into a 20-byte info hash.
fn parse_info_hash(line: &str) -> anyhow::Result<[u8; 20]> {
    let mut info_hash = [0u8; 20];

    hex::decode_to_slice(line, &mut info_hash)?;

    Ok(info_hash)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_info_hash() {
        let f = parse_info_hash;

        // Exactly 40 hex chars is ok; too long, too short or non-hex fails.
        // (Removed needless `.into()` conversions of &str to &str.)
        assert!(f("aaaabbbbccccddddeeeeaaaabbbbccccddddeeee").is_ok());
        assert!(f("aaaabbbbccccddddeeeeaaaabbbbccccddddeeeef").is_err());
        assert!(f("aaaabbbbccccddddeeeeaaaabbbbccccddddeee").is_err());
        assert!(f("aaaabbbbccccddddeeeeaaaabbbbccccddddeeeö").is_err());
    }

    #[test]
    fn test_cache_allows() {
        let mut access_list = AccessList::default();

        let a = parse_info_hash("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").unwrap();
        let b = parse_info_hash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb").unwrap();
        let c = parse_info_hash("cccccccccccccccccccccccccccccccccccccccc").unwrap();

        access_list.0.insert(a);
        access_list.0.insert(b);

        let access_list = Arc::new(ArcSwap::new(Arc::new(access_list)));

        let mut access_list_cache = Cache::new(Arc::clone(&access_list));

        // Allow mode: only inserted hashes pass
        assert!(access_list_cache.load().allows(AccessListMode::Allow, &a));
        assert!(access_list_cache.load().allows(AccessListMode::Allow, &b));
        assert!(!access_list_cache.load().allows(AccessListMode::Allow, &c));

        // Deny mode: only non-inserted hashes pass
        assert!(!access_list_cache.load().allows(AccessListMode::Deny, &a));
        assert!(!access_list_cache.load().allows(AccessListMode::Deny, &b));
        assert!(access_list_cache.load().allows(AccessListMode::Deny, &c));

        // Off mode: everything passes
        assert!(access_list_cache.load().allows(AccessListMode::Off, &a));
        assert!(access_list_cache.load().allows(AccessListMode::Off, &b));
        assert!(access_list_cache.load().allows(AccessListMode::Off, &c));

        // Cache observes a store on the underlying ArcSwap
        access_list.store(Arc::new(AccessList::default()));

        assert!(access_list_cache.load().allows(AccessListMode::Deny, &a));
        assert!(access_list_cache.load().allows(AccessListMode::Deny, &b));
    }
}

239
crates/common/src/cli.rs Normal file
View file

@@ -0,0 +1,239 @@
use std::fs::File;
use std::io::Read;
use anyhow::Context;
use aquatic_toml_config::TomlConfig;
use git_testament::{git_testament, CommitKind};
use log::LevelFilter;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use simple_logger::SimpleLogger;
/// Log level. Available values are off, error, warn, info, debug and trace.
#[derive(Debug, Clone, Copy, PartialEq, TomlConfig, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    Off,
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}

// By default only warnings and errors are logged
impl Default for LogLevel {
    fn default() -> Self {
        Self::Warn
    }
}
/// Contract for application configuration types used with
/// `run_app_with_cli_and_config`
pub trait Config: Default + TomlConfig + DeserializeOwned + std::fmt::Debug {
    /// Log level to start the logger with. Returning `None` (the default)
    /// means the CLI wrapper starts no logger.
    fn get_log_level(&self) -> Option<LogLevel> {
        None
    }
}
/// Parsed command line options
#[derive(Debug, Default)]
pub struct Options {
    config_file: Option<String>,
    print_config: bool,
    print_parsed_config: bool,
    print_version: bool,
}

impl Options {
    /// Parse options from an iterator over argument strings (excluding the
    /// program path).
    ///
    /// Returns `Err(Some(message))` on invalid input and `Err(None)` when
    /// help was requested.
    pub fn parse_args<I>(mut arg_iter: I) -> Result<Options, Option<String>>
    where
        I: Iterator<Item = String>,
    {
        let mut options = Options::default();

        while let Some(arg) = arg_iter.next() {
            match arg.as_str() {
                // Consumes the following argument as the path
                "-c" | "--config-file" => match arg_iter.next() {
                    Some(path) => options.config_file = Some(path),
                    None => return Err(Some("No config file path given".to_string())),
                },
                "-p" | "--print-config" => options.print_config = true,
                "-P" => options.print_parsed_config = true,
                "-v" | "--version" => options.print_version = true,
                "-h" | "--help" => return Err(None),
                // Empty arguments are ignored
                "" => (),
                _ => return Err(Some("Unrecognized argument".to_string())),
            }
        }

        Ok(options)
    }
}
/// Handle command line interaction and config loading, run the application,
/// then exit the process with a suitable status code (1 on error).
pub fn run_app_with_cli_and_config<T>(
    app_title: &str,
    crate_version: &str,
    // Function that takes config file and runs application
    app_fn: fn(T) -> anyhow::Result<()>,
    opts: Option<Options>,
) where
    T: Config,
{
    let exit_code = match run_inner(app_title, crate_version, app_fn, opts) {
        Ok(()) => 0,
        Err(err) => {
            eprintln!("Error: {:#}", err);

            1
        }
    };

    ::std::process::exit(exit_code)
}
/// CLI and config handling shared by `run_app_with_cli_and_config`.
///
/// Parses arguments from the environment unless preparsed `options` are
/// given, handles the informational flags, loads config and runs `app_fn`.
fn run_inner<T>(
    app_title: &str,
    crate_version: &str,
    // Function that takes config file and runs application
    app_fn: fn(T) -> anyhow::Result<()>,
    // Possibly preparsed options
    options: Option<Options>,
) -> anyhow::Result<()>
where
    T: Config,
{
    let options = match options {
        Some(options) => options,
        None => {
            let mut arg_iter = ::std::env::args();
            let app_path = arg_iter.next().unwrap();

            match Options::parse_args(arg_iter) {
                Ok(options) => options,
                Err(opt_err) => {
                    // Invalid arguments / help request: print help and exit
                    // successfully
                    let gen_info = || format!("{}\n\nUsage: {} [OPTIONS]", app_title, app_path);

                    print_help(gen_info, opt_err);

                    return Ok(());
                }
            }
        }
    };

    if options.print_version {
        println!("{}{}", crate_version, get_commit_info());

        return Ok(());
    }

    if options.print_config {
        print!("{}", default_config_as_toml::<T>());

        return Ok(());
    }

    let config = match options.config_file {
        Some(path) => config_from_toml_file(path)?,
        None => T::default(),
    };

    if let Some(log_level) = config.get_log_level() {
        start_logger(log_level)?;
    }

    if options.print_parsed_config {
        println!("Running with configuration: {:#?}", config);
    }

    app_fn(config)
}
/// Print help text: the app info produced by `info_generator`, the list of
/// supported command line options, and optionally a trailing error message.
pub fn print_help<F>(info_generator: F, opt_error: Option<String>)
where
    F: FnOnce() -> String,
{
    println!("{}", info_generator());
    println!("\nOptions:");
    // NOTE(review): the extracted source shows single spaces inside these
    // literals; original column alignment may have been lost in formatting —
    // verify against upstream before changing
    println!(" -c, --config-file Load config from this path");
    println!(" -h, --help Print this help message");
    println!(" -p, --print-config Print default config");
    println!(" -P Print parsed config");
    println!(" -v, --version Print version information");

    if let Some(error) = opt_error {
        println!("\nError: {}.", error);
    }
}
/// Open, read and deserialize a TOML config file into `T`, with a context
/// message identifying the failing stage and path.
fn config_from_toml_file<T>(path: String) -> anyhow::Result<T>
where
    T: DeserializeOwned,
{
    // Borrow `path` in the error contexts instead of cloning it three times
    // (the closures only need it for formatting)
    let mut file =
        File::open(&path).with_context(|| format!("Couldn't open config file {}", path))?;

    let mut data = String::new();

    file.read_to_string(&mut data)
        .with_context(|| format!("Couldn't read config file {}", path))?;

    toml::from_str(&data).with_context(|| format!("Couldn't parse config file {}", path))
}
/// Serialize the default configuration of `T` to a TOML string via its
/// `TomlConfig` implementation.
fn default_config_as_toml<T>() -> String
where
    T: Default + TomlConfig,
{
    <T as TomlConfig>::default_to_string()
}
/// Initialize the global logger with the level matching `log_level` and
/// UTC timestamps.
fn start_logger(log_level: LogLevel) -> ::anyhow::Result<()> {
    let level_filter = match log_level {
        LogLevel::Off => LevelFilter::Off,
        LogLevel::Error => LevelFilter::Error,
        LogLevel::Warn => LevelFilter::Warn,
        LogLevel::Info => LevelFilter::Info,
        LogLevel::Debug => LevelFilter::Debug,
        LogLevel::Trace => LevelFilter::Trace,
    };

    let logger = SimpleLogger::new()
        .with_level(level_filter)
        .with_utc_timestamps();

    logger.init().context("Couldn't initialize logger")?;

    Ok(())
}
/// Version suffix with abbreviated commit hash and date when built from a
/// git checkout, otherwise an empty string.
fn get_commit_info() -> String {
    git_testament!(TESTAMENT);

    match TESTAMENT.commit {
        // Both variants carry a hash and date and were formatted identically;
        // merged into a single or-pattern arm to avoid duplication
        CommitKind::NoTags(hash, date) | CommitKind::FromTag(_, hash, date, _) => {
            format!(" ({} - {})", first_8_chars(hash), date)
        }
        _ => String::new(),
    }
}
/// Return at most the first 8 characters of `input` as an owned String
/// (character-based, so multi-byte UTF-8 is handled correctly).
fn first_8_chars(input: &str) -> String {
    match input.char_indices().nth(8) {
        // There is a 9th char: slice off everything from it onwards
        Some((byte_index, _)) => input[..byte_index].to_string(),
        // 8 chars or fewer: the whole string
        None => input.to_string(),
    }
}

415
crates/common/src/cpu_pinning.rs Normal file
View file

@@ -0,0 +1,415 @@
//! Experimental CPU pinning
use aquatic_toml_config::TomlConfig;
use serde::{Deserialize, Serialize};
// Direction in which worker threads are assigned to CPU cores
#[derive(Clone, Copy, Debug, PartialEq, TomlConfig, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CpuPinningDirection {
    // Assign workers starting from the lowest core index
    Ascending,
    // Assign workers starting from the highest core index
    Descending,
}

impl Default for CpuPinningDirection {
    fn default() -> Self {
        Self::Ascending
    }
}
// How logical CPU (hyperthread) indices are assumed to relate to physical
// core indices; used by the glommio pinning code in get_worker_cpu_set below
#[cfg(feature = "glommio")]
#[derive(Clone, Copy, Debug, PartialEq, TomlConfig, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum HyperThreadMapping {
    // Use the OS-reported core index of each logical CPU
    System,
    // Sibling hyperthreads assumed to have adjacent CPU indices (0-1, 2-3, ...)
    Subsequent,
    // Sibling hyperthreads assumed to be (number of CPUs / 2) apart
    Split,
}

#[cfg(feature = "glommio")]
impl Default for HyperThreadMapping {
    fn default() -> Self {
        Self::System
    }
}
/// Interface over the ascending/descending CPU pinning config structs
/// generated below
pub trait CpuPinningConfig {
    /// Whether CPU pinning is enabled
    fn active(&self) -> bool;
    /// Direction in which workers are assigned to cores
    fn direction(&self) -> CpuPinningDirection;
    /// Assumed mapping from logical CPUs to physical cores
    #[cfg(feature = "glommio")]
    fn hyperthread(&self) -> HyperThreadMapping;
    /// Index of the first core to consider for pinning
    fn core_offset(&self) -> usize;
}
// Do these shenanigans for compatibility with aquatic_toml_config
//
// `duplicate_item` stamps out two modules containing structurally identical
// config structs, differing only in name and default pinning direction,
// since the TomlConfig derive needs concrete struct definitions.
#[duplicate::duplicate_item(
    mod_name struct_name cpu_pinning_direction;
    [asc] [CpuPinningConfigAsc] [CpuPinningDirection::Ascending];
    [desc] [CpuPinningConfigDesc] [CpuPinningDirection::Descending];
)]
pub mod mod_name {
    use super::*;

    /// Experimental cpu pinning
    #[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
    pub struct struct_name {
        pub active: bool,
        pub direction: CpuPinningDirection,
        #[cfg(feature = "glommio")]
        pub hyperthread: HyperThreadMapping,
        pub core_offset: usize,
    }

    impl Default for struct_name {
        fn default() -> Self {
            Self {
                // Pinning is opt-in
                active: false,
                // This is the field that differs between the two variants
                direction: cpu_pinning_direction,
                #[cfg(feature = "glommio")]
                hyperthread: Default::default(),
                core_offset: 0,
            }
        }
    }

    impl CpuPinningConfig for struct_name {
        fn active(&self) -> bool {
            self.active
        }
        fn direction(&self) -> CpuPinningDirection {
            self.direction
        }
        #[cfg(feature = "glommio")]
        fn hyperthread(&self) -> HyperThreadMapping {
            self.hyperthread
        }
        fn core_offset(&self) -> usize {
            self.core_offset
        }
    }
}
/// Identifies a worker thread for CPU pinning purposes
#[derive(Clone, Copy, Debug)]
pub enum WorkerIndex {
    SocketWorker(usize),
    SwarmWorker(usize),
    Util,
}

impl WorkerIndex {
    /// Compute which core this worker should be pinned to.
    ///
    /// Workers are laid out as: socket workers, then swarm workers, then the
    /// single utility worker, all shifted by `core_offset`. The index is
    /// clamped to the last core, and mirrored when the direction is
    /// descending.
    pub fn get_core_index<C: CpuPinningConfig>(
        &self,
        config: &C,
        socket_workers: usize,
        swarm_workers: usize,
        num_cores: usize,
    ) -> usize {
        let offset = config.core_offset();

        let unclamped_index = match self {
            Self::SocketWorker(index) => offset + index,
            Self::SwarmWorker(index) => offset + socket_workers + index,
            Self::Util => offset + socket_workers + swarm_workers,
        };

        let last_core_index = num_cores - 1;
        let clamped_index = unclamped_index.min(last_core_index);

        match config.direction() {
            CpuPinningDirection::Ascending => clamped_index,
            CpuPinningDirection::Descending => last_core_index - clamped_index,
        }
    }
}
#[cfg(feature = "glommio")]
pub mod glommio {
    use ::glommio::{CpuSet, Placement};

    use super::*;

    fn get_cpu_set() -> anyhow::Result<CpuSet> {
        CpuSet::online().map_err(|err| anyhow::anyhow!("Couldn't get CPU set: {:#}", err))
    }

    // Number of physical cores, derived from the highest core index in the
    // online CPU set
    fn get_num_cpu_cores() -> anyhow::Result<usize> {
        get_cpu_set()?
            .iter()
            .map(|l| l.core)
            .max()
            .map(|index| index + 1)
            .ok_or(anyhow::anyhow!("CpuSet is empty"))
    }

    // Sorted, comma-separated list of logical CPU indices, for logging
    fn logical_cpus_string(cpu_set: &CpuSet) -> String {
        let mut logical_cpus = cpu_set.iter().map(|l| l.cpu).collect::<Vec<usize>>();

        logical_cpus.sort_unstable();

        logical_cpus
            .into_iter()
            .map(|cpu| cpu.to_string())
            .collect::<Vec<String>>()
            .join(", ")
    }

    // Compute the set of logical CPUs a worker should be fenced to, taking
    // the configured hyperthread mapping into account
    fn get_worker_cpu_set<C: CpuPinningConfig>(
        config: &C,
        socket_workers: usize,
        swarm_workers: usize,
        worker_index: WorkerIndex,
    ) -> anyhow::Result<CpuSet> {
        let num_cpu_cores = get_num_cpu_cores()?;
        let core_index =
            worker_index.get_core_index(config, socket_workers, swarm_workers, num_cpu_cores);

        // In split/subsequent modes each worker occupies two logical CPUs,
        // so only half of the index space is usable
        let too_many_workers = match (&config.hyperthread(), &config.direction()) {
            (
                HyperThreadMapping::Split | HyperThreadMapping::Subsequent,
                CpuPinningDirection::Ascending,
            ) => core_index >= num_cpu_cores / 2,
            (
                HyperThreadMapping::Split | HyperThreadMapping::Subsequent,
                CpuPinningDirection::Descending,
            ) => core_index < num_cpu_cores / 2,
            (_, _) => false,
        };

        if too_many_workers {
            return Err(anyhow::anyhow!("CPU pinning: total number of workers (including the single utility worker) can not exceed number of virtual CPUs / 2 - core_offset in this hyperthread mapping mode"));
        }

        let cpu_set = match config.hyperthread() {
            // Trust the OS numbering: take all logical CPUs of the core
            HyperThreadMapping::System => get_cpu_set()?.filter(|l| l.core == core_index),
            // Siblings assumed num_cpu_cores / 2 apart
            HyperThreadMapping::Split => match config.direction() {
                CpuPinningDirection::Ascending => get_cpu_set()?
                    .filter(|l| l.cpu == core_index || l.cpu == core_index + num_cpu_cores / 2),
                CpuPinningDirection::Descending => get_cpu_set()?
                    .filter(|l| l.cpu == core_index || l.cpu == core_index - num_cpu_cores / 2),
            },
            // Siblings assumed adjacent
            HyperThreadMapping::Subsequent => {
                let cpu_index_offset = match config.direction() {
                    // 0 -> 0 and 1
                    // 1 -> 2 and 3
                    // 2 -> 4 and 5
                    CpuPinningDirection::Ascending => core_index * 2,
                    // 15 -> 14 and 15
                    // 14 -> 12 and 13
                    // 13 -> 10 and 11
                    CpuPinningDirection::Descending => {
                        num_cpu_cores - 2 * (num_cpu_cores - core_index)
                    }
                };

                get_cpu_set()?
                    .filter(|l| l.cpu == cpu_index_offset || l.cpu == cpu_index_offset + 1)
            }
        };

        if cpu_set.is_empty() {
            Err(anyhow::anyhow!(
                "CPU pinning: produced empty CPU set for {:?}. Try decreasing number of workers",
                worker_index
            ))
        } else {
            ::log::info!(
                "Logical CPUs for {:?}: {}",
                worker_index,
                logical_cpus_string(&cpu_set)
            );

            Ok(cpu_set)
        }
    }

    // Placement for a glommio executor: fenced to the worker's CPU set when
    // pinning is active, otherwise unbound
    pub fn get_worker_placement<C: CpuPinningConfig>(
        config: &C,
        socket_workers: usize,
        swarm_workers: usize,
        worker_index: WorkerIndex,
    ) -> anyhow::Result<Placement> {
        if config.active() {
            let cpu_set = get_worker_cpu_set(config, socket_workers, swarm_workers, worker_index)?;

            Ok(Placement::Fenced(cpu_set))
        } else {
            Ok(Placement::Unbound)
        }
    }

    // Pin the current thread as the utility worker via pthread affinity,
    // presumably because the utility worker runs no glommio executor of its
    // own — confirm against callers
    pub fn set_affinity_for_util_worker<C: CpuPinningConfig>(
        config: &C,
        socket_workers: usize,
        swarm_workers: usize,
    ) -> anyhow::Result<()> {
        let worker_cpu_set =
            get_worker_cpu_set(config, socket_workers, swarm_workers, WorkerIndex::Util)?;

        // SAFETY: cpu_set_t is valid when zero-initialized, and `set` lives
        // until after the pthread_setaffinity_np call returns
        unsafe {
            let mut set: libc::cpu_set_t = ::std::mem::zeroed();

            for cpu_location in worker_cpu_set {
                libc::CPU_SET(cpu_location.cpu, &mut set);
            }

            let status = libc::pthread_setaffinity_np(
                libc::pthread_self(),
                ::std::mem::size_of::<libc::cpu_set_t>(),
                &set,
            );

            // pthread functions return the error code directly rather than
            // setting errno
            if status != 0 {
                return Err(anyhow::Error::new(::std::io::Error::from_raw_os_error(
                    status,
                )));
            }
        }

        Ok(())
    }
}
/// Pin current thread to a suitable core
///
/// Requires hwloc (`apt-get install libhwloc-dev`)
#[cfg(feature = "hwloc")]
pub fn pin_current_if_configured_to<C: CpuPinningConfig>(
    config: &C,
    socket_workers: usize,
    swarm_workers: usize,
    worker_index: WorkerIndex,
) {
    use hwloc::{CpuSet, ObjectType, Topology, CPUBIND_THREAD};

    if config.active() {
        let mut topology = Topology::new();

        let core_cpu_sets: Vec<CpuSet> = topology
            .objects_with_type(&ObjectType::Core)
            .expect("hwloc: list cores")
            .into_iter()
            .map(|core| core.allowed_cpuset().expect("hwloc: get core cpu set"))
            .collect();

        let num_cores = core_cpu_sets.len();
        let core_index =
            worker_index.get_core_index(config, socket_workers, swarm_workers, num_cores);

        // unwrap_or_else + panic! instead of expect(&format!(..)), so the
        // panic message is only built on failure (clippy: expect_fun_call)
        let cpu_set = core_cpu_sets
            .get(core_index)
            .unwrap_or_else(|| panic!("get cpu set for core {}", core_index))
            .to_owned();

        topology
            .set_cpubind(cpu_set, CPUBIND_THREAD)
            .unwrap_or_else(|err| panic!("bind thread to core {}: {:?}", core_index, err));

        ::log::info!(
            "Pinned worker {:?} to cpu core {}",
            worker_index,
            core_index
        );
    }
}
/// Tell Linux that incoming messages should be handled by the socket worker
/// with the same index as the CPU core receiving the interrupt.
///
/// Requires that sockets are actually bound in order, so waiting has to be done
/// in socket workers.
///
/// It might make sense to first enable RSS or RPS (if hardware doesn't support
/// RSS) and enable sending interrupts to all CPUs that have socket workers
/// running on them. Possibly, CPU 0 should be excluded.
///
/// More Information:
/// - https://talawah.io/blog/extreme-http-performance-tuning-one-point-two-million/
/// - https://www.kernel.org/doc/Documentation/networking/scaling.txt
/// - https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rps
#[cfg(target_os = "linux")]
pub fn socket_attach_cbpf<S: ::std::os::unix::prelude::AsRawFd>(
    socket: &S,
    _num_sockets: usize,
) -> ::std::io::Result<()> {
    use std::mem::size_of;
    use std::os::raw::c_void;

    use libc::{setsockopt, sock_filter, sock_fprog, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF};

    // Good BPF documentation: https://man.openbsd.org/bpf.4

    // Values of constants were copied from the following Linux source files:
    // - include/uapi/linux/bpf_common.h
    // - include/uapi/linux/filter.h

    // Instruction
    const BPF_LD: u16 = 0x00; // Load into A
    // const BPF_LDX: u16 = 0x01; // Load into X
    // const BPF_ALU: u16 = 0x04; // Load into X
    const BPF_RET: u16 = 0x06; // Return value

    // const BPF_MOD: u16 = 0x90; // Run modulo on A

    // Size
    const BPF_W: u16 = 0x00; // 32-bit width

    // Source
    // const BPF_IMM: u16 = 0x00; // Use constant (k)
    const BPF_ABS: u16 = 0x20;

    // Registers
    // const BPF_K: u16 = 0x00;
    const BPF_A: u16 = 0x10;

    // k
    const SKF_AD_OFF: i32 = -0x1000; // Activate extensions
    const SKF_AD_CPU: i32 = 36; // Extension for getting CPU

    // Return index of socket that should receive packet
    // (`filter` must be mutable because sock_fprog stores a *mut pointer)
    let mut filter = [
        // Store index of CPU receiving packet in register A
        sock_filter {
            code: BPF_LD | BPF_W | BPF_ABS,
            jt: 0,
            jf: 0,
            // Bit-preserving reinterpretation of the negative extension offset
            k: u32::from_ne_bytes((SKF_AD_OFF + SKF_AD_CPU).to_ne_bytes()),
        },
        /* Disabled, because it doesn't make a lot of sense
        // Run A = A % socket_workers
        sock_filter {
            code: BPF_ALU | BPF_MOD,
            jt: 0,
            jf: 0,
            k: num_sockets as u32,
        },
        */
        // Return A
        sock_filter {
            code: BPF_RET | BPF_A,
            jt: 0,
            jf: 0,
            k: 0,
        },
    ];

    let program = sock_fprog {
        filter: filter.as_mut_ptr(),
        len: filter.len() as u16,
    };

    let program_ptr: *const sock_fprog = &program;

    // SAFETY: `program` (and the `filter` array it points into) outlives the
    // setsockopt call, and the size passed matches sock_fprog
    unsafe {
        let result = setsockopt(
            socket.as_raw_fd(),
            SOL_SOCKET,
            SO_ATTACH_REUSEPORT_CBPF,
            program_ptr as *const c_void,
            size_of::<sock_fprog>() as u32,
        );

        if result != 0 {
            Err(::std::io::Error::last_os_error())
        } else {
            Ok(())
        }
    }
}

279
crates/common/src/lib.rs Normal file
View file

@@ -0,0 +1,279 @@
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use ahash::RandomState;
use rand::Rng;
pub mod access_list;
pub mod cli;
pub mod cpu_pinning;
pub mod privileges;
#[cfg(feature = "rustls")]
pub mod rustls_config;
/// IndexMap using AHash hasher
pub type IndexMap<K, V> = indexmap::IndexMap<K, V, RandomState>;

/// Peer, connection or similar valid until this instant
///
/// Stored as whole seconds since server start; compact enough to keep per-peer
#[derive(Debug, Clone, Copy)]
pub struct ValidUntil(SecondsSinceServerStart);
impl ValidUntil {
    /// Deadline `offset_seconds` from the current server uptime.
    #[inline]
    pub fn new(start_instant: ServerStartInstant, offset_seconds: u32) -> Self {
        // Delegate to new_with_now so deadline arithmetic lives in one place
        Self::new_with_now(start_instant.seconds_elapsed(), offset_seconds)
    }

    /// Deadline `offset_seconds` from an already-sampled uptime value.
    pub fn new_with_now(now: SecondsSinceServerStart, offset_seconds: u32) -> Self {
        Self(SecondsSinceServerStart(now.0 + offset_seconds))
    }

    /// Whether the deadline is still strictly in the future.
    pub fn valid(&self, now: SecondsSinceServerStart) -> bool {
        self.0 .0 > now.0
    }
}
/// Instant the server started; epoch for compact u32 uptime timestamps
#[derive(Debug, Clone, Copy)]
pub struct ServerStartInstant(Instant);

impl ServerStartInstant {
    pub fn new() -> Self {
        Self(Instant::now())
    }

    /// Whole seconds elapsed since server start.
    ///
    /// Panics once the server has run longer than u32::MAX seconds
    /// (about 136 years).
    pub fn seconds_elapsed(&self) -> SecondsSinceServerStart {
        SecondsSinceServerStart(
            self.0
                .elapsed()
                .as_secs()
                .try_into()
                .expect("server ran for more seconds than what fits in a u32"),
        )
    }
}

// Added for clippy `new_without_default` conformance; delegates to new()
impl Default for ServerStartInstant {
    fn default() -> Self {
        Self::new()
    }
}
/// Whole seconds elapsed since a `ServerStartInstant`, stored as u32
#[derive(Debug, Clone, Copy)]
pub struct SecondsSinceServerStart(u32);
/// Watcher half of the panic sentinel pair: lets the owner check whether any
/// thread holding the matching sentinel has panicked.
pub struct PanicSentinelWatcher(Arc<AtomicBool>);

impl PanicSentinelWatcher {
    /// Create a watcher together with the sentinel it observes, sharing a
    /// single atomic flag.
    pub fn create_with_sentinel() -> (Self, PanicSentinel) {
        let triggered = Arc::new(AtomicBool::new(false));

        (Self(Arc::clone(&triggered)), PanicSentinel(triggered))
    }

    pub fn panic_was_triggered(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}
/// Raises SIGTERM when dropped
///
/// Pass to threads to have panics in them cause whole program to exit.
#[derive(Clone)]
pub struct PanicSentinel(Arc<AtomicBool>);
impl Drop for PanicSentinel {
fn drop(&mut self) {
if ::std::thread::panicking() {
let already_triggered = self.0.fetch_or(true, Ordering::SeqCst);
if !already_triggered {
if unsafe { libc::raise(15) } == -1 {
panic!(
"Could not raise SIGTERM: {:#}",
::std::io::Error::last_os_error()
)
}
}
}
}
}
/// SocketAddr that is not an IPv6-mapped IPv4 address
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct CanonicalSocketAddr(SocketAddr);
impl CanonicalSocketAddr {
pub fn new(addr: SocketAddr) -> Self {
match addr {
addr @ SocketAddr::V4(_) => Self(addr),
SocketAddr::V6(addr) => {
match addr.ip().octets() {
// Convert IPv4-mapped address (available in std but nightly-only)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => Self(SocketAddr::V4(
SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), addr.port()),
)),
_ => Self(addr.into()),
}
}
}
}
pub fn get_ipv6_mapped(self) -> SocketAddr {
match self.0 {
SocketAddr::V4(addr) => {
let ip = addr.ip().to_ipv6_mapped();
SocketAddr::V6(SocketAddrV6::new(ip, addr.port(), 0, 0))
}
addr => addr,
}
}
pub fn get(self) -> SocketAddr {
self.0
}
pub fn get_ipv4(self) -> Option<SocketAddr> {
match self.0 {
addr @ SocketAddr::V4(_) => Some(addr),
_ => None,
}
}
pub fn is_ipv4(&self) -> bool {
self.0.is_ipv4()
}
}
/// Extract response peers
///
/// If there are more peers in map than `max_num_peers_to_take`, do a random
/// selection of peers from first and second halves of map in order to avoid
/// returning too homogeneous peers.
#[inline]
pub fn extract_response_peers<K, V, R, F>(
    rng: &mut impl Rng,
    peer_map: &IndexMap<K, V>,
    max_num_peers_to_take: usize,
    sender_peer_map_key: K,
    peer_conversion_function: F,
) -> Vec<R>
where
    K: Eq + ::std::hash::Hash,
    F: Fn(&V) -> R,
{
    if peer_map.len() <= max_num_peers_to_take + 1 {
        // This branch: number of peers in map (minus sender peer) is less than
        // or equal to number of peers to take, so return all except sender
        // peer.
        let mut peers = Vec::with_capacity(peer_map.len());

        peers.extend(peer_map.iter().filter_map(|(k, v)| {
            (*k != sender_peer_map_key).then_some(peer_conversion_function(v))
        }));

        // Handle the case when sender peer is not in peer list. Typically,
        // this function will not be called when this is the case.
        if peers.len() > max_num_peers_to_take {
            peers.pop();
        }

        peers
    } else {
        // Note: if this branch is taken, the peer map contains at least two
        // more peers than max_num_peers_to_take

        let middle_index = peer_map.len() / 2;

        // Add one to take two extra peers in case sender peer is among
        // selected peers and will need to be filtered out
        let num_to_take_per_half = (max_num_peers_to_take / 2) + 1;

        // Random start offset within each half, leaving room for a full
        // contiguous range of num_to_take_per_half entries
        let offset_half_one = {
            let from = 0;
            let to = usize::max(1, middle_index - num_to_take_per_half);

            rng.gen_range(from..to)
        };
        let offset_half_two = {
            let from = middle_index;
            let to = usize::max(middle_index + 1, peer_map.len() - num_to_take_per_half);

            rng.gen_range(from..to)
        };

        let end_half_one = offset_half_one + num_to_take_per_half;
        let end_half_two = offset_half_two + num_to_take_per_half;

        let mut peers = Vec::with_capacity(max_num_peers_to_take + 2);

        if let Some(slice) = peer_map.get_range(offset_half_one..end_half_one) {
            peers.extend(slice.iter().filter_map(|(k, v)| {
                (*k != sender_peer_map_key).then_some(peer_conversion_function(v))
            }));
        }
        if let Some(slice) = peer_map.get_range(offset_half_two..end_half_two) {
            peers.extend(slice.iter().filter_map(|(k, v)| {
                (*k != sender_peer_map_key).then_some(peer_conversion_function(v))
            }));
        }

        // Drop any overshoot from the "+1 per half" allowance above
        while peers.len() > max_num_peers_to_take {
            peers.pop();
        }

        peers
    }
}
#[cfg(test)]
mod tests {
    use ahash::HashSet;
    use rand::{rngs::SmallRng, SeedableRng};

    use super::*;

    // Exhaustively sweep small combinations of map size, take-count and
    // sender key, checking the invariants in the helper below
    #[test]
    fn test_extract_response_peers() {
        let mut rng = SmallRng::from_entropy();

        for num_peers_in_map in 0..50 {
            for max_num_peers_to_take in 0..50 {
                for sender_peer_map_key in 0..50 {
                    test_extract_response_peers_helper(
                        &mut rng,
                        num_peers_in_map,
                        max_num_peers_to_take,
                        sender_peer_map_key,
                    );
                }
            }
        }
    }

    fn test_extract_response_peers_helper(
        rng: &mut SmallRng,
        num_peers_in_map: usize,
        max_num_peers_to_take: usize,
        sender_peer_map_key: usize,
    ) {
        let peer_map = IndexMap::from_iter((0..num_peers_in_map).map(|i| (i, i)));

        let response_peers = extract_response_peers(
            rng,
            &peer_map,
            max_num_peers_to_take,
            sender_peer_map_key,
            |p| *p,
        );

        // Exactly max_num_peers_to_take peers when enough are available,
        // otherwise at most that many
        if num_peers_in_map > max_num_peers_to_take + 1 {
            assert_eq!(response_peers.len(), max_num_peers_to_take);
        } else {
            assert!(response_peers.len() <= max_num_peers_to_take);
        }

        // The sender peer is never returned
        assert!(!response_peers.contains(&sender_peer_map_key));

        // No duplicate peers are returned
        assert_eq!(
            response_peers.len(),
            HashSet::from_iter(response_peers.iter().copied()).len()
        );
    }
}

64
crates/common/src/privileges.rs Normal file
View file

@@ -0,0 +1,64 @@
use std::{
path::PathBuf,
sync::{Arc, Barrier},
};
use anyhow::Context;
use privdrop::PrivDrop;
use serde::Deserialize;
use aquatic_toml_config::TomlConfig;
// Configuration for dropping privileges (chroot + setgid + setuid) after
// sockets have been bound
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct PrivilegeConfig {
    /// Chroot and switch group and user after binding to sockets
    pub drop_privileges: bool,
    /// Chroot to this path
    pub chroot_path: PathBuf,
    /// Group to switch to after chrooting
    pub group: String,
    /// User to switch to after chrooting
    pub user: String,
}
impl Default for PrivilegeConfig {
fn default() -> Self {
Self {
drop_privileges: false,
chroot_path: ".".into(),
user: "nobody".to_string(),
group: "nogroup".to_string(),
}
}
}
/// Coordinates a one-time privilege drop across socket workers.
///
/// Clone one instance to each worker; the shared barrier synchronizes them
/// so the drop only happens after all sockets are created.
#[derive(Clone)]
pub struct PrivilegeDropper {
    barrier: Arc<Barrier>,
    config: Arc<PrivilegeConfig>,
}

impl PrivilegeDropper {
    /// `num_sockets` must equal the number of `after_socket_creation` calls
    /// that will be made, since the barrier waits for that many participants.
    pub fn new(config: PrivilegeConfig, num_sockets: usize) -> Self {
        Self {
            barrier: Arc::new(Barrier::new(num_sockets)),
            config: Arc::new(config),
        }
    }

    /// Call exactly once per socket after it has been created.
    ///
    /// When privilege dropping is enabled, blocks until all sockets have been
    /// created, then the barrier leader performs the chroot and user/group
    /// switch. No-op (and no waiting) when disabled.
    pub fn after_socket_creation(self) -> anyhow::Result<()> {
        if self.config.drop_privileges {
            if self.barrier.wait().is_leader() {
                // Clones are needed because the shared config sits behind Arc
                PrivDrop::default()
                    .chroot(self.config.chroot_path.clone())
                    .group(self.config.group.clone())
                    .user(self.config.user.clone())
                    .apply()
                    .with_context(|| "couldn't drop privileges after socket creation")?;
            }
        }

        Ok(())
    }
}

48
crates/common/src/rustls_config.rs Normal file
View file

@@ -0,0 +1,48 @@
use std::{fs::File, io::BufReader, path::Path};
use anyhow::Context;
/// Server-side TLS configuration
pub type RustlsConfig = rustls::ServerConfig;

/// Build a rustls server config from a PEM-encoded certificate chain file
/// and a PEM-encoded PKCS#8 private key file.
///
/// Uses the first private key found in the key file; errors if there is none.
pub fn create_rustls_config(
    tls_certificate_path: &Path,
    tls_private_key_path: &Path,
) -> anyhow::Result<RustlsConfig> {
    let certs = {
        let f = File::open(tls_certificate_path).with_context(|| {
            format!(
                "open tls certificate file at {}",
                tls_certificate_path.to_string_lossy()
            )
        })?;
        let mut f = BufReader::new(f);

        // Tuple-struct constructor used directly instead of a redundant closure
        rustls_pemfile::certs(&mut f)?
            .into_iter()
            .map(rustls::Certificate)
            .collect()
    };

    let private_key = {
        let f = File::open(tls_private_key_path).with_context(|| {
            format!(
                "open tls private key file at {}",
                tls_private_key_path.to_string_lossy()
            )
        })?;
        let mut f = BufReader::new(f);

        // into_iter().next() takes ownership of the first key, avoiding the
        // clone that .first() required; ok_or_else builds the error lazily
        rustls_pemfile::pkcs8_private_keys(&mut f)?
            .into_iter()
            .next()
            .map(rustls::PrivateKey)
            .ok_or_else(|| anyhow::anyhow!("No private keys in file"))?
    };

    let tls_config = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certs, private_key)
        .with_context(|| "create rustls config")?;

    Ok(tls_config)
}