Merge pull request #9 from greatest-ape/udp-handle-connect-in-socket-workers

aquatic_udp: handle connection checking in socket workers, other improvements
This commit is contained in:
Joakim Frostegård 2021-10-18 12:09:07 +02:00 committed by GitHub
commit 65ef9a8ab2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 263 additions and 412 deletions

1
Cargo.lock generated
View file

@ -216,6 +216,7 @@ name = "aquatic_udp_protocol"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"either",
"quickcheck", "quickcheck",
"quickcheck_macros", "quickcheck_macros",
] ]

View file

@ -1,13 +1,8 @@
# TODO # TODO
* aquatic_udp needs to check connection validity before sending error responses! connection validity checks could be moved to socket workers, since theyare sharded by ip
* access lists: * access lists:
* use arc-swap Cache * use arc-swap Cache
* test functionality * add CI tests
* aquatic_udp
* aquatic_http
* aquatic_ws, including sending back new error responses
* aquatic_ws: should it send back error on message parse error, or does that * aquatic_ws: should it send back error on message parse error, or does that
just indicate that not enough data has been received yet? just indicate that not enough data has been received yet?

View file

@ -1,5 +1,5 @@
use std::hash::Hash; use std::hash::Hash;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::{atomic::AtomicUsize, Arc}; use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant; use std::time::Instant;
@ -30,23 +30,25 @@ impl Ip for Ipv6Addr {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ConnectedRequest {
pub struct ConnectionKey { Announce(AnnounceRequest),
pub connection_id: ConnectionId, Scrape(ScrapeRequest),
pub socket_addr: SocketAddr,
} }
impl ConnectionKey { pub enum ConnectedResponse {
pub fn new(connection_id: ConnectionId, socket_addr: SocketAddr) -> Self { Announce(AnnounceResponse),
Self { Scrape(ScrapeResponse),
connection_id, }
socket_addr,
impl Into<Response> for ConnectedResponse {
fn into(self) -> Response {
match self {
Self::Announce(response) => Response::Announce(response),
Self::Scrape(response) => Response::Scrape(response),
} }
} }
} }
pub type ConnectionMap = HashMap<ConnectionKey, ValidUntil>;
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]
pub enum PeerStatus { pub enum PeerStatus {
Seeding, Seeding,
@ -172,7 +174,6 @@ impl TorrentMaps {
pub struct Statistics { pub struct Statistics {
pub requests_received: AtomicUsize, pub requests_received: AtomicUsize,
pub responses_sent: AtomicUsize, pub responses_sent: AtomicUsize,
pub readable_events: AtomicUsize,
pub bytes_received: AtomicUsize, pub bytes_received: AtomicUsize,
pub bytes_sent: AtomicUsize, pub bytes_sent: AtomicUsize,
} }
@ -180,7 +181,6 @@ pub struct Statistics {
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub access_list: Arc<AccessList>, pub access_list: Arc<AccessList>,
pub connections: Arc<Mutex<ConnectionMap>>,
pub torrents: Arc<Mutex<TorrentMaps>>, pub torrents: Arc<Mutex<TorrentMaps>>,
pub statistics: Arc<Statistics>, pub statistics: Arc<Statistics>,
} }
@ -189,7 +189,6 @@ impl Default for State {
fn default() -> Self { fn default() -> Self {
Self { Self {
access_list: Arc::new(AccessList::default()), access_list: Arc::new(AccessList::default()),
connections: Arc::new(Mutex::new(HashMap::new())),
torrents: Arc::new(Mutex::new(TorrentMaps::default())), torrents: Arc::new(Mutex::new(TorrentMaps::default())),
statistics: Arc::new(Statistics::default()), statistics: Arc::new(Statistics::default()),
} }

View file

@ -84,7 +84,7 @@ pub struct StatisticsConfig {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)] #[serde(default)]
pub struct CleaningConfig { pub struct CleaningConfig {
/// Clean torrents and connections this often (seconds) /// Update access list and clean torrents this often (seconds)
pub interval: u64, pub interval: u64,
/// Remove peers that haven't announced for this long (seconds) /// Remove peers that haven't announced for this long (seconds)
pub max_peer_age: u64, pub max_peer_age: u64,

View file

@ -16,7 +16,7 @@ pub fn handle_announce_requests(
torrents: &mut MutexGuard<TorrentMaps>, torrents: &mut MutexGuard<TorrentMaps>,
rng: &mut SmallRng, rng: &mut SmallRng,
requests: Drain<(AnnounceRequest, SocketAddr)>, requests: Drain<(AnnounceRequest, SocketAddr)>,
responses: &mut Vec<(Response, SocketAddr)>, responses: &mut Vec<(ConnectedResponse, SocketAddr)>,
) { ) {
let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age); let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age);
@ -42,7 +42,7 @@ pub fn handle_announce_requests(
), ),
}; };
(Response::Announce(response), src) (ConnectedResponse::Announce(response), src)
})); }));
} }

View file

