ws: remove mio implementation

This commit is contained in:
Joakim Frostegård 2022-03-18 15:15:34 +01:00
parent 065e007ede
commit 667cf04085
16 changed files with 283 additions and 1831 deletions

86
Cargo.lock generated
View file

@ -264,18 +264,14 @@ dependencies = [
"aquatic_ws_protocol",
"async-tungstenite",
"cfg-if",
"crossbeam-channel",
"either",
"futures",
"futures-lite",
"futures-rustls",
"glommio",
"hashbrown 0.12.0",
"histogram",
"log",
"mimalloc",
"mio",
"parking_lot",
"privdrop",
"quickcheck",
"quickcheck_macros",
@ -285,7 +281,6 @@ dependencies = [
"serde",
"signal-hook",
"slab",
"socket2 0.4.4",
"tungstenite",
]
@ -1066,12 +1061,6 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "histogram"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12cb882ccb290b8646e554b157ab0b71e64e8d5bef775cd66b6531e52d302669"
[[package]]
name = "http"
version = "0.2.6"
@ -1471,29 +1460,6 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72"
[[package]]
name = "parking_lot"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-sys",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@ -1687,15 +1653,6 @@ dependencies = [
"num_cpus",
]
[[package]]
name = "redox_syscall"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "regex"
version = "1.5.4"
@ -2397,46 +2354,3 @@ name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6"
dependencies = [
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_msvc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5"
[[package]]
name = "windows_i686_gnu"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615"
[[package]]
name = "windows_i686_msvc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172"
[[package]]
name = "windows_x86_64_gnu"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc"
[[package]]
name = "windows_x86_64_msvc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316"

View file

@ -14,10 +14,7 @@ name = "aquatic_ws"
name = "aquatic_ws"
[features]
default = ["with-mio"]
cpu-pinning = ["aquatic_common/cpu-pinning"]
with-glommio = ["cpu-pinning", "async-tungstenite", "futures-lite", "futures", "futures-rustls", "glommio"]
with-mio = ["crossbeam-channel", "histogram", "mio", "parking_lot", "socket2"]
[dependencies]
aquatic_cli_helpers = "0.1.0"
@ -39,20 +36,11 @@ serde = { version = "1", features = ["derive"] }
signal-hook = { version = "0.3" }
slab = "0.4"
tungstenite = "0.17"
# mio
crossbeam-channel = { version = "0.5", optional = true }
histogram = { version = "0.6", optional = true }
mio = { version = "0.8", features = ["net", "os-poll"], optional = true }
parking_lot = { version = "0.12", optional = true }
socket2 = { version = "0.4", features = ["all"], optional = true }
# glommio
async-tungstenite = { version = "0.17", optional = true }
futures-lite = { version = "1", optional = true }
futures = { version = "0.3", optional = true }
futures-rustls = { version = "0.22", optional = true }
glommio = { version = "0.7", optional = true }
async-tungstenite = "0.17"
futures-lite = "1"
futures = "0.3"
futures-rustls = "0.22"
glommio = "0.7"
[dev-dependencies]
quickcheck = "1"

View file

@ -1,5 +1,3 @@
pub mod handlers;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
@ -14,6 +12,13 @@ use aquatic_ws_protocol::*;
use crate::config::Config;
pub type TlsConfig = futures_rustls::rustls::ServerConfig;
#[derive(Default, Clone)]
pub struct State {
pub access_list: Arc<AccessListArcSwap>,
}
#[derive(Copy, Clone, Debug)]
pub struct PendingScrapeId(pub usize);

View file

@ -23,15 +23,28 @@ pub struct Config {
pub log_level: LogLevel,
pub network: NetworkConfig,
pub protocol: ProtocolConfig,
#[cfg(feature = "with-mio")]
pub handlers: HandlerConfig,
pub cleaning: CleaningConfig,
pub privileges: PrivilegeConfig,
pub access_list: AccessListConfig,
#[cfg(feature = "cpu-pinning")]
pub cpu_pinning: CpuPinningConfig,
#[cfg(feature = "with-mio")]
pub statistics: StatisticsConfig,
}
impl Default for Config {
fn default() -> Self {
Self {
socket_workers: 1,
request_workers: 1,
log_level: LogLevel::default(),
network: NetworkConfig::default(),
protocol: ProtocolConfig::default(),
cleaning: CleaningConfig::default(),
privileges: PrivilegeConfig::default(),
access_list: AccessListConfig::default(),
#[cfg(feature = "cpu-pinning")]
cpu_pinning: Default::default(),
}
}
}
impl aquatic_cli_helpers::Config for Config {
@ -55,81 +68,6 @@ pub struct NetworkConfig {
pub websocket_max_message_size: usize,
pub websocket_max_frame_size: usize,
#[cfg(feature = "with-mio")]
pub poll_event_capacity: usize,
#[cfg(feature = "with-mio")]
pub poll_timeout_microseconds: u64,
}
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)]
pub struct ProtocolConfig {
/// Maximum number of torrents to accept in scrape request
pub max_scrape_torrents: usize,
/// Maximum number of offers to accept in announce request
pub max_offers: usize,
/// Ask peers to announce this often (seconds)
pub peer_announce_interval: usize,
}
#[cfg(feature = "with-mio")]
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)]
pub struct HandlerConfig {
/// Maximum number of requests to receive from channel before locking
/// mutex and starting work
pub max_requests_per_iter: usize,
pub channel_recv_timeout_microseconds: u64,
}
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)]
pub struct CleaningConfig {
/// Clean peers this often (seconds)
pub torrent_cleaning_interval: u64,
/// Remove peers that have not announced for this long (seconds)
pub max_peer_age: u64,
// Clean connections this often (seconds)
#[cfg(feature = "with-glommio")]
pub connection_cleaning_interval: u64,
/// Close connections if no responses have been sent to them for this long (seconds)
#[cfg(feature = "with-glommio")]
pub max_connection_idle: u64,
/// Remove connections that are older than this (seconds)
#[cfg(feature = "with-mio")]
pub max_connection_age: u64,
}
#[cfg(feature = "with-mio")]
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)]
pub struct StatisticsConfig {
/// Print statistics this often (seconds). Do not print when set to zero.
pub interval: u64,
}
impl Default for Config {
fn default() -> Self {
Self {
socket_workers: 1,
request_workers: 1,
log_level: LogLevel::default(),
network: NetworkConfig::default(),
protocol: ProtocolConfig::default(),
#[cfg(feature = "with-mio")]
handlers: Default::default(),
cleaning: CleaningConfig::default(),
privileges: PrivilegeConfig::default(),
access_list: AccessListConfig::default(),
#[cfg(feature = "cpu-pinning")]
cpu_pinning: Default::default(),
#[cfg(feature = "with-mio")]
statistics: Default::default(),
}
}
}
impl Default for NetworkConfig {
@ -143,13 +81,19 @@ impl Default for NetworkConfig {
websocket_max_message_size: 64 * 1024,
websocket_max_frame_size: 16 * 1024,
}
}
}
#[cfg(feature = "with-mio")]
poll_event_capacity: 4096,
#[cfg(feature = "with-mio")]
poll_timeout_microseconds: 200_000,
}
}
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)]
pub struct ProtocolConfig {
/// Maximum number of torrents to accept in scrape request
pub max_scrape_torrents: usize,
/// Maximum number of offers to accept in announce request
pub max_offers: usize,
/// Ask peers to announce this often (seconds)
pub peer_announce_interval: usize,
}
impl Default for ProtocolConfig {
@ -162,14 +106,17 @@ impl Default for ProtocolConfig {
}
}
#[cfg(feature = "with-mio")]
impl Default for HandlerConfig {
fn default() -> Self {
Self {
max_requests_per_iter: 256,
channel_recv_timeout_microseconds: 200,
}
}
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)]
pub struct CleaningConfig {
/// Clean peers this often (seconds)
pub torrent_cleaning_interval: u64,
/// Remove peers that have not announced for this long (seconds)
pub max_peer_age: u64,
// Clean connections this often (seconds)
pub connection_cleaning_interval: u64,
/// Close connections if no responses have been sent to them for this long (seconds)
pub max_connection_idle: u64,
}
impl Default for CleaningConfig {
@ -177,24 +124,12 @@ impl Default for CleaningConfig {
Self {
torrent_cleaning_interval: 30,
max_peer_age: 1800,
#[cfg(feature = "with-glommio")]
max_connection_idle: 60 * 5,
#[cfg(feature = "with-mio")]
max_connection_age: 1800,
#[cfg(feature = "with-glommio")]
connection_cleaning_interval: 30,
}
}
}
#[cfg(feature = "with-mio")]
impl Default for StatisticsConfig {
fn default() -> Self {
Self { interval: 0 }
}
}
#[cfg(test)]
mod tests {
use super::Config;

View file

@ -1,10 +0,0 @@
use std::sync::Arc;
use aquatic_common::access_list::AccessListArcSwap;
pub type TlsConfig = futures_rustls::rustls::ServerConfig;
#[derive(Default, Clone)]
pub struct State {
pub access_list: Arc<AccessListArcSwap>,
}

View file

@ -1,107 +0,0 @@
pub mod common;
pub mod request;
pub mod socket;
use std::sync::{atomic::AtomicUsize, Arc};
use crate::{common::create_tls_config, config::Config};
#[cfg(feature = "cpu-pinning")]
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
use aquatic_common::privileges::drop_privileges_after_socket_binding;
use self::common::*;
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
pub const SHARED_IN_CHANNEL_SIZE: usize = 1024;
pub fn run(config: Config, state: State) -> anyhow::Result<()> {
let num_peers = config.socket_workers + config.request_workers;
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE);
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16);
let num_bound_sockets = Arc::new(AtomicUsize::new(0));
let tls_config = Arc::new(create_tls_config(&config).unwrap());
let mut executors = Vec::new();
for i in 0..(config.socket_workers) {
let config = config.clone();
let state = state.clone();
let tls_config = tls_config.clone();
let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone();
let num_bound_sockets = num_bound_sockets.clone();
let builder = LocalExecutorBuilder::default().name("socket");
let executor = builder.spawn(move || async move {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::SocketWorker(i),
);
socket::run_socket_worker(
config,
state,
tls_config,
request_mesh_builder,
response_mesh_builder,
num_bound_sockets,
)
.await
});
executors.push(executor);
}
for i in 0..(config.request_workers) {
let config = config.clone();
let state = state.clone();
let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone();
let builder = LocalExecutorBuilder::default().name("request");
let executor = builder.spawn(move || async move {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::RequestWorker(i),
);
request::run_request_worker(config, state, request_mesh_builder, response_mesh_builder)
.await
});
executors.push(executor);
}
drop_privileges_after_socket_binding(
&config.privileges,
num_bound_sockets,
config.socket_workers,
)
.unwrap();
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::Other,
);
for executor in executors {
executor
.expect("failed to spawn local executor")
.join()
.unwrap();
}
Ok(())
}

