mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-03-31 17:55:36 +00:00
ws: remove mio implementation
This commit is contained in:
parent
065e007ede
commit
667cf04085
16 changed files with 283 additions and 1831 deletions
86
Cargo.lock
generated
86
Cargo.lock
generated
|
|
@ -264,18 +264,14 @@ dependencies = [
|
|||
"aquatic_ws_protocol",
|
||||
"async-tungstenite",
|
||||
"cfg-if",
|
||||
"crossbeam-channel",
|
||||
"either",
|
||||
"futures",
|
||||
"futures-lite",
|
||||
"futures-rustls",
|
||||
"glommio",
|
||||
"hashbrown 0.12.0",
|
||||
"histogram",
|
||||
"log",
|
||||
"mimalloc",
|
||||
"mio",
|
||||
"parking_lot",
|
||||
"privdrop",
|
||||
"quickcheck",
|
||||
"quickcheck_macros",
|
||||
|
|
@ -285,7 +281,6 @@ dependencies = [
|
|||
"serde",
|
||||
"signal-hook",
|
||||
"slab",
|
||||
"socket2 0.4.4",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
|
|
@ -1066,12 +1061,6 @@ version = "0.4.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "histogram"
|
||||
version = "0.6.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12cb882ccb290b8646e554b157ab0b71e64e8d5bef775cd66b6531e52d302669"
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.6"
|
||||
|
|
@ -1471,29 +1460,6 @@ version = "2.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.1.0"
|
||||
|
|
@ -1687,15 +1653,6 @@ dependencies = [
|
|||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.5.4"
|
||||
|
|
@ -2397,46 +2354,3 @@ name = "winapi-x86_64-pc-windows-gnu"
|
|||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.32.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6"
|
||||
dependencies = [
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.32.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.32.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.32.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.32.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.32.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316"
|
||||
|
|
|
|||
|
|
@ -14,10 +14,7 @@ name = "aquatic_ws"
|
|||
name = "aquatic_ws"
|
||||
|
||||
[features]
|
||||
default = ["with-mio"]
|
||||
cpu-pinning = ["aquatic_common/cpu-pinning"]
|
||||
with-glommio = ["cpu-pinning", "async-tungstenite", "futures-lite", "futures", "futures-rustls", "glommio"]
|
||||
with-mio = ["crossbeam-channel", "histogram", "mio", "parking_lot", "socket2"]
|
||||
|
||||
[dependencies]
|
||||
aquatic_cli_helpers = "0.1.0"
|
||||
|
|
@ -39,20 +36,11 @@ serde = { version = "1", features = ["derive"] }
|
|||
signal-hook = { version = "0.3" }
|
||||
slab = "0.4"
|
||||
tungstenite = "0.17"
|
||||
|
||||
# mio
|
||||
crossbeam-channel = { version = "0.5", optional = true }
|
||||
histogram = { version = "0.6", optional = true }
|
||||
mio = { version = "0.8", features = ["net", "os-poll"], optional = true }
|
||||
parking_lot = { version = "0.12", optional = true }
|
||||
socket2 = { version = "0.4", features = ["all"], optional = true }
|
||||
|
||||
# glommio
|
||||
async-tungstenite = { version = "0.17", optional = true }
|
||||
futures-lite = { version = "1", optional = true }
|
||||
futures = { version = "0.3", optional = true }
|
||||
futures-rustls = { version = "0.22", optional = true }
|
||||
glommio = { version = "0.7", optional = true }
|
||||
async-tungstenite = "0.17"
|
||||
futures-lite = "1"
|
||||
futures = "0.3"
|
||||
futures-rustls = "0.22"
|
||||
glommio = "0.7"
|
||||
|
||||
[dev-dependencies]
|
||||
quickcheck = "1"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
pub mod handlers;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -14,6 +12,13 @@ use aquatic_ws_protocol::*;
|
|||
|
||||
use crate::config::Config;
|
||||
|
||||
pub type TlsConfig = futures_rustls::rustls::ServerConfig;
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct State {
|
||||
pub access_list: Arc<AccessListArcSwap>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct PendingScrapeId(pub usize);
|
||||
|
||||
|
|
@ -23,15 +23,28 @@ pub struct Config {
|
|||
pub log_level: LogLevel,
|
||||
pub network: NetworkConfig,
|
||||
pub protocol: ProtocolConfig,
|
||||
#[cfg(feature = "with-mio")]
|
||||
pub handlers: HandlerConfig,
|
||||
pub cleaning: CleaningConfig,
|
||||
pub privileges: PrivilegeConfig,
|
||||
pub access_list: AccessListConfig,
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pub cpu_pinning: CpuPinningConfig,
|
||||
#[cfg(feature = "with-mio")]
|
||||
pub statistics: StatisticsConfig,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
socket_workers: 1,
|
||||
request_workers: 1,
|
||||
log_level: LogLevel::default(),
|
||||
network: NetworkConfig::default(),
|
||||
protocol: ProtocolConfig::default(),
|
||||
cleaning: CleaningConfig::default(),
|
||||
privileges: PrivilegeConfig::default(),
|
||||
access_list: AccessListConfig::default(),
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
cpu_pinning: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl aquatic_cli_helpers::Config for Config {
|
||||
|
|
@ -55,81 +68,6 @@ pub struct NetworkConfig {
|
|||
|
||||
pub websocket_max_message_size: usize,
|
||||
pub websocket_max_frame_size: usize,
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
pub poll_event_capacity: usize,
|
||||
#[cfg(feature = "with-mio")]
|
||||
pub poll_timeout_microseconds: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct ProtocolConfig {
|
||||
/// Maximum number of torrents to accept in scrape request
|
||||
pub max_scrape_torrents: usize,
|
||||
/// Maximum number of offers to accept in announce request
|
||||
pub max_offers: usize,
|
||||
/// Ask peers to announce this often (seconds)
|
||||
pub peer_announce_interval: usize,
|
||||
}
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct HandlerConfig {
|
||||
/// Maximum number of requests to receive from channel before locking
|
||||
/// mutex and starting work
|
||||
pub max_requests_per_iter: usize,
|
||||
pub channel_recv_timeout_microseconds: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct CleaningConfig {
|
||||
/// Clean peers this often (seconds)
|
||||
pub torrent_cleaning_interval: u64,
|
||||
/// Remove peers that have not announced for this long (seconds)
|
||||
pub max_peer_age: u64,
|
||||
|
||||
// Clean connections this often (seconds)
|
||||
#[cfg(feature = "with-glommio")]
|
||||
pub connection_cleaning_interval: u64,
|
||||
/// Close connections if no responses have been sent to them for this long (seconds)
|
||||
#[cfg(feature = "with-glommio")]
|
||||
pub max_connection_idle: u64,
|
||||
|
||||
/// Remove connections that are older than this (seconds)
|
||||
#[cfg(feature = "with-mio")]
|
||||
pub max_connection_age: u64,
|
||||
}
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct StatisticsConfig {
|
||||
/// Print statistics this often (seconds). Do not print when set to zero.
|
||||
pub interval: u64,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
socket_workers: 1,
|
||||
request_workers: 1,
|
||||
log_level: LogLevel::default(),
|
||||
network: NetworkConfig::default(),
|
||||
protocol: ProtocolConfig::default(),
|
||||
#[cfg(feature = "with-mio")]
|
||||
handlers: Default::default(),
|
||||
cleaning: CleaningConfig::default(),
|
||||
privileges: PrivilegeConfig::default(),
|
||||
access_list: AccessListConfig::default(),
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
cpu_pinning: Default::default(),
|
||||
#[cfg(feature = "with-mio")]
|
||||
statistics: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NetworkConfig {
|
||||
|
|
@ -143,15 +81,21 @@ impl Default for NetworkConfig {
|
|||
|
||||
websocket_max_message_size: 64 * 1024,
|
||||
websocket_max_frame_size: 16 * 1024,
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
poll_event_capacity: 4096,
|
||||
#[cfg(feature = "with-mio")]
|
||||
poll_timeout_microseconds: 200_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct ProtocolConfig {
|
||||
/// Maximum number of torrents to accept in scrape request
|
||||
pub max_scrape_torrents: usize,
|
||||
/// Maximum number of offers to accept in announce request
|
||||
pub max_offers: usize,
|
||||
/// Ask peers to announce this often (seconds)
|
||||
pub peer_announce_interval: usize,
|
||||
}
|
||||
|
||||
impl Default for ProtocolConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
|
@ -162,14 +106,17 @@ impl Default for ProtocolConfig {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
impl Default for HandlerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_requests_per_iter: 256,
|
||||
channel_recv_timeout_microseconds: 200,
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct CleaningConfig {
|
||||
/// Clean peers this often (seconds)
|
||||
pub torrent_cleaning_interval: u64,
|
||||
/// Remove peers that have not announced for this long (seconds)
|
||||
pub max_peer_age: u64,
|
||||
// Clean connections this often (seconds)
|
||||
pub connection_cleaning_interval: u64,
|
||||
/// Close connections if no responses have been sent to them for this long (seconds)
|
||||
pub max_connection_idle: u64,
|
||||
}
|
||||
|
||||
impl Default for CleaningConfig {
|
||||
|
|
@ -177,24 +124,12 @@ impl Default for CleaningConfig {
|
|||
Self {
|
||||
torrent_cleaning_interval: 30,
|
||||
max_peer_age: 1800,
|
||||
#[cfg(feature = "with-glommio")]
|
||||
max_connection_idle: 60 * 5,
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
max_connection_age: 1800,
|
||||
#[cfg(feature = "with-glommio")]
|
||||
connection_cleaning_interval: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "with-mio")]
|
||||
impl Default for StatisticsConfig {
|
||||
fn default() -> Self {
|
||||
Self { interval: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Config;
|
||||
|
|
|
|||
|
|
@ -1,10 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use aquatic_common::access_list::AccessListArcSwap;
|
||||
|
||||
pub type TlsConfig = futures_rustls::rustls::ServerConfig;
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct State {
|
||||
pub access_list: Arc<AccessListArcSwap>,
|
||||
}
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
pub mod common;
|
||||
pub mod request;
|
||||
pub mod socket;
|
||||
|
||||
use std::sync::{atomic::AtomicUsize, Arc};
|
||||
|
||||
use crate::{common::create_tls_config, config::Config};
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
|
||||
use aquatic_common::privileges::drop_privileges_after_socket_binding;
|
||||
|
||||
use self::common::*;
|
||||
|
||||
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
|
||||
|
||||
pub const SHARED_IN_CHANNEL_SIZE: usize = 1024;
|
||||
|
||||
pub fn run(config: Config, state: State) -> anyhow::Result<()> {
|
||||
let num_peers = config.socket_workers + config.request_workers;
|
||||
|
||||
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE);
|
||||
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16);
|
||||
|
||||
let num_bound_sockets = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let tls_config = Arc::new(create_tls_config(&config).unwrap());
|
||||
|
||||
let mut executors = Vec::new();
|
||||
|
||||
for i in 0..(config.socket_workers) {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let tls_config = tls_config.clone();
|
||||
let request_mesh_builder = request_mesh_builder.clone();
|
||||
let response_mesh_builder = response_mesh_builder.clone();
|
||||
let num_bound_sockets = num_bound_sockets.clone();
|
||||
|
||||
let builder = LocalExecutorBuilder::default().name("socket");
|
||||
|
||||
let executor = builder.spawn(move || async move {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::SocketWorker(i),
|
||||
);
|
||||
|
||||
socket::run_socket_worker(
|
||||
config,
|
||||
state,
|
||||
tls_config,
|
||||
request_mesh_builder,
|
||||
response_mesh_builder,
|
||||
num_bound_sockets,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
executors.push(executor);
|
||||
}
|
||||
|
||||
for i in 0..(config.request_workers) {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let request_mesh_builder = request_mesh_builder.clone();
|
||||
let response_mesh_builder = response_mesh_builder.clone();
|
||||
|
||||
let builder = LocalExecutorBuilder::default().name("request");
|
||||
|
||||
let executor = builder.spawn(move || async move {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::RequestWorker(i),
|
||||
);
|
||||
|
||||
request::run_request_worker(config, state, request_mesh_builder, response_mesh_builder)
|
||||
.await
|
||||
});
|
||||
|
||||
executors.push(executor);
|
||||
}
|
||||
|
||||
drop_privileges_after_socket_binding(
|
||||
&config.privileges,
|
||||
num_bound_sockets,
|
||||
config.socket_workers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::Other,
|
||||
);
|
||||
|
||||
for executor in executors {
|
||||
executor
|
||||
.expect("failed to spawn local executor")
|
||||
.join()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::StreamExt;
|
||||
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
|
||||
use glommio::enclose;
|
||||
use glommio::prelude::*;
|
||||
use glommio::timer::TimerActionRepeat;
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
|
||||
use aquatic_ws_protocol::*;
|
||||
|
||||
use crate::common::handlers::*;
|
||||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
|
||||
use super::common::State;
|
||||
use super::SHARED_IN_CHANNEL_SIZE;
|
||||
|
||||
pub async fn run_request_worker(
|
||||
config: Config,
|
||||
state: State,
|
||||
in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
|
||||
out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
|
||||
) {
|
||||
let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap();
|
||||
let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap();
|
||||
|
||||
let out_message_senders = Rc::new(out_message_senders);
|
||||
|
||||
let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
|
||||
let access_list = state.access_list;
|
||||
|
||||
// Periodically clean torrents
|
||||
TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
|
||||
enclose!((config, torrents, access_list) move || async move {
|
||||
torrents.borrow_mut().clean(&config, &access_list);
|
||||
|
||||
Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval))
|
||||
})()
|
||||
}));
|
||||
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for (_, receiver) in in_message_receivers.streams() {
|
||||
let handle = spawn_local(handle_request_stream(
|
||||
config.clone(),
|
||||
torrents.clone(),
|
||||
out_message_senders.clone(),
|
||||
receiver,
|
||||
))
|
||||
.detach();
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_request_stream<S>(
|
||||
config: Config,
|
||||
torrents: Rc<RefCell<TorrentMaps>>,
|
||||
out_message_senders: Rc<Senders<(ConnectionMeta, OutMessage)>>,
|
||||
stream: S,
|
||||
) where
|
||||
S: futures_lite::Stream<Item = (ConnectionMeta, InMessage)> + ::std::marker::Unpin,
|
||||
{
|
||||
let rng = Rc::new(RefCell::new(SmallRng::from_entropy()));
|
||||
|
||||
let max_peer_age = config.cleaning.max_peer_age;
|
||||
let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));
|
||||
|
||||
TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
|
||||
enclose!((peer_valid_until) move || async move {
|
||||
*peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);
|
||||
|
||||
Some(Duration::from_secs(1))
|
||||
})()
|
||||
}));
|
||||
|
||||
let config = &config;
|
||||
let torrents = &torrents;
|
||||
let peer_valid_until = &peer_valid_until;
|
||||
let rng = &rng;
|
||||
let out_message_senders = &out_message_senders;
|
||||
|
||||
stream
|
||||
.for_each_concurrent(
|
||||
SHARED_IN_CHANNEL_SIZE,
|
||||
move |(meta, in_message)| async move {
|
||||
let mut out_messages = Vec::new();
|
||||
|
||||
match in_message {
|
||||
InMessage::AnnounceRequest(request) => handle_announce_request(
|
||||
&config,
|
||||
&mut rng.borrow_mut(),
|
||||
&mut torrents.borrow_mut(),
|
||||
&mut out_messages,
|
||||
peer_valid_until.borrow().to_owned(),
|
||||
meta,
|
||||
request,
|
||||
),
|
||||
InMessage::ScrapeRequest(request) => handle_scrape_request(
|
||||
&config,
|
||||
&mut torrents.borrow_mut(),
|
||||
&mut out_messages,
|
||||
meta,
|
||||
request,
|
||||
),
|
||||
};
|
||||
|
||||
for (meta, out_message) in out_messages.drain(..) {
|
||||
::log::info!("request worker trying to send OutMessage to socket worker");
|
||||
|
||||
out_message_senders
|
||||
.send_to(meta.out_message_consumer_id.0, (meta, out_message))
|
||||
.await
|
||||
.expect("failed sending out_message to socket worker");
|
||||
|
||||
::log::info!("request worker sent OutMessage to socket worker");
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -1,28 +1,26 @@
|
|||
use aquatic_common::access_list::update_access_list;
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
|
||||
use cfg_if::cfg_if;
|
||||
use common::State;
|
||||
use signal_hook::{consts::SIGUSR1, iterator::Signals};
|
||||
|
||||
use crate::config::Config;
|
||||
use std::sync::{atomic::AtomicUsize, Arc};
|
||||
|
||||
use crate::{common::create_tls_config, config::Config};
|
||||
use aquatic_common::privileges::drop_privileges_after_socket_binding;
|
||||
|
||||
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
|
||||
|
||||
pub mod common;
|
||||
pub mod config;
|
||||
#[cfg(feature = "with-glommio")]
|
||||
pub mod glommio;
|
||||
#[cfg(feature = "with-mio")]
|
||||
pub mod mio;
|
||||
pub mod workers;
|
||||
|
||||
pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker";
|
||||
|
||||
pub const SHARED_IN_CHANNEL_SIZE: usize = 1024;
|
||||
|
||||
pub fn run(config: Config) -> ::anyhow::Result<()> {
|
||||
cfg_if!(
|
||||
if #[cfg(feature = "with-glommio")] {
|
||||
let state = glommio::common::State::default();
|
||||
} else {
|
||||
let state = mio::common::State::default();
|
||||
}
|
||||
);
|
||||
let state = State::default();
|
||||
|
||||
update_access_list(&config.access_list, &state.access_list)?;
|
||||
|
||||
|
|
@ -32,13 +30,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
|
|||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
|
||||
cfg_if!(
|
||||
if #[cfg(feature = "with-glommio")] {
|
||||
::std::thread::spawn(move || glommio::run(config, state));
|
||||
} else {
|
||||
::std::thread::spawn(move || mio::run(config, state));
|
||||
}
|
||||
);
|
||||
::std::thread::spawn(move || run_workers(config, state));
|
||||
}
|
||||
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
|
|
@ -59,3 +51,99 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn run_workers(config: Config, state: State) -> anyhow::Result<()> {
|
||||
let num_peers = config.socket_workers + config.request_workers;
|
||||
|
||||
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE);
|
||||
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16);
|
||||
|
||||
let num_bound_sockets = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let tls_config = Arc::new(create_tls_config(&config).unwrap());
|
||||
|
||||
let mut executors = Vec::new();
|
||||
|
||||
for i in 0..(config.socket_workers) {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let tls_config = tls_config.clone();
|
||||
let request_mesh_builder = request_mesh_builder.clone();
|
||||
let response_mesh_builder = response_mesh_builder.clone();
|
||||
let num_bound_sockets = num_bound_sockets.clone();
|
||||
|
||||
let builder = LocalExecutorBuilder::default().name("socket");
|
||||
|
||||
let executor = builder.spawn(move || async move {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::SocketWorker(i),
|
||||
);
|
||||
|
||||
workers::socket::run_socket_worker(
|
||||
config,
|
||||
state,
|
||||
tls_config,
|
||||
request_mesh_builder,
|
||||
response_mesh_builder,
|
||||
num_bound_sockets,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
executors.push(executor);
|
||||
}
|
||||
|
||||
for i in 0..(config.request_workers) {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let request_mesh_builder = request_mesh_builder.clone();
|
||||
let response_mesh_builder = response_mesh_builder.clone();
|
||||
|
||||
let builder = LocalExecutorBuilder::default().name("request");
|
||||
|
||||
let executor = builder.spawn(move || async move {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::RequestWorker(i),
|
||||
);
|
||||
|
||||
workers::request::run_request_worker(
|
||||
config,
|
||||
state,
|
||||
request_mesh_builder,
|
||||
response_mesh_builder,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
executors.push(executor);
|
||||
}
|
||||
|
||||
drop_privileges_after_socket_binding(
|
||||
&config.privileges,
|
||||
num_bound_sockets,
|
||||
config.socket_workers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::Other,
|
||||
);
|
||||
|
||||
for executor in executors {
|
||||
executor
|
||||
.expect("failed to spawn local executor")
|
||||
.join()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,51 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use aquatic_common::access_list::AccessListArcSwap;
|
||||
use aquatic_ws_protocol::*;
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
use log::error;
|
||||
use mio::Token;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::common::*;
|
||||
|
||||
pub const LISTENER_TOKEN: Token = Token(0);
|
||||
pub const CHANNEL_TOKEN: Token = Token(1);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct State {
|
||||
pub access_list: Arc<AccessListArcSwap>,
|
||||
pub torrent_maps: Arc<Mutex<TorrentMaps>>,
|
||||
}
|
||||
|
||||
impl Default for State {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
access_list: Arc::new(Default::default()),
|
||||
torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type InMessageSender = Sender<(ConnectionMeta, InMessage)>;
|
||||
pub type InMessageReceiver = Receiver<(ConnectionMeta, InMessage)>;
|
||||
pub type OutMessageReceiver = Receiver<(ConnectionMeta, OutMessage)>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OutMessageSender(Vec<Sender<(ConnectionMeta, OutMessage)>>);
|
||||
|
||||
impl OutMessageSender {
|
||||
pub fn new(senders: Vec<Sender<(ConnectionMeta, OutMessage)>>) -> Self {
|
||||
Self(senders)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn send(&self, meta: ConnectionMeta, message: OutMessage) {
|
||||
if let Err(err) = self.0[meta.out_message_consumer_id.0].send((meta, message)) {
|
||||
error!("OutMessageSender: couldn't send message: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type SocketWorkerStatus = Option<Result<(), String>>;
|
||||
pub type SocketWorkerStatuses = Arc<Mutex<Vec<SocketWorkerStatus>>>;
|
||||
|
|
@ -1,218 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
use std::thread::Builder;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
|
||||
use histogram::Histogram;
|
||||
use mio::{Poll, Waker};
|
||||
use parking_lot::Mutex;
|
||||
use privdrop::PrivDrop;
|
||||
|
||||
pub mod common;
|
||||
pub mod request;
|
||||
pub mod socket;
|
||||
|
||||
use crate::{common::create_tls_config, config::Config};
|
||||
use common::*;
|
||||
|
||||
pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker";
|
||||
|
||||
const SHARED_IN_CHANNEL_SIZE: usize = 1024;
|
||||
|
||||
pub fn run(config: Config, state: State) -> anyhow::Result<()> {
|
||||
start_workers(config.clone(), state.clone()).expect("couldn't start workers");
|
||||
|
||||
// TODO: privdrop here instead
|
||||
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::Other,
|
||||
);
|
||||
|
||||
loop {
|
||||
::std::thread::sleep(Duration::from_secs(
|
||||
config.cleaning.torrent_cleaning_interval,
|
||||
));
|
||||
|
||||
state.torrent_maps.lock().clean(&config, &state.access_list);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
|
||||
let tls_config = Arc::new(create_tls_config(&config)?);
|
||||
|
||||
let (in_message_sender, in_message_receiver) =
|
||||
::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE);
|
||||
|
||||
let mut out_message_senders = Vec::new();
|
||||
let mut wakers = Vec::new();
|
||||
|
||||
let socket_worker_statuses: SocketWorkerStatuses = {
|
||||
let mut statuses = Vec::new();
|
||||
|
||||
for _ in 0..config.socket_workers {
|
||||
statuses.push(None);
|
||||
}
|
||||
|
||||
Arc::new(Mutex::new(statuses))
|
||||
};
|
||||
|
||||
for i in 0..config.socket_workers {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let socket_worker_statuses = socket_worker_statuses.clone();
|
||||
let in_message_sender = in_message_sender.clone();
|
||||
let tls_config = tls_config.clone();
|
||||
let poll = Poll::new()?;
|
||||
let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?);
|
||||
|
||||
let (out_message_sender, out_message_receiver) =
|
||||
::crossbeam_channel::bounded(SHARED_IN_CHANNEL_SIZE * 16);
|
||||
|
||||
out_message_senders.push(out_message_sender);
|
||||
wakers.push(waker);
|
||||
|
||||
Builder::new()
|
||||
.name(format!("socket-{:02}", i + 1))
|
||||
.spawn(move || {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::SocketWorker(i),
|
||||
);
|
||||
|
||||
socket::run_socket_worker(
|
||||
config,
|
||||
state,
|
||||
i,
|
||||
socket_worker_statuses,
|
||||
poll,
|
||||
in_message_sender,
|
||||
out_message_receiver,
|
||||
tls_config,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
// Wait for socket worker statuses. On error from any, quit program.
|
||||
// On success from all, drop privileges if corresponding setting is set
|
||||
// and continue program.
|
||||
loop {
|
||||
::std::thread::sleep(::std::time::Duration::from_millis(10));
|
||||
|
||||
if let Some(statuses) = socket_worker_statuses.try_lock() {
|
||||
for opt_status in statuses.iter() {
|
||||
if let Some(Err(err)) = opt_status {
|
||||
return Err(::anyhow::anyhow!(err.to_owned()));
|
||||
}
|
||||
}
|
||||
|
||||
if statuses.iter().all(Option::is_some) {
|
||||
if config.privileges.drop_privileges {
|
||||
PrivDrop::default()
|
||||
.chroot(config.privileges.chroot_path.clone())
|
||||
.user(config.privileges.user.clone())
|
||||
.apply()
|
||||
.context("Couldn't drop root privileges")?;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let out_message_sender = OutMessageSender::new(out_message_senders);
|
||||
|
||||
for i in 0..config.request_workers {
|
||||
let config = config.clone();
|
||||
let state = state.clone();
|
||||
let in_message_receiver = in_message_receiver.clone();
|
||||
let out_message_sender = out_message_sender.clone();
|
||||
let wakers = wakers.clone();
|
||||
|
||||
Builder::new()
|
||||
.name(format!("request-{:02}", i + 1))
|
||||
.spawn(move || {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::RequestWorker(i),
|
||||
);
|
||||
|
||||
request::run_request_worker(
|
||||
config,
|
||||
state,
|
||||
in_message_receiver,
|
||||
out_message_sender,
|
||||
wakers,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
if config.statistics.interval != 0 {
|
||||
let state = state.clone();
|
||||
let config = config.clone();
|
||||
|
||||
Builder::new()
|
||||
.name("statistics".to_string())
|
||||
.spawn(move || {
|
||||
#[cfg(feature = "cpu-pinning")]
|
||||
pin_current_if_configured_to(
|
||||
&config.cpu_pinning,
|
||||
config.socket_workers,
|
||||
WorkerIndex::Other,
|
||||
);
|
||||
|
||||
loop {
|
||||
::std::thread::sleep(Duration::from_secs(config.statistics.interval));
|
||||
|
||||
print_statistics(&state);
|
||||
}
|
||||
})
|
||||
.expect("spawn statistics thread");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_statistics(state: &State) {
|
||||
let mut peers_per_torrent = Histogram::new();
|
||||
|
||||
{
|
||||
let torrents = &mut state.torrent_maps.lock();
|
||||
|
||||
for torrent in torrents.ipv4.values() {
|
||||
let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;
|
||||
|
||||
if let Err(err) = peers_per_torrent.increment(num_peers) {
|
||||
eprintln!("error incrementing peers_per_torrent histogram: {}", err)
|
||||
}
|
||||
}
|
||||
for torrent in torrents.ipv6.values() {
|
||||
let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;
|
||||
|
||||
if let Err(err) = peers_per_torrent.increment(num_peers) {
|
||||
eprintln!("error incrementing peers_per_torrent histogram: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if peers_per_torrent.entries() != 0 {
|
||||
println!(
|
||||
"peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}",
|
||||
peers_per_torrent.minimum().unwrap(),
|
||||
peers_per_torrent.percentile(50.0).unwrap(),
|
||||
peers_per_torrent.percentile(75.0).unwrap(),
|
||||
peers_per_torrent.percentile(90.0).unwrap(),
|
||||
peers_per_torrent.percentile(99.0).unwrap(),
|
||||
peers_per_torrent.percentile(99.9).unwrap(),
|
||||
peers_per_torrent.maximum().unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use mio::Waker;
|
||||
use parking_lot::MutexGuard;
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
|
||||
use aquatic_ws_protocol::*;
|
||||
|
||||
use crate::common::handlers::{handle_announce_request, handle_scrape_request};
|
||||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
|
||||
use super::common::*;
|
||||
|
||||
pub fn run_request_worker(
|
||||
config: Config,
|
||||
state: State,
|
||||
in_message_receiver: InMessageReceiver,
|
||||
out_message_sender: OutMessageSender,
|
||||
wakers: Vec<Arc<Waker>>,
|
||||
) {
|
||||
let mut wake_socket_workers: Vec<bool> = (0..config.socket_workers).map(|_| false).collect();
|
||||
|
||||
let mut announce_requests = Vec::new();
|
||||
let mut scrape_requests = Vec::new();
|
||||
let mut out_messages = Vec::new();
|
||||
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
|
||||
let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);
|
||||
|
||||
loop {
|
||||
let mut opt_torrent_map_guard: Option<MutexGuard<TorrentMaps>> = None;
|
||||
|
||||
for i in 0..config.handlers.max_requests_per_iter {
|
||||
let opt_in_message = if i == 0 {
|
||||
in_message_receiver.recv().ok()
|
||||
} else {
|
||||
in_message_receiver.recv_timeout(timeout).ok()
|
||||
};
|
||||
|
||||
match opt_in_message {
|
||||
Some((meta, InMessage::AnnounceRequest(r))) => {
|
||||
announce_requests.push((meta, r));
|
||||
}
|
||||
Some((meta, InMessage::ScrapeRequest(r))) => {
|
||||
scrape_requests.push((meta, r));
|
||||
}
|
||||
None => {
|
||||
if let Some(torrent_guard) = state.torrent_maps.try_lock() {
|
||||
opt_torrent_map_guard = Some(torrent_guard);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut torrent_map_guard =
|
||||
opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock());
|
||||
|
||||
let valid_until = ValidUntil::new(config.cleaning.max_peer_age);
|
||||
|
||||
for (meta, request) in announce_requests.drain(..) {
|
||||
handle_announce_request(
|
||||
&config,
|
||||
&mut rng,
|
||||
&mut torrent_map_guard,
|
||||
&mut out_messages,
|
||||
valid_until,
|
||||
meta,
|
||||
request,
|
||||
);
|
||||
}
|
||||
for (meta, request) in scrape_requests.drain(..) {
|
||||
handle_scrape_request(
|
||||
&config,
|
||||
&mut torrent_map_guard,
|
||||
&mut out_messages,
|
||||
meta,
|
||||
request,
|
||||
);
|
||||
}
|
||||
|
||||
::std::mem::drop(torrent_map_guard);
|
||||
|
||||
for (meta, out_message) in out_messages.drain(..) {
|
||||
wake_socket_workers[meta.out_message_consumer_id.0] = true;
|
||||
out_message_sender.send(meta, out_message);
|
||||
}
|
||||
|
||||
for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() {
|
||||
if *wake {
|
||||
if let Err(err) = wakers[worker_index].wake() {
|
||||
::log::error!("request handler couldn't wake poll: {:?}", err);
|
||||
}
|
||||
|
||||
*wake = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,577 +0,0 @@
|
|||
use std::{collections::VecDeque, io::ErrorKind, marker::PhantomData, net::Shutdown, sync::Arc};
|
||||
|
||||
use aquatic_common::ValidUntil;
|
||||
use aquatic_ws_protocol::{InMessage, OutMessage};
|
||||
use mio::{net::TcpStream, Interest, Poll, Token};
|
||||
use rustls::{ServerConfig, ServerConnection};
|
||||
use tungstenite::{
|
||||
handshake::{server::NoCallback, MidHandshake},
|
||||
protocol::WebSocketConfig,
|
||||
HandshakeError, ServerHandshake,
|
||||
};
|
||||
|
||||
use crate::common::ConnectionMeta;
|
||||
|
||||
const MAX_PENDING_MESSAGES: usize = 16;
|
||||
|
||||
type TlsStream = rustls::StreamOwned<ServerConnection, TcpStream>;
|
||||
|
||||
type WsHandshakeResult<S> =
|
||||
Result<tungstenite::WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>;
|
||||
|
||||
type ConnectionReadResult<T> = ::std::io::Result<ConnectionReadStatus<T>>;
|
||||
|
||||
pub trait RegistryStatus {}
|
||||
|
||||
pub struct Registered;
|
||||
|
||||
impl RegistryStatus for Registered {}
|
||||
|
||||
pub struct NotRegistered;
|
||||
|
||||
impl RegistryStatus for NotRegistered {}
|
||||
|
||||
enum ConnectionReadStatus<T> {
|
||||
Message(T, InMessage),
|
||||
Ok(T),
|
||||
WouldBlock(T),
|
||||
}
|
||||
|
||||
enum ConnectionState<R: RegistryStatus> {
|
||||
TlsHandshaking(TlsHandshaking<R>),
|
||||
WsHandshaking(WsHandshaking<R>),
|
||||
WsConnection(WsConnection<R>),
|
||||
}
|
||||
|
||||
pub struct Connection<R: RegistryStatus> {
|
||||
pub valid_until: ValidUntil,
|
||||
meta: ConnectionMeta,
|
||||
state: ConnectionState<R>,
|
||||
pub message_queue: VecDeque<OutMessage>,
|
||||
pub interest: Interest,
|
||||
phantom_data: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: RegistryStatus> Connection<R> {
|
||||
pub fn get_meta(&self) -> ConnectionMeta {
|
||||
self.meta
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection<NotRegistered> {
|
||||
pub fn new(
|
||||
tls_config: Arc<ServerConfig>,
|
||||
ws_config: WebSocketConfig,
|
||||
tcp_stream: TcpStream,
|
||||
valid_until: ValidUntil,
|
||||
meta: ConnectionMeta,
|
||||
) -> Self {
|
||||
let state =
|
||||
ConnectionState::TlsHandshaking(TlsHandshaking::new(tls_config, ws_config, tcp_stream));
|
||||
|
||||
Self {
|
||||
valid_until,
|
||||
meta,
|
||||
state,
|
||||
message_queue: Default::default(),
|
||||
interest: Interest::READABLE,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read until stream blocks (or error occurs)
|
||||
///
|
||||
/// Requires Connection not to be registered, since it might be dropped on errors
|
||||
pub fn read<F>(
|
||||
mut self,
|
||||
message_handler: &mut F,
|
||||
) -> ::std::io::Result<Connection<NotRegistered>>
|
||||
where
|
||||
F: FnMut(ConnectionMeta, InMessage),
|
||||
{
|
||||
loop {
|
||||
let result = match self.state {
|
||||
ConnectionState::TlsHandshaking(inner) => inner.read(),
|
||||
ConnectionState::WsHandshaking(inner) => inner.read(),
|
||||
ConnectionState::WsConnection(inner) => inner.read(),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(ConnectionReadStatus::Message(state, message)) => {
|
||||
self.state = state;
|
||||
|
||||
message_handler(self.meta, message);
|
||||
|
||||
// Stop looping even if WouldBlock wasn't necessarily reached. Otherwise,
|
||||
// we might get stuck reading from this connection only. Since we register
|
||||
// the connection again upon reinsertion into the ConnectionMap, we should
|
||||
// be getting new events anyway.
|
||||
return Ok(self);
|
||||
}
|
||||
Ok(ConnectionReadStatus::Ok(state)) => {
|
||||
self.state = state;
|
||||
|
||||
::log::debug!("read connection");
|
||||
}
|
||||
Ok(ConnectionReadStatus::WouldBlock(state)) => {
|
||||
self.state = state;
|
||||
|
||||
::log::debug!("reading connection would block");
|
||||
|
||||
return Ok(self);
|
||||
}
|
||||
Err(err) => {
|
||||
::log::debug!("Connection::read error: {}", err);
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(self, poll: &mut Poll, token: Token) -> Connection<Registered> {
|
||||
let state = match self.state {
|
||||
ConnectionState::TlsHandshaking(inner) => {
|
||||
ConnectionState::TlsHandshaking(inner.register(poll, token, self.interest))
|
||||
}
|
||||
ConnectionState::WsHandshaking(inner) => {
|
||||
ConnectionState::WsHandshaking(inner.register(poll, token, self.interest))
|
||||
}
|
||||
ConnectionState::WsConnection(inner) => {
|
||||
ConnectionState::WsConnection(inner.register(poll, token, self.interest))
|
||||
}
|
||||
};
|
||||
|
||||
Connection {
|
||||
valid_until: self.valid_until,
|
||||
meta: self.meta,
|
||||
state,
|
||||
message_queue: self.message_queue,
|
||||
interest: self.interest,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close(self) {
|
||||
::log::debug!("will close connection to {}", self.meta.peer_addr.get());
|
||||
|
||||
match self.state {
|
||||
ConnectionState::TlsHandshaking(inner) => inner.close(),
|
||||
ConnectionState::WsHandshaking(inner) => inner.close(),
|
||||
ConnectionState::WsConnection(inner) => inner.close(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection<Registered> {
|
||||
pub fn write_or_queue_message(
|
||||
&mut self,
|
||||
poll: &mut Poll,
|
||||
message: OutMessage,
|
||||
) -> ::std::io::Result<()> {
|
||||
let message_clone = message.clone();
|
||||
|
||||
match self.write_message(message) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) if err.kind() == ErrorKind::WouldBlock => {
|
||||
if self.message_queue.len() < MAX_PENDING_MESSAGES {
|
||||
self.message_queue.push_back(message_clone);
|
||||
|
||||
if !self.interest.is_writable() {
|
||||
self.interest = Interest::WRITABLE;
|
||||
|
||||
self.reregister(poll)?;
|
||||
}
|
||||
} else {
|
||||
::log::info!("Connection::message_queue is full, dropping message");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
|
||||
if let ConnectionState::WsConnection(_) = self.state {
|
||||
while let Some(message) = self.message_queue.pop_front() {
|
||||
let message_clone = message.clone();
|
||||
|
||||
match self.write_message(message) {
|
||||
Ok(()) => {}
|
||||
Err(err) if err.kind() == ErrorKind::WouldBlock => {
|
||||
// Can't make message queue longer than it was before pop_front
|
||||
self.message_queue.push_front(message_clone);
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if self.message_queue.is_empty() {
|
||||
self.interest = Interest::READABLE;
|
||||
}
|
||||
|
||||
self.reregister(poll)?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(std::io::Error::new(
|
||||
ErrorKind::NotConnected,
|
||||
"WebSocket connection not established",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn write_message(&mut self, message: OutMessage) -> ::std::io::Result<()> {
|
||||
if let ConnectionState::WsConnection(WsConnection {
|
||||
ref mut web_socket, ..
|
||||
}) = self.state
|
||||
{
|
||||
match web_socket.write_message(message.to_ws_message()) {
|
||||
Ok(_) => {}
|
||||
Err(tungstenite::Error::SendQueueFull(_message)) => {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::WouldBlock,
|
||||
"Send queue full",
|
||||
))
|
||||
}
|
||||
Err(tungstenite::Error::Io(err)) => return Err(err),
|
||||
Err(err) => return Err(std::io::Error::new(ErrorKind::Other, err))?,
|
||||
}
|
||||
|
||||
match web_socket.write_pending() {
|
||||
Ok(()) => Ok(()),
|
||||
Err(tungstenite::Error::Io(err)) => Err(err),
|
||||
Err(err) => Err(std::io::Error::new(ErrorKind::Other, err))?,
|
||||
}
|
||||
} else {
|
||||
Err(std::io::Error::new(
|
||||
ErrorKind::NotConnected,
|
||||
"WebSocket connection not established",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
|
||||
let token = Token(self.meta.connection_id.0);
|
||||
|
||||
match self.state {
|
||||
ConnectionState::TlsHandshaking(ref mut inner) => {
|
||||
inner.reregister(poll, token, self.interest)
|
||||
}
|
||||
ConnectionState::WsHandshaking(ref mut inner) => {
|
||||
inner.reregister(poll, token, self.interest)
|
||||
}
|
||||
ConnectionState::WsConnection(ref mut inner) => {
|
||||
inner.reregister(poll, token, self.interest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deregister(self, poll: &mut Poll) -> Connection<NotRegistered> {
|
||||
let state = match self.state {
|
||||
ConnectionState::TlsHandshaking(inner) => {
|
||||
ConnectionState::TlsHandshaking(inner.deregister(poll))
|
||||
}
|
||||
ConnectionState::WsHandshaking(inner) => {
|
||||
ConnectionState::WsHandshaking(inner.deregister(poll))
|
||||
}
|
||||
ConnectionState::WsConnection(inner) => {
|
||||
ConnectionState::WsConnection(inner.deregister(poll))
|
||||
}
|
||||
};
|
||||
|
||||
Connection {
|
||||
valid_until: self.valid_until,
|
||||
meta: self.meta,
|
||||
state,
|
||||
message_queue: self.message_queue,
|
||||
interest: self.interest,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TlsHandshaking<R: RegistryStatus> {
|
||||
tls_conn: ServerConnection,
|
||||
ws_config: WebSocketConfig,
|
||||
tcp_stream: TcpStream,
|
||||
phantom_data: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl TlsHandshaking<NotRegistered> {
|
||||
fn new(tls_config: Arc<ServerConfig>, ws_config: WebSocketConfig, stream: TcpStream) -> Self {
|
||||
Self {
|
||||
tls_conn: ServerConnection::new(tls_config).unwrap(),
|
||||
ws_config,
|
||||
tcp_stream: stream,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn read(mut self) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
|
||||
match self.tls_conn.read_tls(&mut self.tcp_stream) {
|
||||
Ok(0) => {
|
||||
return Err(::std::io::Error::new(
|
||||
ErrorKind::ConnectionReset,
|
||||
"Connection closed",
|
||||
))
|
||||
}
|
||||
Ok(_) => match self.tls_conn.process_new_packets() {
|
||||
Ok(_) => {
|
||||
while self.tls_conn.wants_write() {
|
||||
self.tls_conn.write_tls(&mut self.tcp_stream)?;
|
||||
}
|
||||
|
||||
if self.tls_conn.is_handshaking() {
|
||||
Ok(ConnectionReadStatus::WouldBlock(
|
||||
ConnectionState::TlsHandshaking(self),
|
||||
))
|
||||
} else {
|
||||
let tls_stream = TlsStream::new(self.tls_conn, self.tcp_stream);
|
||||
|
||||
WsHandshaking::handle_handshake_result(tungstenite::accept_with_config(
|
||||
tls_stream,
|
||||
Some(self.ws_config),
|
||||
))
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = self.tls_conn.write_tls(&mut self.tcp_stream);
|
||||
|
||||
Err(::std::io::Error::new(ErrorKind::InvalidData, err))
|
||||
}
|
||||
},
|
||||
Err(err) if err.kind() == ErrorKind::WouldBlock => {
|
||||
return Ok(ConnectionReadStatus::WouldBlock(
|
||||
ConnectionState::TlsHandshaking(self),
|
||||
))
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
fn register(
|
||||
mut self,
|
||||
poll: &mut Poll,
|
||||
token: Token,
|
||||
interest: Interest,
|
||||
) -> TlsHandshaking<Registered> {
|
||||
poll.registry()
|
||||
.register(&mut self.tcp_stream, token, interest)
|
||||
.unwrap();
|
||||
|
||||
TlsHandshaking {
|
||||
tls_conn: self.tls_conn,
|
||||
ws_config: self.ws_config,
|
||||
tcp_stream: self.tcp_stream,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn close(self) {
|
||||
::log::debug!("closing connection (TlsHandshaking state)");
|
||||
|
||||
let _ = self.tcp_stream.shutdown(Shutdown::Both);
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsHandshaking<Registered> {
|
||||
fn deregister(mut self, poll: &mut Poll) -> TlsHandshaking<NotRegistered> {
|
||||
poll.registry().deregister(&mut self.tcp_stream).unwrap();
|
||||
|
||||
TlsHandshaking {
|
||||
tls_conn: self.tls_conn,
|
||||
ws_config: self.ws_config,
|
||||
tcp_stream: self.tcp_stream,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reregister(
|
||||
&mut self,
|
||||
poll: &mut Poll,
|
||||
token: Token,
|
||||
interest: Interest,
|
||||
) -> std::io::Result<()> {
|
||||
poll.registry()
|
||||
.reregister(&mut self.tcp_stream, token, interest)
|
||||
}
|
||||
}
|
||||
|
||||
struct WsHandshaking<R: RegistryStatus> {
|
||||
mid_handshake: MidHandshake<ServerHandshake<TlsStream, NoCallback>>,
|
||||
phantom_data: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl WsHandshaking<NotRegistered> {
|
||||
fn read(self) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
|
||||
Self::handle_handshake_result(self.mid_handshake.handshake())
|
||||
}
|
||||
|
||||
fn handle_handshake_result(
|
||||
handshake_result: WsHandshakeResult<TlsStream>,
|
||||
) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
|
||||
match handshake_result {
|
||||
Ok(web_socket) => {
|
||||
let conn = ConnectionState::WsConnection(WsConnection {
|
||||
web_socket,
|
||||
phantom_data: PhantomData::default(),
|
||||
});
|
||||
|
||||
Ok(ConnectionReadStatus::Ok(conn))
|
||||
}
|
||||
Err(HandshakeError::Interrupted(mid_handshake)) => {
|
||||
let conn = ConnectionState::WsHandshaking(WsHandshaking {
|
||||
mid_handshake,
|
||||
phantom_data: PhantomData::default(),
|
||||
});
|
||||
|
||||
Ok(ConnectionReadStatus::WouldBlock(conn))
|
||||
}
|
||||
Err(HandshakeError::Failure(err)) => {
|
||||
return Err(std::io::Error::new(ErrorKind::InvalidData, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register(
|
||||
mut self,
|
||||
poll: &mut Poll,
|
||||
token: Token,
|
||||
interest: Interest,
|
||||
) -> WsHandshaking<Registered> {
|
||||
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
|
||||
|
||||
poll.registry()
|
||||
.register(tcp_stream, token, interest)
|
||||
.unwrap();
|
||||
|
||||
WsHandshaking {
|
||||
mid_handshake: self.mid_handshake,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn close(mut self) {
|
||||
::log::debug!("closing connection (WsHandshaking state)");
|
||||
|
||||
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
|
||||
|
||||
let _ = tcp_stream.shutdown(Shutdown::Both);
|
||||
}
|
||||
}
|
||||
|
||||
impl WsHandshaking<Registered> {
|
||||
fn deregister(mut self, poll: &mut Poll) -> WsHandshaking<NotRegistered> {
|
||||
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
|
||||
|
||||
poll.registry().deregister(tcp_stream).unwrap();
|
||||
|
||||
WsHandshaking {
|
||||
mid_handshake: self.mid_handshake,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reregister(
|
||||
&mut self,
|
||||
poll: &mut Poll,
|
||||
token: Token,
|
||||
interest: Interest,
|
||||
) -> std::io::Result<()> {
|
||||
let tcp_stream = &mut self.mid_handshake.get_mut().get_mut().sock;
|
||||
|
||||
poll.registry().reregister(tcp_stream, token, interest)
|
||||
}
|
||||
}
|
||||
|
||||
struct WsConnection<R: RegistryStatus> {
|
||||
web_socket: tungstenite::WebSocket<TlsStream>,
|
||||
phantom_data: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl WsConnection<NotRegistered> {
|
||||
fn read(mut self) -> ConnectionReadResult<ConnectionState<NotRegistered>> {
|
||||
match self.web_socket.read_message() {
|
||||
Ok(
|
||||
message @ tungstenite::Message::Text(_) | message @ tungstenite::Message::Binary(_),
|
||||
) => match InMessage::from_ws_message(message) {
|
||||
Ok(message) => {
|
||||
::log::debug!("received WebSocket message");
|
||||
|
||||
Ok(ConnectionReadStatus::Message(
|
||||
ConnectionState::WsConnection(self),
|
||||
message,
|
||||
))
|
||||
}
|
||||
Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)),
|
||||
},
|
||||
Ok(message) => {
|
||||
::log::info!("received unexpected WebSocket message: {}", message);
|
||||
|
||||
Err(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unexpected WebSocket message type",
|
||||
))
|
||||
}
|
||||
Err(tungstenite::Error::Io(err)) if err.kind() == ErrorKind::WouldBlock => {
|
||||
let conn = ConnectionState::WsConnection(self);
|
||||
|
||||
Ok(ConnectionReadStatus::WouldBlock(conn))
|
||||
}
|
||||
Err(tungstenite::Error::Io(err)) => Err(err),
|
||||
Err(err) => Err(std::io::Error::new(ErrorKind::InvalidData, err)),
|
||||
}
|
||||
}
|
||||
|
||||
fn register(
|
||||
mut self,
|
||||
poll: &mut Poll,
|
||||
token: Token,
|
||||
interest: Interest,
|
||||
) -> WsConnection<Registered> {
|
||||
poll.registry()
|
||||
.register(self.web_socket.get_mut().get_mut(), token, interest)
|
||||
.unwrap();
|
||||
|
||||
WsConnection {
|
||||
web_socket: self.web_socket,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn close(mut self) {
|
||||
::log::debug!("closing connection (WsConnection state)");
|
||||
|
||||
let _ = self.web_socket.close(None);
|
||||
let _ = self.web_socket.write_pending();
|
||||
}
|
||||
}
|
||||
|
||||
impl WsConnection<Registered> {
|
||||
fn deregister(mut self, poll: &mut Poll) -> WsConnection<NotRegistered> {
|
||||
poll.registry()
|
||||
.deregister(self.web_socket.get_mut().get_mut())
|
||||
.unwrap();
|
||||
|
||||
WsConnection {
|
||||
web_socket: self.web_socket,
|
||||
phantom_data: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reregister(
|
||||
&mut self,
|
||||
poll: &mut Poll,
|
||||
token: Token,
|
||||
interest: Interest,
|
||||
) -> std::io::Result<()> {
|
||||
poll.registry()
|
||||
.reregister(self.web_socket.get_mut().get_mut(), token, interest)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,403 +0,0 @@
|
|||
use std::io::ErrorKind;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::Context;
|
||||
use aquatic_common::access_list::AccessListQuery;
|
||||
use aquatic_common::CanonicalSocketAddr;
|
||||
use hashbrown::HashMap;
|
||||
use mio::net::TcpListener;
|
||||
use mio::{Events, Interest, Poll, Token};
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
|
||||
use aquatic_ws_protocol::*;
|
||||
|
||||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
|
||||
pub mod connection;
|
||||
|
||||
use super::common::*;
|
||||
|
||||
use connection::{Connection, NotRegistered, Registered};
|
||||
|
||||
struct ConnectionMap {
|
||||
token_counter: Token,
|
||||
connections: HashMap<Token, Connection<Registered>>,
|
||||
}
|
||||
|
||||
impl Default for ConnectionMap {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
token_counter: Token(2),
|
||||
connections: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionMap {
|
||||
fn insert_and_register_new<F>(&mut self, poll: &mut Poll, connection_creator: F)
|
||||
where
|
||||
F: FnOnce(Token) -> Connection<NotRegistered>,
|
||||
{
|
||||
self.token_counter.0 = self.token_counter.0.wrapping_add(1);
|
||||
|
||||
// Don't assign LISTENER_TOKEN or CHANNEL_TOKEN
|
||||
if self.token_counter.0 < 2 {
|
||||
self.token_counter.0 = 2;
|
||||
}
|
||||
|
||||
let token = self.token_counter;
|
||||
|
||||
// Remove, deregister and close any existing connection with this token.
|
||||
// This shouldn't happen in practice.
|
||||
if let Some(connection) = self.connections.remove(&token) {
|
||||
::log::warn!(
|
||||
"removing existing connection {} because of token reuse",
|
||||
token.0
|
||||
);
|
||||
|
||||
connection.deregister(poll).close();
|
||||
}
|
||||
|
||||
let connection = connection_creator(token);
|
||||
|
||||
self.insert_and_register(poll, token, connection);
|
||||
}
|
||||
|
||||
fn insert_and_register(
|
||||
&mut self,
|
||||
poll: &mut Poll,
|
||||
key: Token,
|
||||
conn: Connection<NotRegistered>,
|
||||
) {
|
||||
self.connections.insert(key, conn.register(poll, key));
|
||||
}
|
||||
|
||||
fn remove_and_deregister(
|
||||
&mut self,
|
||||
poll: &mut Poll,
|
||||
key: &Token,
|
||||
) -> Option<Connection<NotRegistered>> {
|
||||
if let Some(connection) = self.connections.remove(key) {
|
||||
Some(connection.deregister(poll))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mut(&mut self, key: &Token) -> Option<&mut Connection<Registered>> {
|
||||
self.connections.get_mut(key)
|
||||
}
|
||||
|
||||
/// Close and remove inactive connections
|
||||
fn clean(mut self, poll: &mut Poll) -> Self {
|
||||
let now = Instant::now();
|
||||
|
||||
let mut retained_connections = HashMap::default();
|
||||
|
||||
for (token, connection) in self.connections.drain() {
|
||||
if connection.valid_until.0 < now {
|
||||
connection.deregister(poll).close();
|
||||
} else {
|
||||
retained_connections.insert(token, connection);
|
||||
}
|
||||
}
|
||||
|
||||
ConnectionMap {
|
||||
connections: retained_connections,
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_socket_worker(
|
||||
config: Config,
|
||||
state: State,
|
||||
socket_worker_index: usize,
|
||||
socket_worker_statuses: SocketWorkerStatuses,
|
||||
poll: Poll,
|
||||
in_message_sender: InMessageSender,
|
||||
out_message_receiver: OutMessageReceiver,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) {
|
||||
match create_listener(&config) {
|
||||
Ok(listener) => {
|
||||
socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(()));
|
||||
|
||||
run_poll_loop(
|
||||
config,
|
||||
&state,
|
||||
socket_worker_index,
|
||||
poll,
|
||||
in_message_sender,
|
||||
out_message_receiver,
|
||||
listener,
|
||||
tls_config,
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
socket_worker_statuses.lock()[socket_worker_index] =
|
||||
Some(Err(format!("Couldn't open socket: {:#}", err)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_poll_loop(
|
||||
config: Config,
|
||||
state: &State,
|
||||
socket_worker_index: usize,
|
||||
mut poll: Poll,
|
||||
in_message_sender: InMessageSender,
|
||||
out_message_receiver: OutMessageReceiver,
|
||||
listener: ::std::net::TcpListener,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) {
|
||||
let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds);
|
||||
let ws_config = WebSocketConfig {
|
||||
max_message_size: Some(config.network.websocket_max_message_size),
|
||||
max_frame_size: Some(config.network.websocket_max_frame_size),
|
||||
max_send_queue: Some(2),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut listener = TcpListener::from_std(listener);
|
||||
let mut events = Events::with_capacity(config.network.poll_event_capacity);
|
||||
|
||||
poll.registry()
|
||||
.register(&mut listener, LISTENER_TOKEN, Interest::READABLE)
|
||||
.unwrap();
|
||||
|
||||
let mut connections = ConnectionMap::default();
|
||||
let mut local_responses = Vec::new();
|
||||
|
||||
let mut iter_counter = 0usize;
|
||||
|
||||
loop {
|
||||
poll.poll(&mut events, Some(poll_timeout))
|
||||
.expect("failed polling");
|
||||
|
||||
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
|
||||
|
||||
for event in events.iter() {
|
||||
let token = event.token();
|
||||
|
||||
match token {
|
||||
LISTENER_TOKEN => {
|
||||
accept_new_streams(
|
||||
&tls_config,
|
||||
ws_config,
|
||||
socket_worker_index,
|
||||
&mut listener,
|
||||
&mut poll,
|
||||
&mut connections,
|
||||
valid_until,
|
||||
);
|
||||
}
|
||||
CHANNEL_TOKEN => {
|
||||
write_or_queue_messages(
|
||||
&mut poll,
|
||||
out_message_receiver
|
||||
.try_iter()
|
||||
.take(out_message_receiver.len()),
|
||||
&mut connections,
|
||||
);
|
||||
}
|
||||
token => {
|
||||
if event.is_writable() {
|
||||
let mut remove_connection = false;
|
||||
|
||||
if let Some(connection) = connections.get_mut(&token) {
|
||||
if let Err(err) = connection.write(&mut poll) {
|
||||
::log::debug!("Connection::write error: {}", err);
|
||||
|
||||
remove_connection = true;
|
||||
}
|
||||
}
|
||||
|
||||
if remove_connection {
|
||||
if let Some(connection) =
|
||||
connections.remove_and_deregister(&mut poll, &token)
|
||||
{
|
||||
connection.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
if event.is_readable() {
|
||||
handle_stream_read_event(
|
||||
&config,
|
||||
state,
|
||||
&mut local_responses,
|
||||
&in_message_sender,
|
||||
&mut poll,
|
||||
&mut connections,
|
||||
token,
|
||||
valid_until,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
write_or_queue_messages(&mut poll, local_responses.drain(..), &mut connections);
|
||||
}
|
||||
|
||||
// Remove inactive connections, but not every iteration
|
||||
if iter_counter % 128 == 0 {
|
||||
connections = connections.clean(&mut poll);
|
||||
}
|
||||
|
||||
iter_counter = iter_counter.wrapping_add(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_new_streams(
|
||||
tls_config: &Arc<rustls::ServerConfig>,
|
||||
ws_config: WebSocketConfig,
|
||||
socket_worker_index: usize,
|
||||
listener: &mut TcpListener,
|
||||
poll: &mut Poll,
|
||||
connections: &mut ConnectionMap,
|
||||
valid_until: ValidUntil,
|
||||
) {
|
||||
loop {
|
||||
match listener.accept() {
|
||||
Ok((stream, _)) => {
|
||||
let peer_addr = if let Ok(peer_addr) = stream.peer_addr() {
|
||||
CanonicalSocketAddr::new(peer_addr)
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
connections.insert_and_register_new(poll, move |token| {
|
||||
let meta = ConnectionMeta {
|
||||
out_message_consumer_id: ConsumerId(socket_worker_index),
|
||||
connection_id: ConnectionId(token.0),
|
||||
peer_addr,
|
||||
pending_scrape_id: None, // FIXME
|
||||
};
|
||||
|
||||
Connection::new(tls_config.clone(), ws_config, stream, valid_until, meta)
|
||||
});
|
||||
}
|
||||
Err(err) if err.kind() == ErrorKind::WouldBlock => {
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
::log::info!("error while accepting streams: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_stream_read_event(
|
||||
config: &Config,
|
||||
state: &State,
|
||||
local_responses: &mut Vec<(ConnectionMeta, OutMessage)>,
|
||||
in_message_sender: &InMessageSender,
|
||||
poll: &mut Poll,
|
||||
connections: &mut ConnectionMap,
|
||||
token: Token,
|
||||
valid_until: ValidUntil,
|
||||
) {
|
||||
let access_list_mode = config.access_list.mode;
|
||||
|
||||
if let Some(mut connection) = connections.remove_and_deregister(poll, &token) {
|
||||
let message_handler = &mut |meta, message| match message {
|
||||
InMessage::AnnounceRequest(ref request)
|
||||
if !state
|
||||
.access_list
|
||||
.allows(access_list_mode, &request.info_hash.0) =>
|
||||
{
|
||||
let out_message = OutMessage::ErrorResponse(ErrorResponse {
|
||||
failure_reason: "Info hash not allowed".into(),
|
||||
action: Some(ErrorResponseAction::Announce),
|
||||
info_hash: Some(request.info_hash),
|
||||
});
|
||||
|
||||
local_responses.push((meta, out_message));
|
||||
}
|
||||
in_message => {
|
||||
if let Err(err) = in_message_sender.send((meta, in_message)) {
|
||||
::log::info!("InMessageSender: couldn't send message: {:?}", err);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
connection.valid_until = valid_until;
|
||||
|
||||
match connection.read(message_handler) {
|
||||
Ok(connection) => {
|
||||
connections.insert_and_register(poll, token, connection);
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_or_queue_messages<I>(poll: &mut Poll, responses: I, connections: &mut ConnectionMap)
|
||||
where
|
||||
I: Iterator<Item = (ConnectionMeta, OutMessage)>,
|
||||
{
|
||||
for (meta, out_message) in responses {
|
||||
let token = Token(meta.connection_id.0);
|
||||
|
||||
let mut remove_connection = false;
|
||||
|
||||
if let Some(connection) = connections.get_mut(&token) {
|
||||
if connection.get_meta().peer_addr != meta.peer_addr {
|
||||
::log::warn!(
|
||||
"socket worker error: connection socket addr {} didn't match channel {}. Token: {}.",
|
||||
connection.get_meta().peer_addr.get(),
|
||||
meta.peer_addr.get(),
|
||||
token.0
|
||||
);
|
||||
|
||||
remove_connection = true;
|
||||
} else {
|
||||
match connection.write_or_queue_message(poll, out_message) {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
::log::debug!("Connection::write_or_queue_message error: {}", err);
|
||||
|
||||
remove_connection = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if remove_connection {
|
||||
connections.remove_and_deregister(poll, &token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> {
|
||||
let builder = if config.network.address.is_ipv4() {
|
||||
Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))
|
||||
} else {
|
||||
Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))
|
||||
}
|
||||
.context("Couldn't create socket2::Socket")?;
|
||||
|
||||
if config.network.ipv6_only {
|
||||
builder
|
||||
.set_only_v6(true)
|
||||
.context("Couldn't put socket in ipv6 only mode")?
|
||||
}
|
||||
|
||||
builder
|
||||
.set_nonblocking(true)
|
||||
.context("Couldn't put socket in non-blocking mode")?;
|
||||
builder
|
||||
.set_reuse_port(true)
|
||||
.context("Couldn't put socket in reuse_port mode")?;
|
||||
builder
|
||||
.bind(&config.network.address.into())
|
||||
.with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?;
|
||||
builder
|
||||
.listen(128)
|
||||
.context("Couldn't listen for connections on socket")?;
|
||||
|
||||
Ok(builder.into())
|
||||
}
|
||||
2
aquatic_ws/src/workers/mod.rs
Normal file
2
aquatic_ws/src/workers/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod request;
|
||||
pub mod socket;
|
||||
|
|
@ -1,11 +1,130 @@
|
|||
use aquatic_common::extract_response_peers;
|
||||
use hashbrown::HashMap;
|
||||
use rand::rngs::SmallRng;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::StreamExt;
|
||||
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
|
||||
use glommio::enclose;
|
||||
use glommio::prelude::*;
|
||||
use glommio::timer::TimerActionRepeat;
|
||||
use hashbrown::HashMap;
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
|
||||
use aquatic_common::extract_response_peers;
|
||||
use aquatic_ws_protocol::*;
|
||||
|
||||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
use crate::SHARED_IN_CHANNEL_SIZE;
|
||||
|
||||
pub async fn run_request_worker(
|
||||
config: Config,
|
||||
state: State,
|
||||
in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
|
||||
out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
|
||||
) {
|
||||
let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap();
|
||||
let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap();
|
||||
|
||||
let out_message_senders = Rc::new(out_message_senders);
|
||||
|
||||
let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
|
||||
let access_list = state.access_list;
|
||||
|
||||
// Periodically clean torrents
|
||||
TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
|
||||
enclose!((config, torrents, access_list) move || async move {
|
||||
torrents.borrow_mut().clean(&config, &access_list);
|
||||
|
||||
Some(Duration::from_secs(config.cleaning.torrent_cleaning_interval))
|
||||
})()
|
||||
}));
|
||||
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for (_, receiver) in in_message_receivers.streams() {
|
||||
let handle = spawn_local(handle_request_stream(
|
||||
config.clone(),
|
||||
torrents.clone(),
|
||||
out_message_senders.clone(),
|
||||
receiver,
|
||||
))
|
||||
.detach();
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_request_stream<S>(
|
||||
config: Config,
|
||||
torrents: Rc<RefCell<TorrentMaps>>,
|
||||
out_message_senders: Rc<Senders<(ConnectionMeta, OutMessage)>>,
|
||||
stream: S,
|
||||
) where
|
||||
S: futures_lite::Stream<Item = (ConnectionMeta, InMessage)> + ::std::marker::Unpin,
|
||||
{
|
||||
let rng = Rc::new(RefCell::new(SmallRng::from_entropy()));
|
||||
|
||||
let max_peer_age = config.cleaning.max_peer_age;
|
||||
let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));
|
||||
|
||||
TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
|
||||
enclose!((peer_valid_until) move || async move {
|
||||
*peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);
|
||||
|
||||
Some(Duration::from_secs(1))
|
||||
})()
|
||||
}));
|
||||
|
||||
let config = &config;
|
||||
let torrents = &torrents;
|
||||
let peer_valid_until = &peer_valid_until;
|
||||
let rng = &rng;
|
||||
let out_message_senders = &out_message_senders;
|
||||
|
||||
stream
|
||||
.for_each_concurrent(
|
||||
SHARED_IN_CHANNEL_SIZE,
|
||||
move |(meta, in_message)| async move {
|
||||
let mut out_messages = Vec::new();
|
||||
|
||||
match in_message {
|
||||
InMessage::AnnounceRequest(request) => handle_announce_request(
|
||||
&config,
|
||||
&mut rng.borrow_mut(),
|
||||
&mut torrents.borrow_mut(),
|
||||
&mut out_messages,
|
||||
peer_valid_until.borrow().to_owned(),
|
||||
meta,
|
||||
request,
|
||||
),
|
||||
InMessage::ScrapeRequest(request) => handle_scrape_request(
|
||||
&config,
|
||||
&mut torrents.borrow_mut(),
|
||||
&mut out_messages,
|
||||
meta,
|
||||
request,
|
||||
),
|
||||
};
|
||||
|
||||
for (meta, out_message) in out_messages.drain(..) {
|
||||
::log::info!("request worker trying to send OutMessage to socket worker");
|
||||
|
||||
out_message_senders
|
||||
.send_to(meta.out_message_consumer_id.0, (meta, out_message))
|
||||
.await
|
||||
.expect("failed sending out_message to socket worker");
|
||||
|
||||
::log::info!("request worker sent OutMessage to socket worker");
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub fn handle_announce_request(
|
||||
config: &Config,
|
||||
|
|
@ -29,8 +29,6 @@ use crate::config::Config;
|
|||
|
||||
use crate::common::*;
|
||||
|
||||
use super::common::*;
|
||||
|
||||
const LOCAL_CHANNEL_SIZE: usize = 16;
|
||||
|
||||
struct PendingScrapeResponse {
|
||||
Loading…
Add table
Add a link
Reference in a new issue