udp: use hmac ConnectionValidator in socket workers

This commit is contained in:
Joakim Frostegård 2022-04-13 22:27:45 +02:00
parent dc4523ede5
commit 8b70034900
3 changed files with 28 additions and 67 deletions

View file

@ -16,13 +16,14 @@ use crate::config::Config;
pub const MAX_PACKET_SIZE: usize = 8192;
pub struct ConnectionIdHandler {
#[derive(Clone)]
pub struct ConnectionValidator {
start_time: Instant,
max_connection_age: Duration,
hmac: blake3::Hasher,
}
impl ConnectionIdHandler {
impl ConnectionValidator {
pub fn new(config: &Config) -> anyhow::Result<Self> {
let mut key = [0; 32];
@ -40,19 +41,23 @@ impl ConnectionIdHandler {
})
}
pub fn create_connection_id(&mut self, source_ip: IpAddr) -> ConnectionId {
pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId {
// Seconds elapsed since server start, as bytes
let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes();
self.create_connection_id_inner(elapsed_time_bytes, source_ip)
self.create_connection_id_inner(elapsed_time_bytes, source_addr)
}
pub fn connection_id_valid(&mut self, source_ip: IpAddr, connection_id: ConnectionId) -> bool {
pub fn connection_id_valid(
&mut self,
source_addr: CanonicalSocketAddr,
connection_id: ConnectionId,
) -> bool {
let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap();
// i64 comparison should be constant-time
let hmac_valid =
connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_ip);
connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_addr);
if !hmac_valid {
return false;
@ -67,7 +72,7 @@ impl ConnectionIdHandler {
fn create_connection_id_inner(
&mut self,
elapsed_time_bytes: [u8; 4],
source_ip: IpAddr,
source_addr: CanonicalSocketAddr,
) -> ConnectionId {
// The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a
// truncated message authentication code.
@ -77,7 +82,7 @@ impl ConnectionIdHandler {
self.hmac.update(&elapsed_time_bytes);
match source_ip {
match source_addr.get().ip() {
IpAddr::V4(ip) => self.hmac.update(&ip.octets()),
IpAddr::V6(ip) => self.hmac.update(&ip.octets()),
};

View file

@ -17,7 +17,8 @@ use aquatic_common::privileges::PrivilegeDropper;
use aquatic_common::PanicSentinelWatcher;
use common::{
ConnectedRequestSender, ConnectedResponseSender, RequestWorkerIndex, SocketWorkerIndex, State,
ConnectedRequestSender, ConnectedResponseSender, ConnectionValidator, RequestWorkerIndex,
SocketWorkerIndex, State,
};
use config::Config;
@ -31,6 +32,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let mut signals = Signals::new([SIGUSR1, SIGTERM])?;
let connection_validator = ConnectionValidator::new(&config)?;
let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel();
let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
@ -96,6 +99,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let sentinel = sentinel.clone();
let state = state.clone();
let config = config.clone();
let connection_validator = connection_validator.clone();
let request_sender =
ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone());
let response_receiver = response_receivers.remove(&i).unwrap();
@ -117,6 +121,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
state,
config,
i,
connection_validator,
request_sender,
response_receiver,
priv_dropper,

View file

@ -9,7 +9,6 @@ use aquatic_common::privileges::PrivilegeDropper;
use crossbeam_channel::Receiver;
use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token};
use rand::prelude::{Rng, SeedableRng, StdRng};
use slab::Slab;
use aquatic_common::access_list::create_access_list_cache;
@ -22,31 +21,6 @@ use socket2::{Domain, Protocol, Socket, Type};
use crate::common::*;
use crate::config::Config;
#[derive(Default)]
pub struct ConnectionMap(AmortizedIndexMap<(ConnectionId, CanonicalSocketAddr), ValidUntil>);
impl ConnectionMap {
pub fn insert(
&mut self,
connection_id: ConnectionId,
socket_addr: CanonicalSocketAddr,
valid_until: ValidUntil,
) {
self.0.insert((connection_id, socket_addr), valid_until);
}
pub fn contains(&self, connection_id: ConnectionId, socket_addr: CanonicalSocketAddr) -> bool {
self.0.contains_key(&(connection_id, socket_addr))
}
pub fn clean(&mut self) {
let now = Instant::now();
self.0.retain(|_, v| v.0 > now);
self.0.shrink_to_fit();
}
}
#[derive(Debug)]
pub struct PendingScrapeResponseSlabEntry {
num_pending: usize,
@ -155,11 +129,11 @@ pub fn run_socket_worker(
state: State,
config: Config,
token_num: usize,
mut connection_validator: ConnectionValidator,
request_sender: ConnectedRequestSender,
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
priv_dropper: PrivilegeDropper,
) {
let mut rng = StdRng::from_entropy();
let mut buffer = [0u8; MAX_PACKET_SIZE];
let mut socket =
@ -173,7 +147,6 @@ pub fn run_socket_worker(
.unwrap();
let mut events = Events::with_capacity(config.network.poll_event_capacity);
let mut connections = ConnectionMap::default();
let mut pending_scrape_responses = PendingScrapeResponseSlab::default();
let mut access_list_cache = create_access_list_cache(&state.access_list);
@ -181,15 +154,10 @@ pub fn run_socket_worker(
let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms);
let connection_cleaning_duration =
Duration::from_secs(config.cleaning.connection_cleaning_interval);
let pending_scrape_cleaning_duration =
Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval);
let mut connection_valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let mut pending_scrape_valid_until = ValidUntil::new(config.cleaning.max_pending_scrape_age);
let mut last_connection_cleaning = Instant::now();
let mut last_pending_scrape_cleaning = Instant::now();
let mut iter_counter = 0usize;
@ -205,15 +173,13 @@ pub fn run_socket_worker(
read_requests(
&config,
&state,
&mut connections,
&mut connection_validator,
&mut pending_scrape_responses,
&mut access_list_cache,
&mut rng,
&mut socket,
&mut buffer,
&request_sender,
&mut local_responses,
connection_valid_until,
pending_scrape_valid_until,
);
}
@ -233,16 +199,9 @@ pub fn run_socket_worker(
if iter_counter % 128 == 0 {
let now = Instant::now();
connection_valid_until =
ValidUntil::new_with_now(now, config.cleaning.max_connection_age);
pending_scrape_valid_until =
ValidUntil::new_with_now(now, config.cleaning.max_pending_scrape_age);
if now > last_connection_cleaning + connection_cleaning_duration {
connections.clean();
last_connection_cleaning = now;
}
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
pending_scrape_responses.clean();
@ -258,15 +217,13 @@ pub fn run_socket_worker(
fn read_requests(
config: &Config,
state: &State,
connections: &mut ConnectionMap,
connection_validator: &mut ConnectionValidator,
pending_scrape_responses: &mut PendingScrapeResponseSlab,
access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
socket: &mut UdpSocket,
buffer: &mut [u8],
request_sender: &ConnectedRequestSender,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
connection_valid_until: ValidUntil,
pending_scrape_valid_until: ValidUntil,
) {
let mut requests_received_ipv4: usize = 0;
@ -297,13 +254,11 @@ fn read_requests(
handle_request(
config,
connections,
connection_validator,
pending_scrape_responses,
access_list_cache,
rng,
request_sender,
local_responses,
connection_valid_until,
pending_scrape_valid_until,
res_request,
src,
@ -341,13 +296,11 @@ fn read_requests(
pub fn handle_request(
config: &Config,
connections: &mut ConnectionMap,
connection_validator: &mut ConnectionValidator,
pending_scrape_responses: &mut PendingScrapeResponseSlab,
access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
request_sender: &ConnectedRequestSender,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
connection_valid_until: ValidUntil,
pending_scrape_valid_until: ValidUntil,
res_request: Result<Request, RequestParseError>,
src: CanonicalSocketAddr,
@ -356,9 +309,7 @@ pub fn handle_request(
match res_request {
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(connection_id, src, connection_valid_until);
let connection_id = connection_validator.create_connection_id(src);
let response = Response::Connect(ConnectResponse {
connection_id,
@ -368,7 +319,7 @@ pub fn handle_request(
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if connection_validator.connection_id_valid(src, request.connection_id) {
if access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
@ -392,7 +343,7 @@ pub fn handle_request(
}
}
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
if connection_validator.connection_id_valid(src, request.connection_id) {
let split_requests = pending_scrape_responses.prepare_split_requests(
config,
request,
@ -417,7 +368,7 @@ pub fn handle_request(
err,
} = err
{
if connections.contains(connection_id, src) {
if connection_validator.connection_id_valid(src, connection_id) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),