View file

@ -1,128 +0,0 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;
use futures::StreamExt;
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::enclose;
use glommio::prelude::*;
use glommio::timer::TimerActionRepeat;
use rand::{rngs::SmallRng, SeedableRng};
use aquatic_ws_protocol::*;
use crate::common::handlers::*;
use crate::common::*;
use crate::config::Config;
use super::common::State;
use super::SHARED_IN_CHANNEL_SIZE;
pub async fn run_request_worker(
config: Config,
state: State,
in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
) {
let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap();
let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap();
let out_message_senders = Rc::new(out_message_senders);
let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
let access_list = state.access_list;
// Periodically clean torrents
TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
enclose!((config, torrents, access_list) move || async move {
torrents.borrow_mut().clean(&config, &access_list);
Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval))
})()
}));
let mut handles = Vec::new();
for (_, receiver) in in_message_receivers.streams() {
let handle = spawn_local(handle_request_stream(
config.clone(),
torrents.clone(),
out_message_senders.clone(),
receiver,
))
.detach();
handles.push(handle);
}
for handle in handles {
handle.await;
}
}
async fn handle_request_stream<S>(
config: Config,
torrents: Rc<RefCell<TorrentMaps>>,
out_message_senders: Rc<Senders<(ConnectionMeta, OutMessage)>>,
stream: S,
) where
S: futures_lite::Stream<Item = (ConnectionMeta, InMessage)> + ::std::marker::Unpin,
{
let rng = Rc::new(RefCell::new(SmallRng::from_entropy()));
let max_peer_age = config.cleaning.max_peer_age;
let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));
TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
enclose!((peer_valid_until) move || async move {
*peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);
Some(Duration::from_secs(1))
})()
}));
let config = &config;
let torrents = &torrents;
let peer_valid_until = &peer_valid_until;
let rng = &rng;
let out_message_senders = &out_message_senders;
stream
.for_each_concurrent(
SHARED_IN_CHANNEL_SIZE,
move |(meta, in_message)| async move {
let mut out_messages = Vec::new();
match in_message {
InMessage::AnnounceRequest(request) => handle_announce_request(
&config,
&mut rng.borrow_mut(),
&mut torrents.borrow_mut(),
&mut out_messages,
peer_valid_until.borrow().to_owned(),
meta,
request,
),
InMessage::ScrapeRequest(request) => handle_scrape_request(
&config,
&mut torrents.borrow_mut(),
&mut out_messages,
meta,
request,
),
};
for (meta, out_message) in out_messages.drain(..) {
::log::info!("request worker trying to send OutMessage to socket worker");
out_message_senders
.send_to(meta.out_message_consumer_id.0, (meta, out_message))
.await
.expect("failed sending out_message to socket worker");
::log::info!("request worker sent OutMessage to socket worker");
}
},
)
.await;
}

