aquatic_http: remove mio implementation

This commit is contained in:
Joakim Frostegård 2021-10-28 01:48:32 +02:00
parent a935a3bbf0
commit 130377b8f4
19 changed files with 394 additions and 1587 deletions

5
Cargo.lock generated
View file

@ -94,19 +94,15 @@ dependencies = [
"aquatic_http_protocol",
"cfg-if",
"core_affinity",
"crossbeam-channel",
"either",
"futures-lite",
"glommio",
"hashbrown 0.11.2",
"histogram",
"indexmap",
"itoa",
"log",
"memchr",
"mimalloc",
"mio",
"native-tls",
"parking_lot",
"privdrop",
"quickcheck",
@ -117,7 +113,6 @@ dependencies = [
"serde",
"slab",
"smartstring",
"socket2 0.4.2",
]
[[package]]

View file

@ -1,8 +1,7 @@
# TODO
* aquatic_http glommio:
* remove mio version
* get rid of / improve ConnectionMeta stuff in handler
* get rid of / improve ConnectionMeta stuff in handler
* clean out connections regularly
* timeout inside of task for "it took too long to receive request, send response"?
* handle panicked/cancelled tasks

View file

@ -15,11 +15,6 @@ path = "src/lib/lib.rs"
name = "aquatic_http"
path = "src/bin/main.rs"
[features]
default = ["with-glommio"]
with-glommio = ["glommio", "futures-lite", "rustls", "rustls-pemfile", "slab"]
with-mio = ["crossbeam-channel", "histogram", "mio", "native-tls", "socket2"]
[dependencies]
anyhow = "1"
aquatic_cli_helpers = "0.1.0"
@ -28,6 +23,8 @@ aquatic_http_protocol = "0.1.0"
cfg-if = "1"
core_affinity = "0.5"
either = "1"
futures-lite = "1"
glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5" }
hashbrown = "0.11.2"
indexmap = "1"
itoa = "0.4"
@ -37,23 +34,12 @@ memchr = "2"
parking_lot = "0.11"
privdrop = "0.5"
rand = { version = "0.8", features = ["small_rng"] }
rustls = "0.20"
rustls-pemfile = "0.2"
serde = { version = "1", features = ["derive"] }
slab = "0.4"
smartstring = "0.2"
# mio
crossbeam-channel = { version = "0.5", optional = true }
histogram = { version = "0.6", optional = true }
mio = { version = "0.7", features = ["tcp", "os-poll", "os-util"], optional = true }
native-tls = { version = "0.2", optional = true }
socket2 = { version = "0.4.1", features = ["all"], optional = true }
# glommio
futures-lite = { version = "1", optional = true }
glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5", optional = true }
rustls = { version = "0.20", optional = true }
rustls-pemfile = { version = "0.2", optional = true }
slab = { version = "0.4", optional = true }
[dev-dependencies]
quickcheck = "1.0"
quickcheck_macros = "1.0"

View file

@ -14,7 +14,112 @@ use aquatic_http_protocol::response::ResponsePeer;
use crate::config::Config;
pub mod handlers;
use std::borrow::Borrow;
use std::cell::RefCell;
use std::rc::Rc;
use futures_lite::AsyncBufReadExt;
use glommio::io::{BufferedFile, StreamReaderBuilder};
use glommio::prelude::*;
use aquatic_http_protocol::{
request::{AnnounceRequest, ScrapeRequest},
response::{AnnounceResponse, ScrapeResponse},
};
/// Index of a channel-mesh consumer (a worker); used to route a response
/// back to the worker that forwarded the request.
#[derive(Copy, Clone, Debug)]
pub struct ConsumerId(pub usize);
/// Identifies a single connection within one socket worker.
#[derive(Clone, Copy, Debug)]
pub struct ConnectionId(pub usize);
/// A request forwarded from a socket worker to a request worker, together
/// with the routing data needed to send the response back.
#[derive(Debug)]
pub enum ChannelRequest {
    Announce {
        request: AnnounceRequest,
        peer_addr: SocketAddr,
        // Connection within the originating socket worker
        connection_id: ConnectionId,
        // Mesh consumer the matching ChannelResponse must be sent to
        response_consumer_id: ConsumerId,
    },
    Scrape {
        request: ScrapeRequest,
        peer_addr: SocketAddr,
        connection_id: ConnectionId,
        response_consumer_id: ConsumerId,
    },
}
/// A response sent from a request worker back to the socket worker that
/// owns the connection identified by `connection_id`.
#[derive(Debug)]
pub enum ChannelResponse {
    Announce {
        response: AnnounceResponse,
        peer_addr: SocketAddr,
        connection_id: ConnectionId,
    },
    Scrape {
        response: ScrapeResponse,
        peer_addr: SocketAddr,
        connection_id: ConnectionId,
    },
}
impl ChannelResponse {
    /// Connection (within the receiving socket worker) this response
    /// should be written to.
    pub fn get_connection_id(&self) -> ConnectionId {
        match self {
            Self::Announce { connection_id, .. } | Self::Scrape { connection_id, .. } => {
                *connection_id
            }
        }
    }

    /// Address of the peer that sent the originating request.
    pub fn get_peer_addr(&self) -> SocketAddr {
        match self {
            Self::Announce { peer_addr, .. } | Self::Scrape { peer_addr, .. } => *peer_addr,
        }
    }
}
pub async fn update_access_list<C: Borrow<Config>>(
config: C,
access_list: Rc<RefCell<AccessList>>,
) {
if config.borrow().access_list.mode.is_on() {
match BufferedFile::open(&config.borrow().access_list.path).await {
Ok(file) => {
let mut reader = StreamReaderBuilder::new(file).build();
let mut new_access_list = AccessList::default();
loop {
let mut buf = String::with_capacity(42);
match reader.read_line(&mut buf).await {
Ok(_) => {
if let Err(err) = new_access_list.insert_from_line(&buf) {
::log::error!(
"Couln't parse access list line '{}': {:?}",
buf,
err
);
}
}
Err(err) => {
::log::error!("Couln't read access list line {:?}", err);
break;
}
}
yield_if_needed().await;
}
*access_list.borrow_mut() = new_access_list;
}
Err(err) => {
::log::error!("Couldn't open access list file: {:?}", err)
}
};
}
}
/// Marker trait bundling the bounds an IP address type must satisfy to be
/// used as a peer-map key.
pub trait Ip: ::std::fmt::Debug + Copy + Eq + ::std::hash::Hash {}
@ -186,4 +291,4 @@ mod tests {
assert_eq!(f(101), 3);
assert_eq!(f(1000), 4);
}
}
}

View file

@ -1,109 +0,0 @@
use std::borrow::Borrow;
use std::cell::RefCell;
use std::net::SocketAddr;
use std::rc::Rc;
use aquatic_common::access_list::AccessList;
use futures_lite::AsyncBufReadExt;
use glommio::io::{BufferedFile, StreamReaderBuilder};
use glommio::prelude::*;
use aquatic_http_protocol::{
request::{AnnounceRequest, ScrapeRequest},
response::{AnnounceResponse, ScrapeResponse},
};
use crate::config::Config;
/// Index of a channel-mesh consumer (a worker); used to route a response
/// back to the worker that forwarded the request.
#[derive(Copy, Clone, Debug)]
pub struct ConsumerId(pub usize);
/// Identifies a single connection within one socket worker.
#[derive(Clone, Copy, Debug)]
pub struct ConnectionId(pub usize);
/// A request forwarded from a socket worker to a request worker, together
/// with the routing data needed to send the response back.
#[derive(Debug)]
pub enum ChannelRequest {
    Announce {
        request: AnnounceRequest,
        peer_addr: SocketAddr,
        // Connection within the originating socket worker
        connection_id: ConnectionId,
        // Mesh consumer the matching ChannelResponse must be sent to
        response_consumer_id: ConsumerId,
    },
    Scrape {
        request: ScrapeRequest,
        peer_addr: SocketAddr,
        connection_id: ConnectionId,
        response_consumer_id: ConsumerId,
    },
}
/// A response sent from a request worker back to the socket worker that
/// owns the connection identified by `connection_id`.
#[derive(Debug)]
pub enum ChannelResponse {
    Announce {
        response: AnnounceResponse,
        peer_addr: SocketAddr,
        connection_id: ConnectionId,
    },
    Scrape {
        response: ScrapeResponse,
        peer_addr: SocketAddr,
        connection_id: ConnectionId,
    },
}
impl ChannelResponse {
    /// Connection (within the receiving socket worker) this response
    /// should be written to.
    pub fn get_connection_id(&self) -> ConnectionId {
        match self {
            Self::Announce { connection_id, .. } => *connection_id,
            Self::Scrape { connection_id, .. } => *connection_id,
        }
    }

