udp: use hmac ConnectionValidator in socket workers

This commit is contained in:
Joakim Frostegård 2022-04-13 22:27:45 +02:00
parent dc4523ede5
commit 8b70034900
3 changed files with 28 additions and 67 deletions

View file

@ -16,13 +16,14 @@ use crate::config::Config;
pub const MAX_PACKET_SIZE: usize = 8192; pub const MAX_PACKET_SIZE: usize = 8192;
pub struct ConnectionIdHandler { #[derive(Clone)]
pub struct ConnectionValidator {
start_time: Instant, start_time: Instant,
max_connection_age: Duration, max_connection_age: Duration,
hmac: blake3::Hasher, hmac: blake3::Hasher,
} }
impl ConnectionIdHandler { impl ConnectionValidator {
pub fn new(config: &Config) -> anyhow::Result<Self> { pub fn new(config: &Config) -> anyhow::Result<Self> {
let mut key = [0; 32]; let mut key = [0; 32];
@ -40,19 +41,23 @@ impl ConnectionIdHandler {
}) })
} }
pub fn create_connection_id(&mut self, source_ip: IpAddr) -> ConnectionId { pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId {
// Seconds elapsed since server start, as bytes // Seconds elapsed since server start, as bytes
let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes(); let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes();
self.create_connection_id_inner(elapsed_time_bytes, source_ip) self.create_connection_id_inner(elapsed_time_bytes, source_addr)
} }
pub fn connection_id_valid(&mut self, source_ip: IpAddr, connection_id: ConnectionId) -> bool { pub fn connection_id_valid(
&mut self,
source_addr: CanonicalSocketAddr,
connection_id: ConnectionId,
) -> bool {
let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap(); let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap();
// i64 comparison should be constant-time // i64 comparison should be constant-time
let hmac_valid = let hmac_valid =
connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_ip); connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_addr);
if !hmac_valid { if !hmac_valid {
return false; return false;
@ -67,7 +72,7 @@ impl ConnectionIdHandler {
fn create_connection_id_inner( fn create_connection_id_inner(
&mut self, &mut self,
elapsed_time_bytes: [u8; 4], elapsed_time_bytes: [u8; 4],
source_ip: IpAddr, source_addr: CanonicalSocketAddr,
) -> ConnectionId { ) -> ConnectionId {
// The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a // The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a
// truncated message authentication code. // truncated message authentication code.
@ -77,7 +82,7 @@ impl ConnectionIdHandler {
self.hmac.update(&elapsed_time_bytes); self.hmac.update(&elapsed_time_bytes);
match source_ip { match source_addr.get().ip() {
IpAddr::V4(ip) => self.hmac.update(&ip.octets()), IpAddr::V4(ip) => self.hmac.update(&ip.octets()),
IpAddr::V6(ip) => self.hmac.update(&ip.octets()), IpAddr::V6(ip) => self.hmac.update(&ip.octets()),
}; };

View file

@ -17,7 +17,8 @@ use aquatic_common::privileges::PrivilegeDropper;
use aquatic_common::PanicSentinelWatcher; use aquatic_common::PanicSentinelWatcher;
use common::{ use common::{
ConnectedRequestSender, ConnectedResponseSender, RequestWorkerIndex, SocketWorkerIndex, State, ConnectedRequestSender, ConnectedResponseSender, ConnectionValidator, RequestWorkerIndex,
SocketWorkerIndex, State,
}; };
use config::Config; use config::Config;
@ -31,6 +32,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let mut signals = Signals::new([SIGUSR1, SIGTERM])?; let mut signals = Signals::new([SIGUSR1, SIGTERM])?;
let connection_validator = ConnectionValidator::new(&config)?;
let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel(); let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel();
let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
@ -96,6 +99,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let sentinel = sentinel.clone(); let sentinel = sentinel.clone();
let state = state.clone(); let state = state.clone();
let config = config.clone(); let config = config.clone();
let connection_validator = connection_validator.clone();
let request_sender = let request_sender =
ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone()); ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone());
let response_receiver = response_receivers.remove(&i).unwrap(); let response_receiver = response_receivers.remove(&i).unwrap();
@ -117,6 +121,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
state, state,
config, config,
i, i,
connection_validator,
request_sender, request_sender,
response_receiver, response_receiver,
priv_dropper, priv_dropper,

View file

@ -9,7 +9,6 @@ use aquatic_common::privileges::PrivilegeDropper;
use crossbeam_channel::Receiver; use crossbeam_channel::Receiver;
use mio::net::UdpSocket; use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token}; use mio::{Events, Interest, Poll, Token};
use rand::prelude::{Rng, SeedableRng, StdRng};
use slab::Slab; use slab::Slab;
use aquatic_common::access_list::create_access_list_cache; use aquatic_common::access_list::create_access_list_cache;
@ -22,31 +21,6 @@ use socket2::{Domain, Protocol, Socket, Type};
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
#[derive(Default)]
pub struct ConnectionMap(AmortizedIndexMap<(ConnectionId, CanonicalSocketAddr), ValidUntil>);
impl ConnectionMap {
pub fn insert(
&mut self,
connection_id: ConnectionId,
socket_addr: CanonicalSocketAddr,
valid_until: ValidUntil,
) {
self.0.insert((connection_id, socket_addr), valid_until);
}
pub fn contains(&self, connection_id: ConnectionId, socket_addr: CanonicalSocketAddr) -> bool {
self.0.contains_key(&(connection_id, socket_addr))
}
pub fn clean(&mut self) {
let now = Instant::now();
self.0.retain(|_, v| v.0 > now);
self.0.shrink_to_fit();
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct PendingScrapeResponseSlabEntry { pub struct PendingScrapeResponseSlabEntry {
num_pending: usize, num_pending: usize,
@ -155,11 +129,11 @@ pub fn run_socket_worker(
state: State, state: State,
config: Config, config: Config,
token_num: usize, token_num: usize,
mut connection_validator: ConnectionValidator,
request_sender: ConnectedRequestSender, request_sender: ConnectedRequestSender,
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
priv_dropper: PrivilegeDropper, priv_dropper: PrivilegeDropper,
) { ) {
let mut rng = StdRng::from_entropy();
let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut buffer = [0u8; MAX_PACKET_SIZE];
let mut socket = let mut socket =
@ -173,7 +147,6 @@ pub fn run_socket_worker(
.unwrap(); .unwrap();
let mut events = Events::with_capacity(config.network.poll_event_capacity); let mut events = Events::with_capacity(config.network.poll_event_capacity);
let mut connections = ConnectionMap::default();
let mut pending_scrape_responses = PendingScrapeResponseSlab::default(); let mut pending_scrape_responses = PendingScrapeResponseSlab::default();
let mut access_list_cache = create_access_list_cache(&state.access_list); let mut access_list_cache = create_access_list_cache(&state.access_list);
@ -181,15 +154,10 @@ pub fn run_socket_worker(
let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms); let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms);
let connection_cleaning_duration =
Duration::from_secs(config.cleaning.connection_cleaning_interval);
let pending_scrape_cleaning_duration = let pending_scrape_cleaning_duration =
Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval); Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval);
let mut connection_valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let mut pending_scrape_valid_until = ValidUntil::new(config.cleaning.max_pending_scrape_age); let mut pending_scrape_valid_until = ValidUntil::new(config.cleaning.max_pending_scrape_age);
let mut last_connection_cleaning = Instant::now();
let mut last_pending_scrape_cleaning = Instant::now(); let mut last_pending_scrape_cleaning = Instant::now();
let mut iter_counter = 0usize; let mut iter_counter = 0usize;
@ -205,15 +173,13 @@ pub fn run_socket_worker(
read_requests( read_requests(
&config, &config,
&state, &state,
&mut connections, &mut connection_validator,
&mut pending_scrape_responses, &mut pending_scrape_responses,
&mut access_list_cache, &mut access_list_cache,
&mut rng,
&mut socket, &mut socket,
&mut buffer, &mut buffer,
&request_sender, &request_sender,
&mut local_responses, &mut local_responses,
connection_valid_until,
pending_scrape_valid_until, pending_scrape_valid_until,
); );
} }
@ -233,16 +199,9 @@ pub fn run_socket_worker(
if iter_counter % 128 == 0 { if iter_counter % 128 == 0 {
let now = Instant::now(); let now = Instant::now();
connection_valid_until =
ValidUntil::new_with_now(now, config.cleaning.max_connection_age);
pending_scrape_valid_until = pending_scrape_valid_until =
ValidUntil::new_with_now(now, config.cleaning.max_pending_scrape_age); ValidUntil::new_with_now(now, config.cleaning.max_pending_scrape_age);
if now > last_connection_cleaning + connection_cleaning_duration {
connections.clean();
last_connection_cleaning = now;
}
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
pending_scrape_responses.clean(); pending_scrape_responses.clean();
@ -258,15 +217,13 @@ pub fn run_socket_worker(
fn read_requests( fn read_requests(
config: &Config, config: &Config,
state: &State, state: &State,
connections: &mut ConnectionMap, connection_validator: &mut ConnectionValidator,
pending_scrape_responses: &mut PendingScrapeResponseSlab, pending_scrape_responses: &mut PendingScrapeResponseSlab,
access_list_cache: &mut AccessListCache, access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
socket: &mut UdpSocket, socket: &mut UdpSocket,
buffer: &mut [u8], buffer: &mut [u8],
request_sender: &ConnectedRequestSender, request_sender: &ConnectedRequestSender,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
connection_valid_until: ValidUntil,
pending_scrape_valid_until: ValidUntil, pending_scrape_valid_until: ValidUntil,
) { ) {
let mut requests_received_ipv4: usize = 0; let mut requests_received_ipv4: usize = 0;
@ -297,13 +254,11 @@ fn read_requests(
handle_request( handle_request(
config, config,
connections, connection_validator,
pending_scrape_responses, pending_scrape_responses,
access_list_cache, access_list_cache,
rng,
request_sender, request_sender,
local_responses, local_responses,
connection_valid_until,
pending_scrape_valid_until, pending_scrape_valid_until,
res_request, res_request,
src, src,
@ -341,13 +296,11 @@ fn read_requests(
pub fn handle_request( pub fn handle_request(
config: &Config, config: &Config,
connections: &mut ConnectionMap, connection_validator: &mut ConnectionValidator,
pending_scrape_responses: &mut PendingScrapeResponseSlab, pending_scrape_responses: &mut PendingScrapeResponseSlab,
access_list_cache: &mut AccessListCache, access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
request_sender: &ConnectedRequestSender, request_sender: &ConnectedRequestSender,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>, local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
connection_valid_until: ValidUntil,
pending_scrape_valid_until: ValidUntil, pending_scrape_valid_until: ValidUntil,
res_request: Result<Request, RequestParseError>, res_request: Result<Request, RequestParseError>,
src: CanonicalSocketAddr, src: CanonicalSocketAddr,
@ -356,9 +309,7 @@ pub fn handle_request(
match res_request { match res_request {
Ok(Request::Connect(request)) => { Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen()); let connection_id = connection_validator.create_connection_id(src);
connections.insert(connection_id, src, connection_valid_until);
let response = Response::Connect(ConnectResponse { let response = Response::Connect(ConnectResponse {
connection_id, connection_id,
@ -368,7 +319,7 @@ pub fn handle_request(
local_responses.push((response, src)) local_responses.push((response, src))
} }
Ok(Request::Announce(request)) => { Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) { if connection_validator.connection_id_valid(src, request.connection_id) {
if access_list_cache if access_list_cache
.load() .load()
.allows(access_list_mode, &request.info_hash.0) .allows(access_list_mode, &request.info_hash.0)
@ -392,7 +343,7 @@ pub fn handle_request(
} }
} }
Ok(Request::Scrape(request)) => { Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) { if connection_validator.connection_id_valid(src, request.connection_id) {
let split_requests = pending_scrape_responses.prepare_split_requests( let split_requests = pending_scrape_responses.prepare_split_requests(
config, config,
request, request,
@ -417,7 +368,7 @@ pub fn handle_request(
err, err,
} = err } = err
{ {
if connections.contains(connection_id, src) { if connection_validator.connection_id_valid(src, connection_id) {
let response = ErrorResponse { let response = ErrorResponse {
transaction_id, transaction_id,
message: err.right_or("Parse error").into(), message: err.right_or("Parse error").into(),