View file

@ -1,28 +1,26 @@
use aquatic_common::access_list::update_access_list;
#[cfg(feature = "cpu-pinning")]
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
use cfg_if::cfg_if;
use common::State;
use signal_hook::{consts::SIGUSR1, iterator::Signals};
use crate::config::Config;
use std::sync::{atomic::AtomicUsize, Arc};
use crate::{common::create_tls_config, config::Config};
use aquatic_common::privileges::drop_privileges_after_socket_binding;
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
pub mod common;
pub mod config;
#[cfg(feature = "with-glommio")]
pub mod glommio;
#[cfg(feature = "with-mio")]
pub mod mio;
pub mod workers;
pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker";
pub const SHARED_IN_CHANNEL_SIZE: usize = 1024;
pub fn run(config: Config) -> ::anyhow::Result<()> {
cfg_if!(
if #[cfg(feature = "with-glommio")] {
let state = glommio::common::State::default();
} else {
let state = mio::common::State::default();
}
);
let state = State::default();
update_access_list(&config.access_list, &state.access_list)?;
@ -32,13 +30,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let config = config.clone();
let state = state.clone();
cfg_if!(
if #[cfg(feature = "with-glommio")] {
::std::thread::spawn(move || glommio::run(config, state));
} else {
::std::thread::spawn(move || mio::run(config, state));
}
);
::std::thread::spawn(move || run_workers(config, state));
}
#[cfg(feature = "cpu-pinning")]
@ -59,3 +51,99 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
Ok(())
}
pub fn run_workers(config: Config, state: State) -> anyhow::Result<()> {
let num_peers = config.socket_workers + config.request_workers;
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE);
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16);
let num_bound_sockets = Arc::new(AtomicUsize::new(0));
let tls_config = Arc::new(create_tls_config(&config).unwrap());
let mut executors = Vec::new();
for i in 0..(config.socket_workers) {
let config = config.clone();
let state = state.clone();
let tls_config = tls_config.clone();
let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone();
let num_bound_sockets = num_bound_sockets.clone();
let builder = LocalExecutorBuilder::default().name("socket");
let executor = builder.spawn(move || async move {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::SocketWorker(i),
);
workers::socket::run_socket_worker(
config,
state,
tls_config,
request_mesh_builder,
response_mesh_builder,
num_bound_sockets,
)
.await
});
executors.push(executor);
}
for i in 0..(config.request_workers) {
let config = config.clone();
let state = state.clone();
let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone();
let builder = LocalExecutorBuilder::default().name("request");
let executor = builder.spawn(move || async move {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::RequestWorker(i),
);
workers::request::run_request_worker(
config,
state,
request_mesh_builder,
response_mesh_builder,
)
.await
});
executors.push(executor);
}
drop_privileges_after_socket_binding(
&config.privileges,
num_bound_sockets,
config.socket_workers,
)
.unwrap();
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::Other,
);
for executor in executors {
executor
.expect("failed to spawn local executor")
.join()
.unwrap();
}
Ok(())
}

View file

@ -1,51 +0,0 @@
use std::sync::Arc;
use aquatic_common::access_list::AccessListArcSwap;
use aquatic_ws_protocol::*;
use crossbeam_channel::{Receiver, Sender};
use log::error;
use mio::Token;
use parking_lot::Mutex;
use crate::common::*;
pub const LISTENER_TOKEN: Token = Token(0);
pub const CHANNEL_TOKEN: Token = Token(1);
#[derive(Clone)]
pub struct State {
pub access_list: Arc<AccessListArcSwap>,
pub torrent_maps: Arc<Mutex<TorrentMaps>>,
}
impl Default for State {
fn default() -> Self {
Self {
access_list: Arc::new(Default::default()),
torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())),
}
}
}
pub type InMessageSender = Sender<(ConnectionMeta, InMessage)>;
pub type InMessageReceiver = Receiver<(ConnectionMeta, InMessage)>;
pub type OutMessageReceiver = Receiver<(ConnectionMeta, OutMessage)>;
#[derive(Clone)]
pub struct OutMessageSender(Vec<Sender<(ConnectionMeta, OutMessage)>>);
impl OutMessageSender {
pub fn new(senders: Vec<Sender<(ConnectionMeta, OutMessage)>>) -> Self {
Self(senders)
}
#[inline]
pub fn send(&self, meta: ConnectionMeta, message: OutMessage) {
if let Err(err) = self.0[meta.out_message_consumer_id.0].send((meta, message)) {
error!("OutMessageSender: couldn't send message: {:?}", err);
}
}
}
pub type SocketWorkerStatus = Option<Result<(), String>>;
pub type SocketWorkerStatuses = Arc<Mutex<Vec<SocketWorkerStatus>>>;

View file