    /// Address of the peer that sent the originating request.
    pub fn get_peer_addr(&self) -> SocketAddr {
        match self {
            Self::Announce { peer_addr, .. } => *peer_addr,
            Self::Scrape { peer_addr, .. } => *peer_addr,
        }
    }
}
pub async fn update_access_list<C: Borrow<Config>>(
config: C,
access_list: Rc<RefCell<AccessList>>,
) {
if config.borrow().access_list.mode.is_on() {
match BufferedFile::open(&config.borrow().access_list.path).await {
Ok(file) => {
let mut reader = StreamReaderBuilder::new(file).build();
let mut new_access_list = AccessList::default();
loop {
let mut buf = String::with_capacity(42);
match reader.read_line(&mut buf).await {
Ok(_) => {
if let Err(err) = new_access_list.insert_from_line(&buf) {
::log::error!(
"Couln't parse access list line '{}': {:?}",
buf,
err
);
}
}
Err(err) => {
::log::error!("Couln't read access list line {:?}", err);
break;
}
}
yield_if_needed().await;
}
*access_list.borrow_mut() = new_access_list;
}
Err(err) => {
::log::error!("Couldn't open access list file: {:?}", err)
}
};
}
}

View file

@ -1,149 +0,0 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;
use aquatic_common::access_list::AccessList;
use futures_lite::{Stream, StreamExt};
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::timer::TimerActionRepeat;
use glommio::{enclose, prelude::*};
use rand::prelude::SmallRng;
use rand::SeedableRng;
use crate::common::handlers::handle_announce_request;
use crate::common::handlers::*;
use crate::common::*;
use crate::config::Config;
use super::common::*;
/// Glommio request worker: joins the request/response channel meshes and
/// processes announce/scrape requests against worker-local torrent state.
pub async fn run_request_worker(
    config: Config,
    request_mesh_builder: MeshBuilder<ChannelRequest, Partial>,
    response_mesh_builder: MeshBuilder<ChannelResponse, Partial>,
    access_list: AccessList,
) {
    // Consume requests from socket workers, produce responses back to them
    let (_, mut request_receivers) = request_mesh_builder.join(Role::Consumer).await.unwrap();
    let (response_senders, _) = response_mesh_builder.join(Role::Producer).await.unwrap();

    let response_senders = Rc::new(response_senders);

    // Single-threaded executor, so Rc/RefCell suffice for shared state
    let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
    let access_list = Rc::new(RefCell::new(access_list));

    // Periodically clean torrents and update access list
    TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
        enclose!((config, torrents, access_list) move || async move {
            update_access_list(&config, access_list.clone()).await;

            torrents.borrow_mut().clean(&config, &*access_list.borrow());

            Some(Duration::from_secs(config.cleaning.interval))
        })()
    }));

    // One task per incoming mesh stream; all tasks share the torrent state
    let mut handles = Vec::new();

    for (_, receiver) in request_receivers.streams() {
        let handle = spawn_local(handle_request_stream(
            config.clone(),
            torrents.clone(),
            response_senders.clone(),
            receiver,
        ))
        .detach();

        handles.push(handle);
    }

    for handle in handles {
        handle.await;
    }
}
/// Drain one mesh stream of requests: handle each request against the shared
/// torrent maps and send the response back to the socket worker that asked.
async fn handle_request_stream<S>(
    config: Config,
    torrents: Rc<RefCell<TorrentMaps>>,
    response_senders: Rc<Senders<ChannelResponse>>,
    mut stream: S,
) where
    S: Stream<Item = ChannelRequest> + ::std::marker::Unpin,
{
    let mut rng = SmallRng::from_entropy();

    let max_peer_age = config.cleaning.max_peer_age;

    // Refreshed once per second by the timer below, so each handled request
    // doesn't recompute its own expiry time
    let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));

    TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
        enclose!((peer_valid_until) move || async move {
            *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);

            Some(Duration::from_secs(1))
        })()
    }));

    while let Some(channel_request) = stream.next().await {
        let (response, consumer_id) = match channel_request {
            ChannelRequest::Announce {
                request,
                peer_addr,
                response_consumer_id,
                connection_id,
            } => {
                // Repackage routing data into the ConnectionMeta shape the
                // shared handlers expect
                let meta = ConnectionMeta {
                    worker_index: response_consumer_id.0,
                    poll_token: connection_id.0,
                    peer_addr,
                };

                let response = handle_announce_request(
                    &config,
                    &mut rng,
                    &mut torrents.borrow_mut(),
                    peer_valid_until.borrow().to_owned(),
                    meta,
                    request,
                );

                let response = ChannelResponse::Announce {
                    response,
                    peer_addr,
                    connection_id,
                };

                (response, response_consumer_id)
            }
            ChannelRequest::Scrape {
                request,
                peer_addr,
                response_consumer_id,
                connection_id,
            } => {
                let meta = ConnectionMeta {
                    worker_index: response_consumer_id.0,
                    poll_token: connection_id.0,
                    peer_addr,
                };

                let response =
                    handle_scrape_request(&config, &mut torrents.borrow_mut(), meta, request);

                let response = ChannelResponse::Scrape {
                    response,
                    peer_addr,
                    connection_id,
                };

                (response, response_consumer_id)
            }
        };

        ::log::debug!("preparing to send response to channel: {:?}", response);

        // Best-effort send back to the originating socket worker; a failed
        // try_send drops the response with a warning
        if let Err(err) = response_senders.try_send_to(consumer_id.0, response) {
            ::log::warn!("response_sender.try_send: {:?}", err);
        }

        yield_if_needed().await;
    }
}

View file

@ -1,140 +0,0 @@
use std::{
fs::File,
io::BufReader,
sync::{atomic::AtomicUsize, Arc},
};
use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding};
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
use crate::config::Config;
mod common;
mod handlers;
mod network;
// Capacity of each channel in the request/response meshes
const SHARED_CHANNEL_SIZE: usize = 1024;
/// Entry point for the glommio implementation: spawns one executor per
/// socket worker and per request worker, wired together with two channel
/// meshes, then drops privileges and joins all executors.
pub fn run(config: Config) -> anyhow::Result<()> {
    if config.cpu_pinning.active {
        core_affinity::set_for_current(core_affinity::CoreId {
            id: config.cpu_pinning.offset,
        });
    }

    let access_list = if config.access_list.mode.is_on() {
        AccessList::create_from_path(&config.access_list.path).expect("Load access list")
    } else {
        AccessList::default()
    };

    // Both meshes span all workers; each joins as producer or consumer
    let num_peers = config.socket_workers + config.request_workers;

    let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);
    let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);

    // Shared counter handed to socket workers; presumably incremented as
    // each socket binds — see drop_privileges_after_socket_binding below
    let num_bound_sockets = Arc::new(AtomicUsize::new(0));

    let tls_config = Arc::new(create_tls_config(&config).unwrap());

    let mut executors = Vec::new();

    for i in 0..(config.socket_workers) {
        let config = config.clone();
        let tls_config = tls_config.clone();
        let request_mesh_builder = request_mesh_builder.clone();
        let response_mesh_builder = response_mesh_builder.clone();
        let num_bound_sockets = num_bound_sockets.clone();
        let access_list = access_list.clone();

        let mut builder = LocalExecutorBuilder::default();

        if config.cpu_pinning.active {
            // Offset core is the main thread; socket workers start at +1
            builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + i);
        }

        let executor = builder.spawn(|| async move {
            network::run_socket_worker(
                config,
                tls_config,
                request_mesh_builder,
                response_mesh_builder,
                num_bound_sockets,
                access_list,
            )
            .await
        });

        executors.push(executor);
    }

    for i in 0..(config.request_workers) {
        let config = config.clone();
        let request_mesh_builder = request_mesh_builder.clone();
        let response_mesh_builder = response_mesh_builder.clone();
        let access_list = access_list.clone();

        let mut builder = LocalExecutorBuilder::default();

        if config.cpu_pinning.active {
            // Request workers are pinned after all socket workers
            builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + config.socket_workers + i);
        }

        let executor = builder.spawn(|| async move {
            handlers::run_request_worker(
                config,
                request_mesh_builder,
                response_mesh_builder,
                access_list,
            )
            .await
        });

        executors.push(executor);
    }

    drop_privileges_after_socket_binding(
        &config.privileges,
        num_bound_sockets,
        config.socket_workers,
    )
    .unwrap();

    for executor in executors {
        executor
            .expect("failed to spawn local executor")
            .join()
            .unwrap();
    }

    Ok(())
}
/// Build a rustls server config from the certificate and private key files
/// named in the network TLS config.
///
/// # Errors
///
/// Returns an error if either file can't be opened or parsed, if the key
/// file contains no PKCS#8 private key, or if rustls rejects the cert/key.
fn create_tls_config(config: &Config) -> anyhow::Result<rustls::ServerConfig> {
    let certs = {
        let f = File::open(&config.network.tls.tls_certificate_path)?;
        let mut f = BufReader::new(f);

        rustls_pemfile::certs(&mut f)?
            .into_iter()
            .map(rustls::Certificate)
            .collect()
    };

    let private_key = {
        let f = File::open(&config.network.tls.tls_private_key_path)?;
        let mut f = BufReader::new(f);

        // Take ownership of the first key instead of cloning it; build the
        // error lazily so the happy path doesn't allocate it
        rustls_pemfile::pkcs8_private_keys(&mut f)?
            .into_iter()
            .next()
            .map(rustls::PrivateKey)
            .ok_or_else(|| anyhow::anyhow!("No private keys in file"))?
    };

    let tls_config = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certs, private_key)?;

    Ok(tls_config)
}

