aquatic_udp: validate requests in socket workers

Also, don't send error responses for unconnected requests
This commit is contained in:
Joakim Frostegård 2021-10-18 01:14:32 +02:00
parent fc90c71a56
commit 7616df9686
11 changed files with 88 additions and 259 deletions

View file

@ -30,6 +30,11 @@ impl Ip for Ipv6Addr {
}
}
pub enum ConnectedRequest {
Announce(AnnounceRequest),
Scrape(ScrapeRequest),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConnectionKey {
pub connection_id: ConnectionId,
@ -180,7 +185,6 @@ pub struct Statistics {
#[derive(Clone)]
pub struct State {
pub access_list: Arc<AccessList>,
pub connections: Arc<Mutex<ConnectionMap>>,
pub torrents: Arc<Mutex<TorrentMaps>>,
pub statistics: Arc<Statistics>,
}
@ -189,7 +193,6 @@ impl Default for State {
fn default() -> Self {
Self {
access_list: Arc::new(AccessList::default()),
connections: Arc::new(Mutex::new(HashMap::new())),
torrents: Arc::new(Mutex::new(TorrentMaps::default())),
statistics: Arc::new(Statistics::default()),
}

View file

@ -1,39 +0,0 @@
use std::net::SocketAddr;
use std::vec::Drain;
use parking_lot::MutexGuard;
use rand::{rngs::StdRng, Rng};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
#[inline]
pub fn handle_connect_requests(
config: &Config,
connections: &mut MutexGuard<ConnectionMap>,
rng: &mut StdRng,
requests: Drain<(ConnectRequest, SocketAddr)>,
responses: &mut Vec<(Response, SocketAddr)>,
) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
responses.extend(requests.map(|(request, src)| {
let connection_id = ConnectionId(rng.gen());
let key = ConnectionKey {
connection_id,
socket_addr: src,
};
connections.insert(key, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
(response, src)
}));
}

View file

@ -14,20 +14,17 @@ use crate::common::*;
use crate::config::Config;
mod announce;
mod connect;
mod scrape;
use announce::handle_announce_requests;
use connect::handle_connect_requests;
use scrape::handle_scrape_requests;
pub fn run_request_worker(
state: State,
config: Config,
request_receiver: Receiver<(Request, SocketAddr)>,
request_receiver: Receiver<(ConnectedRequest, SocketAddr)>,
response_sender: Sender<(Response, SocketAddr)>,
) {
let mut connect_requests: Vec<(ConnectRequest, SocketAddr)> = Vec::new();
let mut announce_requests: Vec<(AnnounceRequest, SocketAddr)> = Vec::new();
let mut scrape_requests: Vec<(ScrapeRequest, SocketAddr)> = Vec::new();
@ -39,15 +36,15 @@ pub fn run_request_worker(
let timeout = Duration::from_micros(config.handlers.channel_recv_timeout_microseconds);
loop {
let mut opt_connections = None;
let mut opt_torrents = None;
// Collect requests from channel, divide them by type
//
// Collect a maximum number of request. Stop collecting before that
// number is reached if having waited for too long for a request, but
// only if ConnectionMap mutex isn't locked.
// only if TorrentMaps mutex isn't locked.
for i in 0..config.handlers.max_requests_per_iter {
let (request, src): (Request, SocketAddr) = if i == 0 {
let (request, src): (ConnectedRequest, SocketAddr) = if i == 0 {
match request_receiver.recv() {
Ok(r) => r,
Err(_) => break, // Really shouldn't happen
@ -56,8 +53,8 @@ pub fn run_request_worker(
match request_receiver.recv_timeout(timeout) {
Ok(r) => r,
Err(_) => {
if let Some(guard) = state.connections.try_lock() {
opt_connections = Some(guard);
if let Some(guard) = state.torrents.try_lock() {
opt_torrents = Some(guard);
break;
} else {
@ -68,69 +65,25 @@ pub fn run_request_worker(
};
match request {
Request::Connect(r) => connect_requests.push((r, src)),
Request::Announce(r) => announce_requests.push((r, src)),
Request::Scrape(r) => scrape_requests.push((r, src)),
ConnectedRequest::Announce(r) => announce_requests.push((r, src)),
ConnectedRequest::Scrape(r) => scrape_requests.push((r, src)),
}
}
let mut connections: MutexGuard<ConnectionMap> =
opt_connections.unwrap_or_else(|| state.connections.lock());
handle_connect_requests(
&config,
&mut connections,
&mut std_rng,
connect_requests.drain(..),
&mut responses,
);
// Check announce and scrape requests for valid connections
announce_requests.retain(|(request, src)| {
let connection_valid =
connections.contains_key(&ConnectionKey::new(request.connection_id, *src));
if !connection_valid {
responses.push((
create_invalid_connection_response(request.transaction_id),
*src,
));
}
connection_valid
});
scrape_requests.retain(|(request, src)| {
let connection_valid =
connections.contains_key(&ConnectionKey::new(request.connection_id, *src));
if !connection_valid {
responses.push((
create_invalid_connection_response(request.transaction_id),
*src,
));
}
connection_valid
});
::std::mem::drop(connections);
let mut torrents: MutexGuard<TorrentMaps> =
opt_torrents.unwrap_or_else(|| state.torrents.lock());
// Generate responses for announce and scrape requests
if !(announce_requests.is_empty() && scrape_requests.is_empty()) {
let mut torrents = state.torrents.lock();
handle_announce_requests(
&config,
&mut torrents,
&mut small_rng,
announce_requests.drain(..),
&mut responses,
);
handle_announce_requests(
&config,
&mut torrents,
&mut small_rng,
announce_requests.drain(..),
&mut responses,
);
handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses);
}
handle_scrape_requests(&mut torrents, scrape_requests.drain(..), &mut responses);
for r in responses.drain(..) {
if let Err(err) = response_sender.send(r) {
@ -139,10 +92,3 @@ pub fn run_request_worker(
}
}
}
fn create_invalid_connection_response(transaction_id: TransactionId) -> Response {
Response::Error(ErrorResponse {
transaction_id,
message: "Connection invalid or expired".into(),
})
}

View file

@ -55,7 +55,6 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
loop {
::std::thread::sleep(Duration::from_secs(config.cleaning.interval));
tasks::clean_connections(&state);
tasks::update_access_list(&config, &state);
state.torrents.lock().clean(&config, &state.access_list);

View file

@ -4,12 +4,13 @@ use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::Duration;
use std::time::{Duration, Instant};
use std::vec::Drain;
use crossbeam_channel::{Receiver, Sender};
use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token};
use rand::prelude::{Rng, SeedableRng, StdRng};
use socket2::{Domain, Protocol, Socket, Type};
use aquatic_udp_protocol::{IpVersion, Request, Response};
@ -21,10 +22,11 @@ pub fn run_socket_worker(
state: State,
config: Config,
token_num: usize,
request_sender: Sender<(Request, SocketAddr)>,
request_sender: Sender<(ConnectedRequest, SocketAddr)>,
response_receiver: Receiver<(Response, SocketAddr)>,
num_bound_sockets: Arc<AtomicUsize>,
) {
let mut rng = StdRng::from_entropy();
let mut buffer = [0u8; MAX_PACKET_SIZE];
let mut socket = UdpSocket::from_std(create_socket(&config));
@ -39,8 +41,9 @@ pub fn run_socket_worker(
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
let mut events = Events::with_capacity(config.network.poll_event_capacity);
let mut connections = ConnectionMap::default();
let mut requests: Vec<(Request, SocketAddr)> = Vec::new();
let mut requests: Vec<(ConnectedRequest, SocketAddr)> = Vec::new();
let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new();
let timeout = Duration::from_millis(50);
@ -54,8 +57,10 @@ pub fn run_socket_worker(
if (token.0 == token_num) & event.is_readable() {
read_requests(
&state,
&config,
&state,
&mut connections,
&mut rng,
&mut socket,
&mut buffer,
&mut requests,
@ -83,6 +88,11 @@ pub fn run_socket_worker(
&response_receiver,
local_responses.drain(..),
);
let now = Instant::now();
connections.retain(|_, v| v.0 > now);
connections.shrink_to_fit();
}
}
@ -121,16 +131,19 @@ fn create_socket(config: &Config) -> ::std::net::UdpSocket {
#[inline]
fn read_requests(
state: &State,
config: &Config,
state: &State,
connections: &mut ConnectionMap,
rng: &mut StdRng,
socket: &mut UdpSocket,
buffer: &mut [u8],
requests: &mut Vec<(Request, SocketAddr)>,
requests: &mut Vec<(ConnectedRequest, SocketAddr)>,
local_responses: &mut Vec<(Response, SocketAddr)>,
) {
let mut requests_received: usize = 0;
let mut bytes_received: usize = 0;
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode;
loop {
@ -146,20 +159,43 @@ fn read_requests(
}
match request {
Ok(Request::Announce(AnnounceRequest {
info_hash,
transaction_id,
..
})) if !state.access_list.allows(access_list_mode, &info_hash.0) => {
let response = Response::Error(ErrorResponse {
transaction_id,
message: "Info hash not allowed".into(),
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(ConnectionKey::new(connection_id, src), valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(request) => {
requests.push((request, src));
Ok(Request::Announce(request)) => {
let key = ConnectionKey::new(request.connection_id, src);
if connections.contains_key(&key) {
if state
.access_list
.allows(access_list_mode, &request.info_hash.0)
{
requests.push((ConnectedRequest::Announce(request), src));
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
let key = ConnectionKey::new(request.connection_id, src);
if connections.contains_key(&key) {
requests.push((ConnectedRequest::Scrape(request), src));
}
}
Err(err) => {
::log::debug!("request_from_bytes error: {:?}", err);

View file

@ -1,5 +1,4 @@
use std::sync::atomic::Ordering;
use std::time::Instant;
use histogram::Histogram;
@ -19,15 +18,6 @@ pub fn update_access_list(config: &Config, state: &State) {
}
}
pub fn clean_connections(state: &State) {
let now = Instant::now();
let mut connections = state.connections.lock();
connections.retain(|_, v| v.0 > now);
connections.shrink_to_fit();
}
pub fn gather_and_print_statistics(state: &State, config: &Config) {
let interval = config.statistics.interval;