@ -1,218 +0,0 @@
use std::sync::Arc;
use std::thread::Builder;
use std::time::Duration;
use anyhow::Context;
#[cfg(feature = "cpu-pinning")]
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
use histogram::Histogram;
use mio::{Poll, Waker};
use parking_lot::Mutex;
use privdrop::PrivDrop;
pub mod common;
pub mod request;
pub mod socket;
use crate::{common::create_tls_config, config::Config};
use common::*;
pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker";
const SHARED_IN_CHANNEL_SIZE: usize = 1024;
pub fn run(config: Config, state: State) -> anyhow::Result<()> {
start_workers(config.clone(), state.clone()).expect("couldn't start workers");
// TODO: privdrop here instead
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::Other,
);
loop {
::std::thread::sleep(Duration::from_secs(
config.cleaning.torrent_cleaning_interval,
));
state.torrent_maps.lock().clean(&config, &state.access_list);
}
}
pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
let tls_config = Arc::new(create_tls_config(&config)?);
let (in_message_sender, in_message_receiver) =
::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE);
let mut out_message_senders = Vec::new();
let mut wakers = Vec::new();
let socket_worker_statuses: SocketWorkerStatuses = {
let mut statuses = Vec::new();
for _ in 0..config.socket_workers {
statuses.push(None);
}
Arc::new(Mutex::new(statuses))
};
for i in 0..config.socket_workers {
let config = config.clone();
let state = state.clone();
let socket_worker_statuses = socket_worker_statuses.clone();
let in_message_sender = in_message_sender.clone();
let tls_config = tls_config.clone();
let poll = Poll::new()?;
let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?);
let (out_message_sender, out_message_receiver) =
::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE * 16);
out_message_senders.push(out_message_sender);
wakers.push(waker);
Builder::new()
.name(format!("socket-{:02}", i + 1))
.spawn(move || {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::SocketWorker(i),
);
socket::run_socket_worker(
config,
state,
i,
socket_worker_statuses,
poll,
in_message_sender,
out_message_receiver,
tls_config,
);
})?;
}
// Wait for socket worker statuses. On error from any, quit program.
// On success from all, drop privileges if corresponding setting is set
// and continue program.
loop {
::std::thread::sleep(::std::time::Duration::from_millis(10));
if let Some(statuses) = socket_worker_statuses.try_lock() {
for opt_status in statuses.iter() {
if let Some(Err(err)) = opt_status {
return Err(::anyhow::anyhow!(err.to_owned()));
}
}
if statuses.iter().all(Option::is_some) {
if config.privileges.drop_privileges {
PrivDrop::default()
.chroot(config.privileges.chroot_path.clone())
.user(config.privileges.user.clone())
.apply()
.context("Couldn't drop root privileges")?;
}
break;
}
}
}
let out_message_sender = OutMessageSender::new(out_message_senders);
for i in 0..config.request_workers {
let config = config.clone();
let state = state.clone();
let in_message_receiver = in_message_receiver.clone();
let out_message_sender = out_message_sender.clone();
let wakers = wakers.clone();
Builder::new()
.name(format!("request-{:02}", i + 1))
.spawn(move || {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::RequestWorker(i),
);
request::run_request_worker(
config,
state,
in_message_receiver,
out_message_sender,
wakers,
);
})?;
}
if config.statistics.interval != 0 {
let state = state.clone();
let config = config.clone();
Builder::new()
.name("statistics".to_string())
.spawn(move || {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
WorkerIndex::Other,
);
loop {
::std::thread::sleep(Duration::from_secs(config.statistics.interval));
print_statistics(&state);
}
})
.expect("spawn statistics thread");
}
Ok(())
}
fn print_statistics(state: &State) {
let mut peers_per_torrent = Histogram::new();
{
let torrents = &mut state.torrent_maps.lock();
for torrent in torrents.ipv4.values() {
let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;
if let Err(err) = peers_per_torrent.increment(num_peers) {
eprintln!("error incrementing peers_per_torrent histogram: {}", err)
}
}
for torrent in torrents.ipv6.values() {
let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;
if let Err(err) = peers_per_torrent.increment(num_peers) {
eprintln!("error incrementing peers_per_torrent histogram: {}", err)
}
}
}
if peers_per_torrent.entries() != 0 {
println!(
"peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}",
peers_per_torrent.minimum().unwrap(),
peers_per_torrent.percentile(50.0).unwrap(),
peers_per_torrent.percentile(75.0).unwrap(),
peers_per_torrent.percentile(90.0).unwrap(),
peers_per_torrent.percentile(99.0).unwrap(),
peers_per_torrent.percentile(99.9).unwrap(),
peers_per_torrent.maximum().unwrap(),
);
}
}

View file

@ -1,103 +0,0 @@
use std::sync::Arc;
use std::time::Duration;
use mio::Waker;
use parking_lot::MutexGuard;
use rand::{rngs::SmallRng, SeedableRng};
use aquatic_ws_protocol::*;
use crate::common::handlers::{handle_announce_request, handle_scrape_request};
use crate::common::*;
use crate::config::Config;
use super::common::*;
pub fn run_request_worker(
config: Config,
state: State,
in_message_receiver: InMessageReceiver,
out_message_sender: OutMessageSender,
wakers: Vec<Arc<Waker>>,
) {
let mut wake_socket_workers: Vec<bool> = (0..config.socket_workers).map(|_| false).collect();
let mut announce_requests = Vec::new();
let mut scrape_requests = Vec::new();
let mut out_messages = Vec::new();
let mut rng = SmallRng::from_entropy();
let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);
loop {
let mut opt_torrent_map_guard: Option<MutexGuard<TorrentMaps>> = None;
for i in 0..config.handlers.max_requests_per_iter {
let opt_in_message = if i == 0 {
in_message_receiver.recv().ok()
} else {
in_message_receiver.recv_timeout(timeout).ok()
};
match opt_in_message {
Some((meta, InMessage::AnnounceRequest(r))) => {
announce_requests.push((meta, r));
}
Some((meta, InMessage::ScrapeRequest(r))) => {
scrape_requests.push((meta, r));
}
None => {
if let Some(torrent_guard) = state.torrent_maps.try_lock() {
opt_torrent_map_guard = Some(torrent_guard);
break;
}
}
}
}
let mut torrent_map_guard =
opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock());
let valid_until = ValidUntil::new(config.cleaning.max_peer_age);
for (meta, request) in announce_requests.drain(..) {
handle_announce_request(
&config,
&mut rng,
&mut torrent_map_guard,
&mut out_messages,
valid_until,
meta,
request,
);
}
for (meta, request) in scrape_requests.drain(..) {
handle_scrape_request(
&config,
&mut torrent_map_guard,
&mut out_messages,
meta,
request,
);
}
::std::mem::drop(torrent_map_guard);
for (meta, out_message) in out_messages.drain(..) {
wake_socket_workers[meta.out_message_consumer_id.0] = true;
out_message_sender.send(meta, out_message);
}
for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() {
if *wake {
if let Err(err) = wakers[worker_index].wake() {
::log::error!("request handler couldn't wake poll: {:?}", err);
}
*wake = false;
}
}
}
}

View file