View file

@ -8,9 +8,152 @@ use aquatic_common::extract_response_peers;
use aquatic_http_protocol::request::*;
use aquatic_http_protocol::response::*;
use super::*;
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;
use aquatic_common::access_list::AccessList;
use futures_lite::{Stream, StreamExt};
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::timer::TimerActionRepeat;
use glommio::{enclose, prelude::*};
use rand::prelude::SmallRng;
use rand::SeedableRng;
use crate::common::*;
use crate::config::Config;
/// Glommio request worker: joins the request/response channel meshes and
/// processes announce/scrape requests against worker-local torrent state.
pub async fn run_request_worker(
    config: Config,
    request_mesh_builder: MeshBuilder<ChannelRequest, Partial>,
    response_mesh_builder: MeshBuilder<ChannelResponse, Partial>,
    access_list: AccessList,
) {
    // Consume requests from socket workers, produce responses back to them
    let (_, mut request_receivers) = request_mesh_builder.join(Role::Consumer).await.unwrap();
    let (response_senders, _) = response_mesh_builder.join(Role::Producer).await.unwrap();

    let response_senders = Rc::new(response_senders);

    // Single-threaded executor, so Rc/RefCell suffice for shared state
    let torrents = Rc::new(RefCell::new(TorrentMaps::default()));
    let access_list = Rc::new(RefCell::new(access_list));

    // Periodically clean torrents and update access list
    TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || {
        enclose!((config, torrents, access_list) move || async move {
            update_access_list(&config, access_list.clone()).await;

            torrents.borrow_mut().clean(&config, &*access_list.borrow());

            Some(Duration::from_secs(config.cleaning.interval))
        })()
    }));

    // One task per incoming mesh stream; all tasks share the torrent state
    let mut handles = Vec::new();

    for (_, receiver) in request_receivers.streams() {
        let handle = spawn_local(handle_request_stream(
            config.clone(),
            torrents.clone(),
            response_senders.clone(),
            receiver,
        ))
        .detach();

        handles.push(handle);
    }

    for handle in handles {
        handle.await;
    }
}
/// Drain one mesh stream of requests: handle each request against the shared
/// torrent maps and send the response back to the socket worker that asked.
async fn handle_request_stream<S>(
    config: Config,
    torrents: Rc<RefCell<TorrentMaps>>,
    response_senders: Rc<Senders<ChannelResponse>>,
    mut stream: S,
) where
    S: Stream<Item = ChannelRequest> + ::std::marker::Unpin,
{
    let mut rng = SmallRng::from_entropy();

    let max_peer_age = config.cleaning.max_peer_age;

    // Refreshed once per second by the timer below, so each handled request
    // doesn't recompute its own expiry time
    let peer_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_peer_age)));

    TimerActionRepeat::repeat(enclose!((peer_valid_until) move || {
        enclose!((peer_valid_until) move || async move {
            *peer_valid_until.borrow_mut() = ValidUntil::new(max_peer_age);

            Some(Duration::from_secs(1))
        })()
    }));

    while let Some(channel_request) = stream.next().await {
        let (response, consumer_id) = match channel_request {
            ChannelRequest::Announce {
                request,
                peer_addr,
                response_consumer_id,
                connection_id,
            } => {
                // Repackage routing data into the ConnectionMeta shape the
                // shared handlers expect
                let meta = ConnectionMeta {
                    worker_index: response_consumer_id.0,
                    poll_token: connection_id.0,
                    peer_addr,
                };

                let response = handle_announce_request(
                    &config,
                    &mut rng,
                    &mut torrents.borrow_mut(),
                    peer_valid_until.borrow().to_owned(),
                    meta,
                    request,
                );

                let response = ChannelResponse::Announce {
                    response,
                    peer_addr,
                    connection_id,
                };

                (response, response_consumer_id)
            }
            ChannelRequest::Scrape {
                request,
                peer_addr,
                response_consumer_id,
                connection_id,
            } => {
                let meta = ConnectionMeta {
                    worker_index: response_consumer_id.0,
                    poll_token: connection_id.0,
                    peer_addr,
                };

                let response =
                    handle_scrape_request(&config, &mut torrents.borrow_mut(), meta, request);

                let response = ChannelResponse::Scrape {
                    response,
                    peer_addr,
                    connection_id,
                };

                (response, response_consumer_id)
            }
        };

        ::log::debug!("preparing to send response to channel: {:?}", response);

        // Best-effort send back to the originating socket worker; a failed
        // try_send drops the response with a warning
        if let Err(err) = response_senders.try_send_to(consumer_id.0, response) {
            ::log::warn!("response_sender.try_send: {:?}", err);
        }

        yield_if_needed().await;
    }
}
pub fn handle_announce_request(
config: &Config,
rng: &mut impl Rng,
@ -214,4 +357,4 @@ pub fn handle_scrape_request(
};
response
}
}

View file

