diff --git a/Cargo.lock b/Cargo.lock index f1dcab0..c8f0e2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,6 +216,7 @@ name = "aquatic_udp_protocol" version = "0.1.0" dependencies = [ "byteorder", + "either", "quickcheck", "quickcheck_macros", ] diff --git a/TODO.md b/TODO.md index 66c2b6f..ea65cc2 100644 --- a/TODO.md +++ b/TODO.md @@ -1,13 +1,8 @@ # TODO -* aquatic_udp needs to check connection validity before sending error responses! connection validity checks could be moved to socket workers, since theyare sharded by ip - * access lists: * use arc-swap Cache - * test functionality - * aquatic_udp - * aquatic_http - * aquatic_ws, including sending back new error responses + * add CI tests * aquatic_ws: should it send back error on message parse error, or does that just indicate that not enough data has been received yet? diff --git a/aquatic_udp/src/lib/common.rs b/aquatic_udp/src/lib/common.rs index 67f0f4b..d733c52 100644 --- a/aquatic_udp/src/lib/common.rs +++ b/aquatic_udp/src/lib/common.rs @@ -1,5 +1,5 @@ use std::hash::Hash; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::{atomic::AtomicUsize, Arc}; use std::time::Instant; @@ -30,23 +30,25 @@ impl Ip for Ipv6Addr { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ConnectionKey { - pub connection_id: ConnectionId, - pub socket_addr: SocketAddr, +pub enum ConnectedRequest { + Announce(AnnounceRequest), + Scrape(ScrapeRequest), } -impl ConnectionKey { - pub fn new(connection_id: ConnectionId, socket_addr: SocketAddr) -> Self { - Self { - connection_id, - socket_addr, +pub enum ConnectedResponse { + Announce(AnnounceResponse), + Scrape(ScrapeResponse), +} + +impl Into for ConnectedResponse { + fn into(self) -> Response { + match self { + Self::Announce(response) => Response::Announce(response), + Self::Scrape(response) => Response::Scrape(response), } } } -pub type ConnectionMap = HashMap; - #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] pub enum PeerStatus { Seeding, @@ -172,7 +174,6 @@ impl TorrentMaps { pub struct Statistics { pub requests_received: AtomicUsize, pub responses_sent: AtomicUsize, - pub readable_events: AtomicUsize, pub bytes_received: AtomicUsize, pub bytes_sent: AtomicUsize, } @@ -180,7 +181,6 @@ pub struct Statistics { #[derive(Clone)] pub struct State { pub access_list: Arc, - pub connections: Arc>, pub torrents: Arc>, pub statistics: Arc, } @@ -189,7 +189,6 @@ impl Default for State { fn default() -> Self { Self { access_list: Arc::new(AccessList::default()), - connections: Arc::new(Mutex::new(HashMap::new())), torrents: Arc::new(Mutex::new(TorrentMaps::default())), statistics: Arc::new(Statistics::default()), } diff --git a/aquatic_udp/src/lib/config.rs b/aquatic_udp/src/lib/config.rs index 5ce9685..981095e 100644 --- a/aquatic_udp/src/lib/config.rs +++ b/aquatic_udp/src/lib/config.rs @@ -84,7 +84,7 @@ pub struct StatisticsConfig { #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(default)] pub struct CleaningConfig { - /// Clean torrents and connections this often (seconds) + /// Update access list and clean torrents this often (seconds) pub interval: u64, /// Remove peers that haven't announced for this long (seconds) pub max_peer_age: u64, diff --git a/aquatic_udp/src/lib/handlers/announce.rs b/aquatic_udp/src/lib/handlers/announce.rs index bda60d6..913a0d6 100644 --- a/aquatic_udp/src/lib/handlers/announce.rs +++ b/aquatic_udp/src/lib/handlers/announce.rs @@ -16,7 +16,7 @@ pub fn handle_announce_requests( torrents: &mut MutexGuard, rng: &mut SmallRng, requests: Drain<(AnnounceRequest, SocketAddr)>, - responses: &mut Vec<(Response, SocketAddr)>, + responses: &mut Vec<(ConnectedResponse, SocketAddr)>, ) { let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age); @@ -42,7 +42,7 @@ pub fn handle_announce_requests( ), }; - (Response::Announce(response), src) + (ConnectedResponse::Announce(response), src) })); } diff --git a/aquatic_udp/src/lib/handlers/connect.rs b/aquatic_udp/src/lib/handlers/connect.rs deleted file mode 100644 index e058241..0000000 --- a/aquatic_udp/src/lib/handlers/connect.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::net::SocketAddr; -use std::vec::Drain; - -use parking_lot::MutexGuard; -use rand::{rngs::StdRng, Rng}; - -use aquatic_udp_protocol::*; - -use crate::common::*; -use crate::config::Config; - -#[inline] -pub fn handle_connect_requests( - config: &Config, - connections: &mut MutexGuard, - rng: &mut StdRng, - requests: Drain<(ConnectRequest, SocketAddr)>, - responses: &mut Vec<(Response, SocketAddr)>, -) { - let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - - responses.extend(requests.map(|(request, src)| { - let connection_id = ConnectionId(rng.gen()); - - let key = ConnectionKey { - connection_id, - socket_addr: src, - }; - - connections.insert(key, valid_until); - - let response = Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - (response, src) - })); -} diff --git a/aquatic_udp/src/lib/handlers/mod.rs b/aquatic_udp/src/lib/handlers/mod.rs index a1906d9..0634702 100644 --- a/aquatic_udp/src/lib/handlers/mod.rs +++ b/aquatic_udp/src/lib/handlers/mod.rs @@ -2,11 +2,7 @@ use std::net::SocketAddr; use std::time::Duration; use crossbeam_channel::{Receiver, Sender}; -use parking_lot::MutexGuard; -use rand::{ - rngs::{SmallRng, StdRng}, - SeedableRng, -}; +use rand::{rngs::SmallRng, SeedableRng}; use aquatic_udp_protocol::*; @@ -14,40 +10,35 @@ use crate::common::*; use crate::config::Config; mod announce; -mod connect; mod scrape; use announce::handle_announce_requests; -use connect::handle_connect_requests; use scrape::handle_scrape_requests; pub fn run_request_worker( state: State, config: Config, - request_receiver: Receiver<(Request, SocketAddr)>, - response_sender: Sender<(Response, SocketAddr)>, + request_receiver: Receiver<(ConnectedRequest, SocketAddr)>, + response_sender: Sender<(ConnectedResponse, SocketAddr)>, ) { - let mut connect_requests: Vec<(ConnectRequest, SocketAddr)> = Vec::new(); let mut announce_requests: Vec<(AnnounceRequest, SocketAddr)> = Vec::new(); let mut scrape_requests: Vec<(ScrapeRequest, SocketAddr)> = Vec::new(); + let mut responses: Vec<(ConnectedResponse, SocketAddr)> = Vec::new(); - let mut responses: Vec<(Response, SocketAddr)> = Vec::new(); - - let mut std_rng = StdRng::from_entropy(); - let mut small_rng = SmallRng::from_rng(&mut std_rng).unwrap(); + let mut small_rng = SmallRng::from_entropy(); let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds); loop { - let mut opt_connections = None; + let mut opt_torrents = None; // Collect requests from channel, divide them by type // // Collect a maximum number of request. Stop collecting before that // number is reached if having waited for too long for a request, but - // only if ConnectionMap mutex isn't locked. + // only if TorrentMaps mutex isn't locked. for i in 0..config.handlers.max_requests_per_iter { - let (request, src): (Request, SocketAddr) = if i == 0 { + let (request, src): (ConnectedRequest, SocketAddr) = if i == 0 { match request_receiver.recv() { Ok(r) => r, Err(_) => break, // Really shouldn't happen @@ -56,8 +47,8 @@ pub fn run_request_worker( match request_receiver.recv_timeout(timeout) { Ok(r) => r, Err(_) => { - if let Some(guard) = state.connections.try_lock() { - opt_connections = Some(guard); + if let Some(guard) = state.torrents.try_lock() { + opt_torrents = Some(guard); break; } else { @@ -68,59 +59,14 @@ pub fn run_request_worker( }; match request { - Request::Connect(r) => connect_requests.push((r, src)), - Request::Announce(r) => announce_requests.push((r, src)), - Request::Scrape(r) => scrape_requests.push((r, src)), + ConnectedRequest::Announce(r) => announce_requests.push((r, src)), + ConnectedRequest::Scrape(r) => scrape_requests.push((r, src)), } } - let mut connections: MutexGuard = - opt_connections.unwrap_or_else(|| state.connections.lock()); - - handle_connect_requests( - &config, - &mut connections, - &mut std_rng, - connect_requests.drain(..), - &mut responses, - ); - - // Check announce and scrape requests for valid connections - - announce_requests.retain(|(request, src)| { - let connection_valid = - connections.contains_key(&ConnectionKey::new(request.connection_id, *src)); - - if !connection_valid { - responses.push(( - create_invalid_connection_response(request.transaction_id), - *src, - )); - } - - connection_valid - }); - - scrape_requests.retain(|(request, src)| { - let connection_valid = - connections.contains_key(&ConnectionKey::new(request.connection_id, *src)); - - if !connection_valid { - responses.push(( - create_invalid_connection_response(request.transaction_id), - *src, - )); - } - - connection_valid - }); - - ::std::mem::drop(connections); - - // Generate responses for announce and scrape requests - - if !(announce_requests.is_empty() && scrape_requests.is_empty()) { - let mut torrents = state.torrents.lock(); + // Generate responses for announce and scrape requests, then drop MutexGuard. + { + let mut torrents = opt_torrents.unwrap_or_else(|| state.torrents.lock()); handle_announce_requests( &config, @@ -129,6 +75,7 @@ pub fn run_request_worker( announce_requests.drain(..), &mut responses, ); + handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses); } @@ -139,10 +86,3 @@ pub fn run_request_worker( } } } - -fn create_invalid_connection_response(transaction_id: TransactionId) -> Response { - Response::Error(ErrorResponse { - transaction_id, - message: "Connection invalid or expired".into(), - }) -} diff --git a/aquatic_udp/src/lib/handlers/scrape.rs b/aquatic_udp/src/lib/handlers/scrape.rs index 8198bd8..b544ccf 100644 --- a/aquatic_udp/src/lib/handlers/scrape.rs +++ b/aquatic_udp/src/lib/handlers/scrape.rs @@ -12,7 +12,7 @@ use crate::common::*; pub fn handle_scrape_requests( torrents: &mut MutexGuard, requests: Drain<(ScrapeRequest, SocketAddr)>, - responses: &mut Vec<(Response, SocketAddr)>, + responses: &mut Vec<(ConnectedResponse, SocketAddr)>, ) { let empty_stats = create_torrent_scrape_statistics(0, 0); @@ -45,7 +45,7 @@ pub fn handle_scrape_requests( } } - let response = Response::Scrape(ScrapeResponse { + let response = ConnectedResponse::Scrape(ScrapeResponse { transaction_id: request.transaction_id, torrent_stats: stats, }); diff --git a/aquatic_udp/src/lib/lib.rs b/aquatic_udp/src/lib/lib.rs index 5e32d83..a740a3e 100644 --- a/aquatic_udp/src/lib/lib.rs +++ b/aquatic_udp/src/lib/lib.rs @@ -23,7 +23,7 @@ pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker"; pub fn run(config: Config) -> ::anyhow::Result<()> { let state = State::default(); - tasks::update_access_list(&config, &state); + tasks::update_access_list(&config, &state.access_list); let num_bound_sockets = start_workers(config.clone(), state.clone())?; @@ -55,8 +55,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { loop { ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); - tasks::clean_connections(&state); - tasks::update_access_list(&config, &state); + tasks::update_access_list(&config, &state.access_list); state.torrents.lock().clean(&config, &state.access_list); } diff --git a/aquatic_udp/src/lib/network.rs b/aquatic_udp/src/lib/network.rs index b1baf65..da6bd2a 100644 --- a/aquatic_udp/src/lib/network.rs +++ b/aquatic_udp/src/lib/network.rs @@ -4,12 +4,14 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use std::time::Duration; +use std::time::{Duration, Instant}; use std::vec::Drain; use crossbeam_channel::{Receiver, Sender}; +use hashbrown::HashMap; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; +use rand::prelude::{Rng, SeedableRng, StdRng}; use socket2::{Domain, Protocol, Socket, Type}; use aquatic_udp_protocol::{IpVersion, Request, Response}; @@ -17,14 +19,40 @@ use aquatic_udp_protocol::{IpVersion, Request, Response}; use crate::common::*; use crate::config::Config; +#[derive(Default)] +struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>); + +impl ConnectionMap { + fn insert( + &mut self, + connection_id: ConnectionId, + socket_addr: SocketAddr, + valid_until: ValidUntil, + ) { + self.0.insert((connection_id, socket_addr), valid_until); + } + + fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool { + self.0.contains_key(&(connection_id, socket_addr)) + } + + fn clean(&mut self) { + let now = Instant::now(); + + self.0.retain(|_, v| v.0 > now); + self.0.shrink_to_fit(); + } +} + pub fn run_socket_worker( state: State, config: Config, token_num: usize, - request_sender: Sender<(Request, SocketAddr)>, - response_receiver: Receiver<(Response, SocketAddr)>, + request_sender: Sender<(ConnectedRequest, SocketAddr)>, + response_receiver: Receiver<(ConnectedResponse, SocketAddr)>, num_bound_sockets: Arc, ) { + let mut rng = StdRng::from_entropy(); let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut socket = UdpSocket::from_std(create_socket(&config)); @@ -39,12 +67,14 @@ pub fn run_socket_worker( num_bound_sockets.fetch_add(1, Ordering::SeqCst); let mut events = Events::with_capacity(config.network.poll_event_capacity); + let mut connections = ConnectionMap::default(); - let mut requests: Vec<(Request, SocketAddr)> = Vec::new(); let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new(); let timeout = Duration::from_millis(50); + let mut iter_counter = 0usize; + loop { poll.poll(&mut events, Some(timeout)) .expect("failed polling"); @@ -54,24 +84,15 @@ pub fn run_socket_worker( if (token.0 == token_num) & event.is_readable() { read_requests( - &state, &config, + &state, + &mut connections, + &mut rng, &mut socket, &mut buffer, - &mut requests, + &request_sender, &mut local_responses, ); - - for r in requests.drain(..) { - if let Err(err) = request_sender.send(r) { - ::log::error!("error sending to request_sender: {}", err); - } - } - - state - .statistics - .readable_events - .fetch_add(1, Ordering::SeqCst); } } @@ -83,6 +104,14 @@ pub fn run_socket_worker( &response_receiver, local_responses.drain(..), ); + + iter_counter += 1; + + if iter_counter == 1000 { + connections.clean(); + + iter_counter = 0; + } } } @@ -121,16 +150,19 @@ fn create_socket(config: &Config) -> ::std::net::UdpSocket { #[inline] fn read_requests( - state: &State, config: &Config, + state: &State, + connections: &mut ConnectionMap, + rng: &mut StdRng, socket: &mut UdpSocket, buffer: &mut [u8], - requests: &mut Vec<(Request, SocketAddr)>, + request_sender: &Sender<(ConnectedRequest, SocketAddr)>, local_responses: &mut Vec<(Response, SocketAddr)>, ) { let mut requests_received: usize = 0; let mut bytes_received: usize = 0; + let valid_until = ValidUntil::new(config.cleaning.max_connection_age); let access_list_mode = config.access_list.mode; loop { @@ -146,37 +178,61 @@ fn read_requests( } match request { - Ok(Request::Announce(AnnounceRequest { - info_hash, - transaction_id, - .. - })) if !state.access_list.allows(access_list_mode, &info_hash.0) => { - let response = Response::Error(ErrorResponse { - transaction_id, - message: "Info hash not allowed".into(), + Ok(Request::Connect(request)) => { + let connection_id = ConnectionId(rng.gen()); + + connections.insert(connection_id, src, valid_until); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, }); local_responses.push((response, src)) } - Ok(request) => { - requests.push((request, src)); + Ok(Request::Announce(request)) => { + if connections.contains(request.connection_id, src) { + if state + .access_list + .allows(access_list_mode, &request.info_hash.0) + { + if let Err(err) = request_sender + .try_send((ConnectedRequest::Announce(request), src)) + { + ::log::warn!("request_sender.try_send failed: {:?}", err) + } + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + local_responses.push((response, src)) + } + } + } + Ok(Request::Scrape(request)) => { + if connections.contains(request.connection_id, src) { + if let Err(err) = + request_sender.try_send((ConnectedRequest::Scrape(request), src)) + { + ::log::warn!("request_sender.try_send failed: {:?}", err) + } + } } Err(err) => { - ::log::debug!("request_from_bytes error: {:?}", err); + ::log::debug!("Request::from_bytes error: {:?}", err); - if let Some(transaction_id) = err.transaction_id { - let opt_message = if err.error.is_some() { - Some("Parse error".into()) - } else if let Some(message) = err.message { - Some(message.into()) - } else { - None - }; - - if let Some(message) = opt_message { + if let RequestParseError::Sendable { + connection_id, + transaction_id, + err, + } = err + { + if connections.contains(connection_id, src) { let response = ErrorResponse { transaction_id, - message, + message: err.right_or("Parse error").into(), }; local_responses.push((response.into(), src)); @@ -213,7 +269,7 @@ fn send_responses( config: &Config, socket: &mut UdpSocket, buffer: &mut [u8], - response_receiver: &Receiver<(Response, SocketAddr)>, + response_receiver: &Receiver<(ConnectedResponse, SocketAddr)>, local_responses: Drain<(Response, SocketAddr)>, ) { let mut responses_sent: usize = 0; @@ -221,30 +277,37 @@ fn send_responses( let mut cursor = Cursor::new(buffer); - let response_iterator = local_responses - .into_iter() - .chain(response_receiver.try_iter()); + let response_iterator = local_responses.into_iter().chain( + response_receiver + .try_iter() + .map(|(response, addr)| (response.into(), addr)), + ); for (response, src) in response_iterator { cursor.set_position(0); let ip_version = ip_version_from_ip(src.ip()); - response.write(&mut cursor, ip_version).unwrap(); + match response.write(&mut cursor, ip_version) { + Ok(()) => { + let amt = cursor.position() as usize; - let amt = cursor.position() as usize; + match socket.send_to(&cursor.get_ref()[..amt], src) { + Ok(amt) => { + responses_sent += 1; + bytes_sent += amt; + } + Err(err) => { + if err.kind() == ErrorKind::WouldBlock { + break; + } - match socket.send_to(&cursor.get_ref()[..amt], src) { - Ok(amt) => { - responses_sent += 1; - bytes_sent += amt; + ::log::info!("send_to error: {}", err); + } + } } Err(err) => { - if err.kind() == ErrorKind::WouldBlock { - break; - } - - ::log::info!("send_to error: {}", err); + ::log::error!("Response::write error: {:?}", err); } } } diff --git a/aquatic_udp/src/lib/tasks.rs b/aquatic_udp/src/lib/tasks.rs index 83acd3b..2665c45 100644 --- a/aquatic_udp/src/lib/tasks.rs +++ b/aquatic_udp/src/lib/tasks.rs @@ -1,5 +1,5 @@ use std::sync::atomic::Ordering; -use std::time::Instant; +use std::sync::Arc; use histogram::Histogram; @@ -8,10 +8,10 @@ use aquatic_common::access_list::AccessListMode; use crate::common::*; use crate::config::Config; -pub fn update_access_list(config: &Config, state: &State) { +pub fn update_access_list(config: &Config, access_list: &Arc) { match config.access_list.mode { AccessListMode::White | AccessListMode::Black => { - if let Err(err) = state.access_list.update_from_path(&config.access_list.path) { + if let Err(err) = access_list.update_from_path(&config.access_list.path) { ::log::error!("Update access list from path: {:?}", err); } } @@ -19,15 +19,6 @@ pub fn update_access_list(config: &Config, state: &State) { } } -pub fn clean_connections(state: &State) { - let now = Instant::now(); - - let mut connections = state.connections.lock(); - - connections.retain(|_, v| v.0 > now); - connections.shrink_to_fit(); -} - pub fn gather_and_print_statistics(state: &State, config: &Config) { let interval = config.statistics.interval; @@ -50,19 +41,9 @@ pub fn gather_and_print_statistics(state: &State, config: &Config) { let bytes_received_per_second: f64 = bytes_received / interval as f64; let bytes_sent_per_second: f64 = bytes_sent / interval as f64; - let readable_events: f64 = state - .statistics - .readable_events - .fetch_and(0, Ordering::SeqCst) as f64; - let requests_per_readable_event = if readable_events == 0.0 { - 0.0 - } else { - requests_received / readable_events - }; - println!( - "stats: {:.2} requests/second, {:.2} responses/second, {:.2} requests/readable event", - requests_per_second, responses_per_second, requests_per_readable_event + "stats: {:.2} requests/second, {:.2} responses/second", + requests_per_second, responses_per_second ); println!( diff --git a/aquatic_udp_bench/src/announce.rs b/aquatic_udp_bench/src/announce.rs index 8288713..12b35e3 100644 --- a/aquatic_udp_bench/src/announce.rs +++ b/aquatic_udp_bench/src/announce.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, Instant}; use crossbeam_channel::{Receiver, Sender}; @@ -13,15 +13,14 @@ use crate::common::*; use crate::config::BenchConfig; pub fn bench_announce_handler( - state: &State, bench_config: &BenchConfig, aquatic_config: &Config, - request_sender: &Sender<(Request, SocketAddr)>, - response_receiver: &Receiver<(Response, SocketAddr)>, + request_sender: &Sender<(ConnectedRequest, SocketAddr)>, + response_receiver: &Receiver<(ConnectedResponse, SocketAddr)>, rng: &mut impl Rng, info_hashes: &[InfoHash], ) -> (usize, Duration) { - let requests = create_requests(state, rng, info_hashes, bench_config.num_announce_requests); + let requests = create_requests(rng, info_hashes, bench_config.num_announce_requests); let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads; let mut num_responses = 0usize; @@ -37,10 +36,12 @@ pub fn bench_announce_handler( for round in (0..bench_config.num_rounds).progress_with(pb) { for request_chunk in requests.chunks(p) { for (request, src) in request_chunk { - request_sender.send((request.clone().into(), *src)).unwrap(); + request_sender + .send((ConnectedRequest::Announce(request.clone()), *src)) + .unwrap(); } - while let Ok((Response::Announce(r), _)) = response_receiver.try_recv() { + while let Ok((ConnectedResponse::Announce(r), _)) = response_receiver.try_recv() { num_responses += 1; if let Some(last_peer) = r.peers.last() { @@ -52,7 +53,7 @@ pub fn bench_announce_handler( let total = bench_config.num_announce_requests * (round + 1); while num_responses < total { - if let Ok((Response::Announce(r), _)) = response_receiver.recv() { + if let Ok((ConnectedResponse::Announce(r), _)) = response_receiver.recv() { num_responses += 1; if let Some(last_peer) = r.peers.last() { @@ -72,7 +73,6 @@ pub fn bench_announce_handler( } pub fn create_requests( - state: &State, rng: &mut impl Rng, info_hashes: &[InfoHash], number: usize, @@ -83,15 +83,11 @@ pub fn create_requests( let mut requests = Vec::new(); - let connections = state.connections.lock(); - - let connection_keys: Vec = connections.keys().take(number).cloned().collect(); - - for connection_key in connection_keys.into_iter() { + for _ in 0..number { let info_hash_index = pareto_usize(rng, pareto, max_index); let request = AnnounceRequest { - connection_id: connection_key.connection_id, + connection_id: ConnectionId(0), transaction_id: TransactionId(rng.gen()), info_hash: info_hashes[info_hash_index], peer_id: PeerId(rng.gen()), @@ -105,7 +101,10 @@ pub fn create_requests( port: Port(rng.gen()), }; - requests.push((request, connection_key.socket_addr)); + requests.push(( + request, + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1)), + )); } requests diff --git a/aquatic_udp_bench/src/connect.rs b/aquatic_udp_bench/src/connect.rs deleted file mode 100644 index 3f3bc05..0000000 --- a/aquatic_udp_bench/src/connect.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::time::{Duration, Instant}; - -use crossbeam_channel::{Receiver, Sender}; -use indicatif::ProgressIterator; -use rand::{rngs::SmallRng, thread_rng, Rng, SeedableRng}; -use std::net::SocketAddr; - -use aquatic_udp::common::*; -use aquatic_udp::config::Config; - -use crate::common::*; -use crate::config::BenchConfig; - -pub fn bench_connect_handler( - bench_config: &BenchConfig, - aquatic_config: &Config, - request_sender: &Sender<(Request, SocketAddr)>, - response_receiver: &Receiver<(Response, SocketAddr)>, -) -> (usize, Duration) { - let requests = create_requests(bench_config.num_connect_requests); - - let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads; - let mut num_responses = 0usize; - - let mut dummy: i64 = thread_rng().gen(); - - let pb = create_progress_bar("Connect", bench_config.num_rounds as u64); - - // Start connect benchmark - - let before = Instant::now(); - - for round in (0..bench_config.num_rounds).progress_with(pb) { - for request_chunk in requests.chunks(p) { - for (request, src) in request_chunk { - request_sender.send((request.clone().into(), *src)).unwrap(); - } - - while let Ok((Response::Connect(r), _)) = response_receiver.try_recv() { - num_responses += 1; - dummy ^= r.connection_id.0; - } - } - - let total = bench_config.num_connect_requests * (round + 1); - - while num_responses < total { - if let Ok((Response::Connect(r), _)) = response_receiver.recv() { - num_responses += 1; - dummy ^= r.connection_id.0; - } - } - } - - let elapsed = before.elapsed(); - - if dummy == 0 { - println!("dummy dummy"); - } - - (num_responses, elapsed) -} - -pub fn create_requests(number: usize) -> Vec<(ConnectRequest, SocketAddr)> { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - - let mut requests = Vec::new(); - - for _ in 0..number { - let request = ConnectRequest { - transaction_id: TransactionId(rng.gen()), - }; - - let src = SocketAddr::from(([rng.gen(), rng.gen(), rng.gen(), rng.gen()], rng.gen())); - - requests.push((request, src)); - } - - requests -} diff --git a/aquatic_udp_bench/src/main.rs b/aquatic_udp_bench/src/main.rs index a64cb07..c89de8b 100644 --- a/aquatic_udp_bench/src/main.rs +++ b/aquatic_udp_bench/src/main.rs @@ -2,16 +2,9 @@ //! //! Example outputs: //! ``` -//! # Results over 20 rounds with 1 threads -//! Connect: 2 306 637 requests/second, 433.53 ns/request -//! Announce: 688 391 requests/second, 1452.66 ns/request -//! Scrape: 1 505 700 requests/second, 664.14 ns/request -//! ``` -//! ``` -//! # Results over 20 rounds with 2 threads -//! Connect: 3 472 434 requests/second, 287.98 ns/request -//! Announce: 739 371 requests/second, 1352.50 ns/request -//! Scrape: 1 845 253 requests/second, 541.93 ns/request +//! # Results over 10 rounds with 2 threads +//! Announce: 429 540 requests/second, 2328.07 ns/request +//! Scrape: 1 873 545 requests/second, 533.75 ns/request //! ``` use crossbeam_channel::unbounded; @@ -29,7 +22,6 @@ use config::BenchConfig; mod announce; mod common; mod config; -mod connect; mod scrape; #[global_allocator] @@ -65,18 +57,10 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> { // Run benchmarks - let c = connect::bench_connect_handler( - &bench_config, - &aquatic_config, - &request_sender, - &response_receiver, - ); - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let info_hashes = create_info_hashes(&mut rng); let a = announce::bench_announce_handler( - &state, &bench_config, &aquatic_config, &request_sender, @@ -86,7 +70,6 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> { ); let s = scrape::bench_scrape_handler( - &state, &bench_config, &aquatic_config, &request_sender, @@ -100,7 +83,6 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> { bench_config.num_rounds, bench_config.num_threads, ); - print_results("Connect: ", c.0, c.1); print_results("Announce:", a.0, a.1); print_results("Scrape: ", s.0, s.1); diff --git a/aquatic_udp_bench/src/scrape.rs b/aquatic_udp_bench/src/scrape.rs index 09534cc..7b62152 100644 --- a/aquatic_udp_bench/src/scrape.rs +++ b/aquatic_udp_bench/src/scrape.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, Instant}; use crossbeam_channel::{Receiver, Sender}; @@ -13,16 +13,14 @@ use crate::common::*; use crate::config::BenchConfig; pub fn bench_scrape_handler( - state: &State, bench_config: &BenchConfig, aquatic_config: &Config, - request_sender: &Sender<(Request, SocketAddr)>, - response_receiver: &Receiver<(Response, SocketAddr)>, + request_sender: &Sender<(ConnectedRequest, SocketAddr)>, + response_receiver: &Receiver<(ConnectedResponse, SocketAddr)>, rng: &mut impl Rng, info_hashes: &[InfoHash], ) -> (usize, Duration) { let requests = create_requests( - state, rng, info_hashes, bench_config.num_scrape_requests, @@ -43,10 +41,12 @@ pub fn bench_scrape_handler( for round in (0..bench_config.num_rounds).progress_with(pb) { for request_chunk in requests.chunks(p) { for (request, src) in request_chunk { - request_sender.send((request.clone().into(), *src)).unwrap(); + request_sender + .send((ConnectedRequest::Scrape(request.clone()), *src)) + .unwrap(); } - while let Ok((Response::Scrape(r), _)) = response_receiver.try_recv() { + while let Ok((ConnectedResponse::Scrape(r), _)) = response_receiver.try_recv() { num_responses += 1; if let Some(stat) = r.torrent_stats.last() { @@ -58,7 +58,7 @@ pub fn bench_scrape_handler( let total = bench_config.num_scrape_requests * (round + 1); while num_responses < total { - if let Ok((Response::Scrape(r), _)) = response_receiver.recv() { + if let Ok((ConnectedResponse::Scrape(r), _)) = response_receiver.recv() { num_responses += 1; if let Some(stat) = r.torrent_stats.last() { @@ -78,7 +78,6 @@ pub fn bench_scrape_handler( } pub fn create_requests( - state: &State, rng: &mut impl Rng, info_hashes: &[InfoHash], number: usize, @@ -88,13 +87,9 @@ pub fn create_requests( let max_index = info_hashes.len() - 1; - let connections = state.connections.lock(); - - let connection_keys: Vec = connections.keys().take(number).cloned().collect(); - let mut requests = Vec::new(); - for connection_key in connection_keys.into_iter() { + for _ in 0..number { let mut request_info_hashes = Vec::new(); for _ in 0..hashes_per_request { @@ -103,12 +98,15 @@ pub fn create_requests( } let request = ScrapeRequest { - connection_id: connection_key.connection_id, + connection_id: ConnectionId(0), transaction_id: TransactionId(rng.gen()), info_hashes: request_info_hashes, }; - requests.push((request, connection_key.socket_addr)); + requests.push(( + request, + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1)), + )); } requests diff --git a/aquatic_udp_protocol/Cargo.toml b/aquatic_udp_protocol/Cargo.toml index 12d11ce..3cf3f43 100644 --- a/aquatic_udp_protocol/Cargo.toml +++ b/aquatic_udp_protocol/Cargo.toml @@ -9,6 +9,7 @@ repository = "https://github.com/greatest-ape/aquatic" [dependencies] byteorder = "1" +either = "1" [dev-dependencies] quickcheck = "1.0" diff --git a/aquatic_udp_protocol/src/request.rs b/aquatic_udp_protocol/src/request.rs index d88ad8d..8d74196 100644 --- a/aquatic_udp_protocol/src/request.rs +++ b/aquatic_udp_protocol/src/request.rs @@ -3,6 +3,7 @@ use std::io::{self, Cursor, Read, Write}; use std::net::Ipv4Addr; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +use either::Either; use super::common::*; @@ -67,32 +68,40 @@ pub struct ScrapeRequest { } #[derive(Debug)] -pub struct RequestParseError { - pub transaction_id: Option, - pub message: Option, - pub error: Option, +pub enum RequestParseError { + Sendable { + connection_id: ConnectionId, + transaction_id: TransactionId, + err: Either, + }, + Unsendable { + err: Either, + }, } impl RequestParseError { - pub fn new(err: io::Error, transaction_id: i32) -> Self { - Self { - transaction_id: Some(TransactionId(transaction_id)), - message: None, - error: Some(err), + pub fn sendable_io(err: io::Error, connection_id: i64, transaction_id: i32) -> Self { + Self::Sendable { + connection_id: ConnectionId(connection_id), + transaction_id: TransactionId(transaction_id), + err: Either::Left(err), } } - pub fn io(err: io::Error) -> Self { - Self { - transaction_id: None, - message: None, - error: Some(err), + pub fn sendable_text(text: &'static str, connection_id: i64, transaction_id: i32) -> Self { + Self::Sendable { + connection_id: ConnectionId(connection_id), + transaction_id: TransactionId(transaction_id), + err: Either::Right(text), } } - pub fn text(transaction_id: i32, message: &str) -> Self { - Self { - transaction_id: Some(TransactionId(transaction_id)), - message: Some(message.to_string()), - error: None, + pub fn unsendable_io(err: io::Error) -> Self { + Self::Unsendable { + err: Either::Left(err), + } + } + pub fn unsendable_text(text: &'static str) -> Self { + Self::Unsendable { + err: Either::Right(text), } } } @@ -171,13 +180,13 @@ impl Request { let connection_id = cursor .read_i64::() - .map_err(RequestParseError::io)?; + .map_err(RequestParseError::unsendable_io)?; let action = cursor .read_i32::() - .map_err(RequestParseError::io)?; + .map_err(RequestParseError::unsendable_io)?; let transaction_id = cursor .read_i32::() - .map_err(RequestParseError::io)?; + .map_err(RequestParseError::unsendable_io)?; match action { // Connect @@ -188,8 +197,7 @@ impl Request { }) .into()) } else { - Err(RequestParseError::text( - transaction_id, + Err(RequestParseError::unsendable_text( "Protocol identifier missing", )) } @@ -201,39 +209,39 @@ impl Request { let mut peer_id = [0; 20]; let mut ip = [0; 4]; - cursor - .read_exact(&mut info_hash) - .map_err(|err| RequestParseError::new(err, transaction_id))?; - cursor - .read_exact(&mut peer_id) - .map_err(|err| RequestParseError::new(err, transaction_id))?; + cursor.read_exact(&mut info_hash).map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; + cursor.read_exact(&mut peer_id).map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; - let bytes_downloaded = cursor - .read_i64::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; - let bytes_left = cursor - .read_i64::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; - let bytes_uploaded = cursor - .read_i64::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; - let event = cursor - .read_i32::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; + let bytes_downloaded = cursor.read_i64::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; + let bytes_left = cursor.read_i64::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; + let bytes_uploaded = cursor.read_i64::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; + let event = cursor.read_i32::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; - cursor - .read_exact(&mut ip) - .map_err(|err| RequestParseError::new(err, transaction_id))?; + cursor.read_exact(&mut ip).map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; - let key = cursor - .read_u32::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; - let peers_wanted = cursor - .read_i32::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; - let port = cursor - .read_u16::() - .map_err(|err| RequestParseError::new(err, transaction_id))?; + let key = cursor.read_u32::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; + let peers_wanted = cursor.read_i32::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; + let port = cursor.read_u16::().map_err(|err| { + RequestParseError::sendable_io(err, connection_id, transaction_id) + })?; let opt_ip = if ip == [0; 4] { None @@ -277,7 +285,11 @@ impl Request { .into()) } - _ => Err(RequestParseError::text(transaction_id, "Invalid action")), + _ => Err(RequestParseError::sendable_text( + "Invalid action", + connection_id, + transaction_id, + )), } } }