@ -1,577 +0,0 @@
use std::{collections::VecDeque, io::ErrorKind, marker::PhantomData, net::Shutdown, sync::Arc};
use aquatic_common::ValidUntil;
use aquatic_ws_protocol::{InMessage, OutMessage};
use mio::{net::TcpStream, Interest, Poll, Token};
use rustls::{ServerConfig, ServerConnection};
use tungstenite::{
handshake::{server::NoCallback, MidHandshake},
protocol::WebSocketConfig,
HandshakeError, ServerHandshake,
};
use crate::common::ConnectionMeta;
const MAX_PENDING_MESSAGES: usize = 16;
type TlsStream = rustls::StreamOwned<ServerConnection, TcpStream>;
type WsHandshakeResult<S> =
Result<tungstenite::WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>;
type ConnectionReadResult<T> = ::std::io::Result<ConnectionReadStatus<T>>;
pub trait RegistryStatus {}
pub struct Registered;
impl RegistryStatus for Registered {}
pub struct NotRegistered;
impl RegistryStatus for NotRegistered {}
enum ConnectionReadStatus<T> {
Message(T, InMessage),
Ok(T),
WouldBlock(T),
}
enum ConnectionState<R: RegistryStatus> {
TlsHandshaking(TlsHandshaking<R>),
WsHandshaking(WsHandshaking<R>),
WsConnection(WsConnection<R>),
}
pub struct Connection<R: RegistryStatus> {
pub valid_until: ValidUntil,
meta: ConnectionMeta,
state: ConnectionState<R>,
pub message_queue: VecDeque<OutMessage>,
pub interest: Interest,
phantom_data: PhantomData<R>,
}
impl<R: RegistryStatus> Connection<R> {
pub fn get_meta(&self) -> ConnectionMeta {
self.meta
}
}
impl Connection<NotRegistered> {
pub fn new(
tls_config: Arc<ServerConfig>,
ws_config: WebSocketConfig,
tcp_stream: TcpStream,
valid_until: ValidUntil,
meta: ConnectionMeta,
) -> Self {
let state =
ConnectionState::TlsHandshaking(TlsHandshaking::new(tls_config, ws_config, tcp_stream));
Self {
valid_until,
meta,
state,
message_queue: Default::default(),
interest: Interest::READABLE,
phantom_data: PhantomData::default(),
}
}
/// Read until stream blocks (or error occurs)
///
/// Requires Connection not to be registered, since it might be dropped on errors
pub fn read<F>(
mut self,
message_handler: &mut F,
) -> ::std::io::Result<Connection<NotRegistered>>
where
F: FnMut(ConnectionMeta, InMessage),
{
loop {
let result = match self.state {
ConnectionState::TlsHandshaking(inner) => inner.read(),
ConnectionState::WsHandshaking(inner) => inner.read(),
ConnectionState::WsConnection(inner) => inner.read(),
};
match result {
Ok(ConnectionReadStatus::Message(state, message)) => {
self.state = state;
message_handler(self.meta, message);
// Stop looping even if WouldBlock wasn't necessarily reached. Otherwise,
// we might get stuck reading from this connection only. Since we register
// the connection again upon reinsertion into the ConnectionMap, we should
// be getting new events anyway.
return Ok(self);
}
Ok(ConnectionReadStatus::Ok(state)) => {
self.state = state;
::log::debug!("read connection");
}
Ok(ConnectionReadStatus::WouldBlock(state)) => {
self.state = state;
::log::debug!("reading connection would block");
return Ok(self);
}
Err(err) => {
::log::debug!("Connection::read error: {}", err);
return Err(err);
}
}
}
}
pub fn register(self, poll: &mut Poll, token: Token) -> Connection<Registered> {
let state = match self.state {
ConnectionState::TlsHandshaking(inner) => {
ConnectionState::TlsHandshaking(inner.register(poll, token, self.interest))
}
ConnectionState::WsHandshaking(inner) => {
ConnectionState::WsHandshaking(inner.register(poll, token, self.interest))
}
ConnectionState::WsConnection(inner) => {
ConnectionState::WsConnection(inner.register(poll, token, self.interest))
}
};
Connection {
valid_until: self.valid_until,
meta: self.meta,
state,
message_queue: self.message_queue,
interest: self.interest,
phantom_data: PhantomData::default(),
}
}
pub fn close(self) {
::log::debug!("will close connection to {}", self.meta.peer_addr.get());
match self.state {
ConnectionState::TlsHandshaking(inner) => inner.close(),
ConnectionState::WsHandshaking(inner) => inner.close(),
ConnectionState::WsConnection(inner) => inner.close(),
}
}
}
impl Connection<Registered> {
pub fn write_or_queue_message(
&mut self,
poll: &mut Poll,
message: OutMessage,
) -> ::std::io::Result<()> {
let message_clone = message.clone();
match self.write_message(message) {
Ok(()) => Ok(()),
Err(err) if err.kind() == ErrorKind::WouldBlock => {
if self.message_queue.len() < MAX_PENDING_MESSAGES {
self.message_queue.push_back(message_clone);
if !self.interest.is_writable() {
self.interest = Interest::WRITABLE;
self.reregister(poll)?;
}
} else {
::log::info!("Connection::message_queue is full, dropping message");
}
Ok(())
}
Err(err) => Err(err),
}
}
pub fn write(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
if let ConnectionState::WsConnection(_) = self.state {
while let Some(message) = self.message_queue.pop_front() {
let message_clone = message.clone();
match self.write_message(message) {
Ok(()) => {}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
// Can't make message queue longer than it was before pop_front
self.message_queue.push_front(message_clone);
return Ok(());
}
Err(err) => {
return Err(err);
}
}
}
if self.message_queue.is_empty() {
self.interest = Interest::READABLE;
}
self.reregister(poll)?;
Ok(())
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"WebSocket connection not established",
))
}
}
fn write_message(&mut self, message: OutMessage) -> ::std::io::Result<()> {
if let ConnectionState::WsConnection(WsConnection {
ref mut web_socket, ..
}) = self.state
{
match web_socket.write_message(message.to_ws_message()) {
Ok(_) => {}
Err(tungstenite::Error::SendQueueFull(_message)) => {
return Err(std::io::Error::new(
ErrorKind::WouldBlock,
"Send queue full",
))
}
Err(tungstenite::Error::Io(err)) => return Err(err),
Err(err) => return Err(std::io::Error::new(ErrorKind::Other, err))?,
}
match web_socket.write_pending() {
Ok(()) => Ok(()),
Err(tungstenite::Error::Io(err)) => Err(err),
Err(err) => Err(std::io::Error::new(ErrorKind::Other, err))?,
}
} else {
Err(std::io::Error::new(
ErrorKind::NotConnected,
"WebSocket connection not established",
))
}
}
pub fn reregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
let token = Token(self.meta.connection_id.0);
match self.state {
ConnectionState::TlsHandshaking(ref mut inner) => {
inner.reregister(poll, token, self.interest)
}
ConnectionState::WsHandshaking(ref mut inner) => {
inner.reregister(poll, token, self.interest)
}
ConnectionState::WsConnection(ref mut inner) => {
inner.reregister(poll, token, self.interest)
}
}
}
pub fn deregister(self, poll: &mut Poll) -> Connection<NotRegistered> {
let state = match self.state {
ConnectionState::TlsHandshaking(inner) => {
ConnectionState::TlsHandshaking(inner.deregister(poll))
}
ConnectionState::WsHandshaking(inner) => {
ConnectionState::WsHandshaking(inner.deregister(poll))
}
ConnectionState::WsConnection(inner) => {
ConnectionState::WsConnection(inner.deregister(poll))
}
};
Connection {
valid_until: self.valid_until,
meta: self.meta,
state,
message_queue: self.message_queue,
interest: self.interest,
phantom_data: PhantomData::default(),
}
}
}
struct TlsHandshaking<R: RegistryStatus> {
tls_conn: ServerConnection,
ws_config: WebSocketConfig,
tcp_stream: TcpStream,
phantom_data: PhantomData<R>,
}
impl TlsHandshaking<NotRegistered> {
fn new(tls_config: Arc<ServerConfig>, ws_config: WebSocketConfig, stream: TcpStream) -> Self {
Self {
tls_conn: ServerConnection::new(tls_config).unwrap(),
ws_config,
tcp_stream: stream,
phantom_data: PhantomData::default(),
}
}
fn read(mut self) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
match self.tls_conn.read_tls(&mut self.tcp_stream) {
Ok(0) => {
return Err(::std::io::Error::new(
ErrorKind::ConnectionReset,
"Connection closed",
))
}
Ok(_) => match self.tls_conn.process_new_packets() {
Ok(_) => {
while self.tls_conn.wants_write() {
self.tls_conn.write_tls(&mut self.tcp_stream)?;
}
if self.tls_conn.is_handshaking() {
Ok(ConnectionReadStatus::WouldBlock(
ConnectionState::TlsHandshaking(self),
))
} else {
let tls_stream = TlsStream::new(self.tls_conn, self.tcp_stream);
WsHandshaking::handle_handshake_result(tungstenite::accept_with_config(
tls_stream,
Some(self.ws_config),
))
}
}
Err(err) => {
let _ = self.tls_conn.write_tls(&mut self.tcp_stream);
Err(::std::io::Error::new(ErrorKind::InvalidData, err))
}
},
Err(err) if err.kind() == ErrorKind::WouldBlock => {
return Ok(ConnectionReadStatus::WouldBlock(
ConnectionState::TlsHandshaking(self),
))
}
Err(err) => return Err(err),
}
}
fn register(
mut self,
poll: &mut Poll,
token: Token,
interest: Interest,
) -> TlsHandshaking<Registered> {
poll.registry()
.register(&mut self.tcp_stream, token, interest)
.unwrap();
TlsHandshaking {
tls_conn: self.tls_conn,
ws_config: self.ws_config,
tcp_stream: self.tcp_stream,
phantom_data: PhantomData::default(),
}
}
fn close(self) {
::log::debug!("closing connection (TlsHandshaking state)");
let _ = self.tcp_stream.shutdown(Shutdown::Both);
}
}
impl TlsHandshaking<Registered> {
fn deregister(mut self, poll: &mut Poll) -> TlsHandshaking<NotRegistered> {
poll.registry().deregister(&mut self.tcp_stream).unwrap();
TlsHandshaking {
tls_conn: self.tls_conn,
ws_config: self.ws_config,
tcp_stream: self.tcp_stream,
phantom_data: PhantomData::default(),
}
}
fn reregister(
&mut self,
poll: &mut Poll,
token: Token,
interest: Interest,
) -> std::io::Result<()> {
poll.registry()
.reregister(&mut self.tcp_stream, token, interest)
}
}
struct WsHandshaking<R: RegistryStatus> {
mid_handshake: MidHandshake<ServerHandshake<TlsStream, NoCallback>>,
phantom_data: PhantomData<R>,
}
impl WsHandshaking<NotRegistered> {
fn read(self) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
Self::handle_handshake_result(self.mid_handshake.handshake())
}
fn handle_handshake_result(
handshake_result: WsHandshakeResult<TlsStream>,
) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
match handshake_result {
Ok(web_socket) => {
let conn = ConnectionState::WsConnection(WsConnection {
web_socket,
phantom_data: PhantomData::default(),
});
Ok(ConnectionReadStatus::Ok(conn))
}
Err(HandshakeError::Interrupted(mid_handshake)) => {
let conn = ConnectionState::WsHandshaking(WsHandshaking {
mid_handshake,
phantom_data: PhantomData::default(),
});
Ok(ConnectionReadStatus::WouldBlock(conn))
}
Err(HandshakeError::Failure(err)) => {
return Err(std::io::Error::new(ErrorKind::InvalidData, err))
}
}
}
fn register(
mut self,
poll: &mut Poll,
token: Token,
interest: Interest,
) -> WsHandshaking<Registered> {
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
poll.registry()
.register(tcp_stream, token, interest)
.unwrap();
WsHandshaking {
mid_handshake: self.mid_handshake,
phantom_data: PhantomData::default(),
}
}
fn close(mut self) {
::log::debug!("closing connection (WsHandshaking state)");
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
let _ = tcp_stream.shutdown(Shutdown::Both);
}
}
impl WsHandshaking<Registered> {
fn deregister(mut self, poll: &mut Poll) -> WsHandshaking<NotRegistered> {
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
poll.registry().deregister(tcp_stream).unwrap();
WsHandshaking {
mid_handshake: self.mid_handshake,
phantom_data: PhantomData::default(),
}
}
fn reregister(
&mut self,
poll: &mut Poll,
token: Token,
interest: Interest,
) -> std::io::Result<()> {
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
poll.registry().reregister(tcp_stream, token, interest)
}
}
struct WsConnection<R: RegistryStatus> {
web_socket: tungstenite::WebSocket<TlsStream>,
phantom_data: PhantomData<R>,
}
impl WsConnection<NotRegistered> {
fn read(mut self) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
match self.web_socket.read_message() {
Ok(
message @ tungstenite::Message::Text(_) | message @ tungstenite::Message::Binary(_),
) => match InMessage::from_ws_message(message) {
Ok(message) => {
::log::debug!("received WebSocket message");
Ok(ConnectionReadStatus::Message(
ConnectionState::WsConnection(self),
message,
))
}
Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)),
},
Ok(message) => {
::log::info!("received unexpected WebSocket message: {}", message);
Err(std::io::Error::new(
ErrorKind::InvalidData,
"unexpected WebSocket message type",
))
}
Err(tungstenite::Error::Io(err)) if err.kind() == ErrorKind::WouldBlock => {
let conn = ConnectionState::WsConnection(self);
Ok(ConnectionReadStatus::WouldBlock(conn))
}
Err(tungstenite::Error::Io(err)) => Err(err),
Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)),
}
}
fn register(
mut self,
poll: &mut Poll,
token: Token,
interest: Interest,
) -> WsConnection<Registered> {
poll.registry()
.register(self.web_socket.get_mut().get_mut(), token, interest)
.unwrap();
WsConnection {
web_socket: self.web_socket,
phantom_data: PhantomData::default(),
}
}
fn close(mut self) {
::log::debug!("closing connection (WsConnection state)");
let _ = self.web_socket.close(None);
let _ = self.web_socket.write_pending();
}
}
impl WsConnection<Registered> {
fn deregister(mut self, poll: &mut Poll) -> WsConnection<NotRegistered> {
poll.registry()
.deregister(self.web_socket.get_mut().get_mut())
.unwrap();
WsConnection {
web_socket: self.web_socket,
phantom_data: PhantomData::default(),
}
}
fn reregister(
&mut self,
poll: &mut Poll,
token: Token,
interest: Interest,
) -> std::io::Result<()> {
poll.registry()
.reregister(self.web_socket.get_mut().get_mut(), token, interest)
}
}