@ -1,21 +1,143 @@
use cfg_if::cfg_if;
use std::{
fs::File,
io::BufReader,
sync::{atomic::AtomicUsize, Arc},
};
use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding};
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
use crate::config::Config;
pub mod common;
pub mod config;
#[cfg(all(feature = "with-glommio", target_os = "linux"))]
pub mod glommio;
#[cfg(feature = "with-mio")]
pub mod mio;
mod common;
mod handlers;
mod network;
pub const APP_NAME: &str = "aquatic_http: HTTP/TLS BitTorrent tracker";
pub fn run(config: config::Config) -> ::anyhow::Result<()> {
cfg_if! {
if #[cfg(all(feature = "with-glommio", target_os = "linux"))] {
glommio::run(config)
} else {
mio::run(config)
}
const SHARED_CHANNEL_SIZE: usize = 1024;
/// Entry point for the glommio implementation: spawns one executor per
/// socket worker and per request worker, wired together with two channel
/// meshes, then drops privileges and joins all executors.
pub fn run(config: Config) -> anyhow::Result<()> {
    if config.cpu_pinning.active {
        core_affinity::set_for_current(core_affinity::CoreId {
            id: config.cpu_pinning.offset,
        });
    }

    let access_list = if config.access_list.mode.is_on() {
        AccessList::create_from_path(&config.access_list.path).expect("Load access list")
    } else {
        AccessList::default()
    };

    // Both meshes span all workers; each joins as producer or consumer
    let num_peers = config.socket_workers + config.request_workers;

    let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);
    let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);

    // Shared counter handed to socket workers; presumably incremented as
    // each socket binds — see drop_privileges_after_socket_binding below
    let num_bound_sockets = Arc::new(AtomicUsize::new(0));

    let tls_config = Arc::new(create_tls_config(&config).unwrap());

    let mut executors = Vec::new();

    for i in 0..(config.socket_workers) {
        let config = config.clone();
        let tls_config = tls_config.clone();
        let request_mesh_builder = request_mesh_builder.clone();
        let response_mesh_builder = response_mesh_builder.clone();
        let num_bound_sockets = num_bound_sockets.clone();
        let access_list = access_list.clone();

        let mut builder = LocalExecutorBuilder::default();

        if config.cpu_pinning.active {
            // Offset core is the main thread; socket workers start at +1
            builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + i);
        }

        let executor = builder.spawn(|| async move {
            network::run_socket_worker(
                config,
                tls_config,
                request_mesh_builder,
                response_mesh_builder,
                num_bound_sockets,
                access_list,
            )
            .await
        });

        executors.push(executor);
    }

    for i in 0..(config.request_workers) {
        let config = config.clone();
        let request_mesh_builder = request_mesh_builder.clone();
        let response_mesh_builder = response_mesh_builder.clone();
        let access_list = access_list.clone();

        let mut builder = LocalExecutorBuilder::default();

        if config.cpu_pinning.active {
            // Request workers are pinned after all socket workers
            builder = builder.pin_to_cpu(config.cpu_pinning.offset + 1 + config.socket_workers + i);
        }

        let executor = builder.spawn(|| async move {
            handlers::run_request_worker(
                config,
                request_mesh_builder,
                response_mesh_builder,
                access_list,
            )
            .await
        });

        executors.push(executor);
    }

    drop_privileges_after_socket_binding(
        &config.privileges,
        num_bound_sockets,
        config.socket_workers,
    )
    .unwrap();

    for executor in executors {
        executor
            .expect("failed to spawn local executor")
            .join()
            .unwrap();
    }

    Ok(())
}
/// Build a rustls server config from the certificate and private key files
/// named in the network TLS config.
///
/// # Errors
///
/// Returns an error if either file can't be opened or parsed, if the key
/// file contains no PKCS#8 private key, or if rustls rejects the cert/key.
fn create_tls_config(config: &Config) -> anyhow::Result<rustls::ServerConfig> {
    let certs = {
        let f = File::open(&config.network.tls.tls_certificate_path)?;
        let mut f = BufReader::new(f);

        rustls_pemfile::certs(&mut f)?
            .into_iter()
            .map(rustls::Certificate)
            .collect()
    };

    let private_key = {
        let f = File::open(&config.network.tls.tls_private_key_path)?;
        let mut f = BufReader::new(f);

        // Take ownership of the first key instead of cloning it; build the
        // error lazily so the happy path doesn't allocate it
        rustls_pemfile::pkcs8_private_keys(&mut f)?
            .into_iter()
            .next()
            .map(rustls::PrivateKey)
            .ok_or_else(|| anyhow::anyhow!("No private keys in file"))?
    };

    let tls_config = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certs, private_key)?;

    Ok(tls_config)
}

View file

@ -1,57 +0,0 @@
use std::sync::Arc;
use aquatic_common::access_list::AccessListArcSwap;
use crossbeam_channel::{Receiver, Sender};
use log::error;
use mio::Token;
use parking_lot::Mutex;
pub use aquatic_common::{convert_ipv4_mapped_ipv6, ValidUntil};
use aquatic_http_protocol::request::Request;
use aquatic_http_protocol::response::Response;
use crate::common::*;
// Fixed mio poll tokens: 0 = TCP listener, 1 = cross-thread waker channel
pub const LISTENER_TOKEN: Token = Token(0);
pub const CHANNEL_TOKEN: Token = Token(1);
/// Shared cross-thread state for the mio implementation.
#[derive(Clone)]
pub struct State {
    // Swappable access list shared between all workers
    pub access_list: Arc<AccessListArcSwap>,
    // All torrent data, guarded by a single mutex
    pub torrent_maps: Arc<Mutex<TorrentMaps>>,
}
impl Default for State {
    /// Empty access list and empty torrent maps.
    fn default() -> Self {
        let access_list = Arc::new(Default::default());
        let torrent_maps = Arc::new(Mutex::new(TorrentMaps::default()));

        Self {
            access_list,
            torrent_maps,
        }
    }
}
// Socket worker -> request worker channel endpoints
pub type RequestChannelSender = Sender<(ConnectionMeta, Request)>;
pub type RequestChannelReceiver = Receiver<(ConnectionMeta, Request)>;
// Request worker -> socket worker channel (receiving side)
pub type ResponseChannelReceiver = Receiver<(ConnectionMeta, Response)>;
/// Fan-out sender: one response channel per socket worker, indexed by
/// worker index.
#[derive(Clone)]
pub struct ResponseChannelSender {
    senders: Vec<Sender<(ConnectionMeta, Response)>>,
}
impl ResponseChannelSender {
    pub fn new(senders: Vec<Sender<(ConnectionMeta, Response)>>) -> Self {
        Self { senders }
    }

    /// Route `message` to the socket worker given by `meta.worker_index`.
    ///
    /// NOTE(review): indexes `senders` directly, so an out-of-range
    /// `worker_index` panics; send failures are only logged, not returned.
    #[inline]
    pub fn send(&self, meta: ConnectionMeta, message: Response) {
        if let Err(err) = self.senders[meta.worker_index].send((meta, message)) {
            error!("ResponseChannelSender: couldn't send message: {:?}", err);
        }
    }
}
// None = worker still starting; Some(Err(msg)) = bind/startup failure
pub type SocketWorkerStatus = Option<Result<(), String>>;
pub type SocketWorkerStatuses = Arc<Mutex<Vec<SocketWorkerStatus>>>;

View file

@ -1,97 +0,0 @@
use std::sync::Arc;
use std::time::Duration;
use mio::Waker;
use parking_lot::MutexGuard;
use rand::{rngs::SmallRng, SeedableRng};
use aquatic_http_protocol::request::*;
use aquatic_http_protocol::response::*;
use super::common::*;
use crate::common::handlers::{handle_announce_request, handle_scrape_request};
use crate::common::*;
use crate::config::Config;
/// Mio request worker: batch requests from the channel, process the whole
/// batch under a single torrent-map lock, then wake exactly the socket
/// workers that received responses.
pub fn run_request_worker(
    config: Config,
    state: State,
    request_channel_receiver: RequestChannelReceiver,
    response_channel_sender: ResponseChannelSender,
    wakers: Vec<Arc<Waker>>,
) {
    // One flag per socket worker: wake only the ones that got responses
    let mut wake_socket_workers: Vec<bool> = (0..config.socket_workers).map(|_| false).collect();

    let mut announce_requests = Vec::new();
    let mut scrape_requests = Vec::new();

    let mut rng = SmallRng::from_entropy();

    let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);

    loop {
        let mut opt_torrent_maps: Option<MutexGuard<TorrentMaps>> = None;

        // If torrent state mutex is locked, just keep collecting requests
        // and process them later. This can happen with either multiple
        // request workers or while cleaning is underway.
        for i in 0..config.handlers.max_requests_per_iter {
            let opt_in_message = if i == 0 {
                // Block indefinitely for the first request of a batch
                request_channel_receiver.recv().ok()
            } else {
                request_channel_receiver.recv_timeout(timeout).ok()
            };

            match opt_in_message {
                Some((meta, Request::Announce(r))) => {
                    announce_requests.push((meta, r));
                }
                Some((meta, Request::Scrape(r))) => {
                    scrape_requests.push((meta, r));
                }
                None => {
                    // Channel idle: take the lock now if it's free,
                    // otherwise keep batching
                    if let Some(torrent_guard) = state.torrent_maps.try_lock() {
                        opt_torrent_maps = Some(torrent_guard);

                        break;
                    }
                }
            }
        }

        // Fall back to a blocking lock if try_lock never succeeded above
        let mut torrent_maps = opt_torrent_maps.unwrap_or_else(|| state.torrent_maps.lock());

        let valid_until = ValidUntil::new(config.cleaning.max_peer_age);

        for (meta, request) in announce_requests.drain(..) {
            let response = handle_announce_request(
                &config,
                &mut rng,
                &mut torrent_maps,
                valid_until,
                meta,
                request,
            );

            response_channel_sender.send(meta, Response::Announce(response));

            wake_socket_workers[meta.worker_index] = true;
        }

        for (meta, request) in scrape_requests.drain(..) {
            let response = handle_scrape_request(&config, &mut torrent_maps, meta, request);

            response_channel_sender.send(meta, Response::Scrape(response));

            wake_socket_workers[meta.worker_index] = true;
        }

        for (worker_index, wake) in wake_socket_workers.iter_mut().enumerate() {
            if *wake {
                if let Err(err) = wakers[worker_index].wake() {
                    ::log::error!("request handler couldn't wake poll: {:?}", err);
                }

                *wake = false;
            }
        }
    }
}

