From 7616df96864879ac8dce1e201b22b3eb1dc8a004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Mon, 18 Oct 2021 01:14:32 +0200 Subject: [PATCH] aquatic_udp: validate requests in socket workers Also, don't send error responses for unconnected requests --- TODO.md | 2 - aquatic_udp/src/lib/common.rs | 7 +- aquatic_udp/src/lib/handlers/connect.rs | 39 ----------- aquatic_udp/src/lib/handlers/mod.rs | 90 +++++-------------------- aquatic_udp/src/lib/lib.rs | 1 - aquatic_udp/src/lib/network.rs | 68 ++++++++++++++----- aquatic_udp/src/lib/tasks.rs | 10 --- aquatic_udp_bench/src/announce.rs | 20 ++---- aquatic_udp_bench/src/connect.rs | 80 ---------------------- aquatic_udp_bench/src/main.rs | 11 --- aquatic_udp_bench/src/scrape.rs | 19 ++---- 11 files changed, 88 insertions(+), 259 deletions(-) delete mode 100644 aquatic_udp/src/lib/handlers/connect.rs delete mode 100644 aquatic_udp_bench/src/connect.rs diff --git a/TODO.md b/TODO.md index 66c2b6f..45e45d5 100644 --- a/TODO.md +++ b/TODO.md @@ -1,7 +1,5 @@ # TODO -* aquatic_udp needs to check connection validity before sending error responses! connection validity checks could be moved to socket workers, since theyare sharded by ip - * access lists: * use arc-swap Cache * test functionality diff --git a/aquatic_udp/src/lib/common.rs b/aquatic_udp/src/lib/common.rs index 67f0f4b..132666c 100644 --- a/aquatic_udp/src/lib/common.rs +++ b/aquatic_udp/src/lib/common.rs @@ -30,6 +30,11 @@ impl Ip for Ipv6Addr { } } +pub enum ConnectedRequest { + Announce(AnnounceRequest), + Scrape(ScrapeRequest), +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ConnectionKey { pub connection_id: ConnectionId, @@ -180,7 +185,6 @@ pub struct Statistics { #[derive(Clone)] pub struct State { pub access_list: Arc, - pub connections: Arc>, pub torrents: Arc>, pub statistics: Arc, } @@ -189,7 +193,6 @@ impl Default for State { fn default() -> Self { Self { access_list: Arc::new(AccessList::default()), - connections: Arc::new(Mutex::new(HashMap::new())), torrents: Arc::new(Mutex::new(TorrentMaps::default())), statistics: Arc::new(Statistics::default()), } diff --git a/aquatic_udp/src/lib/handlers/connect.rs b/aquatic_udp/src/lib/handlers/connect.rs deleted file mode 100644 index e058241..0000000 --- a/aquatic_udp/src/lib/handlers/connect.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::net::SocketAddr; -use std::vec::Drain; - -use parking_lot::MutexGuard; -use rand::{rngs::StdRng, Rng}; - -use aquatic_udp_protocol::*; - -use crate::common::*; -use crate::config::Config; - -#[inline] -pub fn handle_connect_requests( - config: &Config, - connections: &mut MutexGuard, - rng: &mut StdRng, - requests: Drain<(ConnectRequest, SocketAddr)>, - responses: &mut Vec<(Response, SocketAddr)>, -) { - let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - - responses.extend(requests.map(|(request, src)| { - let connection_id = ConnectionId(rng.gen()); - - let key = ConnectionKey { - connection_id, - socket_addr: src, - }; - - connections.insert(key, valid_until); - - let response = Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - (response, src) - })); -} diff --git a/aquatic_udp/src/lib/handlers/mod.rs b/aquatic_udp/src/lib/handlers/mod.rs index a1906d9..5cb7d39 100644 --- a/aquatic_udp/src/lib/handlers/mod.rs +++ b/aquatic_udp/src/lib/handlers/mod.rs @@ -14,20 +14,17 @@ use crate::common::*; use crate::config::Config; mod announce; -mod connect; mod scrape; use announce::handle_announce_requests; -use connect::handle_connect_requests; use scrape::handle_scrape_requests; pub fn run_request_worker( state: State, config: Config, - request_receiver: Receiver<(Request, SocketAddr)>, + request_receiver: Receiver<(ConnectedRequest, SocketAddr)>, response_sender: Sender<(Response, SocketAddr)>, ) { - let mut connect_requests: Vec<(ConnectRequest, SocketAddr)> = Vec::new(); let mut announce_requests: Vec<(AnnounceRequest, SocketAddr)> = Vec::new(); let mut scrape_requests: Vec<(ScrapeRequest, SocketAddr)> = Vec::new(); @@ -39,15 +36,15 @@ pub fn run_request_worker( let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds); loop { - let mut opt_connections = None; + let mut opt_torrents = None; // Collect requests from channel, divide them by type // // Collect a maximum number of request. Stop collecting before that // number is reached if having waited for too long for a request, but - // only if ConnectionMap mutex isn't locked. + // only if TorrentMaps mutex isn't locked. for i in 0..config.handlers.max_requests_per_iter { - let (request, src): (Request, SocketAddr) = if i == 0 { + let (request, src): (ConnectedRequest, SocketAddr) = if i == 0 { match request_receiver.recv() { Ok(r) => r, Err(_) => break, // Really shouldn't happen @@ -56,8 +53,8 @@ pub fn run_request_worker( match request_receiver.recv_timeout(timeout) { Ok(r) => r, Err(_) => { - if let Some(guard) = state.connections.try_lock() { - opt_connections = Some(guard); + if let Some(guard) = state.torrents.try_lock() { + opt_torrents = Some(guard); break; } else { @@ -68,69 +65,25 @@ pub fn run_request_worker( }; match request { - Request::Connect(r) => connect_requests.push((r, src)), - Request::Announce(r) => announce_requests.push((r, src)), - Request::Scrape(r) => scrape_requests.push((r, src)), + ConnectedRequest::Announce(r) => announce_requests.push((r, src)), + ConnectedRequest::Scrape(r) => scrape_requests.push((r, src)), } } - let mut connections: MutexGuard = - opt_connections.unwrap_or_else(|| state.connections.lock()); - - handle_connect_requests( - &config, - &mut connections, - &mut std_rng, - connect_requests.drain(..), - &mut responses, - ); - - // Check announce and scrape requests for valid connections - - announce_requests.retain(|(request, src)| { - let connection_valid = - connections.contains_key(&ConnectionKey::new(request.connection_id, *src)); - - if !connection_valid { - responses.push(( - create_invalid_connection_response(request.transaction_id), - *src, - )); - } - - connection_valid - }); - - scrape_requests.retain(|(request, src)| { - let connection_valid = - connections.contains_key(&ConnectionKey::new(request.connection_id, *src)); - - if !connection_valid { - responses.push(( - create_invalid_connection_response(request.transaction_id), - *src, - )); - } - - connection_valid - }); - - ::std::mem::drop(connections); + let mut torrents: MutexGuard = + opt_torrents.unwrap_or_else(|| state.torrents.lock()); // Generate responses for announce and scrape requests - if !(announce_requests.is_empty() && scrape_requests.is_empty()) { - let mut torrents = state.torrents.lock(); + handle_announce_requests( + &config, + &mut torrents, + &mut small_rng, + announce_requests.drain(..), + &mut responses, + ); - handle_announce_requests( - &config, - &mut torrents, - &mut small_rng, - announce_requests.drain(..), - &mut responses, - ); - handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses); - } + handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses); for r in responses.drain(..) { if let Err(err) = response_sender.send(r) { @@ -139,10 +92,3 @@ pub fn run_request_worker( } } } - -fn create_invalid_connection_response(transaction_id: TransactionId) -> Response { - Response::Error(ErrorResponse { - transaction_id, - message: "Connection invalid or expired".into(), - }) -} diff --git a/aquatic_udp/src/lib/lib.rs b/aquatic_udp/src/lib/lib.rs index 5e32d83..7ccee1c 100644 --- a/aquatic_udp/src/lib/lib.rs +++ b/aquatic_udp/src/lib/lib.rs @@ -55,7 +55,6 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { loop { ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); - tasks::clean_connections(&state); tasks::update_access_list(&config, &state); state.torrents.lock().clean(&config, &state.access_list); diff --git a/aquatic_udp/src/lib/network.rs b/aquatic_udp/src/lib/network.rs index b1baf65..ccea096 100644 --- a/aquatic_udp/src/lib/network.rs +++ b/aquatic_udp/src/lib/network.rs @@ -4,12 +4,13 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use std::time::Duration; +use std::time::{Duration, Instant}; use std::vec::Drain; use crossbeam_channel::{Receiver, Sender}; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; +use rand::prelude::{Rng, SeedableRng, StdRng}; use socket2::{Domain, Protocol, Socket, Type}; use aquatic_udp_protocol::{IpVersion, Request, Response}; @@ -21,10 +22,11 @@ pub fn run_socket_worker( state: State, config: Config, token_num: usize, - request_sender: Sender<(Request, SocketAddr)>, + request_sender: Sender<(ConnectedRequest, SocketAddr)>, response_receiver: Receiver<(Response, SocketAddr)>, num_bound_sockets: Arc, ) { + let mut rng = StdRng::from_entropy(); let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut socket = UdpSocket::from_std(create_socket(&config)); @@ -39,8 +41,9 @@ pub fn run_socket_worker( num_bound_sockets.fetch_add(1, Ordering::SeqCst); let mut events = Events::with_capacity(config.network.poll_event_capacity); + let mut connections = ConnectionMap::default(); - let mut requests: Vec<(Request, SocketAddr)> = Vec::new(); + let mut requests: Vec<(ConnectedRequest, SocketAddr)> = Vec::new(); let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new(); let timeout = Duration::from_millis(50); @@ -54,8 +57,10 @@ pub fn run_socket_worker( if (token.0 == token_num) & event.is_readable() { read_requests( - &state, &config, + &state, + &mut connections, + &mut rng, &mut socket, &mut buffer, &mut requests, @@ -83,6 +88,11 @@ pub fn run_socket_worker( &response_receiver, local_responses.drain(..), ); + + let now = Instant::now(); + + connections.retain(|_, v| v.0 > now); + connections.shrink_to_fit(); } } @@ -121,16 +131,19 @@ fn create_socket(config: &Config) -> ::std::net::UdpSocket { #[inline] fn read_requests( - state: &State, config: &Config, + state: &State, + connections: &mut ConnectionMap, + rng: &mut StdRng, socket: &mut UdpSocket, buffer: &mut [u8], - requests: &mut Vec<(Request, SocketAddr)>, + requests: &mut Vec<(ConnectedRequest, SocketAddr)>, local_responses: &mut Vec<(Response, SocketAddr)>, ) { let mut requests_received: usize = 0; let mut bytes_received: usize = 0; + let valid_until = ValidUntil::new(config.cleaning.max_connection_age); let access_list_mode = config.access_list.mode; loop { @@ -146,20 +159,43 @@ fn read_requests( } match request { - Ok(Request::Announce(AnnounceRequest { - info_hash, - transaction_id, - .. - })) if !state.access_list.allows(access_list_mode, &info_hash.0) => { - let response = Response::Error(ErrorResponse { - transaction_id, - message: "Info hash not allowed".into(), + Ok(Request::Connect(request)) => { + let connection_id = ConnectionId(rng.gen()); + + connections.insert(ConnectionKey::new(connection_id, src), valid_until); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, }); local_responses.push((response, src)) } - Ok(request) => { - requests.push((request, src)); + Ok(Request::Announce(request)) => { + let key = ConnectionKey::new(request.connection_id, src); + + if connections.contains_key(&key) { + if state + .access_list + .allows(access_list_mode, &request.info_hash.0) + { + requests.push((ConnectedRequest::Announce(request), src)); + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + local_responses.push((response, src)) + } + } + } + Ok(Request::Scrape(request)) => { + let key = ConnectionKey::new(request.connection_id, src); + + if connections.contains_key(&key) { + requests.push((ConnectedRequest::Scrape(request), src)); + } } Err(err) => { ::log::debug!("request_from_bytes error: {:?}", err); diff --git a/aquatic_udp/src/lib/tasks.rs b/aquatic_udp/src/lib/tasks.rs index 83acd3b..44675c9 100644 --- a/aquatic_udp/src/lib/tasks.rs +++ b/aquatic_udp/src/lib/tasks.rs @@ -1,5 +1,4 @@ use std::sync::atomic::Ordering; -use std::time::Instant; use histogram::Histogram; @@ -19,15 +18,6 @@ pub fn update_access_list(config: &Config, state: &State) { } } -pub fn clean_connections(state: &State) { - let now = Instant::now(); - - let mut connections = state.connections.lock(); - - connections.retain(|_, v| v.0 > now); - connections.shrink_to_fit(); -} - pub fn gather_and_print_statistics(state: &State, config: &Config) { let interval = config.statistics.interval; diff --git a/aquatic_udp_bench/src/announce.rs b/aquatic_udp_bench/src/announce.rs index 8288713..c759b97 100644 --- a/aquatic_udp_bench/src/announce.rs +++ b/aquatic_udp_bench/src/announce.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, Instant}; use crossbeam_channel::{Receiver, Sender}; @@ -13,15 +13,14 @@ use crate::common::*; use crate::config::BenchConfig; pub fn bench_announce_handler( - state: &State, bench_config: &BenchConfig, aquatic_config: &Config, - request_sender: &Sender<(Request, SocketAddr)>, + request_sender: &Sender<(ConnectedRequest, SocketAddr)>, response_receiver: &Receiver<(Response, SocketAddr)>, rng: &mut impl Rng, info_hashes: &[InfoHash], ) -> (usize, Duration) { - let requests = create_requests(state, rng, info_hashes, bench_config.num_announce_requests); + let requests = create_requests(rng, info_hashes, bench_config.num_announce_requests); let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads; let mut num_responses = 0usize; @@ -37,7 +36,7 @@ pub fn bench_announce_handler( for round in (0..bench_config.num_rounds).progress_with(pb) { for request_chunk in requests.chunks(p) { for (request, src) in request_chunk { - request_sender.send((request.clone().into(), *src)).unwrap(); + request_sender.send((ConnectedRequest::Announce(request.clone()), *src)).unwrap(); } while let Ok((Response::Announce(r), _)) = response_receiver.try_recv() { @@ -72,7 +71,6 @@ pub fn bench_announce_handler( } pub fn create_requests( - state: &State, rng: &mut impl Rng, info_hashes: &[InfoHash], number: usize, @@ -83,15 +81,11 @@ pub fn create_requests( let mut requests = Vec::new(); - let connections = state.connections.lock(); - - let connection_keys: Vec = connections.keys().take(number).cloned().collect(); - - for connection_key in connection_keys.into_iter() { + for _ in 0..number { let info_hash_index = pareto_usize(rng, pareto, max_index); let request = AnnounceRequest { - connection_id: connection_key.connection_id, + connection_id: ConnectionId(0), transaction_id: TransactionId(rng.gen()), info_hash: info_hashes[info_hash_index], peer_id: PeerId(rng.gen()), @@ -105,7 +99,7 @@ pub fn create_requests( port: Port(rng.gen()), }; - requests.push((request, connection_key.socket_addr)); + requests.push((request, SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1)))); } requests diff --git a/aquatic_udp_bench/src/connect.rs b/aquatic_udp_bench/src/connect.rs deleted file mode 100644 index 3f3bc05..0000000 --- a/aquatic_udp_bench/src/connect.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::time::{Duration, Instant}; - -use crossbeam_channel::{Receiver, Sender}; -use indicatif::ProgressIterator; -use rand::{rngs::SmallRng, thread_rng, Rng, SeedableRng}; -use std::net::SocketAddr; - -use aquatic_udp::common::*; -use aquatic_udp::config::Config; - -use crate::common::*; -use crate::config::BenchConfig; - -pub fn bench_connect_handler( - bench_config: &BenchConfig, - aquatic_config: &Config, - request_sender: &Sender<(Request, SocketAddr)>, - response_receiver: &Receiver<(Response, SocketAddr)>, -) -> (usize, Duration) { - let requests = create_requests(bench_config.num_connect_requests); - - let p = aquatic_config.handlers.max_requests_per_iter * bench_config.num_threads; - let mut num_responses = 0usize; - - let mut dummy: i64 = thread_rng().gen(); - - let pb = create_progress_bar("Connect", bench_config.num_rounds as u64); - - // Start connect benchmark - - let before = Instant::now(); - - for round in (0..bench_config.num_rounds).progress_with(pb) { - for request_chunk in requests.chunks(p) { - for (request, src) in request_chunk { - request_sender.send((request.clone().into(), *src)).unwrap(); - } - - while let Ok((Response::Connect(r), _)) = response_receiver.try_recv() { - num_responses += 1; - dummy ^= r.connection_id.0; - } - } - - let total = bench_config.num_connect_requests * (round + 1); - - while num_responses < total { - if let Ok((Response::Connect(r), _)) = response_receiver.recv() { - num_responses += 1; - dummy ^= r.connection_id.0; - } - } - } - - let elapsed = before.elapsed(); - - if dummy == 0 { - println!("dummy dummy"); - } - - (num_responses, elapsed) -} - -pub fn create_requests(number: usize) -> Vec<(ConnectRequest, SocketAddr)> { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - - let mut requests = Vec::new(); - - for _ in 0..number { - let request = ConnectRequest { - transaction_id: TransactionId(rng.gen()), - }; - - let src = SocketAddr::from(([rng.gen(), rng.gen(), rng.gen(), rng.gen()], rng.gen())); - - requests.push((request, src)); - } - - requests -} diff --git a/aquatic_udp_bench/src/main.rs b/aquatic_udp_bench/src/main.rs index a64cb07..c0313df 100644 --- a/aquatic_udp_bench/src/main.rs +++ b/aquatic_udp_bench/src/main.rs @@ -29,7 +29,6 @@ use config::BenchConfig; mod announce; mod common; mod config; -mod connect; mod scrape; #[global_allocator] @@ -65,18 +64,10 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> { // Run benchmarks - let c = connect::bench_connect_handler( - &bench_config, - &aquatic_config, - &request_sender, - &response_receiver, - ); - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); let info_hashes = create_info_hashes(&mut rng); let a = announce::bench_announce_handler( - &state, &bench_config, &aquatic_config, &request_sender, @@ -86,7 +77,6 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> { ); let s = scrape::bench_scrape_handler( - &state, &bench_config, &aquatic_config, &request_sender, @@ -100,7 +90,6 @@ pub fn run(bench_config: BenchConfig) -> ::anyhow::Result<()> { bench_config.num_rounds, bench_config.num_threads, ); - print_results("Connect: ", c.0, c.1); print_results("Announce:", a.0, a.1); print_results("Scrape: ", s.0, s.1); diff --git a/aquatic_udp_bench/src/scrape.rs b/aquatic_udp_bench/src/scrape.rs index 09534cc..26a5ad3 100644 --- a/aquatic_udp_bench/src/scrape.rs +++ b/aquatic_udp_bench/src/scrape.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, Instant}; use crossbeam_channel::{Receiver, Sender}; @@ -13,16 +13,14 @@ use crate::common::*; use crate::config::BenchConfig; pub fn bench_scrape_handler( - state: &State, bench_config: &BenchConfig, aquatic_config: &Config, - request_sender: &Sender<(Request, SocketAddr)>, + request_sender: &Sender<(ConnectedRequest, SocketAddr)>, response_receiver: &Receiver<(Response, SocketAddr)>, rng: &mut impl Rng, info_hashes: &[InfoHash], ) -> (usize, Duration) { let requests = create_requests( - state, rng, info_hashes, bench_config.num_scrape_requests, @@ -43,7 +41,7 @@ pub fn bench_scrape_handler( for round in (0..bench_config.num_rounds).progress_with(pb) { for request_chunk in requests.chunks(p) { for (request, src) in request_chunk { - request_sender.send((request.clone().into(), *src)).unwrap(); + request_sender.send((ConnectedRequest::Scrape(request.clone()), *src)).unwrap(); } while let Ok((Response::Scrape(r), _)) = response_receiver.try_recv() { @@ -78,7 +76,6 @@ pub fn bench_scrape_handler( } pub fn create_requests( - state: &State, rng: &mut impl Rng, info_hashes: &[InfoHash], number: usize, @@ -88,13 +85,9 @@ pub fn create_requests( let max_index = info_hashes.len() - 1; - let connections = state.connections.lock(); - - let connection_keys: Vec = connections.keys().take(number).cloned().collect(); - let mut requests = Vec::new(); - for connection_key in connection_keys.into_iter() { + for _ in 0..number { let mut request_info_hashes = Vec::new(); for _ in 0..hashes_per_request { @@ -103,12 +96,12 @@ pub fn create_requests( } let request = ScrapeRequest { - connection_id: connection_key.connection_id, + connection_id: ConnectionId(0), transaction_id: TransactionId(rng.gen()), info_hashes: request_info_hashes, }; - requests.push((request, connection_key.socket_addr)); + requests.push((request, SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1)))); } requests