View file

@ -1,403 +0,0 @@
use std::io::ErrorKind;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Context;
use aquatic_common::access_list::AccessListQuery;
use aquatic_common::CanonicalSocketAddr;
use hashbrown::HashMap;
use mio::net::TcpListener;
use mio::{Events, Interest, Poll, Token};
use socket2::{Domain, Protocol, Socket, Type};
use tungstenite::protocol::WebSocketConfig;
use aquatic_ws_protocol::*;
use crate::common::*;
use crate::config::Config;
pub mod connection;
use super::common::*;
use connection::{Connection, NotRegistered, Registered};
struct ConnectionMap {
token_counter: Token,
connections: HashMap<Token, Connection<Registered>>,
}
impl Default for ConnectionMap {
fn default() -> Self {
Self {
token_counter: Token(2),
connections: Default::default(),
}
}
}
impl ConnectionMap {
fn insert_and_register_new<F>(&mut self, poll: &mut Poll, connection_creator: F)
where
F: FnOnce(Token) -> Connection<NotRegistered>,
{
self.token_counter.0 = self.token_counter.0.wrapping_add(1);
// Don't assign LISTENER_TOKEN or CHANNEL_TOKEN
if self.token_counter.0 < 2 {
self.token_counter.0 = 2;
}
let token = self.token_counter;
// Remove, deregister and close any existing connection with this token.
// This shouldn't happen in practice.
if let Some(connection) = self.connections.remove(&token) {
::log::warn!(
"removing existing connection {} because of token reuse",
token.0
);
connection.deregister(poll).close();
}
let connection = connection_creator(token);
self.insert_and_register(poll, token, connection);
}
fn insert_and_register(
&mut self,
poll: &mut Poll,
key: Token,
conn: Connection<NotRegistered>,
) {
self.connections.insert(key, conn.register(poll, key));
}
fn remove_and_deregister(
&mut self,
poll: &mut Poll,
key: &Token,
) -> Option<Connection<NotRegistered>> {
if let Some(connection) = self.connections.remove(key) {
Some(connection.deregister(poll))
} else {
None
}
}
fn get_mut(&mut self, key: &Token) -> Option<&mut Connection<Registered>> {
self.connections.get_mut(key)
}
/// Close and remove inactive connections
fn clean(mut self, poll: &mut Poll) -> Self {
let now = Instant::now();
let mut retained_connections = HashMap::default();
for (token, connection) in self.connections.drain() {
if connection.valid_until.0 < now {
connection.deregister(poll).close();
} else {
retained_connections.insert(token, connection);
}
}
ConnectionMap {
connections: retained_connections,
..self
}
}
}
pub fn run_socket_worker(
config: Config,
state: State,
socket_worker_index: usize,
socket_worker_statuses: SocketWorkerStatuses,
poll: Poll,
in_message_sender: InMessageSender,
out_message_receiver: OutMessageReceiver,
tls_config: Arc<rustls::ServerConfig>,
) {
match create_listener(&config) {
Ok(listener) => {
socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(()));
run_poll_loop(
config,
&state,
socket_worker_index,
poll,
in_message_sender,
out_message_receiver,
listener,
tls_config,
);
}
Err(err) => {
socket_worker_statuses.lock()[socket_worker_index] =
Some(Err(format!("Couldn't open socket: {:#}", err)));
}
}
}
fn run_poll_loop(
config: Config,
state: &State,
socket_worker_index: usize,
mut poll: Poll,
in_message_sender: InMessageSender,
out_message_receiver: OutMessageReceiver,
listener: ::std::net::TcpListener,
tls_config: Arc<rustls::ServerConfig>,
) {
let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds);
let ws_config = WebSocketConfig {
max_message_size: Some(config.network.websocket_max_message_size),
max_frame_size: Some(config.network.websocket_max_frame_size),
max_send_queue: Some(2),
..Default::default()
};
let mut listener = TcpListener::from_std(listener);
let mut events = Events::with_capacity(config.network.poll_event_capacity);
poll.registry()
.register(&mut listener, LISTENER_TOKEN, Interest::READABLE)
.unwrap();
let mut connections = ConnectionMap::default();
let mut local_responses = Vec::new();
let mut iter_counter = 0usize;
loop {
poll.poll(&mut events, Some(poll_timeout))
.expect("failed polling");
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
for event in events.iter() {
let token = event.token();
match token {
LISTENER_TOKEN => {
accept_new_streams(
&tls_config,
ws_config,
socket_worker_index,
&mut listener,
&mut poll,
&mut connections,
valid_until,
);
}
CHANNEL_TOKEN => {
write_or_queue_messages(
&mut poll,
out_message_receiver
.try_iter()
.take(out_message_receiver.len()),
&mut connections,
);
}
token => {
if event.is_writable() {
let mut remove_connection = false;
if let Some(connection) = connections.get_mut(&token) {
if let Err(err) = connection.write(&mut poll) {
::log::debug!("Connection::write error: {}", err);
remove_connection = true;
}
}
if remove_connection {
if let Some(connection) =
connections.remove_and_deregister(&mut poll, &token)
{
connection.close();
}
}
}
if event.is_readable() {
handle_stream_read_event(
&config,
state,
&mut local_responses,
&in_message_sender,
&mut poll,
&mut connections,
token,
valid_until,
);
}
}
}
write_or_queue_messages(&mut poll, local_responses.drain(..), &mut connections);
}
// Remove inactive connections, but not every iteration
if iter_counter % 128 == 0 {
connections = connections.clean(&mut poll);
}
iter_counter = iter_counter.wrapping_add(1);
}
}
fn accept_new_streams(
tls_config: &Arc<rustls::ServerConfig>,
ws_config: WebSocketConfig,
socket_worker_index: usize,
listener: &mut TcpListener,
poll: &mut Poll,
connections: &mut ConnectionMap,
valid_until: ValidUntil,
) {
loop {
match listener.accept() {
Ok((stream, _)) => {
let peer_addr = if let Ok(peer_addr) = stream.peer_addr() {
CanonicalSocketAddr::new(peer_addr)
} else {
continue;
};
connections.insert_and_register_new(poll, move |token| {
let meta = ConnectionMeta {
out_message_consumer_id: ConsumerId(socket_worker_index),
connection_id: ConnectionId(token.0),
peer_addr,
pending_scrape_id: None, // FIXME
};
Connection::new(tls_config.clone(), ws_config, stream, valid_until, meta)
});
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
break;
}
Err(err) => {
::log::info!("error while accepting streams: {}", err);
}
}
}
}
fn handle_stream_read_event(
config: &Config,
state: &State,
local_responses: &mut Vec<(ConnectionMeta, OutMessage)>,
in_message_sender: &InMessageSender,
poll: &mut Poll,
connections: &mut ConnectionMap,
token: Token,
valid_until: ValidUntil,
) {
let access_list_mode = config.access_list.mode;
if let Some(mut connection) = connections.remove_and_deregister(poll, &token) {
let message_handler = &mut |meta, message| match message {
InMessage::AnnounceRequest(ref request)
if !state
.access_list
.allows(access_list_mode, &request.info_hash.0) =>
{
let out_message = OutMessage::ErrorResponse(ErrorResponse {
failure_reason: "Info hash not allowed".into(),
action: Some(ErrorResponseAction::Announce),
info_hash: Some(request.info_hash),
});
local_responses.push((meta, out_message));
}
in_message => {
if let Err(err) = in_message_sender.send((meta, in_message)) {
::log::info!("InMessageSender: couldn't send message: {:?}", err);
}
}
};
connection.valid_until = valid_until;
match connection.read(message_handler) {
Ok(connection) => {
connections.insert_and_register(poll, token, connection);
}
Err(_) => {}
}
}
}
fn write_or_queue_messages<I>(poll: &mut Poll, responses: I, connections: &mut ConnectionMap)
where
I: Iterator<Item = (ConnectionMeta, OutMessage)>,
{
for (meta, out_message) in responses {
let token = Token(meta.connection_id.0);
let mut remove_connection = false;
if let Some(connection) = connections.get_mut(&token) {
if connection.get_meta().peer_addr != meta.peer_addr {
::log::warn!(
"socket worker error: connection socket addr {} didn't match channel {}. Token: {}.",
connection.get_meta().peer_addr.get(),
meta.peer_addr.get(),
token.0
);
remove_connection = true;
} else {
match connection.write_or_queue_message(poll, out_message) {
Ok(()) => {}
Err(err) => {
::log::debug!("Connection::write_or_queue_message error: {}", err);
remove_connection = true;
}
}
}
}
if remove_connection {
connections.remove_and_deregister(poll, &token);
}
}
}
pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> {
let builder = if config.network.address.is_ipv4() {
Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))
} else {
Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))
}
.context("Couldn't create socket2::Socket")?;
if config.network.ipv6_only {
builder
.set_only_v6(true)
.context("Couldn't put socket in ipv6 only mode")?
}
builder
.set_nonblocking(true)
.context("Couldn't put socket in non-blocking mode")?;
builder
.set_reuse_port(true)
.context("Couldn't put socket in reuse_port mode")?;
builder
.bind(&config.network.address.into())
.with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?;
builder
.listen(128)
.context("Couldn't listen for connections on socket")?;
Ok(builder.into())
}