View file

@ -1,150 +0,0 @@
use std::sync::Arc;
use std::thread::Builder;
use std::time::Duration;
use anyhow::Context;
use mio::{Poll, Waker};
use parking_lot::Mutex;
use privdrop::PrivDrop;
pub mod common;
pub mod handler;
pub mod network;
pub mod tasks;
use crate::config::Config;
use common::*;
use network::utils::create_tls_acceptor;
/// Entry point for the mio implementation: start all worker threads, then
/// loop forever on this thread doing periodic cleaning and access-list
/// reloads (never returns Ok).
pub fn run(config: Config) -> anyhow::Result<()> {
    let state = State::default();

    // Load the access list once before any worker starts
    tasks::update_access_list(&config, &state);

    start_workers(config.clone(), state.clone())?;

    loop {
        ::std::thread::sleep(Duration::from_secs(config.cleaning.interval));

        tasks::update_access_list(&config, &state);

        state
            .torrent_maps
            .lock()
            .clean(&config, &state.access_list.load_full());
    }
}
/// Spawn socket worker and request worker threads, plus an optional
/// statistics thread; drop privileges once every socket worker has
/// reported a successful start.
pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
    let opt_tls_acceptor = create_tls_acceptor(&config.network.tls)?;

    // Single shared channel: all socket workers feed all request workers
    let (request_channel_sender, request_channel_receiver) = ::crossbeam_channel::unbounded();

    let mut out_message_senders = Vec::new();

    let mut wakers = Vec::new();

    // One status slot per socket worker; None until the worker reports in
    let socket_worker_statuses: SocketWorkerStatuses = {
        let mut statuses = Vec::new();

        for _ in 0..config.socket_workers {
            statuses.push(None);
        }

        Arc::new(Mutex::new(statuses))
    };

    for i in 0..config.socket_workers {
        let config = config.clone();
        let state = state.clone();
        let socket_worker_statuses = socket_worker_statuses.clone();
        let request_channel_sender = request_channel_sender.clone();
        let opt_tls_acceptor = opt_tls_acceptor.clone();

        let poll = Poll::new().expect("create poll");
        // Waker lets request workers interrupt this worker's poll loop
        let waker = Arc::new(Waker::new(poll.registry(), CHANNEL_TOKEN).expect("create waker"));

        // Each socket worker gets its own dedicated response channel
        let (response_channel_sender, response_channel_receiver) = ::crossbeam_channel::unbounded();

        out_message_senders.push(response_channel_sender);

        wakers.push(waker);

        Builder::new()
            .name(format!("socket-{:02}", i + 1))
            .spawn(move || {
                network::run_socket_worker(
                    config,
                    state,
                    i,
                    socket_worker_statuses,
                    request_channel_sender,
                    response_channel_receiver,
                    opt_tls_acceptor,
                    poll,
                );
            })?;
    }

    // Wait for socket worker statuses. On error from any, quit program.
    // On success from all, drop privileges if corresponding setting is set
    // and continue program.
    loop {
        ::std::thread::sleep(::std::time::Duration::from_millis(10));

        if let Some(statuses) = socket_worker_statuses.try_lock() {
            for opt_status in statuses.iter() {
                if let Some(Err(err)) = opt_status {
                    return Err(::anyhow::anyhow!(err.to_owned()));
                }
            }

            if statuses.iter().all(Option::is_some) {
                if config.privileges.drop_privileges {
                    PrivDrop::default()
                        .chroot(config.privileges.chroot_path.clone())
                        .user(config.privileges.user.clone())
                        .apply()
                        .context("Couldn't drop root privileges")?;
                }

                break;
            }
        }
    }

    let response_channel_sender = ResponseChannelSender::new(out_message_senders);

    for i in 0..config.request_workers {
        let config = config.clone();
        let state = state.clone();
        let request_channel_receiver = request_channel_receiver.clone();
        let response_channel_sender = response_channel_sender.clone();
        let wakers = wakers.clone();

        Builder::new()
            .name(format!("request-{:02}", i + 1))
            .spawn(move || {
                handler::run_request_worker(
                    config,
                    state,
                    request_channel_receiver,
                    response_channel_sender,
                    wakers,
                );
            })?;
    }

    if config.statistics.interval != 0 {
        let state = state.clone();
        let config = config.clone();

        Builder::new()
            .name("statistics".to_string())
            .spawn(move || loop {
                ::std::thread::sleep(Duration::from_secs(config.statistics.interval));

                tasks::print_statistics(&state);
            })
            .expect("spawn statistics thread");
    }

    Ok(())
}

View file

@ -1,260 +0,0 @@
use std::io::ErrorKind;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::sync::Arc;
use hashbrown::HashMap;
use mio::net::TcpStream;
use mio::{Poll, Token};
use native_tls::{MidHandshakeTlsStream, TlsAcceptor};
use aquatic_http_protocol::request::{Request, RequestParseError};
use crate::common::num_digits_in_usize;
use crate::mio::common::*;
use super::stream::Stream;
/// Why reading from a connection didn't produce a complete request.
#[derive(Debug)]
pub enum RequestReadError {
    // Partial request: retry when the socket is readable again
    NeedMoreData,
    // Peer closed the connection (read returned 0 bytes)
    StreamEnded,
    // Request bytes were complete but invalid
    Parse(anyhow::Error),
    // Underlying I/O error other than WouldBlock
    Io(::std::io::Error),
}
/// A connection whose TCP (and optionally TLS) setup has completed.
pub struct EstablishedConnection {
    stream: Stream,
    pub peer_addr: SocketAddr,
    // Read buffer; grows as partial request data arrives
    buf: Vec<u8>,
    // Number of meaningful bytes at the start of `buf`
    bytes_read: usize,
}
impl EstablishedConnection {
    #[inline]
    fn new(stream: Stream) -> Self {
        let peer_addr = stream.get_peer_addr();

        Self {
            stream,
            peer_addr,
            buf: Vec::new(),
            bytes_read: 0,
        }
    }

    /// Try to read a complete HTTP request from the stream.
    ///
    /// The internal buffer persists across calls, so a request split over
    /// several readiness events is assembled piecewise. Returns
    /// `NeedMoreData` on WouldBlock or on an incomplete request.
    pub fn read_request(&mut self) -> Result<Request, RequestReadError> {
        // Grow the buffer in 1 KiB steps while under the ~4 KiB cap; once
        // the cap is reached with a full buffer, the read below gets an
        // empty slice, returns Ok(0), and the request is rejected as ended
        if (self.buf.len() - self.bytes_read < 512) & (self.buf.len() <= 3072) {
            self.buf.extend_from_slice(&[0; 1024]);
        }

        match self.stream.read(&mut self.buf[self.bytes_read..]) {
            Ok(0) => {
                self.clear_buffer();

                return Err(RequestReadError::StreamEnded);
            }
            Ok(bytes_read) => {
                self.bytes_read += bytes_read;

                ::log::debug!("read_request read {} bytes", bytes_read);
            }
            Err(err) if err.kind() == ErrorKind::WouldBlock => {
                return Err(RequestReadError::NeedMoreData);
            }
            Err(err) => {
                self.clear_buffer();

                return Err(RequestReadError::Io(err));
            }
        }

        match Request::from_bytes(&self.buf[..self.bytes_read]) {
            Ok(request) => {
                self.clear_buffer();

                Ok(request)
            }
            // Parser wants more bytes: keep buffer for the next read
            Err(RequestParseError::NeedMoreData) => Err(RequestReadError::NeedMoreData),
            Err(RequestParseError::Invalid(err)) => {
                self.clear_buffer();

                Err(RequestReadError::Parse(err))
            }
        }
    }

