diff --git a/aquatic/src/lib/common.rs b/aquatic/src/lib/common.rs index f3e1e73..7c7a491 100644 --- a/aquatic/src/lib/common.rs +++ b/aquatic/src/lib/common.rs @@ -126,25 +126,10 @@ pub struct Statistics { } -pub struct HandlerData { - pub connections: ConnectionMap, - pub torrents: TorrentMap, -} - - -impl Default for HandlerData { - fn default() -> Self { - Self { - connections: HashMap::new(), - torrents: HashMap::new(), - } - } -} - - #[derive(Clone)] pub struct State { - pub handler_data: Arc>, + pub connections: Arc>, + pub torrents: Arc>, pub statistics: Arc, } @@ -152,7 +137,8 @@ pub struct State { impl State { pub fn new() -> Self { Self { - handler_data: Arc::new(Mutex::new(HandlerData::default())), + connections: Arc::new(Mutex::new(HashMap::new())), + torrents: Arc::new(Mutex::new(HashMap::new())), statistics: Arc::new(Statistics::default()), } } diff --git a/aquatic/src/lib/handlers.rs b/aquatic/src/lib/handlers.rs index 6882da9..c1699de 100644 --- a/aquatic/src/lib/handlers.rs +++ b/aquatic/src/lib/handlers.rs @@ -33,13 +33,13 @@ pub fn run_request_worker( ); loop { - let mut opt_data = None; + let mut opt_connections = None; // Collect requests from channel, divide them by type // // Collect a maximum number of request. Stop collecting before that // number is reached if having waited for too long for a request, but - // only if HandlerData mutex isn't locked. + // only if ConnectionMap mutex isn't locked. for i in 0..config.handlers.max_requests_per_iter { let (request, src): (Request, SocketAddr) = if i == 0 { match request_receiver.recv(){ @@ -50,8 +50,8 @@ pub fn run_request_worker( match request_receiver.recv_timeout(timeout){ Ok(r) => r, Err(_) => { - if let Some(data) = state.handler_data.try_lock(){ - opt_data = Some(data); + if let Some(guard) = state.connections.try_lock(){ + opt_connections = Some(guard); break } else { @@ -74,32 +74,76 @@ pub fn run_request_worker( } } - let mut data: MutexGuard = opt_data.unwrap_or_else(|| - state.handler_data.lock() + let mut connections: MutexGuard = opt_connections.unwrap_or_else(|| + state.connections.lock() ); handle_connect_requests( &config, - &mut data, + &mut connections, &mut std_rng, connect_requests.drain(..), &mut responses ); - handle_announce_requests( - &config, - &mut data, - &mut small_rng, - announce_requests.drain(..), - &mut responses - ); - handle_scrape_requests( - &mut data, - scrape_requests.drain(..), - &mut responses - ); + announce_requests.retain(|(request, src)| { + let connection_key = ConnectionKey { + connection_id: request.connection_id, + socket_addr: *src, + }; + + if connections.contains_key(&connection_key){ + true + } else { + let response = ErrorResponse { + transaction_id: request.transaction_id, + message: "Connection invalid or expired".to_string() + }; + + responses.push((response.into(), *src)); - ::std::mem::drop(data); + false + } + }); + + scrape_requests.retain(|(request, src)| { + let connection_key = ConnectionKey { + connection_id: request.connection_id, + socket_addr: *src, + }; + + if connections.contains_key(&connection_key){ + true + } else { + let response = ErrorResponse { + transaction_id: request.transaction_id, + message: "Connection invalid or expired".to_string() + }; + + responses.push((response.into(), *src)); + + false + } + }); + + ::std::mem::drop(connections); + + if !(announce_requests.is_empty() && scrape_requests.is_empty()){ + let mut torrents = state.torrents.lock(); + + handle_announce_requests( + &config, + &mut torrents, + &mut small_rng, + announce_requests.drain(..), + &mut responses + ); + handle_scrape_requests( + &mut torrents, + scrape_requests.drain(..), + &mut responses + ); + } for r in responses.drain(..){ if let Err(err) = response_sender.send(r){ @@ -113,7 +157,7 @@ pub fn run_request_worker( #[inline] pub fn handle_connect_requests( config: &Config, - data: &mut MutexGuard, + connections: &mut MutexGuard, rng: &mut StdRng, requests: Drain<(ConnectRequest, SocketAddr)>, responses: &mut Vec<(Response, SocketAddr)>, @@ -128,7 +172,7 @@ pub fn handle_connect_requests( socket_addr: src, }; - data.connections.insert(key, valid_until); + connections.insert(key, valid_until); let response = Response::Connect( ConnectResponse { @@ -145,7 +189,7 @@ pub fn handle_connect_requests( #[inline] pub fn handle_announce_requests( config: &Config, - data: &mut MutexGuard, + torrents: &mut MutexGuard, rng: &mut SmallRng, requests: Drain<(AnnounceRequest, SocketAddr)>, responses: &mut Vec<(Response, SocketAddr)>, @@ -153,20 +197,6 @@ pub fn handle_announce_requests( let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age); responses.extend(requests.map(|(request, src)| { - let connection_key = ConnectionKey { - connection_id: request.connection_id, - socket_addr: src, - }; - - if !data.connections.contains_key(&connection_key){ - let response = ErrorResponse { - transaction_id: request.transaction_id, - message: "Connection invalid or expired".to_string() - }; - - return (response.into(), src); - } - let peer_ip = src.ip(); let peer_key = PeerMapKey { @@ -186,7 +216,7 @@ pub fn handle_announce_requests( valid_until: peer_valid_until, }; - let torrent_data = data.torrents + let torrent_data = torrents .entry(request.info_hash) .or_default(); @@ -242,33 +272,19 @@ pub fn handle_announce_requests( #[inline] pub fn handle_scrape_requests( - data: &mut MutexGuard, + torrents: &mut MutexGuard, requests: Drain<(ScrapeRequest, SocketAddr)>, responses: &mut Vec<(Response, SocketAddr)>, ){ let empty_stats = create_torrent_scrape_statistics(0, 0); responses.extend(requests.map(|(request, src)|{ - let connection_key = ConnectionKey { - connection_id: request.connection_id, - socket_addr: src, - }; - - if !data.connections.contains_key(&connection_key){ - let response = ErrorResponse { - transaction_id: request.transaction_id, - message: "Connection invalid or expired".to_string() - }; - - return (response.into(), src); - } - let mut stats: Vec = Vec::with_capacity( request.info_hashes.len() ); for info_hash in request.info_hashes.iter() { - if let Some(torrent_data) = data.torrents.get(info_hash){ + if let Some(torrent_data) = torrents.get(info_hash){ stats.push(create_torrent_scrape_statistics( torrent_data.num_seeders.load(Ordering::SeqCst) as i32, torrent_data.num_leechers.load(Ordering::SeqCst) as i32, diff --git a/aquatic/src/lib/tasks.rs b/aquatic/src/lib/tasks.rs index 1d758bb..e8810ee 100644 --- a/aquatic/src/lib/tasks.rs +++ b/aquatic/src/lib/tasks.rs @@ -10,12 +10,16 @@ use crate::config::Config; pub fn clean_connections_and_torrents(state: &State){ let now = Instant::now(); - let mut data = state.handler_data.lock(); + let mut connections = state.connections.lock(); - data.connections.retain(|_, v| v.0 > now); - data.connections.shrink_to_fit(); + connections.retain(|_, v| v.0 > now); + connections.shrink_to_fit(); - data.torrents.retain(|_, torrent| { + ::std::mem::drop(connections); + + let mut torrents = state.torrents.lock(); + + torrents.retain(|_, torrent| { let num_seeders = &torrent.num_seeders; let num_leechers = &torrent.num_leechers; @@ -40,7 +44,7 @@ pub fn clean_connections_and_torrents(state: &State){ !torrent.peers.is_empty() }); - data.torrents.shrink_to_fit(); + torrents.shrink_to_fit(); } @@ -87,7 +91,7 @@ pub fn gather_and_print_statistics( let mut peers_per_torrent = Histogram::new(); - let torrents = &mut state.handler_data.lock().torrents; + let torrents = &mut state.torrents.lock(); for torrent in torrents.values(){ let num_seeders = torrent.num_seeders.load(Ordering::SeqCst); @@ -100,6 +104,8 @@ pub fn gather_and_print_statistics( } } + ::std::mem::drop(torrents); + if peers_per_torrent.entries() != 0 { println!( "peers per torrent: min: {}, p50: {}, p75: {}, p90: {}, p99: {}, p999: {}, max: {}", diff --git a/aquatic_bench/src/bin/bench_handlers/announce.rs b/aquatic_bench/src/bin/bench_handlers/announce.rs index f05d4e2..3bde889 100644 --- a/aquatic_bench/src/bin/bench_handlers/announce.rs +++ b/aquatic_bench/src/bin/bench_handlers/announce.rs @@ -95,9 +95,9 @@ pub fn create_requests( let mut requests = Vec::new(); - let d = state.handler_data.lock(); + let connections = state.connections.lock(); - let connection_keys: Vec = d.connections.keys() + let connection_keys: Vec = connections.keys() .take(number) .cloned() .collect(); diff --git a/aquatic_bench/src/bin/bench_handlers/scrape.rs b/aquatic_bench/src/bin/bench_handlers/scrape.rs index 190db68..5f979d9 100644 --- a/aquatic_bench/src/bin/bench_handlers/scrape.rs +++ b/aquatic_bench/src/bin/bench_handlers/scrape.rs @@ -96,9 +96,9 @@ pub fn create_requests( let max_index = info_hashes.len() - 1; - let d = state.handler_data.lock(); + let connections = state.connections.lock(); - let connection_keys: Vec = d.connections.keys() + let connection_keys: Vec = connections.keys() .take(number) .cloned() .collect();