View file

@ -0,0 +1,2 @@
pub mod request;
pub mod socket;

View file

@ -1,11 +1,130 @@
use aquatic_common::extract_response_peers;
use hashbrown::HashMap;
use rand::rngs::SmallRng;
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;
use futures::StreamExt;
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::enclose;
use glommio::prelude::*;
use glommio::timer::TimerActionRepeat;
use hashbrown::HashMap;
use rand::{rngs::SmallRng, SeedableRng};
use aquatic_common::extract_response_peers;
use aquatic_ws_protocol::*;
use crate::common::*;
use crate::config::Config;
use crate::SHARED_IN_CHANNEL_SIZE;
pub async fn run_request_worker(
config: Config,
state: State,
in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
) {
let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap();
let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap();
let out_message_senders = Rc::new(out_message_senders);
let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
let access_list = state.access_list;
// Periodically clean torrents
TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
enclose!((config, torrents, access_list) move || async move {
torrents.borrow_mut().clean(&config, &access_list);
Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval))
})()
}));
let mut handles = Vec::new();
for (_, receiver) in in_message_receivers.streams() {
let handle = spawn_local(handle_request_stream(
config.clone(),
torrents.clone(),
out_message_senders.clone(),
receiver,
))
.detach();
handles.push(handle);
}
for handle in handles {
handle.await;
}
}
async fn handle_request_stream<S>(
config: Config,
torrents: Rc<RefCell<TorrentMaps>>,
out_message_senders: Rc<Senders<(ConnectionMeta, OutMessage)>>,
stream: S,
) where
S: futures_lite::Stream<Item = (ConnectionMeta, InMessage)> + ::std::marker::Unpin,
{
let rng = Rc::new(RefCell::new(SmallRng::from_entropy()));
let max_peer_age = config.cleaning.max_peer_age;
let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));
TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
enclose!((peer_valid_until) move || async move {
*peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);
Some(Duration::from_secs(1))
})()
}));
let config = &config;
let torrents = &torrents;
let peer_valid_until = &peer_valid_until;
let rng = &rng;
let out_message_senders = &out_message_senders;
stream
.for_each_concurrent(
SHARED_IN_CHANNEL_SIZE,
move |(meta, in_message)| async move {
let mut out_messages = Vec::new();
match in_message {
InMessage::AnnounceRequest(request) => handle_announce_request(
&config,
&mut rng.borrow_mut(),
&mut torrents.borrow_mut(),
&mut out_messages,
peer_valid_until.borrow().to_owned(),
meta,
request,
),
InMessage::ScrapeRequest(request) => handle_scrape_request(
&config,
&mut torrents.borrow_mut(),
&mut out_messages,
meta,
request,
),
};
for (meta, out_message) in out_messages.drain(..) {
::log::info!("request worker trying to send OutMessage to socket worker");
out_message_senders
.send_to(meta.out_message_consumer_id.0, (meta, out_message))
.await
.expect("failed sending out_message to socket worker");
::log::info!("request worker sent OutMessage to socket worker");
}
},
)
.await;
}
pub fn handle_announce_request(
config: &Config,

View file

@ -29,8 +29,6 @@ use crate::config::Config;
use crate::common::*;
use super::common::*;
const LOCAL_CHANNEL_SIZE: usize = 16;
struct PendingScrapeResponse {