    /// Write a "200 OK" response with `body` as content plus trailing CRLF.
    pub fn send_response(&mut self, body: &[u8]) -> ::std::io::Result<()> {
        let content_len = body.len() + 2; // 2 is for newlines at end
        let content_len_num_digits = num_digits_in_usize(content_len);

        // 39 = total length of the fixed header/CRLF text appended below
        let mut response = Vec::with_capacity(39 + content_len_num_digits + body.len());

        response.extend_from_slice(b"HTTP/1.1 200 OK\r\nContent-Length: ");
        ::itoa::write(&mut response, content_len)?;
        response.extend_from_slice(b"\r\n\r\n");
        response.extend_from_slice(body);
        response.extend_from_slice(b"\r\n");

        // NOTE(review): a short write is only logged, not retried
        let bytes_written = self.stream.write(&response)?;

        if bytes_written != response.len() {
            ::log::error!(
                "send_response: only {} out of {} bytes written",
                bytes_written,
                response.len()
            );
        }

        self.stream.flush()?;

        Ok(())
    }

    /// Reset read state, dropping the buffer's allocation.
    #[inline]
    pub fn clear_buffer(&mut self) {
        self.bytes_read = 0;
        self.buf = Vec::new();
    }
}
/// Outcome of a failed `establish_tls` attempt.
pub enum TlsHandshakeMachineError {
    // Handshake not finished yet; retry later with the returned machine
    WouldBlock(TlsHandshakeMachine),
    // Handshake failed for good
    Failure(native_tls::Error),
}
// Handshake progress: not yet started (raw TCP) vs. mid-handshake
enum TlsHandshakeMachineInner {
    TcpStream(TcpStream),
    TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
}
/// State machine driving a non-blocking TLS handshake to completion.
pub struct TlsHandshakeMachine {
    tls_acceptor: Arc<TlsAcceptor>,
    inner: TlsHandshakeMachineInner,
}
// Note: the original `impl<'a>` declared a lifetime parameter that no
// method or type in the impl used; it has been removed.
impl TlsHandshakeMachine {
    #[inline]
    fn new(tls_acceptor: Arc<TlsAcceptor>, tcp_stream: TcpStream) -> Self {
        Self {
            tls_acceptor,
            inner: TlsHandshakeMachineInner::TcpStream(tcp_stream),
        }
    }

    /// Attempt to establish a TLS connection. On a WouldBlock error, return
    /// the machine wrapped in an error for later attempts.
    pub fn establish_tls(self) -> Result<EstablishedConnection, TlsHandshakeMachineError> {
        let handshake_result = match self.inner {
            // First attempt: start the handshake on the raw TCP stream
            TlsHandshakeMachineInner::TcpStream(stream) => self.tls_acceptor.accept(stream),
            // Later attempts: resume the half-finished handshake
            TlsHandshakeMachineInner::TlsMidHandshake(handshake) => handshake.handshake(),
        };

        match handshake_result {
            Ok(stream) => {
                let established = EstablishedConnection::new(Stream::TlsStream(stream));

                ::log::debug!("established tls connection");

                Ok(established)
            }
            Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
                // Preserve progress so the caller can retry on readiness
                let inner = TlsHandshakeMachineInner::TlsMidHandshake(handshake);

                let machine = Self {
                    tls_acceptor: self.tls_acceptor,
                    inner,
                };

                Err(TlsHandshakeMachineError::WouldBlock(machine))
            }
            Err(native_tls::HandshakeError::Failure(err)) => {
                Err(TlsHandshakeMachineError::Failure(err))
            }
        }
    }
}
/// A connection is either fully established (plain TCP, or TLS whose
/// handshake completed) or still in the middle of a TLS handshake.
enum ConnectionInner {
    Established(EstablishedConnection),
    InProgress(TlsHandshakeMachine),
}
/// A peer connection together with the instant after which it is
/// considered inactive and may be removed by the cleaning pass.
pub struct Connection {
    pub valid_until: ValidUntil,
    inner: ConnectionInner,
}
impl Connection {
    /// Wrap a freshly accepted TCP stream. With a TLS acceptor present
    /// the connection starts out as an in-progress handshake; otherwise
    /// it is immediately an established plain-TCP connection.
    #[inline]
    pub fn new(
        opt_tls_acceptor: &Option<Arc<TlsAcceptor>>,
        valid_until: ValidUntil,
        tcp_stream: TcpStream,
    ) -> Self {
        let inner = match opt_tls_acceptor {
            Some(tls_acceptor) => ConnectionInner::InProgress(TlsHandshakeMachine::new(
                tls_acceptor.clone(),
                tcp_stream,
            )),
            None => {
                ::log::debug!("established tcp connection");

                ConnectionInner::Established(EstablishedConnection::new(Stream::TcpStream(
                    tcp_stream,
                )))
            }
        };

        Self { valid_until, inner }
    }

    /// Wrap an already-established connection (e.g. after a completed
    /// TLS handshake).
    #[inline]
    pub fn from_established(valid_until: ValidUntil, established: EstablishedConnection) -> Self {
        Self {
            valid_until,
            inner: ConnectionInner::Established(established),
        }
    }

    /// Wrap a handshake machine that still needs more I/O.
    #[inline]
    pub fn from_in_progress(valid_until: ValidUntil, machine: TlsHandshakeMachine) -> Self {
        Self {
            valid_until,
            inner: ConnectionInner::InProgress(machine),
        }
    }

    /// Mutable access to the established connection, or None while a
    /// TLS handshake is still in progress.
    #[inline]
    pub fn get_established(&mut self) -> Option<&mut EstablishedConnection> {
        match self.inner {
            ConnectionInner::Established(ref mut established) => Some(established),
            ConnectionInner::InProgress(_) => None,
        }
    }

    /// Takes ownership since TlsStream needs ownership of TcpStream
    #[inline]
    pub fn get_in_progress(self) -> Option<TlsHandshakeMachine> {
        match self.inner {
            ConnectionInner::InProgress(machine) => Some(machine),
            ConnectionInner::Established(_) => None,
        }
    }

    /// Deregister the underlying TCP stream — whatever state the
    /// connection is in — from the given mio poll instance.
    pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
        let registry = poll.registry();

        match self.inner {
            ConnectionInner::Established(ref mut established) => match established.stream {
                Stream::TcpStream(ref mut stream) => registry.deregister(stream),
                Stream::TlsStream(ref mut stream) => registry.deregister(stream.get_mut()),
            },
            ConnectionInner::InProgress(ref mut machine) => match machine.inner {
                TlsHandshakeMachineInner::TcpStream(ref mut stream) => {
                    registry.deregister(stream)
                }
                TlsHandshakeMachineInner::TlsMidHandshake(ref mut mid_handshake) => {
                    registry.deregister(mid_handshake.get_mut())
                }
            },
        }
    }
}
pub type ConnectionMap = HashMap<Token, Connection>;

View file

