Merge pull request #18 from greatest-ape/ws-glommio

Rewrite aquatic_ws with glommio
This commit is contained in:
Joakim Frostegård 2021-11-02 17:03:37 +01:00 committed by GitHub
commit 0d7119d121
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 1347 additions and 1772 deletions

View file

@ -55,8 +55,6 @@ openssl pkcs8 -in key.pem -topk8 -nocrypt -out key.pk8 # rustls
$SUDO cp cert.crt /usr/local/share/ca-certificates/snakeoil.crt
$SUDO update-ca-certificates
openssl pkcs12 -export -passout "pass:p" -out identity.pfx -inkey key.pem -in cert.crt
# Build and start tracker
cargo build --bin aquatic
@ -85,9 +83,8 @@ echo "log_level = 'trace'
[network]
address = '127.0.0.1:3002'
use_tls = true
tls_pkcs12_path = './identity.pfx'
tls_pkcs12_password = 'p'
tls_certificate_path = './cert.crt'
tls_private_key_path = './key.pk8'
" > ws.toml
./target/debug/aquatic ws -c ws.toml > "$HOME/wss.log" 2>&1 &

278
Cargo.lock generated
View file

@ -239,22 +239,24 @@ dependencies = [
"aquatic_cli_helpers",
"aquatic_common",
"aquatic_ws_protocol",
"crossbeam-channel",
"async-tungstenite",
"core_affinity",
"either",
"futures",
"futures-lite",
"futures-rustls",
"glommio",
"hashbrown 0.11.2",
"histogram",
"indexmap",
"log",
"mimalloc",
"mio",
"native-tls",
"parking_lot",
"privdrop",
"quickcheck",
"quickcheck_macros",
"rand",
"rustls-pemfile",
"serde",
"socket2 0.4.2",
"slab",
"tungstenite",
]
@ -265,16 +267,19 @@ dependencies = [
"anyhow",
"aquatic_cli_helpers",
"aquatic_ws_protocol",
"async-tungstenite",
"futures",
"futures-rustls",
"glommio",
"hashbrown 0.11.2",
"mimalloc",
"mio",
"quickcheck",
"quickcheck_macros",
"rand",
"rand_distr",
"rustls",
"serde",
"serde_json",
"slab",
"tungstenite",
]
@ -308,6 +313,19 @@ dependencies = [
"nodrop",
]
[[package]]
name = "async-tungstenite"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "742cc7dcb20b2f84a42f4691aa999070ec7e78f8e7e7438bf14be7017b44907e"
dependencies = [
"futures-io",
"futures-util",
"log",
"pin-project-lite",
"tungstenite",
]
[[package]]
name = "atty"
version = "0.2.14"
@ -487,22 +505,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "core-foundation"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6888e10551bb93e424d8df1d07f1a8b4fceb0001a3a4b048bfc47554946f47b3"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
[[package]]
name = "core_affinity"
version = "0.5.10"
@ -733,21 +735,6 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.0.1"
@ -758,12 +745,48 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a12aa0eb539080d55c3f2d45a67c3b58b6b0773c1a3ca2dfec66d58c97fd66ca"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d"
[[package]]
name = "futures-executor"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.17"
@ -785,6 +808,19 @@ dependencies = [
"waker-fn",
]
[[package]]
name = "futures-macro"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb"
dependencies = [
"autocfg",
"proc-macro-hack",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-rustls"
version = "0.22.0"
@ -796,6 +832,39 @@ dependencies = [
"webpki",
]
[[package]]
name = "futures-sink"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11"
[[package]]
name = "futures-task"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99"
[[package]]
name = "futures-util"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481"
dependencies = [
"autocfg",
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"proc-macro-hack",
"proc-macro-nested",
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.4"
@ -1150,24 +1219,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "native-tls"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d"
dependencies = [
"lazy_static",
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nix"
version = "0.23.0"
@ -1269,39 +1320,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl"
version = "0.10.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d9facdb76fec0b73c406f125d44d86fdad818d66fef0531eec9233ca425ff4a"
dependencies = [
"bitflags",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-sys",
]
[[package]]
name = "openssl-probe"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a"
[[package]]
name = "openssl-sys"
version = "0.9.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69df2d8dfc6ce3aaf44b40dec6f487d5a886516cf6879c49e98e0710f310a058"
dependencies = [
"autocfg",
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "owned-alloc"
version = "0.2.0"
@ -1352,10 +1370,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443"
[[package]]
name = "pkg-config"
version = "0.3.20"
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c9b1041b4387893b91ee6746cddfc28516aff326a3519fb2adf820932c5e6cb"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "plotters"
@ -1401,6 +1419,18 @@ dependencies = [
"nix",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5"
[[package]]
name = "proc-macro-nested"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086"
[[package]]
name = "proc-macro2"
version = "1.0.30"
@ -1548,15 +1578,6 @@ version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "ring"
version = "0.16.20"
@ -1632,16 +1653,6 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75"
dependencies = [
"lazy_static",
"winapi 0.3.9",
]
[[package]]
name = "scoped-tls"
version = "1.0.0"
@ -1664,29 +1675,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "security-framework"
version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525bc1abfda2e1998d152c45cf13e696f76d0a4972310b22fac1658b05df7c87"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9dd14d83160b528b7bfd66439110573efcfbe281b17fc2ca9f39f550d619c7e"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.4"
@ -1873,20 +1861,6 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "tempfile"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22"
dependencies = [
"cfg-if",
"libc",
"rand",
"redox_syscall",
"remove_dir_all",
"winapi 0.3.9",
]
[[package]]
name = "termcolor"
version = "1.1.2"
@ -2105,12 +2079,6 @@ dependencies = [
"ryu",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.3"

View file

@ -13,11 +13,11 @@ of sub-implementations for different protocols:
[mio]: https://github.com/tokio-rs/mio
[glommio]: https://github.com/DataDog/glommio
| Name | Protocol | OS requirements |
|--------------|------------------------------------------------|-----------------------------------------------------------------|
| aquatic_udp | [BitTorrent over UDP] | Cross-platform with [mio] (default) / Linux 5.8+ with [glommio] |
| aquatic_http | [BitTorrent over HTTP] with TLS ([rustls]) | Linux 5.8+ |
| aquatic_ws | [WebTorrent], plain or with TLS ([native-tls]) | Cross-platform |
| Name | Protocol | OS requirements |
|--------------|--------------------------------------------|-----------------------------------------------------------------|
| aquatic_udp | [BitTorrent over UDP] | Cross-platform with [mio] (default) / Linux 5.8+ with [glommio] |
| aquatic_http | [BitTorrent over HTTP] with TLS ([rustls]) | Linux 5.8+ |
| aquatic_ws | [WebTorrent] with TLS (rustls) | Linux 5.8+ |
## Usage
@ -25,8 +25,16 @@ of sub-implementations for different protocols:
- Install Rust with [rustup](https://rustup.rs/) (stable is recommended)
- Install cmake with your package manager (e.g., `apt-get install cmake`)
- If you want to run aquatic_ws and are on Linux or BSD, install OpenSSL
components necessary for dynamic linking (e.g., `apt-get install libssl-dev`)
- Unless you're planning to only run aquatic_udp and only the cross-platform,
mio based implementation, make sure locked memory limits are sufficient.
You can do this by adding the following lines to `/etc/security/limits.conf`,
and then logging out and back in:
```
* hard memlock 512
* soft memlock 512
```
- Clone this git repository and enter it
### Compiling
@ -34,6 +42,11 @@ of sub-implementations for different protocols:
Compile the implementations that you are interested in:
```sh
# Tell Rust to enable support for all CPU extensions present on current CPU
# except for those relating to AVX-512. This is necessary for aquatic_ws and
# recommended for the other implementations.
. ./scripts/env-native-cpu-without-avx-512
cargo build --release -p aquatic_udp
cargo build --release -p aquatic_udp --features "with-glommio" --no-default-features
cargo build --release -p aquatic_http
@ -50,13 +63,12 @@ Begin by generating configuration files. They differ between protocols.
./target/release/aquatic_ws -p > "aquatic-ws-config.toml"
```
Make adjustments to the files. The values you will most likely want to adjust
are `socket_workers` (number of threads reading from and writing to sockets)
and `address` under the `network` section (listening address). This goes for
all three protocols.
Make adjustments to the files. You will likely want to adjust `address`
(listening address) under the `network` section.
`aquatic_http` requires configuring a TLS certificate file and a private key file
to run. More information is available further down in this document.
Both `aquatic_http` and `aquatic_ws` requires configuring a TLS certificate
file as well as a private key file to run. More information is available
in the `aquatic_http` subsection of this document.
Once done, run the tracker:
@ -66,10 +78,16 @@ Once done, run the tracker:
./target/release/aquatic_ws -c "aquatic-ws-config.toml"
```
More documentation of configuration file values might be available under
`src/lib/config.rs` in crates `aquatic_udp`, `aquatic_http`, `aquatic_ws`.
### Configuration values
#### General settings
Starting a lot more socket workers than request workers is recommended. All
implementations are heavily IO-bound and spend most of their time reading from
and writing to sockets. This part is handled by the `socket_workers`, which
also do parsing, serialisation and access control. They pass announce and
scrape requests to the `request_workers`, which update internal tracker state
and pass back responses.
#### Access control
Access control by info hash is supported for all protocols. The relevant part
of configuration is:
@ -80,6 +98,12 @@ mode = 'off' # Change to 'black' (blacklist) or 'white' (whitelist)
path = '' # Path to text file with newline-delimited hex-encoded info hashes
```
#### More information
More documentation of the various configuration options might be available
under `src/lib/config.rs` in directories `aquatic_udp`, `aquatic_http` and
`aquatic_ws`.
## Details on implementations
### aquatic_udp: UDP BitTorrent tracker
@ -154,25 +178,15 @@ tls_private_key_path = './key.pk8'
### aquatic_ws: WebTorrent tracker
Aims for compatibility with [WebTorrent](https://github.com/webtorrent)
clients, including `wss` protocol support (WebSockets over TLS), with some
exceptions:
clients, with some exceptions:
* Only runs over TLS (wss protocol)
* Doesn't keep track of the number of torrent downloads (0 is always sent).
* Doesn't allow full scrapes, i.e. of all registered info hashes
#### TLS
To run over TLS, a pkcs12 file (`.pfx`) is needed. It can be generated from
Let's Encrypt certificates as follows, assuming you are in the directory where
they are stored:
```sh
openssl pkcs12 -export -out identity.pfx -inkey privkey.pem -in cert.pem -certfile fullchain.pem
```
Enter a password when prompted. Then move `identity.pfx` somewhere suitable,
and enter the path into the tracker configuration field `tls_pkcs12_path`. Set
the password in the field `tls_pkcs12_password` and set `use_tls` to true.
Please see `aquatic_http` TLS section above.
#### Benchmarks
@ -201,10 +215,14 @@ Server responses per second, best result in bold:
Please refer to `documents/aquatic-ws-load-test-2021-08-18.pdf` for more details.
__Note__: these benchmarks were made with the previous mio-based
implementation.
## Load testing
There are load test binaries for all protocols. They use a CLI structure
similar to `aquatic` and support generation and loading of configuration files.
similar to the trackers and support generation and loading of configuration
files.
To run, first start the tracker that you want to test. Then run the
corresponding load test binary:
@ -218,18 +236,6 @@ corresponding load test binary:
To fairly compare HTTP performance to opentracker, set keepalive to false in
`aquatic_http` settings.
## Architectural overview
One or more socket workers open sockets, read and parse requests from peers and
send them through channels to request workers. The request workers go through
the requests, update shared internal tracker state as appropriate and generate
responses that are sent back to the socket workers. The responses are then
serialized and sent back to the peers.
This design means little waiting for locks on internal state occurs,
while network work can be efficiently distributed over multiple threads,
making use of SO_REUSEPORT setting.
## Copyright and license
Copyright (c) 2020-2021 Joakim Frostegård

26
TODO.md
View file

@ -1,5 +1,15 @@
# TODO
* readme
* document privilege dropping, cpu pinning
* config: fail on unrecognized keys
* access lists:
* use signals to reload, use arcswap everywhere
* use arc-swap Cache?
* add CI tests
* aquatic_udp
* CI for both implementations
* glommio
@ -27,16 +37,12 @@
where only part of request is read, but that part is valid, and reading
is stopped, which might lead to various issues.
* access lists:
* use arc-swap Cache?
* add CI tests
* aquatic_ws: should it send back error on message parse error, or does that
just indicate that not enough data has been received yet?
* Consider turning on safety and override flags in mimalloc, mostly for
simd-json. It might be faster to just stop using simd-json if I consider
it insecure, which it maybe isn't.
* aquatic_ws
* ipv6 only flag
* load test cpu pinning
* test with multiple socket and request workers
* should it send back error on message parse error, or does that
just indicate that not enough data has been received yet?
## General
* extract response peers: extract "one extra" to compensate for removal,

View file

@ -17,23 +17,25 @@ path = "src/bin/main.rs"
[dependencies]
anyhow = "1"
async-tungstenite = "0.15"
aquatic_cli_helpers = "0.1.0"
aquatic_common = "0.1.0"
aquatic_ws_protocol = "0.1.0"
crossbeam-channel = "0.5"
core_affinity = "0.5"
either = "1"
futures-lite = "1"
futures = "0.3"
futures-rustls = "0.22"
glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5" }
hashbrown = { version = "0.11.2", features = ["serde"] }
histogram = "0.6"
indexmap = "1"
log = "0.4"
mimalloc = { version = "0.1", default-features = false }
mio = { version = "0.7", features = ["tcp", "os-poll", "os-util"] }
native-tls = "0.2"
parking_lot = "0.11"
privdrop = "0.5"
rand = { version = "0.8", features = ["small_rng"] }
rustls-pemfile = "0.2"
serde = { version = "1", features = ["derive"] }
socket2 = { version = "0.4.1", features = ["all"] }
slab = "0.4"
tungstenite = "0.15"
[dev-dependencies]

View file

@ -1,14 +1,15 @@
use std::borrow::Borrow;
use std::cell::RefCell;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::rc::Rc;
use std::time::Instant;
use aquatic_common::access_list::{AccessList, AccessListArcSwap};
use crossbeam_channel::{Receiver, Sender};
use aquatic_common::access_list::AccessList;
use futures_lite::AsyncBufReadExt;
use glommio::io::{BufferedFile, StreamReaderBuilder};
use glommio::yield_if_needed;
use hashbrown::HashMap;
use indexmap::IndexMap;
use log::error;
use mio::Token;
use parking_lot::Mutex;
pub use aquatic_common::ValidUntil;
@ -16,19 +17,28 @@ use aquatic_ws_protocol::*;
use crate::config::Config;
pub const LISTENER_TOKEN: Token = Token(0);
pub const CHANNEL_TOKEN: Token = Token(1);
pub type TlsConfig = futures_rustls::rustls::ServerConfig;
#[derive(Copy, Clone, Debug)]
pub struct PendingScrapeId(pub usize);
#[derive(Copy, Clone, Debug)]
pub struct ConsumerId(pub usize);
#[derive(Clone, Copy, Debug)]
pub struct ConnectionId(pub usize);
#[derive(Clone, Copy, Debug)]
pub struct ConnectionMeta {
/// Index of socket worker responsible for this connection. Required for
/// sending back response through correct channel to correct worker.
pub worker_index: usize,
pub out_message_consumer_id: ConsumerId,
pub connection_id: ConnectionId,
/// Peer address as received from socket, meaning it wasn't converted to
/// an IPv4 address if it was a IPv4-mapped IPv6 address
pub naive_peer_addr: SocketAddr,
pub converted_peer_ip: IpAddr,
pub poll_token: Token,
pub pending_scrape_id: Option<PendingScrapeId>,
}
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
@ -89,16 +99,12 @@ pub struct TorrentMaps {
}
impl TorrentMaps {
pub fn clean(&mut self, config: &Config, access_list: &Arc<AccessList>) {
pub fn clean(&mut self, config: &Config, access_list: &AccessList) {
Self::clean_torrent_map(config, access_list, &mut self.ipv4);
Self::clean_torrent_map(config, access_list, &mut self.ipv6);
}
fn clean_torrent_map(
config: &Config,
access_list: &Arc<AccessList>,
torrent_map: &mut TorrentMap,
) {
fn clean_torrent_map(config: &Config, access_list: &AccessList, torrent_map: &mut TorrentMap) {
let now = Instant::now();
torrent_map.retain(|info_hash, torrent_data| {
@ -134,40 +140,44 @@ impl TorrentMaps {
}
}
#[derive(Clone)]
pub struct State {
pub access_list: Arc<AccessListArcSwap>,
pub torrent_maps: Arc<Mutex<TorrentMaps>>,
}
pub async fn update_access_list<C: Borrow<Config>>(
config: C,
access_list: Rc<RefCell<AccessList>>,
) {
if config.borrow().access_list.mode.is_on() {
match BufferedFile::open(&config.borrow().access_list.path).await {
Ok(file) => {
let mut reader = StreamReaderBuilder::new(file).build();
let mut new_access_list = AccessList::default();
impl Default for State {
fn default() -> Self {
Self {
access_list: Arc::new(Default::default()),
torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())),
}
loop {
let mut buf = String::with_capacity(42);
match reader.read_line(&mut buf).await {
Ok(_) => {
if let Err(err) = new_access_list.insert_from_line(&buf) {
::log::error!(
"Couln't parse access list line '{}': {:?}",
buf,
err
);
}
}
Err(err) => {
::log::error!("Couln't read access list line {:?}", err);
break;
}
}
yield_if_needed().await;
}
*access_list.borrow_mut() = new_access_list;
}
Err(err) => {
::log::error!("Couldn't open access list file: {:?}", err)
}
};
}
}
pub type InMessageSender = Sender<(ConnectionMeta, InMessage)>;
pub type InMessageReceiver = Receiver<(ConnectionMeta, InMessage)>;
pub type OutMessageReceiver = Receiver<(ConnectionMeta, OutMessage)>;
#[derive(Clone)]
pub struct OutMessageSender(Vec<Sender<(ConnectionMeta, OutMessage)>>);
impl OutMessageSender {
pub fn new(senders: Vec<Sender<(ConnectionMeta, OutMessage)>>) -> Self {
Self(senders)
}
#[inline]
pub fn send(&self, meta: ConnectionMeta, message: OutMessage) {
if let Err(err) = self.0[meta.worker_index].send((meta, message)) {
error!("OutMessageSender: couldn't send message: {:?}", err);
}
}
}
pub type SocketWorkerStatus = Option<Result<(), String>>;
pub type SocketWorkerStatuses = Arc<Mutex<Vec<SocketWorkerStatus>>>;

View file

@ -1,5 +1,7 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use aquatic_common::cpu_pinning::CpuPinningConfig;
use aquatic_common::{access_list::AccessListConfig, privileges::PrivilegeConfig};
use serde::{Deserialize, Serialize};
@ -18,11 +20,10 @@ pub struct Config {
pub log_level: LogLevel,
pub network: NetworkConfig,
pub protocol: ProtocolConfig,
pub handlers: HandlerConfig,
pub cleaning: CleaningConfig,
pub statistics: StatisticsConfig,
pub privileges: PrivilegeConfig,
pub access_list: AccessListConfig,
pub cpu_pinning: CpuPinningConfig,
}
impl aquatic_cli_helpers::Config for Config {
@ -37,24 +38,12 @@ pub struct NetworkConfig {
/// Bind to this address
pub address: SocketAddr,
pub ipv6_only: bool,
pub use_tls: bool,
pub tls_pkcs12_path: String,
pub tls_pkcs12_password: String,
pub poll_event_capacity: usize,
pub poll_timeout_microseconds: u64,
pub tls_certificate_path: PathBuf,
pub tls_private_key_path: PathBuf,
pub websocket_max_message_size: usize,
pub websocket_max_frame_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct HandlerConfig {
/// Maximum number of requests to receive from channel before locking
/// mutex and starting work
pub max_requests_per_iter: usize,
pub channel_recv_timeout_microseconds: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct ProtocolConfig {
@ -73,15 +62,6 @@ pub struct CleaningConfig {
pub interval: u64,
/// Remove peers that haven't announced for this long (seconds)
pub max_peer_age: u64,
/// Remove connections that are older than this (seconds)
pub max_connection_age: u64,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct StatisticsConfig {
/// Print statistics this often (seconds). Don't print when set to zero.
pub interval: u64,
}
impl Default for Config {
@ -92,11 +72,10 @@ impl Default for Config {
log_level: LogLevel::default(),
network: NetworkConfig::default(),
protocol: ProtocolConfig::default(),
handlers: HandlerConfig::default(),
cleaning: CleaningConfig::default(),
statistics: StatisticsConfig::default(),
privileges: PrivilegeConfig::default(),
access_list: AccessListConfig::default(),
cpu_pinning: Default::default(),
}
}
}
@ -106,11 +85,8 @@ impl Default for NetworkConfig {
Self {
address: SocketAddr::from(([0, 0, 0, 0], 3000)),
ipv6_only: false,
use_tls: false,
tls_pkcs12_path: "".into(),
tls_pkcs12_password: "".into(),
poll_event_capacity: 4096,
poll_timeout_microseconds: 200_000,
tls_certificate_path: "".into(),
tls_private_key_path: "".into(),
websocket_max_message_size: 64 * 1024,
websocket_max_frame_size: 16 * 1024,
}
@ -120,34 +96,18 @@ impl Default for NetworkConfig {
impl Default for ProtocolConfig {
fn default() -> Self {
Self {
max_scrape_torrents: 255, // FIXME: what value is reasonable?
max_scrape_torrents: 255,
max_offers: 10,
peer_announce_interval: 120,
}
}
}
impl Default for HandlerConfig {
fn default() -> Self {
Self {
max_requests_per_iter: 10000,
channel_recv_timeout_microseconds: 200,
}
}
}
impl Default for CleaningConfig {
fn default() -> Self {
Self {
interval: 30,
max_peer_age: 1800,
max_connection_age: 1800,
}
}
}
impl Default for StatisticsConfig {
fn default() -> Self {
Self { interval: 0 }
}
}

View file

@ -1,282 +0,0 @@
use std::sync::Arc;
use std::time::Duration;
use std::vec::Drain;
use hashbrown::HashMap;
use mio::Waker;
use parking_lot::MutexGuard;
use rand::{rngs::SmallRng, Rng, SeedableRng};
use aquatic_common::extract_response_peers;
use aquatic_ws_protocol::*;
use crate::common::*;
use crate::config::Config;
/// Main loop of a request worker thread.
///
/// Repeatedly drains announce and scrape requests from the incoming channel
/// (batching up to `config.handlers.max_requests_per_iter` per iteration),
/// applies them to the shared torrent state under a single mutex
/// acquisition, and finally wakes each socket worker that had responses
/// queued for it so its poll loop can send them out.
pub fn run_request_worker(
    config: Config,
    state: State,
    in_message_receiver: InMessageReceiver,
    out_message_sender: OutMessageSender,
    wakers: Vec<Arc<Waker>>,
) {
    // One "needs waking" flag per socket worker; set by the handler
    // functions below whenever a response is queued for that worker.
    let mut wake_socket_workers: Vec<bool> = (0..config.socket_workers).map(|_| false).collect();

    let mut announce_requests = Vec::new();
    let mut scrape_requests = Vec::new();

    let mut rng = SmallRng::from_entropy();

    let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);

    loop {
        let mut opt_torrent_map_guard: Option<MutexGuard<TorrentMaps>> = None;

        for i in 0..config.handlers.max_requests_per_iter {
            let opt_in_message = if i == 0 {
                // Block indefinitely on the first receive so the worker
                // sleeps while there is no work at all.
                in_message_receiver.recv().ok()
            } else {
                // After the first message, poll with a short timeout to
                // batch additional requests without stalling.
                in_message_receiver.recv_timeout(timeout).ok()
            };

            match opt_in_message {
                Some((meta, InMessage::AnnounceRequest(r))) => {
                    announce_requests.push((meta, r));
                }
                Some((meta, InMessage::ScrapeRequest(r))) => {
                    scrape_requests.push((meta, r));
                }
                None => {
                    // Channel drained (or timed out): if the torrent map
                    // mutex happens to be free, take it now and start
                    // processing the batch early instead of polling more.
                    if let Some(torrent_guard) = state.torrent_maps.try_lock() {
                        opt_torrent_map_guard = Some(torrent_guard);

                        break;
                    }
                }
            }
        }

        // If try_lock never succeeded above, block for the lock here.
        let mut torrent_map_guard =
            opt_torrent_map_guard.unwrap_or_else(|| state.torrent_maps.lock());

        handle_announce_requests(
            &config,
            &mut rng,
            &mut torrent_map_guard,
            &out_message_sender,
            &mut wake_socket_workers,
            announce_requests.drain(..),
        );

        handle_scrape_requests(
            &config,
            &mut torrent_map_guard,
            &out_message_sender,
            &mut wake_socket_workers,
            scrape_requests.drain(..),
        );

        // Wake each socket worker that received responses this iteration,
        // then reset its flag for the next batch.
        for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() {
            if *wake {
                if let Err(err) = wakers[worker_index].wake() {
                    ::log::error!("request handler couldn't wake poll: {:?}", err);
                }

                *wake = false;
            }
        }
    }
}
/// Process a batch of announce requests against the shared torrent state.
///
/// For each request: inserts/updates/removes the announcing peer in the
/// relevant (IPv4 or IPv6) torrent map, forwards any WebRTC offers to
/// randomly selected peers and any answer to its designated recipient,
/// and queues an announce response back to the sender. Marks the
/// responsible socket workers for waking in `wake_socket_workers`.
pub fn handle_announce_requests(
    config: &Config,
    rng: &mut impl Rng,
    torrent_maps: &mut TorrentMaps,
    out_message_sender: &OutMessageSender,
    wake_socket_workers: &mut Vec<bool>,
    requests: Drain<(ConnectionMeta, AnnounceRequest)>,
) {
    // Shared expiry deadline for all peers updated in this batch.
    let valid_until = ValidUntil::new(config.cleaning.max_peer_age);

    for (request_sender_meta, request) in requests {
        // State is split by converted peer IP family, so an announce only
        // ever touches one of the two maps.
        let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() {
            torrent_maps.ipv4.entry(request.info_hash).or_default()
        } else {
            torrent_maps.ipv6.entry(request.info_hash).or_default()
        };

        // If there is already a peer with this peer_id, check that socket
        // addr is same as that of request sender. Otherwise, ignore request.
        // Since peers have access to each others peer_id's, they could send
        // requests using them, causing all sorts of issues. Checking naive
        // (non-converted) socket addresses is enough, since state is split
        // on converted peer ip.
        if let Some(previous_peer) = torrent_data.peers.get(&request.peer_id) {
            if request_sender_meta.naive_peer_addr != previous_peer.connection_meta.naive_peer_addr
            {
                continue;
            }
        }

        ::log::trace!("received request from {:?}", request_sender_meta);

        // Insert/update/remove peer who sent this request
        {
            let peer_status = PeerStatus::from_event_and_bytes_left(
                request.event.unwrap_or_default(),
                request.bytes_left,
            );

            let peer = Peer {
                connection_meta: request_sender_meta,
                status: peer_status,
                valid_until,
            };

            // Insert (or remove, for Stopped) and count the new status;
            // any displaced previous entry is handled just below.
            let opt_removed_peer = match peer_status {
                PeerStatus::Leeching => {
                    torrent_data.num_leechers += 1;

                    torrent_data.peers.insert(request.peer_id, peer)
                }
                PeerStatus::Seeding => {
                    torrent_data.num_seeders += 1;

                    torrent_data.peers.insert(request.peer_id, peer)
                }
                PeerStatus::Stopped => torrent_data.peers.remove(&request.peer_id),
            };

            // Un-count whatever status the replaced/removed peer had, so
            // leecher/seeder totals stay consistent.
            match opt_removed_peer.map(|peer| peer.status) {
                Some(PeerStatus::Leeching) => {
                    torrent_data.num_leechers -= 1;
                }
                Some(PeerStatus::Seeding) => {
                    torrent_data.num_seeders -= 1;
                }
                _ => {}
            }
        }

        // If peer sent offers, send them on to random peers
        if let Some(offers) = request.offers {
            // FIXME: config: also maybe check this when parsing request
            let max_num_peers_to_take = offers.len().min(config.protocol.max_offers);

            // Named function (rather than a closure) used as the peer
            // extraction callback; simply copies the peer out.
            #[inline]
            fn f(peer: &Peer) -> Peer {
                *peer
            }

            // Select random peers (excluding the sender) to receive the
            // offers, one receiver per offer.
            let offer_receivers: Vec<Peer> = extract_response_peers(
                rng,
                &torrent_data.peers,
                max_num_peers_to_take,
                request.peer_id,
                f,
            );

            for (offer, offer_receiver) in offers.into_iter().zip(offer_receivers) {
                let middleman_offer = MiddlemanOfferToPeer {
                    action: AnnounceAction,
                    info_hash: request.info_hash,
                    peer_id: request.peer_id,
                    offer: offer.offer,
                    offer_id: offer.offer_id,
                };

                out_message_sender.send(
                    offer_receiver.connection_meta,
                    OutMessage::Offer(middleman_offer),
                );
                ::log::trace!(
                    "sent middleman offer to {:?}",
                    offer_receiver.connection_meta
                );
                wake_socket_workers[offer_receiver.connection_meta.worker_index] = true;
            }
        }

        // If peer sent answer, send it on to relevant peer
        if let (Some(answer), Some(answer_receiver_id), Some(offer_id)) =
            (request.answer, request.to_peer_id, request.offer_id)
        {
            if let Some(answer_receiver) = torrent_data.peers.get(&answer_receiver_id) {
                let middleman_answer = MiddlemanAnswerToPeer {
                    action: AnnounceAction,
                    peer_id: request.peer_id,
                    info_hash: request.info_hash,
                    answer,
                    offer_id,
                };

                out_message_sender.send(
                    answer_receiver.connection_meta,
                    OutMessage::Answer(middleman_answer),
                );
                ::log::trace!(
                    "sent middleman answer to {:?}",
                    answer_receiver.connection_meta
                );
                wake_socket_workers[answer_receiver.connection_meta.worker_index] = true;
            }
        }

        // Finally, queue the announce response for the sender itself.
        let response = OutMessage::AnnounceResponse(AnnounceResponse {
            action: AnnounceAction,
            info_hash: request.info_hash,
            complete: torrent_data.num_seeders,
            incomplete: torrent_data.num_leechers,
            announce_interval: config.protocol.peer_announce_interval,
        });

        out_message_sender.send(request_sender_meta, response);
        wake_socket_workers[request_sender_meta.worker_index] = true;
    }
}
/// Handle a batch of drained scrape requests, queueing one scrape
/// response per request on the out-message sender and flagging the
/// owning socket worker for wakeup.
pub fn handle_scrape_requests(
    config: &Config,
    torrent_maps: &mut TorrentMaps,
    out_message_sender: &OutMessageSender,
    wake_socket_workers: &mut Vec<bool>,
    requests: Drain<(ConnectionMeta, ScrapeRequest)>,
) {
    for (meta, request) in requests {
        // If request.info_hashes is empty, don't return scrape for all
        // torrents, even though reference server does it. It is too expensive.
        let info_hashes = match request.info_hashes {
            Some(info_hashes) => info_hashes.as_vec(),
            None => continue,
        };

        let num_to_take = info_hashes.len().min(config.protocol.max_scrape_torrents);

        // State is split on converted peer ip, so pick the matching map
        let torrent_map: &mut TorrentMap = if meta.converted_peer_ip.is_ipv4() {
            &mut torrent_maps.ipv4
        } else {
            &mut torrent_maps.ipv6
        };

        let mut response = ScrapeResponse {
            action: ScrapeAction,
            files: HashMap::with_capacity(num_to_take),
        };

        for info_hash in info_hashes.into_iter().take(num_to_take) {
            if let Some(torrent_data) = torrent_map.get(&info_hash) {
                response.files.insert(
                    info_hash,
                    ScrapeStatistics {
                        complete: torrent_data.num_seeders,
                        downloaded: 0, // No implementation planned
                        incomplete: torrent_data.num_leechers,
                    },
                );
            }
        }

        out_message_sender.send(meta, OutMessage::ScrapeResponse(response));
        wake_socket_workers[meta.worker_index] = true;
    }
}

View file

@ -0,0 +1,297 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;
use aquatic_common::access_list::AccessList;
use aquatic_common::extract_response_peers;
use futures_lite::StreamExt;
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::enclose;
use glommio::prelude::*;
use glommio::timer::TimerActionRepeat;
use hashbrown::HashMap;
use rand::{rngs::SmallRng, Rng, SeedableRng};
use aquatic_ws_protocol::*;
use crate::common::*;
use crate::config::Config;
/// Run a request worker on the current glommio executor.
///
/// Joins the in-message mesh as a consumer and the out-message mesh as a
/// producer, spawns one task per socket-worker stream, and periodically
/// cleans torrent state / reloads the access list.
pub async fn run_request_worker(
    config: Config,
    in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
    out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
    access_list: AccessList,
) {
    let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap();
    let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap();
    let out_message_senders = Rc::new(out_message_senders);
    let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
    let access_list = Rc::new(RefCell::new(access_list));
    // Periodically clean torrents and update access list
    TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
        enclose!((config, torrents, access_list) move || async move {
            update_access_list(&config, access_list.clone()).await;
            torrents.borrow_mut().clean(&config, &*access_list.borrow());
            Some(Duration::from_secs(config.cleaning.interval))
        })()
    }));
    // One handler task per socket-worker stream; the tasks only finish
    // when their stream closes, so awaiting them keeps the worker alive
    let mut handles = Vec::new();
    for (_, receiver) in in_message_receivers.streams() {
        let handle = spawn_local(handle_request_stream(
            config.clone(),
            torrents.clone(),
            out_message_senders.clone(),
            receiver,
        ))
        .detach();
        handles.push(handle);
    }
    for handle in handles {
        handle.await;
    }
}
/// Consume one socket worker's stream of in-messages, apply them to the
/// shared torrent state, and route the produced out-messages back to the
/// socket worker owning each target connection.
async fn handle_request_stream<S>(
    config: Config,
    torrents: Rc<RefCell<TorrentMaps>>,
    out_message_senders: Rc<Senders<(ConnectionMeta, OutMessage)>>,
    mut stream: S,
) where
    S: futures_lite::Stream<Item = (ConnectionMeta, InMessage)> + ::std::marker::Unpin,
{
    let mut rng = SmallRng::from_entropy();
    let max_peer_age = config.cleaning.max_peer_age;
    let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));
    // Refresh the shared peer validity deadline once per second
    TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
        enclose!((peer_valid_until) move || async move {
            *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);
            Some(Duration::from_secs(1))
        })()
    }));
    // Reused between iterations; drained after each request
    let mut out_messages = Vec::new();
    while let Some((meta, in_message)) = stream.next().await {
        match in_message {
            InMessage::AnnounceRequest(request) => handle_announce_request(
                &config,
                &mut rng,
                &mut torrents.borrow_mut(),
                &mut out_messages,
                peer_valid_until.borrow().to_owned(),
                meta,
                request,
            ),
            InMessage::ScrapeRequest(request) => handle_scrape_request(
                &config,
                &mut torrents.borrow_mut(),
                &mut out_messages,
                meta,
                request,
            ),
        };
        // meta.out_message_consumer_id identifies the socket worker that
        // owns the target connection
        for (meta, out_message) in out_messages.drain(..) {
            out_message_senders
                .send_to(meta.out_message_consumer_id.0, (meta, out_message))
                .await
                .expect("failed sending out_message to socket worker");
        }
        // Give other tasks on this executor a chance to run
        yield_if_needed().await;
    }
}
/// Handle a single announce request.
///
/// Inserts/updates/removes the sender's peer entry for the torrent based
/// on its reported status, relays any WebRTC offers to randomly picked
/// peers of the same torrent, relays an answer to the peer it addresses,
/// and finally queues an announce response for the sender.
pub fn handle_announce_request(
    config: &Config,
    rng: &mut impl Rng,
    torrent_maps: &mut TorrentMaps,
    out_messages: &mut Vec<(ConnectionMeta, OutMessage)>,
    valid_until: ValidUntil,
    request_sender_meta: ConnectionMeta,
    request: AnnounceRequest,
) {
    // State is split on converted peer ip
    let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() {
        torrent_maps.ipv4.entry(request.info_hash).or_default()
    } else {
        torrent_maps.ipv6.entry(request.info_hash).or_default()
    };
    // If there is already a peer with this peer_id, check that socket
    // addr is same as that of request sender. Otherwise, ignore request.
    // Since peers have access to each others peer_id's, they could send
    // requests using them, causing all sorts of issues. Checking naive
    // (non-converted) socket addresses is enough, since state is split
    // on converted peer ip.
    if let Some(previous_peer) = torrent_data.peers.get(&request.peer_id) {
        if request_sender_meta.naive_peer_addr != previous_peer.connection_meta.naive_peer_addr {
            return;
        }
    }
    ::log::trace!("received request from {:?}", request_sender_meta);
    // Insert/update/remove peer who sent this request
    {
        let peer_status = PeerStatus::from_event_and_bytes_left(
            request.event.unwrap_or_default(),
            request.bytes_left,
        );
        let peer = Peer {
            connection_meta: request_sender_meta,
            status: peer_status,
            valid_until,
        };
        // Bump the counter for the new status first; `insert`/`remove`
        // return any previous entry for this peer id
        let opt_removed_peer = match peer_status {
            PeerStatus::Leeching => {
                torrent_data.num_leechers += 1;
                torrent_data.peers.insert(request.peer_id, peer)
            }
            PeerStatus::Seeding => {
                torrent_data.num_seeders += 1;
                torrent_data.peers.insert(request.peer_id, peer)
            }
            PeerStatus::Stopped => torrent_data.peers.remove(&request.peer_id),
        };
        // Roll back the counter contribution of any replaced/removed entry
        match opt_removed_peer.map(|peer| peer.status) {
            Some(PeerStatus::Leeching) => {
                torrent_data.num_leechers -= 1;
            }
            Some(PeerStatus::Seeding) => {
                torrent_data.num_seeders -= 1;
            }
            _ => {}
        }
    }
    // If peer sent offers, send them on to random peers
    if let Some(offers) = request.offers {
        // FIXME: config: also maybe check this when parsing request
        let max_num_peers_to_take = offers.len().min(config.protocol.max_offers);
        // Extraction function for extract_response_peers: Peer is Copy
        #[inline]
        fn f(peer: &Peer) -> Peer {
            *peer
        }
        let offer_receivers: Vec<Peer> = extract_response_peers(
            rng,
            &torrent_data.peers,
            max_num_peers_to_take,
            request.peer_id,
            f,
        );
        // Each offer goes to exactly one receiver, paired up by zip
        for (offer, offer_receiver) in offers.into_iter().zip(offer_receivers) {
            let middleman_offer = MiddlemanOfferToPeer {
                action: AnnounceAction,
                info_hash: request.info_hash,
                peer_id: request.peer_id,
                offer: offer.offer,
                offer_id: offer.offer_id,
            };
            out_messages.push((
                offer_receiver.connection_meta,
                OutMessage::Offer(middleman_offer),
            ));
            ::log::trace!(
                "sending middleman offer to {:?}",
                offer_receiver.connection_meta
            );
        }
    }
    // If peer sent answer, send it on to relevant peer
    if let (Some(answer), Some(answer_receiver_id), Some(offer_id)) =
        (request.answer, request.to_peer_id, request.offer_id)
    {
        if let Some(answer_receiver) = torrent_data.peers.get(&answer_receiver_id) {
            let middleman_answer = MiddlemanAnswerToPeer {
                action: AnnounceAction,
                peer_id: request.peer_id,
                info_hash: request.info_hash,
                answer,
                offer_id,
            };
            out_messages.push((
                answer_receiver.connection_meta,
                OutMessage::Answer(middleman_answer),
            ));
            ::log::trace!(
                "sending middleman answer to {:?}",
                answer_receiver.connection_meta
            );
        }
    }
    let out_message = OutMessage::AnnounceResponse(AnnounceResponse {
        action: AnnounceAction,
        info_hash: request.info_hash,
        complete: torrent_data.num_seeders,
        incomplete: torrent_data.num_leechers,
        announce_interval: config.protocol.peer_announce_interval,
    });
    out_messages.push((request_sender_meta, out_message));
}
/// Handle a single scrape request, queueing a scrape response containing
/// statistics for each requested (and known) info hash.
pub fn handle_scrape_request(
    config: &Config,
    torrent_maps: &mut TorrentMaps,
    out_messages: &mut Vec<(ConnectionMeta, OutMessage)>,
    meta: ConnectionMeta,
    request: ScrapeRequest,
) {
    // Full scrapes (no info hashes given) are not served
    let info_hashes = match request.info_hashes {
        Some(info_hashes) => info_hashes.as_vec(),
        None => return,
    };

    let num_to_take = info_hashes.len().min(config.protocol.max_scrape_torrents);

    // State is split on converted peer ip, so pick the matching map
    let torrent_map: &mut TorrentMap = if meta.converted_peer_ip.is_ipv4() {
        &mut torrent_maps.ipv4
    } else {
        &mut torrent_maps.ipv6
    };

    let mut out_message = ScrapeResponse {
        action: ScrapeAction,
        files: HashMap::with_capacity(num_to_take),
    };

    for info_hash in info_hashes.into_iter().take(num_to_take) {
        if let Some(torrent_data) = torrent_map.get(&info_hash) {
            out_message.files.insert(
                info_hash,
                ScrapeStatistics {
                    complete: torrent_data.num_seeders,
                    downloaded: 0, // No implementation planned
                    incomplete: torrent_data.num_leechers,
                },
            );
        }
    }

    out_messages.push((meta, OutMessage::ScrapeResponse(out_message)));
}

View file

@ -1,176 +1,144 @@
use std::fs::File;
use std::io::Read;
use std::sync::Arc;
use std::thread::Builder;
use std::time::Duration;
use std::{
fs::File,
io::BufReader,
sync::{atomic::AtomicUsize, Arc},
};
use anyhow::Context;
use mio::{Poll, Waker};
use native_tls::{Identity, TlsAcceptor};
use parking_lot::Mutex;
use privdrop::PrivDrop;
use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding};
use common::TlsConfig;
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
pub mod common;
use crate::config::Config;
mod common;
pub mod config;
pub mod handler;
pub mod network;
pub mod tasks;
use common::*;
use config::Config;
mod handlers;
mod network;
pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker";
const SHARED_CHANNEL_SIZE: usize = 1024;
pub fn run(config: Config) -> anyhow::Result<()> {
let state = State::default();
tasks::update_access_list(&config, &state);
start_workers(config.clone(), state.clone())?;
loop {
::std::thread::sleep(Duration::from_secs(config.cleaning.interval));
tasks::update_access_list(&config, &state);
state
.torrent_maps
.lock()
.clean(&config, &state.access_list.load_full());
if config.cpu_pinning.active {
core_affinity::set_for_current(core_affinity::CoreId {
id: config.cpu_pinning.offset,
});
}
}
pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
let opt_tls_acceptor = create_tls_acceptor(&config)?;
let (in_message_sender, in_message_receiver) = ::crossbeam_channel::unbounded();
let mut out_message_senders = Vec::new();
let mut wakers = Vec::new();
let socket_worker_statuses: SocketWorkerStatuses = {
let mut statuses = Vec::new();
for _ in 0..config.socket_workers {
statuses.push(None);
}
Arc::new(Mutex::new(statuses))
let access_list = if config.access_list.mode.is_on() {
AccessList::create_from_path(&config.access_list.path).expect("Load access list")
} else {
AccessList::default()
};
for i in 0..config.socket_workers {
let num_peers = config.socket_workers + config.request_workers;
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);
let num_bound_sockets = Arc::new(AtomicUsize::new(0));
let tls_config = Arc::new(create_tls_config(&config).unwrap());
let mut executors = Vec::new();
for i in 0..(config.socket_workers) {
let config = config.clone();
let state = state.clone();
let socket_worker_statuses = socket_worker_statuses.clone();
let in_message_sender = in_message_sender.clone();
let opt_tls_acceptor = opt_tls_acceptor.clone();
let poll = Poll::new()?;
let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN)?);
let tls_config = tls_config.clone();
let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone();
let num_bound_sockets = num_bound_sockets.clone();
let access_list = access_list.clone();
let (out_message_sender, out_message_receiver) = ::crossbeam_channel::unbounded();
let mut builder = LocalExecutorBuilder::default();
out_message_senders.push(out_message_sender);
wakers.push(waker);
Builder::new()
.name(format!("socket-{:02}", i + 1))
.spawn(move || {
network::run_socket_worker(
config,
state,
i,
socket_worker_statuses,
poll,
in_message_sender,
out_message_receiver,
opt_tls_acceptor,
);
})?;
}
// Wait for socket worker statuses. On error from any, quit program.
// On success from all, drop privileges if corresponding setting is set
// and continue program.
loop {
::std::thread::sleep(::std::time::Duration::from_millis(10));
if let Some(statuses) = socket_worker_statuses.try_lock() {
for opt_status in statuses.iter() {
if let Some(Err(err)) = opt_status {
return Err(::anyhow::anyhow!(err.to_owned()));
}
}
if statuses.iter().all(Option::is_some) {
if config.privileges.drop_privileges {
PrivDrop::default()
.chroot(config.privileges.chroot_path.clone())
.user(config.privileges.user.clone())
.apply()
.context("Couldn't drop root privileges")?;
}
break;
}
if config.cpu_pinning.active {
builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + i);
}
let executor = builder.spawn(|| async move {
network::run_socket_worker(
config,
tls_config,
request_mesh_builder,
response_mesh_builder,
num_bound_sockets,
access_list,
)
.await
});
executors.push(executor);
}
let out_message_sender = OutMessageSender::new(out_message_senders);
for i in 0..config.request_workers {
for i in 0..(config.request_workers) {
let config = config.clone();
let state = state.clone();
let in_message_receiver = in_message_receiver.clone();
let out_message_sender = out_message_sender.clone();
let wakers = wakers.clone();
let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone();
let access_list = access_list.clone();
Builder::new()
.name(format!("request-{:02}", i + 1))
.spawn(move || {
handler::run_request_worker(
config,
state,
in_message_receiver,
out_message_sender,
wakers,
);
})?;
let mut builder = LocalExecutorBuilder::default();
if config.cpu_pinning.active {
builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + config.socket_workers + i);
}
let executor = builder.spawn(|| async move {
handlers::run_request_worker(
config,
request_mesh_builder,
response_mesh_builder,
access_list,
)
.await
});
executors.push(executor);
}
if config.statistics.interval != 0 {
let state = state.clone();
let config = config.clone();
drop_privileges_after_socket_binding(
&config.privileges,
num_bound_sockets,
config.socket_workers,
)
.unwrap();
Builder::new()
.name("statistics".to_string())
.spawn(move || loop {
::std::thread::sleep(Duration::from_secs(config.statistics.interval));
tasks::print_statistics(&state);
})
.expect("spawn statistics thread");
for executor in executors {
executor
.expect("failed to spawn local executor")
.join()
.unwrap();
}
Ok(())
}
pub fn create_tls_acceptor(config: &Config) -> anyhow::Result<Option<TlsAcceptor>> {
if config.network.use_tls {
let mut identity_bytes = Vec::new();
let mut file = File::open(&config.network.tls_pkcs12_path)
.context("Couldn't open pkcs12 identity file")?;
fn create_tls_config(config: &Config) -> anyhow::Result<TlsConfig> {
let certs = {
let f = File::open(&config.network.tls_certificate_path)?;
let mut f = BufReader::new(f);
file.read_to_end(&mut identity_bytes)
.context("Couldn't read pkcs12 identity file")?;
rustls_pemfile::certs(&mut f)?
.into_iter()
.map(|bytes| futures_rustls::rustls::Certificate(bytes))
.collect()
};
let identity = Identity::from_pkcs12(&identity_bytes, &config.network.tls_pkcs12_password)
.context("Couldn't parse pkcs12 identity file")?;
let private_key = {
let f = File::open(&config.network.tls_private_key_path)?;
let mut f = BufReader::new(f);
let acceptor = TlsAcceptor::new(identity)
.context("Couldn't create TlsAcceptor from pkcs12 identity")?;
rustls_pemfile::pkcs8_private_keys(&mut f)?
.first()
.map(|bytes| futures_rustls::rustls::PrivateKey(bytes.clone()))
.ok_or(anyhow::anyhow!("No private keys in file"))?
};
Ok(Some(acceptor))
} else {
Ok(None)
}
let tls_config = futures_rustls::rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, private_key)?;
Ok(tls_config)
}

View file

@ -0,0 +1,450 @@
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use aquatic_common::access_list::AccessList;
use aquatic_common::convert_ipv4_mapped_ipv6;
use aquatic_ws_protocol::*;
use async_tungstenite::WebSocketStream;
use futures::stream::{SplitSink, SplitStream};
use futures_lite::future::race;
use futures_lite::StreamExt;
use futures_rustls::server::TlsStream;
use futures_rustls::TlsAcceptor;
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::channels::local_channel::{LocalReceiver, LocalSender, new_unbounded};
use glommio::channels::shared_channel::ConnectedReceiver;
use glommio::net::{TcpListener, TcpStream};
use glommio::timer::TimerActionRepeat;
use glommio::{enclose, prelude::*};
use hashbrown::HashMap;
use slab::Slab;
use crate::config::Config;
use super::common::*;
/// Scrape statistics collected so far for a scrape request that was
/// split across several request workers.
struct PendingScrapeResponse {
    // Number of worker responses still outstanding; the combined
    // response is sent when this reaches zero
    pending_worker_out_messages: usize,
    stats: HashMap<InfoHash, ScrapeStatistics>,
}
/// Entry kept in the connection slab: the local channel used to deliver
/// out-messages to a specific connection's writer task.
struct ConnectionReference {
    out_message_sender: Rc<LocalSender<(ConnectionMeta, OutMessage)>>,
}
/// Run a socket worker on the current glommio executor: bind the
/// listener, join the channel meshes, and accept connections, spawning
/// one task per connection.
pub async fn run_socket_worker(
    config: Config,
    tls_config: Arc<TlsConfig>,
    in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
    out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
    num_bound_sockets: Arc<AtomicUsize>,
    access_list: AccessList,
) {
    let config = Rc::new(config);
    let access_list = Rc::new(RefCell::new(access_list));
    let listener = TcpListener::bind(config.network.address).expect("bind socket");
    // Report that this worker's socket is bound (read by the main thread
    // via drop_privileges_after_socket_binding — TODO confirm)
    num_bound_sockets.fetch_add(1, Ordering::SeqCst);
    let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap();
    let in_message_senders = Rc::new(in_message_senders);
    let (_, mut out_message_receivers) =
        out_message_mesh_builder.join(Role::Consumer).await.unwrap();
    let out_message_consumer_id = ConsumerId(out_message_receivers.consumer_id().unwrap());
    let connection_slab = Rc::new(RefCell::new(Slab::new()));
    let connections_to_remove = Rc::new(RefCell::new(Vec::new()));
    // Periodically update access list
    TimerActionRepeat::repeat(enclose!((config, access_list) move || {
        enclose!((config, access_list) move || async move {
            update_access_list(config.clone(), access_list.clone()).await;
            Some(Duration::from_secs(config.cleaning.interval))
        })()
    }));
    // Periodically remove closed connections
    TimerActionRepeat::repeat(
        enclose!((config, connection_slab, connections_to_remove) move || {
            remove_closed_connections(
                config.clone(),
                connection_slab.clone(),
                connections_to_remove.clone(),
            )
        }),
    );
    // Forward out-messages from each request worker to the local
    // channels of the connections they are addressed to
    for (_, out_message_receiver) in out_message_receivers.streams() {
        spawn_local(receive_out_messages(
            out_message_receiver,
            connection_slab.clone(),
        ))
        .detach();
    }
    let mut incoming = listener.incoming();
    while let Some(stream) = incoming.next().await {
        match stream {
            Ok(stream) => {
                // Per-connection local channel; reader queues error
                // responses on it, writer drains it
                let (out_message_sender, out_message_receiver) =
                    new_unbounded();
                let out_message_sender = Rc::new(out_message_sender);
                // Slab key doubles as the connection id
                let key = RefCell::borrow_mut(&connection_slab).insert(ConnectionReference {
                    out_message_sender: out_message_sender.clone(),
                });
                spawn_local(enclose!((config, access_list, in_message_senders, tls_config, connections_to_remove) async move {
                    if let Err(err) = Connection::run(
                        config,
                        access_list,
                        in_message_senders,
                        out_message_sender,
                        out_message_receiver,
                        out_message_consumer_id,
                        ConnectionId(key),
                        tls_config,
                        stream
                    ).await {
                        ::log::debug!("Connection::run() error: {:?}", err);
                    }
                    // Flag this connection's slab entry for removal by
                    // the periodic cleanup task
                    RefCell::borrow_mut(&connections_to_remove).push(key);
                }))
                .detach();
            }
            Err(err) => {
                ::log::error!("accept connection: {:?}", err);
            }
        }
    }
}
/// Drain the list of connection ids flagged for removal and drop their
/// entries from the connection slab.
///
/// Returns the next repeat interval so it can be driven by
/// `TimerActionRepeat::repeat`.
async fn remove_closed_connections(
    config: Rc<Config>,
    connection_slab: Rc<RefCell<Slab<ConnectionReference>>>,
    connections_to_remove: Rc<RefCell<Vec<usize>>>,
) -> Option<Duration> {
    // Take the pending list, leaving a fresh empty one in place
    let connections_to_remove = connections_to_remove.replace(Vec::new());

    for connection_id in connections_to_remove {
        // `.is_some()` instead of `if let Some(_)` (clippy: redundant_pattern_matching)
        if RefCell::borrow_mut(&connection_slab)
            .try_remove(connection_id)
            .is_some()
        {
            ::log::debug!("removed connection with id {}", connection_id);
        } else {
            ::log::error!(
                "couldn't remove connection with id {}, it is not in connection slab",
                connection_id
            );
        }
    }

    Some(Duration::from_secs(config.cleaning.interval))
}
/// Forward out-messages arriving on the shared (cross-worker) channel to
/// the local channel of the connection each message is addressed to, if
/// that connection still exists in the slab.
async fn receive_out_messages(
    mut out_message_receiver: ConnectedReceiver<(ConnectionMeta, OutMessage)>,
    connection_references: Rc<RefCell<Slab<ConnectionReference>>>,
) {
    while let Some(channel_out_message) = out_message_receiver.next().await {
        let connection_id = channel_out_message.0.connection_id.0;

        let references = connection_references.borrow();

        let reference = match references.get(connection_id) {
            Some(reference) => reference,
            // Connection was already removed: drop the message
            None => continue,
        };

        match reference.out_message_sender.try_send(channel_out_message) {
            // A closed local channel just means the connection went away
            Ok(()) | Err(GlommioError::Closed(_)) => {}
            Err(err) => {
                ::log::error!(
                    "Couldn't send out_message from shared channel to local receiver: {:?}",
                    err
                );
            }
        }
    }
}
/// Marker type grouping the per-connection setup logic.
struct Connection;

impl Connection {
    /// Perform the TLS and WebSocket handshakes on an accepted TCP
    /// stream, then split it and run reader and writer halves as
    /// concurrent tasks until either one finishes or fails.
    async fn run(
        config: Rc<Config>,
        access_list: Rc<RefCell<AccessList>>,
        in_message_senders: Rc<Senders<(ConnectionMeta, InMessage)>>,
        out_message_sender: Rc<LocalSender<(ConnectionMeta, OutMessage)>>,
        out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>,
        out_message_consumer_id: ConsumerId,
        connection_id: ConnectionId,
        tls_config: Arc<TlsConfig>,
        stream: TcpStream,
    ) -> anyhow::Result<()> {
        let peer_addr = stream
            .peer_addr()
            .map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
        let tls_acceptor: TlsAcceptor = tls_config.into();
        let stream = tls_acceptor.accept(stream).await?;
        // Apply configured WebSocket frame/message size limits
        let ws_config = tungstenite::protocol::WebSocketConfig {
            max_frame_size: Some(config.network.websocket_max_frame_size),
            max_message_size: Some(config.network.websocket_max_message_size),
            ..Default::default()
        };
        let stream = async_tungstenite::accept_async_with_config(stream, Some(ws_config)).await?;
        let (ws_out, ws_in) = futures::StreamExt::split(stream);
        // Shared between halves: reader registers pending scrapes,
        // writer resolves them
        let pending_scrape_slab = Rc::new(RefCell::new(Slab::new()));
        let reader_handle = spawn_local(enclose!((pending_scrape_slab) async move {
            let mut reader = ConnectionReader {
                config,
                access_list,
                in_message_senders,
                out_message_sender,
                pending_scrape_slab,
                out_message_consumer_id,
                ws_in,
                peer_addr,
                connection_id,
            };
            reader.run_in_message_loop().await
        }))
        .detach();
        let writer_handle = spawn_local(async move {
            let mut writer = ConnectionWriter {
                out_message_receiver,
                ws_out,
                pending_scrape_slab,
                peer_addr,
            };
            writer.run_out_message_loop().await
        })
        .detach();
        // Resolve as soon as either half finishes.
        // NOTE(review): detached task handles resolve to None if the task
        // is cancelled, which would make this unwrap panic — confirm the
        // tasks cannot be cancelled before this await completes.
        race(reader_handle, writer_handle).await.unwrap()
    }
}
/// Reader half of a client connection: parses incoming WebSocket
/// messages and forwards requests to the request workers.
struct ConnectionReader {
    config: Rc<Config>,
    access_list: Rc<RefCell<AccessList>>,
    in_message_senders: Rc<Senders<(ConnectionMeta, InMessage)>>,
    // Local queue read by the writer half; used here for error responses
    out_message_sender: Rc<LocalSender<(ConnectionMeta, OutMessage)>>,
    // Scrape requests split across workers, awaiting recombination by
    // the writer half
    pending_scrape_slab: Rc<RefCell<Slab<PendingScrapeResponse>>>,
    out_message_consumer_id: ConsumerId,
    ws_in: SplitStream<WebSocketStream<TlsStream<TcpStream>>>,
    peer_addr: SocketAddr,
    connection_id: ConnectionId,
}
impl ConnectionReader {
async fn run_in_message_loop(&mut self) -> anyhow::Result<()> {
loop {
::log::debug!("read_in_message");
let message = self.ws_in.next().await.unwrap()?;
match InMessage::from_ws_message(message) {
Ok(in_message) => {
::log::debug!("received in_message: {:?}", in_message);
self.handle_in_message(in_message).await?;
}
Err(err) => {
::log::debug!("Couldn't parse in_message: {:?}", err);
self.send_error_response("Invalid request".into(), None);
}
}
}
}
async fn handle_in_message(&mut self, in_message: InMessage) -> anyhow::Result<()> {
match in_message {
InMessage::AnnounceRequest(announce_request) => {
let info_hash = announce_request.info_hash;
if self
.access_list
.borrow()
.allows(self.config.access_list.mode, &info_hash.0)
{
let in_message = InMessage::AnnounceRequest(announce_request);
let consumer_index =
calculate_in_message_consumer_index(&self.config, info_hash);
// Only fails when receiver is closed
self.in_message_senders
.send_to(
consumer_index,
(self.make_connection_meta(None), in_message),
)
.await
.unwrap();
} else {
self.send_error_response("Info hash not allowed".into(), Some(info_hash));
}
}
InMessage::ScrapeRequest(ScrapeRequest { info_hashes, .. }) => {
let info_hashes = if let Some(info_hashes) = info_hashes {
info_hashes
} else {
// If request.info_hashes is empty, don't return scrape for all
// torrents, even though reference server does it. It is too expensive.
self.send_error_response("Full scrapes are not allowed".into(), None);
return Ok(());
};
let mut info_hashes_by_worker: BTreeMap<usize, Vec<InfoHash>> = BTreeMap::new();
for info_hash in info_hashes.as_vec() {
let info_hashes = info_hashes_by_worker
.entry(calculate_in_message_consumer_index(&self.config, info_hash))
.or_default();
info_hashes.push(info_hash);
}
let pending_worker_out_messages = info_hashes_by_worker.len();
let pending_scrape_response = PendingScrapeResponse {
pending_worker_out_messages,
stats: Default::default(),
};
let pending_scrape_id = PendingScrapeId(
RefCell::borrow_mut(&mut self.pending_scrape_slab)
.insert(pending_scrape_response),
);
let meta = self.make_connection_meta(Some(pending_scrape_id));
for (consumer_index, info_hashes) in info_hashes_by_worker {
let in_message = InMessage::ScrapeRequest(ScrapeRequest {
action: ScrapeAction,
info_hashes: Some(ScrapeRequestInfoHashes::Multiple(info_hashes)),
});
// Only fails when receiver is closed
self.in_message_senders
.send_to(consumer_index, (meta, in_message))
.await
.unwrap();
}
}
}
Ok(())
}
fn send_error_response(&self, failure_reason: Cow<'static, str>, info_hash: Option<InfoHash>) {
let out_message = OutMessage::ErrorResponse(ErrorResponse {
action: Some(ErrorResponseAction::Scrape),
failure_reason,
info_hash,
});
if let Err(err) = self
.out_message_sender
.try_send((self.make_connection_meta(None), out_message))
{
::log::error!("ConnectionWriter::send_error_response failed: {:?}", err)
}
}
fn make_connection_meta(&self, pending_scrape_id: Option<PendingScrapeId>) -> ConnectionMeta {
ConnectionMeta {
connection_id: self.connection_id,
out_message_consumer_id: self.out_message_consumer_id,
naive_peer_addr: self.peer_addr,
converted_peer_ip: convert_ipv4_mapped_ipv6(self.peer_addr.ip()),
pending_scrape_id,
}
}
}
/// Writer half of a client connection: receives out-messages on the
/// local channel and writes them to the WebSocket.
struct ConnectionWriter {
    out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>,
    ws_out: SplitSink<WebSocketStream<TlsStream<TcpStream>>, tungstenite::Message>,
    // Shared with the reader half, which registers pending scrapes here
    pending_scrape_slab: Rc<RefCell<Slab<PendingScrapeResponse>>>,
    peer_addr: SocketAddr,
}
impl ConnectionWriter {
    /// Receive out-messages for this connection and write them to the
    /// WebSocket. Partial scrape responses from different request workers
    /// are merged and only sent once all expected parts have arrived.
    async fn run_out_message_loop(&mut self) -> anyhow::Result<()> {
        loop {
            let (meta, out_message) = self
                .out_message_receiver
                .recv()
                .await
                .expect("wait_for_out_message: can't receive out_message, sender is closed");
            // NOTE(review): presumably guards against messages addressed
            // to a previous connection whose slab slot was reused —
            // confirm intent
            if meta.naive_peer_addr != self.peer_addr {
                return Err(anyhow::anyhow!("peer addresses didn't match"));
            }
            match out_message {
                OutMessage::ScrapeResponse(out_message) => {
                    let pending_scrape_id = meta
                        .pending_scrape_id
                        .expect("meta.pending_scrape_id not set");
                    // Merge this partial response into the pending entry;
                    // finished once all worker parts have been received
                    let finished = if let Some(pending) = Slab::get_mut(
                        &mut RefCell::borrow_mut(&self.pending_scrape_slab),
                        pending_scrape_id.0,
                    ) {
                        pending.stats.extend(out_message.files);
                        pending.pending_worker_out_messages -= 1;
                        pending.pending_worker_out_messages == 0
                    } else {
                        return Err(anyhow::anyhow!("pending scrape not found in slab"));
                    };
                    if finished {
                        let out_message = {
                            let mut slab = RefCell::borrow_mut(&self.pending_scrape_slab);
                            let pending = slab.remove(pending_scrape_id.0);
                            // Keep the slab small between scrapes
                            slab.shrink_to_fit();
                            OutMessage::ScrapeResponse(ScrapeResponse {
                                action: ScrapeAction,
                                files: pending.stats,
                            })
                        };
                        self.send_out_message(&out_message).await?;
                    }
                }
                out_message => {
                    self.send_out_message(&out_message).await?;
                }
            };
        }
    }
    /// Serialize and send a single out-message, flushing immediately.
    async fn send_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> {
        futures::SinkExt::send(&mut self.ws_out, out_message.to_ws_message()).await?;
        futures::SinkExt::flush(&mut self.ws_out).await?;
        Ok(())
    }
}
/// Map an info hash to the request worker responsible for it: the first
/// byte of the hash, modulo the number of request workers.
fn calculate_in_message_consumer_index(config: &Config, info_hash: InfoHash) -> usize {
    usize::from(info_hash.0[0]) % config.request_workers
}

View file

@ -1,298 +0,0 @@
use std::io::{Read, Write};
use std::net::SocketAddr;
use either::Either;
use hashbrown::HashMap;
use log::info;
use mio::net::TcpStream;
use mio::{Poll, Token};
use native_tls::{MidHandshakeTlsStream, TlsAcceptor, TlsStream};
use tungstenite::handshake::{server::NoCallback, HandshakeError, MidHandshake};
use tungstenite::protocol::WebSocketConfig;
use tungstenite::ServerHandshake;
use tungstenite::WebSocket;
use crate::common::*;
/// A plain TCP stream or a TLS-wrapped one, unified behind a single type
/// so the WebSocket layer can operate over either.
pub enum Stream {
    TcpStream(TcpStream),
    TlsStream(TlsStream<TcpStream>),
}
impl Stream {
    /// Return the remote peer's address of the underlying TCP stream,
    /// regardless of TLS wrapping.
    #[inline]
    pub fn get_peer_addr(&self) -> ::std::io::Result<SocketAddr> {
        match self {
            Self::TcpStream(stream) => stream.peer_addr(),
            Self::TlsStream(stream) => stream.get_ref().peer_addr(),
        }
    }
    /// Deregister the underlying TCP stream from the mio poll instance.
    #[inline]
    pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
        match self {
            Self::TcpStream(stream) => poll.registry().deregister(stream),
            Self::TlsStream(stream) => poll.registry().deregister(stream.get_mut()),
        }
    }
}
// Delegate reads to whichever stream variant is active
impl Read for Stream {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, ::std::io::Error> {
        match self {
            Self::TcpStream(stream) => stream.read(buf),
            Self::TlsStream(stream) => stream.read(buf),
        }
    }
    /// Not used but provided for completeness
    #[inline]
    fn read_vectored(
        &mut self,
        bufs: &mut [::std::io::IoSliceMut<'_>],
    ) -> ::std::io::Result<usize> {
        match self {
            Self::TcpStream(stream) => stream.read_vectored(bufs),
            Self::TlsStream(stream) => stream.read_vectored(bufs),
        }
    }
}
// Delegate writes to whichever stream variant is active
impl Write for Stream {
    #[inline]
    fn write(&mut self, buf: &[u8]) -> ::std::io::Result<usize> {
        match self {
            Self::TcpStream(stream) => stream.write(buf),
            Self::TlsStream(stream) => stream.write(buf),
        }
    }
    /// Not used but provided for completeness
    #[inline]
    fn write_vectored(&mut self, bufs: &[::std::io::IoSlice<'_>]) -> ::std::io::Result<usize> {
        match self {
            Self::TcpStream(stream) => stream.write_vectored(bufs),
            Self::TlsStream(stream) => stream.write_vectored(bufs),
        }
    }
    #[inline]
    fn flush(&mut self) -> ::std::io::Result<()> {
        match self {
            Self::TcpStream(stream) => stream.flush(),
            Self::TlsStream(stream) => stream.flush(),
        }
    }
}
/// State machine advancing a new connection through the (optional) TLS
/// handshake and then the WebSocket handshake without blocking.
enum HandshakeMachine {
    // Fresh TCP connection, no handshake started yet
    TcpStream(TcpStream),
    // TLS established, WebSocket handshake not started
    TlsStream(TlsStream<TcpStream>),
    // TLS handshake returned WouldBlock; resume later
    TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
    // WebSocket handshake was interrupted; resume later
    WsMidHandshake(MidHandshake<ServerHandshake<Stream, NoCallback>>),
}
impl HandshakeMachine {
    /// Start a new handshake from a freshly accepted TCP stream.
    #[inline]
    fn new(tcp_stream: TcpStream) -> Self {
        Self::TcpStream(tcp_stream)
    }

    /// Advance the handshake by one step.
    ///
    /// Returns the new state (`Left` = fully established WebSocket,
    /// `Right` = still handshaking, `None` = failed, drop connection)
    /// and a flag telling the caller to stop looping (would block).
    #[inline]
    fn advance(
        self,
        ws_config: WebSocketConfig,
        opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
    ) -> (Option<Either<EstablishedWs, Self>>, bool) {
        // bool = stop looping
        match self {
            HandshakeMachine::TcpStream(stream) => {
                if let Some(tls_acceptor) = opt_tls_acceptor {
                    Self::handle_tls_handshake_result(tls_acceptor.accept(stream))
                } else {
                    let handshake_result = ::tungstenite::accept_with_config(
                        Stream::TcpStream(stream),
                        Some(ws_config),
                    );

                    Self::handle_ws_handshake_result(handshake_result)
                }
            }
            HandshakeMachine::TlsStream(stream) => {
                // Fix: previously called `accept` without the config, so
                // the frame/message size limits in ws_config were not
                // applied to TLS connections
                let handshake_result = ::tungstenite::accept_with_config(
                    Stream::TlsStream(stream),
                    Some(ws_config),
                );

                Self::handle_ws_handshake_result(handshake_result)
            }
            HandshakeMachine::TlsMidHandshake(handshake) => {
                Self::handle_tls_handshake_result(handshake.handshake())
            }
            HandshakeMachine::WsMidHandshake(handshake) => {
                Self::handle_ws_handshake_result(handshake.handshake())
            }
        }
    }

    /// Map a TLS handshake result into the state/stop-looping pair:
    /// WouldBlock keeps the mid-handshake state and stops the loop,
    /// failure drops the connection.
    #[inline]
    fn handle_tls_handshake_result(
        result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
    ) -> (Option<Either<EstablishedWs, Self>>, bool) {
        match result {
            Ok(stream) => {
                ::log::trace!(
                    "established tls handshake with peer with addr: {:?}",
                    stream.get_ref().peer_addr()
                );

                (Some(Either::Right(Self::TlsStream(stream))), false)
            }
            Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
                (Some(Either::Right(Self::TlsMidHandshake(handshake))), true)
            }
            Err(native_tls::HandshakeError::Failure(err)) => {
                info!("tls handshake error: {}", err);

                (None, false)
            }
        }
    }

    /// Map a WebSocket handshake result into the state/stop-looping
    /// pair. On success the peer address is captured; if it can't be
    /// read, the connection is dropped.
    #[inline]
    fn handle_ws_handshake_result(
        result: Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, NoCallback>>>,
    ) -> (Option<Either<EstablishedWs, Self>>, bool) {
        match result {
            Ok(mut ws) => match ws.get_mut().get_peer_addr() {
                Ok(peer_addr) => {
                    ::log::trace!(
                        "established ws handshake with peer with addr: {:?}",
                        peer_addr
                    );

                    let established_ws = EstablishedWs { ws, peer_addr };

                    (Some(Either::Left(established_ws)), false)
                }
                Err(err) => {
                    ::log::info!(
                        "get_peer_addr failed during handshake, removing connection: {:?}",
                        err
                    );

                    (None, false)
                }
            },
            Err(HandshakeError::Interrupted(handshake)) => (
                Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))),
                true,
            ),
            Err(HandshakeError::Failure(err)) => {
                info!("ws handshake error: {}", err);

                (None, false)
            }
        }
    }
}
/// A connection whose (optional) TLS and WebSocket handshakes have both
/// completed, together with the peer address captured at handshake time.
pub struct EstablishedWs {
    pub ws: WebSocket<Stream>,
    pub peer_addr: SocketAddr,
}
/// A client connection in some stage of being established, stored per poll
/// token in the `ConnectionMap`.
pub struct Connection {
    // WebSocket settings applied when the ws handshake runs
    ws_config: WebSocketConfig,
    // Deadline after which the connection counts as inactive and is removed
    pub valid_until: ValidUntil,
    // Left: fully established; Right: handshake(s) still in progress
    inner: Either<EstablishedWs, HandshakeMachine>,
}
/// Create from TcpStream. Run `advance_handshakes` until `get_established_ws`
/// returns Some(EstablishedWs).
///
/// advance_handshakes takes ownership of self because the TLS and WebSocket
/// handshake methods do. get_established_ws doesn't, since work can be done
/// on a mutable reference to a tungstenite websocket, and this way, the whole
/// Connection doesn't have to be removed from and reinserted into the
/// ConnectionMap. This is also the reason for wrapping Connection.inner in an
/// Either instead of combining all states into one structure just having a
/// single method for advancing handshakes and maybe returning a websocket.
impl Connection {
    /// Wrap a newly accepted TCP stream; no handshake has run yet.
    #[inline]
    pub fn new(ws_config: WebSocketConfig, valid_until: ValidUntil, tcp_stream: TcpStream) -> Self {
        Self {
            ws_config,
            valid_until,
            inner: Either::Right(HandshakeMachine::new(tcp_stream)),
        }
    }

    /// Borrow the established WebSocket, or None while handshakes are still
    /// in progress.
    #[inline]
    pub fn get_established_ws(&mut self) -> Option<&mut EstablishedWs> {
        match self.inner {
            Either::Left(ref mut ews) => Some(ews),
            Either::Right(_) => None,
        }
    }

    /// Drive the handshake machinery one step (no-op on an established
    /// connection). Returns `(new_self, stop_looping)`; `None` means the
    /// handshake failed and the connection should be discarded.
    #[inline]
    pub fn advance_handshakes(
        self,
        opt_tls_acceptor: &Option<TlsAcceptor>,
        valid_until: ValidUntil,
    ) -> (Option<Self>, bool) {
        match self.inner {
            Either::Left(_) => (Some(self), false),
            Either::Right(machine) => {
                let ws_config = self.ws_config;

                let (opt_inner, stop_loop) = machine.advance(ws_config, opt_tls_acceptor);

                // Rebuild with a refreshed activity deadline
                let opt_new_self = opt_inner.map(|inner| Self {
                    ws_config,
                    valid_until,
                    inner,
                });

                (opt_new_self, stop_loop)
            }
        }
    }

    /// Best-effort close of an established WebSocket; errors are only logged.
    /// Does nothing while the connection is still mid-handshake.
    #[inline]
    pub fn close(&mut self) {
        if let Either::Left(ref mut ews) = self.inner {
            if ews.ws.can_read() {
                if let Err(err) = ews.ws.close(None) {
                    ::log::info!("error closing ws: {}", err);
                }

                // Required after ws.close()
                if let Err(err) = ews.ws.write_pending() {
                    ::log::info!("error writing pending messages after closing ws: {}", err)
                }
            }
        }
    }

    /// Deregister the underlying TCP stream from the poll instance, whatever
    /// handshake state the connection is currently in.
    pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
        use Either::{Left, Right};

        match self.inner {
            Left(EstablishedWs { ref mut ws, .. }) => ws.get_mut().deregister(poll),
            Right(HandshakeMachine::TcpStream(ref mut stream)) => {
                poll.registry().deregister(stream)
            }
            Right(HandshakeMachine::TlsMidHandshake(ref mut handshake)) => {
                poll.registry().deregister(handshake.get_mut())
            }
            Right(HandshakeMachine::TlsStream(ref mut stream)) => {
                poll.registry().deregister(stream.get_mut())
            }
            Right(HandshakeMachine::WsMidHandshake(ref mut handshake)) => {
                handshake.get_mut().get_mut().deregister(poll)
            }
        }
    }
}
/// All client connections of a socket worker, keyed by mio poll token.
pub type ConnectionMap = HashMap<Token, Connection>;

View file

@ -1,331 +0,0 @@
use std::io::ErrorKind;
use std::time::Duration;
use std::vec::Drain;
use aquatic_common::access_list::AccessListQuery;
use crossbeam_channel::Receiver;
use hashbrown::HashMap;
use log::{debug, error, info};
use mio::net::TcpListener;
use mio::{Events, Interest, Poll, Token};
use native_tls::TlsAcceptor;
use tungstenite::protocol::WebSocketConfig;
use aquatic_common::convert_ipv4_mapped_ipv6;
use aquatic_ws_protocol::*;
use crate::common::*;
use crate::config::Config;
pub mod connection;
pub mod utils;
use connection::*;
use utils::*;
/// Socket worker entry point: open the listening socket, report the outcome
/// through the shared per-worker status slot, then run the poll loop forever.
pub fn run_socket_worker(
    config: Config,
    state: State,
    socket_worker_index: usize,
    socket_worker_statuses: SocketWorkerStatuses,
    poll: Poll,
    in_message_sender: InMessageSender,
    out_message_receiver: OutMessageReceiver,
    opt_tls_acceptor: Option<TlsAcceptor>,
) {
    // Opening the socket is the only fallible setup step; bail out early and
    // record the failure so the main thread can report it.
    let listener = match create_listener(&config) {
        Ok(listener) => listener,
        Err(err) => {
            socket_worker_statuses.lock()[socket_worker_index] =
                Some(Err(format!("Couldn't open socket: {:#}", err)));

            return;
        }
    };

    socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(()));

    run_poll_loop(
        config,
        &state,
        socket_worker_index,
        poll,
        in_message_sender,
        out_message_receiver,
        listener,
        opt_tls_acceptor,
    );
}
/// Socket worker event loop: accept new connections, drive TLS/WebSocket
/// handshakes, forward incoming messages to request workers over the channel,
/// and write outgoing messages back to peers. Never returns.
pub fn run_poll_loop(
    config: Config,
    state: &State,
    socket_worker_index: usize,
    mut poll: Poll,
    in_message_sender: InMessageSender,
    out_message_receiver: OutMessageReceiver,
    listener: ::std::net::TcpListener,
    opt_tls_acceptor: Option<TlsAcceptor>,
) {
    let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds);

    let ws_config = WebSocketConfig {
        max_message_size: Some(config.network.websocket_max_message_size),
        max_frame_size: Some(config.network.websocket_max_frame_size),
        max_send_queue: None,
        ..Default::default()
    };

    let mut listener = TcpListener::from_std(listener);
    let mut events = Events::with_capacity(config.network.poll_event_capacity);

    poll.registry()
        .register(&mut listener, LISTENER_TOKEN, Interest::READABLE)
        .unwrap();

    let mut connections: ConnectionMap = HashMap::new();

    // Responses generated locally (e.g. access list rejections), delivered
    // without a round trip through a request worker
    let mut local_responses = Vec::new();

    // Wrapping counter used to hand out poll tokens to accepted connections
    let mut poll_token_counter = Token(0usize);
    let mut iter_counter = 0usize;

    loop {
        poll.poll(&mut events, Some(poll_timeout))
            .expect("failed polling");

        // Shared deadline for all connections touched in this poll iteration
        let valid_until = ValidUntil::new(config.cleaning.max_connection_age);

        for event in events.iter() {
            let token = event.token();

            if token == LISTENER_TOKEN {
                accept_new_streams(
                    ws_config,
                    &mut listener,
                    &mut poll,
                    &mut connections,
                    valid_until,
                    &mut poll_token_counter,
                );
            } else if token != CHANNEL_TOKEN {
                // Readiness on a client connection: advance handshakes and
                // read any pending messages. CHANNEL_TOKEN events are only
                // wakeups (presumably tied to the out message channel — the
                // channel itself is drained in send_out_messages below).
                run_handshakes_and_read_messages(
                    &config,
                    state,
                    socket_worker_index,
                    &mut local_responses,
                    &in_message_sender,
                    &opt_tls_acceptor,
                    &mut poll,
                    &mut connections,
                    token,
                    valid_until,
                );
            }

            // NOTE(review): intentionally inside the event loop, so responses
            // are flushed once per event rather than once per poll iteration
            send_out_messages(
                &mut poll,
                local_responses.drain(..),
                &out_message_receiver,
                &mut connections,
            );
        }

        // Remove inactive connections, but not every iteration
        if iter_counter % 128 == 0 {
            remove_inactive_connections(&mut connections);
        }

        iter_counter = iter_counter.wrapping_add(1);
    }
}
fn accept_new_streams(
ws_config: WebSocketConfig,
listener: &mut TcpListener,
poll: &mut Poll,
connections: &mut ConnectionMap,
valid_until: ValidUntil,
poll_token_counter: &mut Token,
) {
loop {
match listener.accept() {
Ok((mut stream, _)) => {
poll_token_counter.0 = poll_token_counter.0.wrapping_add(1);
if poll_token_counter.0 < 2 {
poll_token_counter.0 = 2;
}
let token = *poll_token_counter;
remove_connection_if_exists(poll, connections, token);
poll.registry()
.register(&mut stream, token, Interest::READABLE)
.unwrap();
let connection = Connection::new(ws_config, valid_until, stream);
connections.insert(token, connection);
}
Err(err) => {
if err.kind() == ErrorKind::WouldBlock {
break;
}
info!("error while accepting streams: {}", err);
}
}
}
}
/// On the stream given by poll_token, get TLS (if requested) and tungstenite
/// up and running, then read messages and pass on through channel.
///
/// Loops until the stream would block, the connection is removed, or the
/// handshake machinery asks to stop. Announce requests whose info hash is not
/// permitted by the access list are answered locally with an error response
/// instead of being forwarded to a request worker.
pub fn run_handshakes_and_read_messages(
    config: &Config,
    state: &State,
    socket_worker_index: usize,
    local_responses: &mut Vec<(ConnectionMeta, OutMessage)>,
    in_message_sender: &InMessageSender,
    opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
    poll: &mut Poll,
    connections: &mut ConnectionMap,
    poll_token: Token,
    valid_until: ValidUntil,
) {
    let access_list_mode = config.access_list.mode;

    loop {
        if let Some(established_ws) = connections
            .get_mut(&poll_token)
            .map(|c| {
                // Ugly but works: refresh the activity deadline as a side
                // effect of looking the connection up
                c.valid_until = valid_until;

                c
            })
            .and_then(Connection::get_established_ws)
        {
            // Established connection: read as many messages as are available
            use ::tungstenite::Error::Io;

            match established_ws.ws.read_message() {
                Ok(ws_message) => {
                    let naive_peer_addr = established_ws.peer_addr;
                    let converted_peer_ip = convert_ipv4_mapped_ipv6(naive_peer_addr.ip());

                    let meta = ConnectionMeta {
                        worker_index: socket_worker_index,
                        poll_token,
                        naive_peer_addr,
                        converted_peer_ip,
                    };

                    debug!("read message");

                    match InMessage::from_ws_message(ws_message) {
                        // Access-list rejection: answer locally, skip workers
                        Ok(InMessage::AnnounceRequest(ref request))
                            if !state
                                .access_list
                                .allows(access_list_mode, &request.info_hash.0) =>
                        {
                            let out_message = OutMessage::ErrorResponse(ErrorResponse {
                                failure_reason: "Info hash not allowed".into(),
                                action: Some(ErrorResponseAction::Announce),
                                info_hash: Some(request.info_hash),
                            });

                            local_responses.push((meta, out_message));
                        }
                        Ok(in_message) => {
                            if let Err(err) = in_message_sender.send((meta, in_message)) {
                                error!("InMessageSender: couldn't send message: {:?}", err);
                            }
                        }
                        Err(_) => {
                            // Parse errors are deliberately ignored here.
                            // FIXME: maybe this condition just occurs when enough data hasn't been received?
                            /*
                            info!("error parsing message: {:?}", err);

                            let out_message = OutMessage::ErrorResponse(ErrorResponse {
                                failure_reason: "Error parsing message".into(),
                                action: None,
                                info_hash: None,
                            });

                            local_responses.push((meta, out_message));
                            */
                        }
                    }
                }
                // No more data available right now
                Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {
                    break;
                }
                Err(tungstenite::Error::ConnectionClosed) => {
                    remove_connection_if_exists(poll, connections, poll_token);

                    break;
                }
                Err(err) => {
                    info!("error reading messages: {}", err);

                    remove_connection_if_exists(poll, connections, poll_token);

                    break;
                }
            }
        } else if let Some(connection) = connections.remove(&poll_token) {
            // Handshake not finished: advance it (takes ownership), then put
            // the connection back unless the handshake failed
            let (opt_new_connection, stop_loop) =
                connection.advance_handshakes(opt_tls_acceptor, valid_until);

            if let Some(connection) = opt_new_connection {
                connections.insert(poll_token, connection);
            }

            if stop_loop {
                break;
            }
        } else {
            // No connection registered under this token
            break;
        }
    }
}
/// Read messages from channel, send to peers
///
/// Sends both locally generated responses (access list rejections) and
/// responses received from request workers over the channel. Only
/// established connections receive messages; messages for missing or
/// not-yet-established connections are silently dropped.
pub fn send_out_messages(
    poll: &mut Poll,
    local_responses: Drain<(ConnectionMeta, OutMessage)>,
    out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>,
    connections: &mut ConnectionMap,
) {
    // Snapshot the queue length first: take(len) bounds the drain so a fast
    // producer can't keep this call iterating forever
    let len = out_message_receiver.len();

    for (meta, out_message) in local_responses.chain(out_message_receiver.try_iter().take(len)) {
        let opt_established_ws = connections
            .get_mut(&meta.poll_token)
            .and_then(Connection::get_established_ws);

        if let Some(established_ws) = opt_established_ws {
            // Guard against stale metadata: if the token was reused for a
            // different peer since the request was made, drop the message
            if established_ws.peer_addr != meta.naive_peer_addr {
                info!("socket worker error: peer socket addrs didn't match");

                continue;
            }

            use ::tungstenite::Error::Io;

            let ws_message = out_message.to_ws_message();

            match established_ws.ws.write_message(ws_message) {
                Ok(()) => {
                    debug!("sent message");
                }
                // Write would block: the message is queued by tungstenite
                // and flushed by a later write/read call
                Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {}
                Err(tungstenite::Error::ConnectionClosed) => {
                    remove_connection_if_exists(poll, connections, meta.poll_token);
                }
                Err(err) => {
                    info!("error writing ws message: {}", err);

                    remove_connection_if_exists(poll, connections, meta.poll_token);
                }
            }
        }
    }
}

View file

@ -1,66 +0,0 @@
use std::time::Instant;
use anyhow::Context;
use mio::{Poll, Token};
use socket2::{Domain, Protocol, Socket, Type};
use crate::config::Config;
use super::connection::*;
/// Build a non-blocking, reuse-port TCP listener bound to the configured
/// address, returning it as a standard-library listener.
pub fn create_listener(config: &Config) -> ::anyhow::Result<::std::net::TcpListener> {
    // Socket domain follows the address family of the configured bind address
    let domain = if config.network.address.is_ipv4() {
        Domain::IPV4
    } else {
        Domain::IPV6
    };

    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
        .context("Couldn't create socket2::Socket")?;

    if config.network.ipv6_only {
        socket
            .set_only_v6(true)
            .context("Couldn't put socket in ipv6 only mode")?;
    }

    socket
        .set_nonblocking(true)
        .context("Couldn't put socket in non-blocking mode")?;
    // SO_REUSEPORT, so several socket workers can bind the same address
    socket
        .set_reuse_port(true)
        .context("Couldn't put socket in reuse_port mode")?;
    socket
        .bind(&config.network.address.into())
        .with_context(|| format!("Couldn't bind socket to address {}", config.network.address))?;
    socket
        .listen(128)
        .context("Couldn't listen for connections on socket")?;

    Ok(socket.into())
}
/// Remove the connection registered under `token` (if any), closing its
/// WebSocket best-effort and deregistering its stream from the poll instance.
pub fn remove_connection_if_exists(poll: &mut Poll, connections: &mut ConnectionMap, token: Token) {
    if let Some(mut connection) = connections.remove(&token) {
        connection.close();

        if let Err(err) = connection.deregister(poll) {
            ::log::error!("couldn't deregister stream: {}", err);
        }
    }
}
/// Close and remove connections whose activity deadline has passed, then
/// release the map's excess capacity.
///
/// Note that, unlike `remove_connection_if_exists`, this does not explicitly
/// deregister streams from the poll instance (matching the original
/// behavior; dropping the connection closes the underlying socket).
pub fn remove_inactive_connections(connections: &mut ConnectionMap) {
    let now = Instant::now();

    connections.retain(|_, connection| {
        let keep = connection.valid_until.0 >= now;

        if !keep {
            // Best-effort close before the connection is dropped
            connection.close();
        }

        keep
    });

    connections.shrink_to_fit();
}

View file

@ -14,39 +14,3 @@ pub fn update_access_list(config: &Config, state: &State) {
AccessListMode::Off => {}
}
}
/// Print a peers-per-torrent histogram summary to stdout.
///
/// Locks the torrent maps only while filling the histogram. Percentiles are
/// printed only when at least one torrent was recorded, which also guards
/// the `unwrap` calls on the histogram accessors.
pub fn print_statistics(state: &State) {
    let mut peers_per_torrent = Histogram::new();

    {
        let torrents = &mut state.torrent_maps.lock();

        // IPv4 and IPv6 torrents feed the same histogram, so iterate them as
        // one chained sequence instead of duplicating the loop body per
        // address family
        for torrent in torrents.ipv4.values().chain(torrents.ipv6.values()) {
            let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;

            if let Err(err) = peers_per_torrent.increment(num_peers) {
                eprintln!("error incrementing peers_per_torrent histogram: {}", err)
            }
        }
    }

    if peers_per_torrent.entries() != 0 {
        println!(
            "peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}",
            peers_per_torrent.minimum().unwrap(),
            peers_per_torrent.percentile(50.0).unwrap(),
            peers_per_torrent.percentile(75.0).unwrap(),
            peers_per_torrent.percentile(90.0).unwrap(),
            peers_per_torrent.percentile(99.0).unwrap(),
            peers_per_torrent.percentile(99.9).unwrap(),
            peers_per_torrent.maximum().unwrap(),
        );
    }
}

View file

@ -11,16 +11,19 @@ name = "aquatic_ws_load_test"
[dependencies]
anyhow = "1"
async-tungstenite = "0.15"
aquatic_cli_helpers = "0.1.0"
aquatic_ws_protocol = "0.1.0"
futures = "0.3"
futures-rustls = "0.22"
glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5" }
hashbrown = { version = "0.11.2", features = ["serde"] }
mimalloc = { version = "0.1", default-features = false }
mio = { version = "0.7", features = ["udp", "os-poll", "os-util"] }
rand = { version = "0.8", features = ["small_rng"] }
rand_distr = "0.4"
rustls = { version = "0.20", features = ["dangerous_configuration"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
slab = "0.4"
tungstenite = "0.15"
[dev-dependencies]

View file

@ -9,20 +9,11 @@ pub struct Config {
pub num_workers: u8,
pub num_connections: usize,
pub duration: usize,
pub network: NetworkConfig,
pub torrents: TorrentConfig,
}
impl aquatic_cli_helpers::Config for Config {}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct NetworkConfig {
pub connection_creation_interval: usize,
pub poll_timeout_microseconds: u64,
pub poll_event_capacity: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct TorrentConfig {
@ -49,22 +40,11 @@ impl Default for Config {
num_workers: 1,
num_connections: 16,
duration: 0,
network: NetworkConfig::default(),
torrents: TorrentConfig::default(),
}
}
}
impl Default for NetworkConfig {
fn default() -> Self {
Self {
connection_creation_interval: 40,
poll_timeout_microseconds: 1000,
poll_event_capacity: 4096,
}
}
}
impl Default for TorrentConfig {
fn default() -> Self {
Self {

View file

@ -2,6 +2,7 @@ use std::sync::{atomic::Ordering, Arc};
use std::thread;
use std::time::{Duration, Instant};
use glommio::LocalExecutorBuilder;
use rand::prelude::*;
use rand_distr::Pareto;
@ -48,11 +49,18 @@ fn run(config: Config) -> ::anyhow::Result<()> {
pareto: Arc::new(pareto),
};
let tls_config = create_tls_config().unwrap();
for _ in 0..config.num_workers {
let config = config.clone();
let tls_config = tls_config.clone();
let state = state.clone();
thread::spawn(move || run_socket_thread(&config, state));
LocalExecutorBuilder::default()
.spawn(|| async move {
run_socket_thread(config, tls_config, state).await.unwrap();
})
.unwrap();
}
monitor_statistics(state, &config);
@ -60,6 +68,36 @@ fn run(config: Config) -> ::anyhow::Result<()> {
Ok(())
}
/// A rustls certificate verifier that accepts ANY server certificate.
///
/// SECURITY: this disables TLS certificate validation entirely. It is only
/// acceptable here because this is load-testing client code connecting to a
/// tracker that typically runs with a self-signed certificate. Never reuse
/// this in production code.
struct FakeCertificateVerifier;

impl rustls::client::ServerCertVerifier for FakeCertificateVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::Certificate,
        _intermediates: &[rustls::Certificate],
        _server_name: &rustls::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        // Unconditionally report the certificate as valid
        Ok(rustls::client::ServerCertVerified::assertion())
    }
}
/// Build a rustls client config with an empty root store and certificate
/// verification disabled via `FakeCertificateVerifier`.
///
/// NOTE(review): every TLS connection made with this config is insecure by
/// design; suitable only for load testing against self-signed certificates.
fn create_tls_config() -> anyhow::Result<Arc<rustls::ClientConfig>> {
    let mut config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(rustls::RootCertStore::empty())
        .with_no_client_auth();

    // The builder offers no unverified mode, so the verifier is swapped in
    // through the `dangerous()` escape hatch afterwards
    config
        .dangerous()
        .set_certificate_verifier(Arc::new(FakeCertificateVerifier));

    Ok(Arc::new(config))
}
fn monitor_statistics(state: LoadTestState, config: &Config) {
let start_time = Instant::now();
let mut report_avg_response_vec: Vec<f64> = Vec::new();

View file

@ -1,307 +1,208 @@
use std::io::ErrorKind;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{
cell::RefCell,
convert::TryInto,
rc::Rc,
sync::{atomic::Ordering, Arc},
time::Duration,
};
use hashbrown::HashMap;
use mio::{net::TcpStream, Events, Interest, Poll, Token};
use rand::{prelude::*, rngs::SmallRng};
use tungstenite::{handshake::MidHandshake, ClientHandshake, HandshakeError, WebSocket};
use aquatic_ws_protocol::{InMessage, JsonValue, OfferId, OutMessage, PeerId};
use async_tungstenite::{WebSocketStream, client_async};
use futures::{StreamExt, SinkExt};
use futures_rustls::{TlsConnector, client::TlsStream};
use glommio::net::TcpStream;
use glommio::{prelude::*, timer::TimerActionRepeat};
use rand::{Rng, SeedableRng, prelude::SmallRng};
use crate::common::*;
use crate::config::*;
use crate::utils::create_random_request;
use crate::{common::LoadTestState, config::Config, utils::create_random_request};
// Allow large enum variant WebSocket because it should be very common
#[allow(clippy::large_enum_variant)]
pub enum ConnectionState {
TcpStream(TcpStream),
WebSocket(WebSocket<TcpStream>),
MidHandshake(MidHandshake<ClientHandshake<TcpStream>>),
pub async fn run_socket_thread(
config: Config,
tls_config: Arc<rustls::ClientConfig>,
load_test_state: LoadTestState,
) -> anyhow::Result<()> {
let config = Rc::new(config);
let num_active_connections = Rc::new(RefCell::new(0usize));
TimerActionRepeat::repeat(move || {
periodically_open_connections(
config.clone(),
tls_config.clone(),
load_test_state.clone(),
num_active_connections.clone(),
)
});
futures::future::pending::<bool>().await;
Ok(())
}
impl ConnectionState {
fn advance(self, config: &Config) -> Option<Self> {
match self {
Self::TcpStream(stream) => {
let req = format!(
"ws://{}:{}",
config.server_address.ip(),
config.server_address.port()
);
match ::tungstenite::client(req, stream) {
Ok((ws, _)) => Some(ConnectionState::WebSocket(ws)),
Err(HandshakeError::Interrupted(handshake)) => {
Some(ConnectionState::MidHandshake(handshake))
}
Err(HandshakeError::Failure(err)) => {
eprintln!("handshake error: {:?}", err);
None
}
}
async fn periodically_open_connections(
config: Rc<Config>,
tls_config: Arc<rustls::ClientConfig>,
load_test_state: LoadTestState,
num_active_connections: Rc<RefCell<usize>>,
) -> Option<Duration> {
if *num_active_connections.borrow() < config.num_connections {
spawn_local(async move {
if let Err(err) =
Connection::run(config, tls_config, load_test_state, num_active_connections).await
{
eprintln!("connection creation error: {:?}", err);
}
Self::MidHandshake(handshake) => match handshake.handshake() {
Ok((ws, _)) => Some(ConnectionState::WebSocket(ws)),
Err(HandshakeError::Interrupted(handshake)) => {
Some(ConnectionState::MidHandshake(handshake))
}
Err(HandshakeError::Failure(err)) => {
eprintln!("handshake error: {:?}", err);
None
}
},
Self::WebSocket(ws) => Some(Self::WebSocket(ws)),
}
})
.detach();
}
Some(Duration::from_secs(1))
}
pub struct Connection {
stream: ConnectionState,
peer_id: PeerId,
struct Connection {
config: Rc<Config>,
load_test_state: LoadTestState,
rng: SmallRng,
can_send: bool,
peer_id: PeerId,
send_answer: Option<(PeerId, OfferId)>,
stream: WebSocketStream<TlsStream<TcpStream>>,
}
impl Connection {
pub fn create_and_register(
config: &Config,
rng: &mut impl Rng,
connections: &mut ConnectionMap,
poll: &mut Poll,
token_counter: &mut usize,
async fn run(
config: Rc<Config>,
tls_config: Arc<rustls::ClientConfig>,
load_test_state: LoadTestState,
num_active_connections: Rc<RefCell<usize>>,
) -> anyhow::Result<()> {
let mut stream = TcpStream::connect(config.server_address)?;
let mut rng = SmallRng::from_entropy();
let peer_id = PeerId(rng.gen());
let stream = TcpStream::connect(config.server_address)
.await
.map_err(|err| anyhow::anyhow!("connect: {:?}", err))?;
let stream = TlsConnector::from(tls_config).connect("example.com".try_into().unwrap(), stream).await?;
let request = format!(
"ws://{}:{}",
config.server_address.ip(),
config.server_address.port()
);
let (stream, _) = client_async(request, stream).await?;
poll.registry()
.register(
&mut stream,
Token(*token_counter),
Interest::READABLE | Interest::WRITABLE,
)
.unwrap();
let connection = Connection {
stream: ConnectionState::TcpStream(stream),
peer_id: PeerId(rng.gen()),
can_send: false,
let mut connection = Connection {
config,
load_test_state,
rng,
stream,
can_send: true,
peer_id,
send_answer: None,
};
connections.insert(*token_counter, connection);
*num_active_connections.borrow_mut() += 1;
*token_counter += 1;
println!("run connection");
if let Err(err) = connection.run_connection_loop().await {
eprintln!("connection error: {:?}", err);
}
*num_active_connections.borrow_mut() -= 1;
Ok(())
}
pub fn advance(self, config: &Config) -> Option<Self> {
if let Some(stream) = self.stream.advance(config) {
let can_send = matches!(stream, ConnectionState::WebSocket(_));
async fn run_connection_loop(&mut self) -> anyhow::Result<()> {
loop {
if self.can_send {
let request =
create_random_request(&self.config, &self.load_test_state, &mut self.rng, self.peer_id);
Some(Self {
stream,
peer_id: self.peer_id,
can_send,
send_answer: None,
})
} else {
None
// If self.send_answer is set and request is announce request, make
// the request an offer answer
let request = if let InMessage::AnnounceRequest(mut r) = request {
if let Some((peer_id, offer_id)) = self.send_answer {
r.to_peer_id = Some(peer_id);
r.offer_id = Some(offer_id);
r.answer = Some(JsonValue(::serde_json::json!(
{"sdp": "abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-"}
)));
r.event = None;
r.offers = None;
}
self.send_answer = None;
InMessage::AnnounceRequest(r)
} else {
request
};
self.stream.send(request.to_ws_message()).await?;
self.stream.flush().await?;
self.load_test_state
.statistics
.requests
.fetch_add(1, Ordering::SeqCst);
self.can_send = false;
}
self.read_message().await?;
}
}
pub fn read_responses(&mut self, state: &LoadTestState) -> bool {
// bool = drop connection
if let ConnectionState::WebSocket(ref mut ws) = self.stream {
loop {
match ws.read_message() {
Ok(message) => match OutMessage::from_ws_message(message) {
Ok(OutMessage::Offer(offer)) => {
state
.statistics
.responses_offer
.fetch_add(1, Ordering::SeqCst);
async fn read_message(&mut self) -> anyhow::Result<()> {
match OutMessage::from_ws_message(self.stream.next().await.unwrap()?) {
Ok(OutMessage::Offer(offer)) => {
self.load_test_state
.statistics
.responses_offer
.fetch_add(1, Ordering::SeqCst);
self.send_answer = Some((offer.peer_id, offer.offer_id));
self.send_answer = Some((offer.peer_id, offer.offer_id));
self.can_send = true;
}
Ok(OutMessage::Answer(_)) => {
state
.statistics
.responses_answer
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::Answer(_)) => {
self.load_test_state
.statistics
.responses_answer
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::AnnounceResponse(_)) => {
state
.statistics
.responses_announce
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::AnnounceResponse(_)) => {
self.load_test_state
.statistics
.responses_announce
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::ScrapeResponse(_)) => {
state
.statistics
.responses_scrape
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::ScrapeResponse(_)) => {
self.load_test_state
.statistics
.responses_scrape
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::ErrorResponse(response)) => {
state
.statistics
.responses_error
.fetch_add(1, Ordering::SeqCst);
self.can_send = true;
}
Ok(OutMessage::ErrorResponse(response)) => {
self.load_test_state
.statistics
.responses_error
.fetch_add(1, Ordering::SeqCst);
eprintln!("received error response: {:?}", response.failure_reason);
eprintln!("received error response: {:?}", response.failure_reason);
self.can_send = true;
}
Err(err) => {
eprintln!("error deserializing offer: {:?}", err);
}
},
Err(tungstenite::Error::Io(err)) if err.kind() == ErrorKind::WouldBlock => {
return false;
}
Err(_) => {
return true;
}
}
self.can_send = true;
}
Err(err) => {
eprintln!("error deserializing offer: {:?}", err);
}
}
false
}
pub fn send_request(
&mut self,
config: &Config,
state: &LoadTestState,
rng: &mut impl Rng,
) -> bool {
// bool = remove connection
if !self.can_send {
return false;
}
if let ConnectionState::WebSocket(ref mut ws) = self.stream {
let request = create_random_request(&config, &state, rng, self.peer_id);
// If self.send_answer is set and request is announce request, make
// the request an offer answer
let request = if let InMessage::AnnounceRequest(mut r) = request {
if let Some((peer_id, offer_id)) = self.send_answer {
r.to_peer_id = Some(peer_id);
r.offer_id = Some(offer_id);
r.answer = Some(JsonValue(::serde_json::json!(
{"sdp": "abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-abcdefg-"}
)));
r.event = None;
r.offers = None;
}
self.send_answer = None;
InMessage::AnnounceRequest(r)
} else {
request
};
match ws.write_message(request.to_ws_message()) {
Ok(()) => {
state.statistics.requests.fetch_add(1, Ordering::SeqCst);
self.can_send = false;
false
}
Err(tungstenite::Error::Io(err)) if err.kind() == ErrorKind::WouldBlock => false,
Err(_) => true,
}
} else {
println!("send request can't send to non-ws stream");
false
}
}
}
pub type ConnectionMap = HashMap<usize, Connection>;
pub fn run_socket_thread(config: &Config, state: LoadTestState) {
let timeout = Duration::from_micros(config.network.poll_timeout_microseconds);
let create_conn_interval = 2 ^ config.network.connection_creation_interval;
let mut connections: ConnectionMap = HashMap::with_capacity(config.num_connections);
let mut poll = Poll::new().expect("create poll");
let mut events = Events::with_capacity(config.network.poll_event_capacity);
let mut rng = SmallRng::from_entropy();
let mut token_counter = 0usize;
let mut iter_counter = 0usize;
let mut drop_keys = Vec::new();
loop {
poll.poll(&mut events, Some(timeout))
.expect("failed polling");
for event in events.iter() {
let token = event.token();
if event.is_readable() {
if let Some(connection) = connections.get_mut(&token.0) {
if let ConnectionState::WebSocket(_) = connection.stream {
let drop_connection = connection.read_responses(&state);
if drop_connection {
connections.remove(&token.0);
}
continue;
}
}
}
if let Some(connection) = connections.remove(&token.0) {
if let Some(connection) = connection.advance(config) {
connections.insert(token.0, connection);
}
}
}
for (k, connection) in connections.iter_mut() {
let drop_connection = connection.send_request(config, &state, &mut rng);
if drop_connection {
drop_keys.push(*k)
}
}
for k in drop_keys.drain(..) {
connections.remove(&k);
}
// Slowly create new connections
if connections.len() < config.num_connections && iter_counter % create_conn_interval == 0 {
let res = Connection::create_and_register(
config,
&mut rng,
&mut connections,
&mut poll,
&mut token_counter,
);
if let Err(err) = res {
eprintln!("create connection error: {}", err);
}
}
iter_counter = iter_counter.wrapping_add(1);
Ok(())
}
}

View file

@ -23,14 +23,16 @@ impl InMessage {
#[inline]
pub fn from_ws_message(ws_message: tungstenite::Message) -> ::anyhow::Result<Self> {
use tungstenite::Message::Text;
use tungstenite::Message;
let mut text = if let Text(text) = ws_message {
text
} else {
return Err(anyhow::anyhow!("Message is not text"));
};
return ::simd_json::serde::from_str(&mut text).context("deserialize with serde");
match ws_message {
Message::Text(mut text) => {
::simd_json::serde::from_str(&mut text).context("deserialize with serde")
}
Message::Binary(mut bytes) => {
::simd_json::serde::from_slice(&mut bytes[..]).context("deserialize with serde")
}
_ => Err(anyhow::anyhow!("Message is neither text nor binary")),
}
}
}