@ -1,39 +0,0 @@
use std::net::SocketAddr;
use std::vec::Drain;
use parking_lot::MutexGuard;
use rand::{rngs::StdRng, Rng};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
#[inline]
pub fn handle_connect_requests(
config: &Config,
connections: &mut MutexGuard<ConnectionMap>,
rng: &mut StdRng,
requests: Drain<(ConnectRequest, SocketAddr)>,
responses: &mut Vec<(Response, SocketAddr)>,
) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
responses.extend(requests.map(|(request, src)| {
let connection_id = ConnectionId(rng.gen());
let key = ConnectionKey {
connection_id,
socket_addr: src,
};
connections.insert(key, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
(response, src)
}));
}

View file

@ -2,11 +2,7 @@ use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use crossbeam_channel::{Receiver, Sender}; use crossbeam_channel::{Receiver, Sender};
use parking_lot::MutexGuard; use rand::{rngs::SmallRng, SeedableRng};
use rand::{
rngs::{SmallRng, StdRng},
SeedableRng,
};
use aquatic_udp_protocol::*; use aquatic_udp_protocol::*;
@ -14,40 +10,35 @@ use crate::common::*;
use crate::config::Config; use crate::config::Config;
mod announce; mod announce;
mod connect;
mod scrape; mod scrape;
use announce::handle_announce_requests; use announce::handle_announce_requests;
use connect::handle_connect_requests;
use scrape::handle_scrape_requests; use scrape::handle_scrape_requests;
pub fn run_request_worker( pub fn run_request_worker(
state: State, state: State,
config: Config, config: Config,
request_receiver: Receiver<(Request, SocketAddr)>, request_receiver: Receiver<(ConnectedRequest, SocketAddr)>,
response_sender: Sender<(Response, SocketAddr)>, response_sender: Sender<(ConnectedResponse, SocketAddr)>,
) { ) {
let mut connect_requests: Vec<(ConnectRequest, SocketAddr)> = Vec::new();
let mut announce_requests: Vec<(AnnounceRequest, SocketAddr)> = Vec::new(); let mut announce_requests: Vec<(AnnounceRequest, SocketAddr)> = Vec::new();
let mut scrape_requests: Vec<(ScrapeRequest, SocketAddr)> = Vec::new(); let mut scrape_requests: Vec<(ScrapeRequest, SocketAddr)> = Vec::new();
let mut responses: Vec<(ConnectedResponse, SocketAddr)> = Vec::new();
let mut responses: Vec<(Response, SocketAddr)> = Vec::new(); let mut small_rng = SmallRng::from_entropy();
let mut std_rng = StdRng::from_entropy();
let mut small_rng = SmallRng::from_rng(&mut std_rng).unwrap();
let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds); let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);
loop { loop {
let mut opt_connections = None; let mut opt_torrents = None;
// Collect requests from channel, divide them by type // Collect requests from channel, divide them by type
// //
// Collect a maximum number of request. Stop collecting before that // Collect a maximum number of request. Stop collecting before that
// number is reached if having waited for too long for a request, but // number is reached if having waited for too long for a request, but
// only if ConnectionMap mutex isn't locked. // only if TorrentMaps mutex isn't locked.
for i in 0..config.handlers.max_requests_per_iter { for i in 0..config.handlers.max_requests_per_iter {
let (request, src): (Request, SocketAddr) = if i == 0 { let (request, src): (ConnectedRequest, SocketAddr) = if i == 0 {
match request_receiver.recv() { match request_receiver.recv() {
Ok(r) => r, Ok(r) => r,
Err(_) => break, // Really shouldn't happen Err(_) => break, // Really shouldn't happen
@ -56,8 +47,8 @@ pub fn run_request_worker(
match request_receiver.recv_timeout(timeout) { match request_receiver.recv_timeout(timeout) {
Ok(r) => r, Ok(r) => r,
Err(_) => { Err(_) => {
if let Some(guard) = state.connections.try_lock() { if let Some(guard) = state.torrents.try_lock() {
opt_connections = Some(guard); opt_torrents = Some(guard);
break; break;
} else { } else {
@ -68,59 +59,14 @@ pub fn run_request_worker(
}; };
match request { match request {
Request::Connect(r) => connect_requests.push((r, src)), ConnectedRequest::Announce(r) => announce_requests.push((r, src)),
Request::Announce(r) => announce_requests.push((r, src)), ConnectedRequest::Scrape(r) => scrape_requests.push((r, src)),
Request::Scrape(r) => scrape_requests.push((r, src)),
} }
} }
let mut connections: MutexGuard<ConnectionMap> = // Generate responses for announce and scrape requests, then drop MutexGuard.
opt_connections.unwrap_or_else(|| state.connections.lock()); {
let mut torrents = opt_torrents.unwrap_or_else(|| state.torrents.lock());
handle_connect_requests(
&config,
&mut connections,
&mut std_rng,
connect_requests.drain(..),
&mut responses,
);
// Check announce and scrape requests for valid connections
announce_requests.retain(|(request, src)| {
let connection_valid =
connections.contains_key(&ConnectionKey::new(request.connection_id, *src));
if !connection_valid {
responses.push((
create_invalid_connection_response(request.transaction_id),
*src,
));
}
connection_valid
});
scrape_requests.retain(|(request, src)| {
let connection_valid =
connections.contains_key(&ConnectionKey::new(request.connection_id, *src));
if !connection_valid {
responses.push((
create_invalid_connection_response(request.transaction_id),
*src,
));
}
connection_valid
});
::std::mem::drop(connections);
// Generate responses for announce and scrape requests
if !(announce_requests.is_empty() && scrape_requests.is_empty()) {
let mut torrents = state.torrents.lock();
handle_announce_requests( handle_announce_requests(
&config, &config,
@ -129,6 +75,7 @@ pub fn run_request_worker(
announce_requests.drain(..), announce_requests.drain(..),
&mut responses, &mut responses,
); );
handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses); handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses);
} }
@ -139,10 +86,3 @@ pub fn run_request_worker(
} }
} }
} }
fn create_invalid_connection_response(transaction_id: TransactionId) -> Response {
Response::Error(ErrorResponse {
transaction_id,
message: "Connection invalid or expired".into(),
})
}

View file

@ -12,7 +12,7 @@ use crate::common::*;
pub fn handle_scrape_requests( pub fn handle_scrape_requests(
torrents: &mut MutexGuard<TorrentMaps>, torrents: &mut MutexGuard<TorrentMaps>,
requests: Drain<(ScrapeRequest, SocketAddr)>, requests: Drain<(ScrapeRequest, SocketAddr)>,
responses: &mut Vec<(Response, SocketAddr)>, responses: &mut Vec<(ConnectedResponse, SocketAddr)>,
) { ) {
let empty_stats = create_torrent_scrape_statistics(0, 0); let empty_stats = create_torrent_scrape_statistics(0, 0);
@ -45,7 +45,7 @@ pub fn handle_scrape_requests(
} }
} }
let response = Response::Scrape(ScrapeResponse { let response = ConnectedResponse::Scrape(ScrapeResponse {
transaction_id: request.transaction_id, transaction_id: request.transaction_id,
torrent_stats: stats, torrent_stats: stats,
}); });

View file

@ -23,7 +23,7 @@ pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker";
pub fn run(config: Config) -> ::anyhow::Result<()> { pub fn run(config: Config) -> ::anyhow::Result<()> {
let state = State::default(); let state = State::default();
tasks::update_access_list(&config, &state); tasks::update_access_list(&config, &state.access_list);
let num_bound_sockets = start_workers(config.clone(), state.clone())?; let num_bound_sockets = start_workers(config.clone(), state.clone())?;
@ -55,8 +55,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
loop { loop {
::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); ::std::thread::sleep(Duration::from_secs(config.cleaning.interval));
tasks::clean_connections(&state); tasks::update_access_list(&config, &state.access_list);
tasks::update_access_list(&config, &state);
state.torrents.lock().clean(&config, &state.access_list); state.torrents.lock().clean(&config, &state.access_list);
} }

View file

@ -4,12 +4,14 @@ use std::sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Arc, Arc,
}; };
use std::time::Duration; use std::time::{Duration, Instant};
use std::vec::Drain; use std::vec::Drain;
use crossbeam_channel::{Receiver, Sender}; use crossbeam_channel::{Receiver, Sender};
use hashbrown::HashMap;
use mio::net::UdpSocket; use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token}; use mio::{Events, Interest, Poll, Token};
use rand::prelude::{Rng, SeedableRng, StdRng};
use socket2::{Domain, Protocol, Socket, Type}; use socket2::{Domain, Protocol, Socket, Type};
use aquatic_udp_protocol::{IpVersion, Request, Response}; use aquatic_udp_protocol::{IpVersion, Request, Response};
@ -17,14 +19,40 @@ use aquatic_udp_protocol::{IpVersion, Request, Response};
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
#[derive(Default)]
struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>);
impl ConnectionMap {
fn insert(
&mut self,
connection_id: ConnectionId,
socket_addr: SocketAddr,
valid_until: ValidUntil,
) {
self.0.insert((connection_id, socket_addr), valid_until);
}
fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool {
self.0.contains_key(&(connection_id, socket_addr))
}
fn clean(&mut self) {
let now = Instant::now();
self.0.retain(|_, v| v.0 > now);
self.0.shrink_to_fit();
}
}
pub fn run_socket_worker( pub fn run_socket_worker(
state: State, state: State,
config: Config, config: Config,
token_num: usize, token_num: usize,
request_sender: Sender<(Request, SocketAddr)>, request_sender: Sender<(ConnectedRequest, SocketAddr)>,
response_receiver: Receiver<(Response, SocketAddr)>, response_receiver: Receiver<(ConnectedResponse, SocketAddr)>,
num_bound_sockets: Arc<AtomicUsize>, num_bound_sockets: Arc<AtomicUsize>,
) { ) {
let mut rng = StdRng::from_entropy();
let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut buffer = [0u8; MAX_PACKET_SIZE];
let mut socket = UdpSocket::from_std(create_socket(&config)); let mut socket = UdpSocket::from_std(create_socket(&config));
@ -39,12 +67,14 @@ pub fn run_socket_worker(
num_bound_sockets.fetch_add(1, Ordering::SeqCst); num_bound_sockets.fetch_add(1, Ordering::SeqCst);
let mut events = Events::with_capacity(config.network.poll_event_capacity); let mut events = Events::with_capacity(config.network.poll_event_capacity);
let mut connections = ConnectionMap::default();
let mut requests: Vec<(Request, SocketAddr)> = Vec::new();
let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new(); let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new();
let timeout = Duration::from_millis(50); let timeout = Duration::from_millis(50);
let mut iter_counter = 0usize;
loop { loop {
poll.poll(&mut events, Some(timeout)) poll.poll(&mut events, Some(timeout))
.expect("failed polling"); .expect("failed polling");
@ -54,24 +84,15 @@ pub fn run_socket_worker(
if (token.0 == token_num) & event.is_readable() { if (token.0 == token_num) & event.is_readable() {
read_requests( read_requests(
&state,
&config, &config,
&state,
&mut connections,
&mut rng,
&mut socket, &mut socket,
&mut buffer, &mut buffer,
&mut requests, &request_sender,
&mut local_responses, &mut local_responses,
); );
for r in requests.drain(..) {
if let Err(err) = request_sender.send(r) {
::log::error!("error sending to request_sender: {}", err);
}
}
state
.statistics
.readable_events
.fetch_add(1, Ordering::SeqCst);
} }
} }
@ -83,6 +104,14 @@ pub fn run_socket_worker(
&response_receiver, &response_receiver,
local_responses.drain(..), local_responses.drain(..),
); );
iter_counter += 1;
if iter_counter == 1000 {
connections.clean();
iter_counter = 0;
}
} }
} }
@ -121,16 +150,19 @@ fn create_socket(config: &Config) -> ::std::net::UdpSocket {
#[inline] #[inline]
fn read_requests( fn read_requests(
state: &State,
config: &Config, config: &Config,
state: &State,
connections: &mut ConnectionMap,
rng: &mut StdRng,
socket: &mut UdpSocket, socket: &mut UdpSocket,
buffer: &mut [u8], buffer: &mut [u8],
requests: &mut Vec<(Request, SocketAddr)>, request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
local_responses: &mut Vec<(Response, SocketAddr)>, local_responses: &mut Vec<(Response, SocketAddr)>,
) { ) {
let mut requests_received: usize = 0; let mut requests_received: usize = 0;
let mut bytes_received: usize = 0; let mut bytes_received: usize = 0;
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode; let access_list_mode = config.access_list.mode;
loop { loop {
@ -146,37 +178,61 @@ fn read_requests(
} }
match request { match request {
Ok(Request::Announce(AnnounceRequest { Ok(Request::Connect(request)) => {
info_hash, let connection_id = ConnectionId(rng.gen());
transaction_id,
.. connections.insert(connection_id, src, valid_until);
})) if !state.access_list.allows(access_list_mode, &info_hash.0) => {
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if state
.access_list
.allows(access_list_mode, &request.info_hash.0)
{
if let Err(err) = request_sender
.try_send((ConnectedRequest::Announce(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
} else {
let response = Response::Error(ErrorResponse { let response = Response::Error(ErrorResponse {
transaction_id, transaction_id: request.transaction_id,
message: "Info hash not allowed".into(), message: "Info hash not allowed".into(),
}); });
local_responses.push((response, src)) local_responses.push((response, src))
} }
Ok(request) => { }
requests.push((request, src)); }
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
if let Err(err) =
request_sender.try_send((ConnectedRequest::Scrape(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
}
} }
Err(err) => { Err(err) => {
::log::debug!("request_from_bytes error: {:?}", err); ::log::debug!("Request::from_bytes error: {:?}", err);
if let Some(transaction_id) = err.transaction_id { if let RequestParseError::Sendable {
let opt_message = if err.error.is_some() { connection_id,
Some("Parse error".into()) transaction_id,
} else if let Some(message) = err.message { err,
Some(message.into()) } = err
} else { {
None if connections.contains(connection_id, src) {
};
if let Some(message) = opt_message {
let response = ErrorResponse { let response = ErrorResponse {
transaction_id, transaction_id,
message, message: err.right_or("Parse error").into(),
}; };
local_responses.push((response.into(), src)); local_responses.push((response.into(), src));
@ -213,7 +269,7 @@ fn send_responses(
config: &Config, config: &Config,
socket: &mut UdpSocket, socket: &mut UdpSocket,
buffer: &mut [u8], buffer: &mut [u8],
response_receiver: &Receiver<(Response, SocketAddr)>, response_receiver: &Receiver<(ConnectedResponse, SocketAddr)>,
local_responses: Drain<(Response, SocketAddr)>, local_responses: Drain<(Response, SocketAddr)>,
) { ) {
let mut responses_sent: usize = 0; let mut responses_sent: usize = 0;
@ -221,17 +277,19 @@ fn send_responses(
let mut cursor = Cursor::new(buffer); let mut cursor = Cursor::new(buffer);
let response_iterator = local_responses let response_iterator = local_responses.into_iter().chain(
.into_iter() response_receiver
.chain(response_receiver.try_iter()); .try_iter()
.map(|(response, addr)| (response.into(), addr)),
);
for (response, src) in response_iterator { for (response, src) in response_iterator {
cursor.set_position(0); cursor.set_position(0);
let ip_version = ip_version_from_ip(src.ip()); let ip_version = ip_version_from_ip(src.ip());
response.write(&mut cursor, ip_version).unwrap(); match response.write(&mut cursor, ip_version) {
Ok(()) => {
let amt = cursor.position() as usize; let amt = cursor.position() as usize;
match socket.send_to(&cursor.get_ref()[..amt], src) { match socket.send_to(&cursor.get_ref()[..amt], src) {
@ -248,6 +306,11 @@ fn send_responses(
} }
} }
} }
Err(err) => {
::log::error!("Response::write error: {:?}", err);
}
}
}
if config.statistics.interval != 0 { if config.statistics.interval != 0 {
state state

View file

@ -1,5 +1,5 @@
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Instant; use std::sync::Arc;
use histogram::Histogram; use histogram::Histogram;
@ -8,10 +8,10 @@ use aquatic_common::access_list::AccessListMode;
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
pub fn update_access_list(config: &Config, state: &State) { pub fn update_access_list(config: &Config, access_list: &Arc<AccessList>) {
match config.access_list.mode { match config.access_list.mode {
AccessListMode::White | AccessListMode::Black => { AccessListMode::White | AccessListMode::Black => {
if let Err(err) = state.access_list.update_from_path(&config.access_list.path) { if let Err(err) = access_list.update_from_path(&config.access_list.path) {
::log::error!("Update access list from path: {:?}", err); ::log::error!("Update access list from path: {:?}", err);
} }
} }
@ -19,15 +19,6 @@ pub fn update_access_list(config: &Config, state: &State) {
} }
} }
pub fn clean_connections(state: &State) {
let now = Instant::now();
let mut connections = state.connections.lock();
connections.retain(|_, v| v.0 > now);
connections.shrink_to_fit();
}
pub fn gather_and_print_statistics(state: &State, config: &Config) { pub fn gather_and_print_statistics(state: &State, config: &Config) {
let interval = config.statistics.interval; let interval = config.statistics.interval;
@ -50,19 +41,9 @@ pub fn gather_and_print_statistics(state: &State, config: &Config) {
let bytes_received_per_second: f64 = bytes_received / interval as f64; let bytes_received_per_second: f64 = bytes_received / interval as f64;
let bytes_sent_per_second: f64 = bytes_sent / interval as f64; let bytes_sent_per_second: f64 = bytes_sent / interval as f64;
let readable_events: f64 = state
.statistics
.readable_events
.fetch_and(0, Ordering::SeqCst) as f64;
let requests_per_readable_event = if readable_events == 0.0 {
0.0
} else {
requests_received / readable_events
};
println!( println!(
"stats: {:.2} requests/second, {:.2} responses/second, {:.2} requests/readable event", "stats: {:.2} requests/second, {:.2} responses/second",
requests_per_second, responses_per_second, requests_per_readable_event requests_per_second, responses_per_second
); );
println!( println!(

View file

@ -1,4 +1,4 @@
use std::net::SocketAddr; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use crossbeam_channel::{Receiver, Sender}; use crossbeam_channel::{Receiver, Sender};
@ -13,15 +13,14 @@ use crate::common::*;
use crate::config::BenchConfig; use crate::config::BenchConfig;
pub fn bench_announce_handler( pub fn bench_announce_handler(
state: &State,
bench_config: &BenchConfig, bench_config: &BenchConfig,
aquatic_config: &Config, aquatic_config: &Config,
request_sender: &Sender<(Request, SocketAddr)>, request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
response_receiver: &Receiver<(Response, SocketAddr)>, response_receiver: &Receiver<(ConnectedResponse, SocketAddr)>,
rng: &mut impl Rng, rng: &mut impl Rng,
info_hashes: &[InfoHash], info_hashes: &[InfoHash],
) -> (usize, Duration) { ) -> (usize, Duration) {
let requests = create_requests(state, rng, info_hashes, bench_config.num_announce_requests); let requests = create_requests(rng, info_hashes, bench_config.num_announce_requests);
let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads; let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads;
let mut num_responses = 0usize; let mut num_responses = 0usize;
@ -37,10 +36,12 @@ pub fn bench_announce_handler(
for round in (0..bench_config.num_rounds).progress_with(pb) { for round in (0..bench_config.num_rounds).progress_with(pb) {
for request_chunk in requests.chunks(p) { for request_chunk in requests.chunks(p) {
for (request, src) in request_chunk { for (request, src) in request_chunk {
request_sender.send((request.clone().into(), *src)).unwrap(); request_sender
.send((ConnectedRequest::Announce(request.clone()), *src))
.unwrap();
} }
while let Ok((Response::Announce(r), _)) = response_receiver.try_recv() { while let Ok((ConnectedResponse::Announce(r), _)) = response_receiver.try_recv() {
num_responses += 1; num_responses += 1;
if let Some(last_peer) = r.peers.last() { if let Some(last_peer) = r.peers.last() {
@ -52,7 +53,7 @@ pub fn bench_announce_handler(
let total = bench_config.num_announce_requests * (round + 1); let total = bench_config.num_announce_requests * (round + 1);
while num_responses < total { while num_responses < total {
if let Ok((Response::Announce(r), _)) = response_receiver.recv() { if let Ok((ConnectedResponse::Announce(r), _)) = response_receiver.recv() {
num_responses += 1; num_responses += 1;
if let Some(last_peer) = r.peers.last() { if let Some(last_peer) = r.peers.last() {
@ -72,7 +73,6 @@ pub fn bench_announce_handler(
} }
pub fn create_requests( pub fn create_requests(
state: &State,
rng: &mut impl Rng, rng: &mut impl Rng,
info_hashes: &[InfoHash], info_hashes: &[InfoHash],
number: usize, number: usize,
@ -83,15 +83,11 @@ pub fn create_requests(
let mut requests = Vec::new(); let mut requests = Vec::new();
let connections = state.connections.lock(); for _ in 0..number {
let connection_keys: Vec<ConnectionKey> = connections.keys().take(number).cloned().collect();
for connection_key in connection_keys.into_iter() {
let info_hash_index = pareto_usize(rng, pareto, max_index); let info_hash_index = pareto_usize(rng, pareto, max_index);
let request = AnnounceRequest { let request = AnnounceRequest {
connection_id: connection_key.connection_id, connection_id: ConnectionId(0),
transaction_id: TransactionId(rng.gen()), transaction_id: TransactionId(rng.gen()),
info_hash: info_hashes[info_hash_index], info_hash: info_hashes[info_hash_index],
peer_id: PeerId(rng.gen()), peer_id: PeerId(rng.gen()),
@ -105,7 +101,10 @@ pub fn create_requests(
port: Port(rng.gen()), port: Port(rng.gen()),
}; };
requests.push((request, connection_key.socket_addr)); requests.push((
request,
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1)),
));
} }
requests requests

View file

@ -1,80 +0,0 @@
use std::time::{Duration, Instant};
use crossbeam_channel::{Receiver, Sender};
use indicatif::ProgressIterator;
use rand::{rngs::SmallRng, thread_rng, Rng, SeedableRng};
use std::net::SocketAddr;
use aquatic_udp::common::*;
use aquatic_udp::config::Config;
use crate::common::*;
use crate::config::BenchConfig;
pub fn bench_connect_handler(
bench_config: &BenchConfig,
aquatic_config: &Config,
request_sender: &Sender<(Request, SocketAddr)>,
response_receiver: &Receiver<(Response, SocketAddr)>,
) -> (usize, Duration) {
let requests = create_requests(bench_config.num_connect_requests);
let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads;
let mut num_responses = 0usize;
let mut dummy: i64 = thread_rng().gen();
let pb = create_progress_bar("Connect", bench_config.num_rounds as u64);
// Start connect benchmark
let before = Instant::now();
for round in (0..bench_config.num_rounds).progress_with(pb) {
for request_chunk in requests.chunks(p) {
for (request, src) in request_chunk {
request_sender.send((request.clone().into(), *src)).unwrap();
}
while let Ok((Response::Connect(r), _)) = response_receiver.try_recv() {
num_responses += 1;
dummy ^= r.connection_id.0;
}
}
let total = bench_config.num_connect_requests * (round + 1);
while num_responses < total {
if let Ok((Response::Connect(r), _)) = response_receiver.recv() {
num_responses += 1;
dummy ^= r.connection_id.0;
}
}
}
let elapsed = before.elapsed();
if dummy == 0 {
println!("dummy dummy");
}
(num_responses, elapsed)
}
pub fn create_requests(number: usize) -> Vec<(ConnectRequest, SocketAddr)> {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let mut requests = Vec::new();
for _ in 0..number {
let request = ConnectRequest {
transaction_id: TransactionId(rng.gen()),
};
let src = SocketAddr::from(([rng.gen(), rng.gen(), rng.gen(), rng.gen()], rng.gen()));
requests.push((request, src));
}
requests
}

View file

@ -2,16 +2,9 @@
//! //!
//! Example outputs: //! Example outputs:
//! ``` //! ```
//! # Results over 20 rounds with 1 threads //! # Results over 10 rounds with 2 threads
//! Connect: 2 306 637 requests/second, 433.53 ns/request //! Announce: 429 540 requests/second, 2328.07 ns/request
//! Announce: 688 391 requests/second, 1452.66 ns/request //! Scrape: 1 873 545 requests/second, 533.75 ns/request
//! Scrape: 1 505 700 requests/second, 664.14 ns/request
//! ```
//! ```
//! # Results over 20 rounds with 2 threads
//! Connect: 3 472 434 requests/second, 287.98 ns/request
//! Announce: 739 371 requests/second, 1352.50 ns/request
//! Scrape: 1 845 253 requests/second, 541.93 ns/request
//! ``` //! ```
use crossbeam_channel::unbounded; use crossbeam_channel::unbounded;
@ -29,7 +22,6 @@ use config::BenchConfig;
mod announce; mod announce;
mod common; mod common;
mod config; mod config;
mod connect;
mod scrape; mod scrape;
#[global_allocator] #[global_allocator]
@ -65,18 +57,10 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> {
// Run benchmarks // Run benchmarks
let c = connect::bench_connect_handler(
&bench_config,
&aquatic_config,
&request_sender,
&response_receiver,
);
let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let info_hashes = create_info_hashes(&mut rng); let info_hashes = create_info_hashes(&mut rng);
let a = announce::bench_announce_handler( let a = announce::bench_announce_handler(
&state,
&bench_config, &bench_config,
&aquatic_config, &aquatic_config,
&request_sender, &request_sender,
@ -86,7 +70,6 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> {
); );
let s = scrape::bench_scrape_handler( let s = scrape::bench_scrape_handler(
&state,
&bench_config, &bench_config,
&aquatic_config, &aquatic_config,
&request_sender, &request_sender,
@ -100,7 +83,6 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> {
bench_config.num_rounds, bench_config.num_threads, bench_config.num_rounds, bench_config.num_threads,
); );
print_results("Connect: ", c.0, c.1);
print_results("Announce:", a.0, a.1); print_results("Announce:", a.0, a.1);
print_results("Scrape: ", s.0, s.1); print_results("Scrape: ", s.0, s.1);

View file

@ -1,4 +1,4 @@
use std::net::SocketAddr; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use crossbeam_channel::{Receiver, Sender}; use crossbeam_channel::{Receiver, Sender};
@ -13,16 +13,14 @@ use crate::common::*;
use crate::config::BenchConfig; use crate::config::BenchConfig;
pub fn bench_scrape_handler( pub fn bench_scrape_handler(
state: &State,
bench_config: &BenchConfig, bench_config: &BenchConfig,
aquatic_config: &Config, aquatic_config: &Config,
request_sender: &Sender<(Request, SocketAddr)>, request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
response_receiver: &Receiver<(Response, SocketAddr)>, response_receiver: &Receiver<(ConnectedResponse, SocketAddr)>,
rng: &mut impl Rng, rng: &mut impl Rng,
info_hashes: &[InfoHash], info_hashes: &[InfoHash],
) -> (usize, Duration) { ) -> (usize, Duration) {
let requests = create_requests( let requests = create_requests(
state,
rng, rng,
info_hashes, info_hashes,
bench_config.num_scrape_requests, bench_config.num_scrape_requests,
@ -43,10 +41,12 @@ pub fn bench_scrape_handler(
for round in (0..bench_config.num_rounds).progress_with(pb) { for round in (0..bench_config.num_rounds).progress_with(pb) {
for request_chunk in requests.chunks(p) { for request_chunk in requests.chunks(p) {
for (request, src) in request_chunk { for (request, src) in request_chunk {
request_sender.send((request.clone().into(), *src)).unwrap(); request_sender
.send((ConnectedRequest::Scrape(request.clone()), *src))
.unwrap();
} }
while let Ok((Response::Scrape(r), _)) = response_receiver.try_recv() { while let Ok((ConnectedResponse::Scrape(r), _)) = response_receiver.try_recv() {
num_responses += 1; num_responses += 1;
if let Some(stat) = r.torrent_stats.last() { if let Some(stat) = r.torrent_stats.last() {
@ -58,7 +58,7 @@ pub fn bench_scrape_handler(
let total = bench_config.num_scrape_requests * (round + 1); let total = bench_config.num_scrape_requests * (round + 1);
while num_responses < total { while num_responses < total {
if let Ok((Response::Scrape(r), _)) = response_receiver.recv() { if let Ok((ConnectedResponse::Scrape(r), _)) = response_receiver.recv() {
num_responses += 1; num_responses += 1;
if let Some(stat) = r.torrent_stats.last() { if let Some(stat) = r.torrent_stats.last() {
@ -78,7 +78,6 @@ pub fn bench_scrape_handler(
} }
pub fn create_requests( pub fn create_requests(
state: &State,
rng: &mut impl Rng, rng: &mut impl Rng,
info_hashes: &[InfoHash], info_hashes: &[InfoHash],
number: usize, number: usize,
@ -88,13 +87,9 @@ pub fn create_requests(
let max_index = info_hashes.len() - 1; let max_index = info_hashes.len() - 1;
let connections = state.connections.lock();
let connection_keys: Vec<ConnectionKey> = connections.keys().take(number).cloned().collect();
let mut requests = Vec::new(); let mut requests = Vec::new();
for connection_key in connection_keys.into_iter() { for _ in 0..number {
let mut request_info_hashes = Vec::new(); let mut request_info_hashes = Vec::new();
for _ in 0..hashes_per_request { for _ in 0..hashes_per_request {
@ -103,12 +98,15 @@ pub fn create_requests(
} }
let request = ScrapeRequest { let request = ScrapeRequest {
connection_id: connection_key.connection_id, connection_id: ConnectionId(0),
transaction_id: TransactionId(rng.gen()), transaction_id: TransactionId(rng.gen()),
info_hashes: request_info_hashes, info_hashes: request_info_hashes,
}; };
requests.push((request, connection_key.socket_addr)); requests.push((
request,
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1)),
));
} }
requests requests

View file

@ -9,6 +9,7 @@ repository = "https://github.com/greatest-ape/aquatic"
[dependencies] [dependencies]
byteorder = "1" byteorder = "1"
either = "1"
[dev-dependencies] [dev-dependencies]
quickcheck = "1.0" quickcheck = "1.0"

View file

@ -3,6 +3,7 @@ use std::io::{self, Cursor, Read, Write};
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use either::Either;
use super::common::*; use super::common::*;
@ -67,32 +68,40 @@ pub struct ScrapeRequest {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct RequestParseError { pub enum RequestParseError {
pub transaction_id: Option<TransactionId>, Sendable {
pub message: Option<String>, connection_id: ConnectionId,
pub error: Option<io::Error>, transaction_id: TransactionId,
err: Either<io::Error, &'static str>,
},
Unsendable {
err: Either<io::Error, &'static str>,
},
} }
impl RequestParseError { impl RequestParseError {
pub fn new(err: io::Error, transaction_id: i32) -> Self { pub fn sendable_io(err: io::Error, connection_id: i64, transaction_id: i32) -> Self {
Self { Self::Sendable {
transaction_id: Some(TransactionId(transaction_id)), connection_id: ConnectionId(connection_id),
message: None, transaction_id: TransactionId(transaction_id),
error: Some(err), err: Either::Left(err),
} }
} }
pub fn io(err: io::Error) -> Self { pub fn sendable_text(text: &'static str, connection_id: i64, transaction_id: i32) -> Self {
Self { Self::Sendable {
transaction_id: None, connection_id: ConnectionId(connection_id),
message: None, transaction_id: TransactionId(transaction_id),
error: Some(err), err: Either::Right(text),
} }
} }
pub fn text(transaction_id: i32, message: &str) -> Self { pub fn unsendable_io(err: io::Error) -> Self {
Self { Self::Unsendable {
transaction_id: Some(TransactionId(transaction_id)), err: Either::Left(err),
message: Some(message.to_string()), }
error: None, }
pub fn unsendable_text(text: &'static str) -> Self {
Self::Unsendable {
err: Either::Right(text),
} }
} }
} }
@ -171,13 +180,13 @@ impl Request {
let connection_id = cursor let connection_id = cursor
.read_i64::<NetworkEndian>() .read_i64::<NetworkEndian>()
.map_err(RequestParseError::io)?; .map_err(RequestParseError::unsendable_io)?;
let action = cursor let action = cursor
.read_i32::<NetworkEndian>() .read_i32::<NetworkEndian>()
.map_err(RequestParseError::io)?; .map_err(RequestParseError::unsendable_io)?;
let transaction_id = cursor let transaction_id = cursor
.read_i32::<NetworkEndian>() .read_i32::<NetworkEndian>()
.map_err(RequestParseError::io)?; .map_err(RequestParseError::unsendable_io)?;
match action { match action {
// Connect // Connect
@ -188,8 +197,7 @@ impl Request {
}) })
.into()) .into())
} else { } else {
Err(RequestParseError::text( Err(RequestParseError::unsendable_text(
transaction_id,
"Protocol identifier missing", "Protocol identifier missing",
)) ))
} }
@ -201,39 +209,39 @@ impl Request {
let mut peer_id = [0; 20]; let mut peer_id = [0; 20];
let mut ip = [0; 4]; let mut ip = [0; 4];
cursor cursor.read_exact(&mut info_hash).map_err(|err| {
.read_exact(&mut info_hash) RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
cursor cursor.read_exact(&mut peer_id).map_err(|err| {
.read_exact(&mut peer_id) RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let bytes_downloaded = cursor let bytes_downloaded = cursor.read_i64::<NetworkEndian>().map_err(|err| {
.read_i64::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let bytes_left = cursor let bytes_left = cursor.read_i64::<NetworkEndian>().map_err(|err| {
.read_i64::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let bytes_uploaded = cursor let bytes_uploaded = cursor.read_i64::<NetworkEndian>().map_err(|err| {
.read_i64::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let event = cursor let event = cursor.read_i32::<NetworkEndian>().map_err(|err| {
.read_i32::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
cursor cursor.read_exact(&mut ip).map_err(|err| {
.read_exact(&mut ip) RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let key = cursor let key = cursor.read_u32::<NetworkEndian>().map_err(|err| {
.read_u32::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let peers_wanted = cursor let peers_wanted = cursor.read_i32::<NetworkEndian>().map_err(|err| {
.read_i32::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let port = cursor let port = cursor.read_u16::<NetworkEndian>().map_err(|err| {
.read_u16::<NetworkEndian>() RequestParseError::sendable_io(err, connection_id, transaction_id)
.map_err(|err| RequestParseError::new(err, transaction_id))?; })?;
let opt_ip = if ip == [0; 4] { let opt_ip = if ip == [0; 4] {
None None
@ -277,7 +285,11 @@ impl Request {
.into()) .into())
} }
_ => Err(RequestParseError::text(transaction_id, "Invalid action")), _ => Err(RequestParseError::sendable_text(
"Invalid action",
connection_id,
transaction_id,
)),
} }
} }
} }