@ -1,388 +0,0 @@
use std::io::{Cursor, ErrorKind};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::vec::Drain;
use aquatic_common::access_list::AccessListQuery;
use aquatic_http_protocol::request::Request;
use hashbrown::HashMap;
use log::{debug, error, info};
use mio::net::TcpListener;
use mio::{Events, Interest, Poll, Token};
use native_tls::TlsAcceptor;
use aquatic_http_protocol::response::*;
use crate::common::*;
use crate::config::Config;
use crate::mio::common::*;
pub mod connection;
pub mod stream;
pub mod utils;
use connection::*;
use utils::*;
const CONNECTION_CLEAN_INTERVAL: usize = 2 ^ 22;
/// Socket worker entry point: open the listening socket, report the
/// outcome through `socket_worker_statuses`, and on success run the
/// poll loop (which never returns).
pub fn run_socket_worker(
    config: Config,
    state: State,
    socket_worker_index: usize,
    socket_worker_statuses: SocketWorkerStatuses,
    request_channel_sender: RequestChannelSender,
    response_channel_receiver: ResponseChannelReceiver,
    opt_tls_acceptor: Option<TlsAcceptor>,
    poll: Poll,
) {
    let listener = match create_listener(config.network.address, config.network.ipv6_only) {
        Ok(listener) => listener,
        Err(err) => {
            // Let the supervisor know why this worker couldn't start.
            socket_worker_statuses.lock()[socket_worker_index] =
                Some(Err(format!("Couldn't open socket: {:#}", err)));

            return;
        }
    };

    socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(()));

    run_poll_loop(
        config,
        &state,
        socket_worker_index,
        request_channel_sender,
        response_channel_receiver,
        listener,
        opt_tls_acceptor,
        poll,
    );
}
/// Socket worker main loop: poll for readiness, accept new streams,
/// read requests (forwarding them to request workers), and send
/// responses. Never returns.
pub fn run_poll_loop(
    config: Config,
    state: &State,
    socket_worker_index: usize,
    request_channel_sender: RequestChannelSender,
    response_channel_receiver: ResponseChannelReceiver,
    listener: ::std::net::TcpListener,
    opt_tls_acceptor: Option<TlsAcceptor>,
    mut poll: Poll,
) {
    let poll_timeout = Duration::from_micros(config.network.poll_timeout_microseconds);

    let mut listener = TcpListener::from_std(listener);
    let mut events = Events::with_capacity(config.network.poll_event_capacity);

    // Fix: register under LISTENER_TOKEN (rather than the literal
    // Token(0)) so registration stays consistent with the token
    // comparison in the event loop below.
    poll.registry()
        .register(&mut listener, LISTENER_TOKEN, Interest::READABLE)
        .unwrap();

    let mut connections: ConnectionMap = HashMap::new();

    let opt_tls_acceptor = opt_tls_acceptor.map(Arc::new);

    let mut poll_token_counter = Token(0usize);
    let mut iter_counter = 0usize;

    // Responses are serialized into this fixed-size buffer before sending.
    let mut response_buffer = [0u8; 4096];
    let mut response_buffer = Cursor::new(&mut response_buffer[..]);
    let mut local_responses = Vec::new();

    loop {
        poll.poll(&mut events, Some(poll_timeout))
            .expect("failed polling");

        for event in events.iter() {
            let token = event.token();

            if token == LISTENER_TOKEN {
                accept_new_streams(
                    &config,
                    &mut listener,
                    &mut poll,
                    &mut connections,
                    &mut poll_token_counter,
                    &opt_tls_acceptor,
                );
            } else if token != CHANNEL_TOKEN {
                handle_connection_read_event(
                    &config,
                    &state,
                    socket_worker_index,
                    &mut poll,
                    &request_channel_sender,
                    &mut local_responses,
                    &mut connections,
                    token,
                );
            }

            // Send responses for each event. Channel token is not interesting
            // by itself, but is just for making sure responses are sent even
            // if no new connects / requests come in.
            send_responses(
                &config,
                &mut poll,
                &mut response_buffer,
                local_responses.drain(..),
                &response_channel_receiver,
                &mut connections,
            );
        }

        // Remove inactive connections, but not every iteration
        if iter_counter % CONNECTION_CLEAN_INTERVAL == 0 {
            remove_inactive_connections(&mut poll, &mut connections);
        }

        iter_counter = iter_counter.wrapping_add(1);
    }
}
/// Accept all currently pending streams on the listener, registering
/// each one with the poll instance under a fresh token.
fn accept_new_streams(
    config: &Config,
    listener: &mut TcpListener,
    poll: &mut Poll,
    connections: &mut ConnectionMap,
    poll_token_counter: &mut Token,
    opt_tls_acceptor: &Option<Arc<TlsAcceptor>>,
) {
    let valid_until = ValidUntil::new(config.cleaning.max_connection_age);

    loop {
        let (mut stream, _) = match listener.accept() {
            Ok(pair) => pair,
            // Nothing more to accept right now.
            Err(err) if err.kind() == ErrorKind::WouldBlock => break,
            Err(err) => {
                info!("error while accepting streams: {}", err);

                continue;
            }
        };

        poll_token_counter.0 = poll_token_counter.0.wrapping_add(1);

        // Skip listener and channel tokens
        if poll_token_counter.0 < 2 {
            poll_token_counter.0 = 2;
        }

        let token = *poll_token_counter;

        // Remove connection if it exists (which is unlikely)
        remove_connection(poll, connections, poll_token_counter);

        poll.registry()
            .register(&mut stream, token, Interest::READABLE)
            .unwrap();

        connections.insert(token, Connection::new(opt_tls_acceptor, valid_until, stream));
    }
}
/// On the stream given by poll_token, get TLS up and running if requested,
/// then read requests and pass on through channel.
///
/// Requests for disallowed info hashes and unparseable requests become
/// failure responses pushed onto `local_responses`; valid requests are
/// forwarded to request workers via `request_channel_sender`.
pub fn handle_connection_read_event(
    config: &Config,
    state: &State,
    socket_worker_index: usize,
    poll: &mut Poll,
    request_channel_sender: &RequestChannelSender,
    local_responses: &mut Vec<(ConnectionMeta, Response)>,
    connections: &mut ConnectionMap,
    poll_token: Token,
) {
    let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
    let access_list_mode = config.access_list.mode;

    // Loop so that a connection whose TLS handshake completes during
    // this event is immediately read from; every other path breaks or
    // returns.
    loop {
        // Get connection, updating valid_until
        let connection = if let Some(c) = connections.get_mut(&poll_token) {
            c
        } else {
            // If there is no connection, there is no stream, so there
            // shouldn't be any (relevant) poll events. In other words, it's
            // safe to return here
            return;
        };

        connection.valid_until = valid_until;

        if let Some(established) = connection.get_established() {
            match established.read_request() {
                // Parsed, but the info hash is rejected by the access
                // list: answer locally with a failure response.
                Ok(Request::Announce(ref r))
                    if !state.access_list.allows(access_list_mode, &r.info_hash.0) =>
                {
                    let meta = ConnectionMeta {
                        worker_index: socket_worker_index,
                        poll_token: poll_token.0,
                        peer_addr: established.peer_addr,
                    };

                    let response = FailureResponse::new("Info hash not allowed");

                    debug!("read disallowed request, sending back error response");

                    local_responses.push((meta, Response::Failure(response)));

                    break;
                }
                Ok(request) => {
                    let meta = ConnectionMeta {
                        worker_index: socket_worker_index,
                        poll_token: poll_token.0,
                        peer_addr: established.peer_addr,
                    };

                    debug!("read allowed request, sending on to channel");

                    if let Err(err) = request_channel_sender.send((meta, request)) {
                        error!("RequestChannelSender: couldn't send message: {:?}", err);
                    }

                    break;
                }
                Err(RequestReadError::NeedMoreData) => {
                    info!("need more data");

                    // Stop reading data (defer to later events)
                    break;
                }
                Err(RequestReadError::Parse(err)) => {
                    info!("error reading request (invalid): {:#?}", err);

                    let meta = ConnectionMeta {
                        worker_index: socket_worker_index,
                        poll_token: poll_token.0,
                        peer_addr: established.peer_addr,
                    };

                    let response = FailureResponse::new("Invalid request");

                    local_responses.push((meta, Response::Failure(response)));

                    break;
                }
                Err(RequestReadError::StreamEnded) => {
                    ::log::debug!("stream ended");

                    // Peer closed the connection: drop and deregister it.
                    remove_connection(poll, connections, &poll_token);

                    break;
                }
                Err(RequestReadError::Io(err)) => {
                    ::log::info!("error reading request (io): {}", err);

                    remove_connection(poll, connections, &poll_token);

                    break;
                }
            }
        } else if let Some(handshake_machine) = connections
            .remove(&poll_token)
            .and_then(Connection::get_in_progress)
        {
            // The handshake machine needs ownership of the stream, so
            // the connection is removed from the map here and re-inserted
            // below unless the handshake failed permanently.
            match handshake_machine.establish_tls() {
                Ok(established) => {
                    // Handshake done: store the connection and loop again
                    // to try reading a request from it right away.
                    let connection = Connection::from_established(valid_until, established);

                    connections.insert(poll_token, connection);
                }
                Err(TlsHandshakeMachineError::WouldBlock(machine)) => {
                    let connection = Connection::from_in_progress(valid_until, machine);

                    connections.insert(poll_token, connection);

                    // Break and wait for more data
                    break;
                }
                Err(TlsHandshakeMachineError::Failure(err)) => {
                    info!("tls handshake error: {}", err);

                    // TLS negotiation failed
                    break;
                }
            }
        }
    }
}
/// Read responses from channel, send to peers
pub fn send_responses(
config: &Config,
poll: &mut Poll,
buffer: &mut Cursor<&mut [u8]>,
local_responses: Drain<(ConnectionMeta, Response)>,
channel_responses: &ResponseChannelReceiver,
connections: &mut ConnectionMap,
) {
let channel_responses_len = channel_responses.len();
let channel_responses_drain = channel_responses.try_iter().take(channel_responses_len);
for (meta, response) in local_responses.chain(channel_responses_drain) {
if let Some(established) = connections
.get_mut(&Token(meta.poll_token))
.and_then(Connection::get_established)
{
if established.peer_addr != meta.peer_addr {
info!("socket worker error: peer socket addrs didn't match");
continue;
}
buffer.set_position(0);
let bytes_written = response.write(buffer).unwrap();
match established.send_response(&buffer.get_mut()[..bytes_written]) {
Ok(()) => {
::log::debug!(
"sent response: {:?} with response string {}",
response,
String::from_utf8_lossy(&buffer.get_ref()[..bytes_written])
);
if !config.network.keep_alive {
remove_connection(poll, connections, &Token(meta.poll_token));
}
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
debug!("send response: would block");
}
Err(err) => {
info!("error sending response: {}", err);
remove_connection(poll, connections, &Token(meta.poll_token));
}
}
}
}
}
/// Drop and deregister every connection whose `valid_until` deadline
/// has passed, then release excess map capacity.
pub fn remove_inactive_connections(poll: &mut Poll, connections: &mut ConnectionMap) {
    let now = Instant::now();

    connections.retain(|_, connection| {
        if connection.valid_until.0 >= now {
            true
        } else {
            if let Err(err) = connection.deregister(poll) {
                ::log::error!("deregister connection error: {}", err);
            }

            false
        }
    });

    connections.shrink_to_fit();
}
fn remove_connection(poll: &mut Poll, connections: &mut ConnectionMap, connection_token: &Token) {
if let Some(mut connection) = connections.remove(connection_token) {
if let Err(err) = connection.deregister(poll) {
::log::error!("deregister connection error: {}", err);
}
}
}

View file

@ -1,69 +0,0 @@
use std::io::{Read, Write};
use std::net::SocketAddr;
use mio::net::TcpStream;
use native_tls::TlsStream;
/// A peer stream: either plain TCP or TLS layered over TCP.
pub enum Stream {
    TcpStream(TcpStream),
    TlsStream(TlsStream<TcpStream>),
}
impl Stream {
    /// Address of the connected peer.
    ///
    /// NOTE(review): `peer_addr()` is unwrapped, so this panics if the
    /// underlying socket is no longer connected — confirm callers only
    /// invoke it on live connections.
    #[inline]
    pub fn get_peer_addr(&self) -> SocketAddr {
        match self {
            Self::TcpStream(stream) => stream.peer_addr().unwrap(),
            Self::TlsStream(stream) => stream.get_ref().peer_addr().unwrap(),
        }
    }
}
impl Read for Stream {
    /// Delegate reads to whichever transport the stream wraps.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, ::std::io::Error> {
        let inner: &mut dyn Read = match self {
            Self::TcpStream(stream) => stream,
            Self::TlsStream(stream) => stream,
        };

        inner.read(buf)
    }

    /// Not used but provided for completeness
    #[inline]
    fn read_vectored(
        &mut self,
        bufs: &mut [::std::io::IoSliceMut<'_>],
    ) -> ::std::io::Result<usize> {
        let inner: &mut dyn Read = match self {
            Self::TcpStream(stream) => stream,
            Self::TlsStream(stream) => stream,
        };

        inner.read_vectored(bufs)
    }
}
impl Write for Stream {
    /// Delegate writes to whichever transport the stream wraps.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> ::std::io::Result<usize> {
        let inner: &mut dyn Write = match self {
            Self::TcpStream(stream) => stream,
            Self::TlsStream(stream) => stream,
        };

        inner.write(buf)
    }

    /// Not used but provided for completeness
    #[inline]
    fn write_vectored(&mut self, bufs: &[::std::io::IoSlice<'_>]) -> ::std::io::Result<usize> {
        let inner: &mut dyn Write = match self {
            Self::TcpStream(stream) => stream,
            Self::TlsStream(stream) => stream,
        };

        inner.write_vectored(bufs)
    }

    #[inline]
    fn flush(&mut self) -> ::std::io::Result<()> {
        let inner: &mut dyn Write = match self {
            Self::TcpStream(stream) => stream,
            Self::TlsStream(stream) => stream,
        };

        inner.flush()
    }
}

View file

@ -1,63 +0,0 @@
use std::fs::File;
use std::io::Read;
use std::net::SocketAddr;
use anyhow::Context;
use native_tls::{Identity, TlsAcceptor};
use socket2::{Domain, Protocol, Socket, Type};
use crate::config::TlsConfig;
/// Build a `TlsAcceptor` from the pkcs12 identity named in the config,
/// or return `Ok(None)` when TLS is disabled.
pub fn create_tls_acceptor(config: &TlsConfig) -> anyhow::Result<Option<TlsAcceptor>> {
    if !config.use_tls {
        return Ok(None);
    }

    let mut identity_bytes = Vec::new();
    let mut file =
        File::open(&config.tls_pkcs12_path).context("Couldn't open pkcs12 identity file")?;

    file.read_to_end(&mut identity_bytes)
        .context("Couldn't read pkcs12 identity file")?;

    let identity = Identity::from_pkcs12(&identity_bytes[..], &config.tls_pkcs12_password)
        .context("Couldn't parse pkcs12 identity file")?;

    TlsAcceptor::new(identity)
        .context("Couldn't create TlsAcceptor from pkcs12 identity")
        .map(Some)
}
/// Create a non-blocking, reuse-port TCP listener bound to `address`,
/// optionally restricted to IPv6-only traffic.
pub fn create_listener(
    address: SocketAddr,
    ipv6_only: bool,
) -> ::anyhow::Result<::std::net::TcpListener> {
    let domain = if address.is_ipv4() {
        Domain::IPV4
    } else {
        Domain::IPV6
    };

    let builder = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
        .context("Couldn't create socket2::Socket")?;

    if ipv6_only {
        builder
            .set_only_v6(true)
            .context("Couldn't put socket in ipv6 only mode")?;
    }

    builder
        .set_nonblocking(true)
        .context("Couldn't put socket in non-blocking mode")?;
    builder
        .set_reuse_port(true)
        .context("Couldn't put socket in reuse_port mode")?;
    builder
        .bind(&address.into())
        .with_context(|| format!("Couldn't bind socket to address {}", address))?;
    builder
        .listen(128)
        .context("Couldn't listen for connections on socket")?;

    Ok(builder.into())
}

View file

@ -1,53 +0,0 @@
use histogram::Histogram;
use aquatic_common::access_list::{AccessListMode, AccessListQuery};
use super::common::*;
use crate::config::Config;
/// Reload the access list from disk when white- or blacklisting is
/// enabled; a no-op when the access list is off.
pub fn update_access_list(config: &Config, state: &State) {
    let mode = config.access_list.mode;

    if matches!(mode, AccessListMode::White | AccessListMode::Black) {
        if let Err(err) = state.access_list.update_from_path(&config.access_list.path) {
            ::log::error!("Couldn't update access list: {:?}", err);
        }
    }
}
/// Print a histogram of peers per torrent (IPv4 then IPv6) to stdout;
/// silent when there are no torrents.
pub fn print_statistics(state: &State) {
    let mut peers_per_torrent = Histogram::new();

    // Scope the torrent-map lock to the counting pass only.
    {
        let torrents = &mut state.torrent_maps.lock();

        for torrent in torrents.ipv4.values().chain(torrents.ipv6.values()) {
            let num_peers = (torrent.num_seeders + torrent.num_leechers) as u64;

            if let Err(err) = peers_per_torrent.increment(num_peers) {
                eprintln!("error incrementing peers_per_torrent histogram: {}", err)
            }
        }
    }

    if peers_per_torrent.entries() != 0 {
        println!(
            "peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}",
            peers_per_torrent.minimum().unwrap(),
            peers_per_torrent.percentile(50.0).unwrap(),
            peers_per_torrent.percentile(75.0).unwrap(),
            peers_per_torrent.percentile(90.0).unwrap(),
            peers_per_torrent.percentile(99.0).unwrap(),
            peers_per_torrent.percentile(99.9).unwrap(),
            peers_per_torrent.maximum().unwrap(),
        );
    }
}

View file

@ -1,11 +1,3 @@
use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use cfg_if::cfg_if;
pub mod common;