mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-03-31 17:55:36 +00:00
udp: use hmac ConnectionValidator in socket workers
This commit is contained in:
parent
dc4523ede5
commit
8b70034900
3 changed files with 28 additions and 67 deletions
|
|
@ -16,13 +16,14 @@ use crate::config::Config;
|
|||
|
||||
pub const MAX_PACKET_SIZE: usize = 8192;
|
||||
|
||||
pub struct ConnectionIdHandler {
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectionValidator {
|
||||
start_time: Instant,
|
||||
max_connection_age: Duration,
|
||||
hmac: blake3::Hasher,
|
||||
}
|
||||
|
||||
impl ConnectionIdHandler {
|
||||
impl ConnectionValidator {
|
||||
pub fn new(config: &Config) -> anyhow::Result<Self> {
|
||||
let mut key = [0; 32];
|
||||
|
||||
|
|
@ -40,19 +41,23 @@ impl ConnectionIdHandler {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn create_connection_id(&mut self, source_ip: IpAddr) -> ConnectionId {
|
||||
pub fn create_connection_id(&mut self, source_addr: CanonicalSocketAddr) -> ConnectionId {
|
||||
// Seconds elapsed since server start, as bytes
|
||||
let elapsed_time_bytes = (self.start_time.elapsed().as_secs() as u32).to_ne_bytes();
|
||||
|
||||
self.create_connection_id_inner(elapsed_time_bytes, source_ip)
|
||||
self.create_connection_id_inner(elapsed_time_bytes, source_addr)
|
||||
}
|
||||
|
||||
pub fn connection_id_valid(&mut self, source_ip: IpAddr, connection_id: ConnectionId) -> bool {
|
||||
pub fn connection_id_valid(
|
||||
&mut self,
|
||||
source_addr: CanonicalSocketAddr,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
let elapsed_time_bytes = connection_id.0.to_ne_bytes()[..4].try_into().unwrap();
|
||||
|
||||
// i64 comparison should be constant-time
|
||||
let hmac_valid =
|
||||
connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_ip);
|
||||
connection_id == self.create_connection_id_inner(elapsed_time_bytes, source_addr);
|
||||
|
||||
if !hmac_valid {
|
||||
return false;
|
||||
|
|
@ -67,7 +72,7 @@ impl ConnectionIdHandler {
|
|||
fn create_connection_id_inner(
|
||||
&mut self,
|
||||
elapsed_time_bytes: [u8; 4],
|
||||
source_ip: IpAddr,
|
||||
source_addr: CanonicalSocketAddr,
|
||||
) -> ConnectionId {
|
||||
// The first 4 bytes is the elapsed time since server start in seconds. The last 4 is a
|
||||
// truncated message authentication code.
|
||||
|
|
@ -77,7 +82,7 @@ impl ConnectionIdHandler {
|
|||
|
||||
self.hmac.update(&elapsed_time_bytes);
|
||||
|
||||
match source_ip {
|
||||
match source_addr.get().ip() {
|
||||
IpAddr::V4(ip) => self.hmac.update(&ip.octets()),
|
||||
IpAddr::V6(ip) => self.hmac.update(&ip.octets()),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ use aquatic_common::privileges::PrivilegeDropper;
|
|||
use aquatic_common::PanicSentinelWatcher;
|
||||
|
||||
use common::{
|
||||
ConnectedRequestSender, ConnectedResponseSender, RequestWorkerIndex, SocketWorkerIndex, State,
|
||||
ConnectedRequestSender, ConnectedResponseSender, ConnectionValidator, RequestWorkerIndex,
|
||||
SocketWorkerIndex, State,
|
||||
};
|
||||
use config::Config;
|
||||
|
||||
|
|
@ -31,6 +32,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
|
|||
|
||||
let mut signals = Signals::new([SIGUSR1, SIGTERM])?;
|
||||
|
||||
let connection_validator = ConnectionValidator::new(&config)?;
|
||||
|
||||
let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel();
|
||||
let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
|
||||
|
||||
|
|
@ -96,6 +99,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
|
|||
let sentinel = sentinel.clone();
|
||||
let state = state.clone();
|
||||
let config = config.clone();
|
||||
let connection_validator = connection_validator.clone();
|
||||
let request_sender =
|
||||
ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone());
|
||||
let response_receiver = response_receivers.remove(&i).unwrap();
|
||||
|
|
@ -117,6 +121,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
|
|||
state,
|
||||
config,
|
||||
i,
|
||||
connection_validator,
|
||||
request_sender,
|
||||
response_receiver,
|
||||
priv_dropper,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ use aquatic_common::privileges::PrivilegeDropper;
|
|||
use crossbeam_channel::Receiver;
|
||||
use mio::net::UdpSocket;
|
||||
use mio::{Events, Interest, Poll, Token};
|
||||
use rand::prelude::{Rng, SeedableRng, StdRng};
|
||||
use slab::Slab;
|
||||
|
||||
use aquatic_common::access_list::create_access_list_cache;
|
||||
|
|
@ -22,31 +21,6 @@ use socket2::{Domain, Protocol, Socket, Type};
|
|||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ConnectionMap(AmortizedIndexMap<(ConnectionId, CanonicalSocketAddr), ValidUntil>);
|
||||
|
||||
impl ConnectionMap {
|
||||
pub fn insert(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
socket_addr: CanonicalSocketAddr,
|
||||
valid_until: ValidUntil,
|
||||
) {
|
||||
self.0.insert((connection_id, socket_addr), valid_until);
|
||||
}
|
||||
|
||||
pub fn contains(&self, connection_id: ConnectionId, socket_addr: CanonicalSocketAddr) -> bool {
|
||||
self.0.contains_key(&(connection_id, socket_addr))
|
||||
}
|
||||
|
||||
pub fn clean(&mut self) {
|
||||
let now = Instant::now();
|
||||
|
||||
self.0.retain(|_, v| v.0 > now);
|
||||
self.0.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PendingScrapeResponseSlabEntry {
|
||||
num_pending: usize,
|
||||
|
|
@ -155,11 +129,11 @@ pub fn run_socket_worker(
|
|||
state: State,
|
||||
config: Config,
|
||||
token_num: usize,
|
||||
mut connection_validator: ConnectionValidator,
|
||||
request_sender: ConnectedRequestSender,
|
||||
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
|
||||
priv_dropper: PrivilegeDropper,
|
||||
) {
|
||||
let mut rng = StdRng::from_entropy();
|
||||
let mut buffer = [0u8; MAX_PACKET_SIZE];
|
||||
|
||||
let mut socket =
|
||||
|
|
@ -173,7 +147,6 @@ pub fn run_socket_worker(
|
|||
.unwrap();
|
||||
|
||||
let mut events = Events::with_capacity(config.network.poll_event_capacity);
|
||||
let mut connections = ConnectionMap::default();
|
||||
let mut pending_scrape_responses = PendingScrapeResponseSlab::default();
|
||||
let mut access_list_cache = create_access_list_cache(&state.access_list);
|
||||
|
||||
|
|
@ -181,15 +154,10 @@ pub fn run_socket_worker(
|
|||
|
||||
let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms);
|
||||
|
||||
let connection_cleaning_duration =
|
||||
Duration::from_secs(config.cleaning.connection_cleaning_interval);
|
||||
let pending_scrape_cleaning_duration =
|
||||
Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval);
|
||||
|
||||
let mut connection_valid_until = ValidUntil::new(config.cleaning.max_connection_age);
|
||||
let mut pending_scrape_valid_until = ValidUntil::new(config.cleaning.max_pending_scrape_age);
|
||||
|
||||
let mut last_connection_cleaning = Instant::now();
|
||||
let mut last_pending_scrape_cleaning = Instant::now();
|
||||
|
||||
let mut iter_counter = 0usize;
|
||||
|
|
@ -205,15 +173,13 @@ pub fn run_socket_worker(
|
|||
read_requests(
|
||||
&config,
|
||||
&state,
|
||||
&mut connections,
|
||||
&mut connection_validator,
|
||||
&mut pending_scrape_responses,
|
||||
&mut access_list_cache,
|
||||
&mut rng,
|
||||
&mut socket,
|
||||
&mut buffer,
|
||||
&request_sender,
|
||||
&mut local_responses,
|
||||
connection_valid_until,
|
||||
pending_scrape_valid_until,
|
||||
);
|
||||
}
|
||||
|
|
@ -233,16 +199,9 @@ pub fn run_socket_worker(
|
|||
if iter_counter % 128 == 0 {
|
||||
let now = Instant::now();
|
||||
|
||||
connection_valid_until =
|
||||
ValidUntil::new_with_now(now, config.cleaning.max_connection_age);
|
||||
pending_scrape_valid_until =
|
||||
ValidUntil::new_with_now(now, config.cleaning.max_pending_scrape_age);
|
||||
|
||||
if now > last_connection_cleaning + connection_cleaning_duration {
|
||||
connections.clean();
|
||||
|
||||
last_connection_cleaning = now;
|
||||
}
|
||||
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
|
||||
pending_scrape_responses.clean();
|
||||
|
||||
|
|
@ -258,15 +217,13 @@ pub fn run_socket_worker(
|
|||
fn read_requests(
|
||||
config: &Config,
|
||||
state: &State,
|
||||
connections: &mut ConnectionMap,
|
||||
connection_validator: &mut ConnectionValidator,
|
||||
pending_scrape_responses: &mut PendingScrapeResponseSlab,
|
||||
access_list_cache: &mut AccessListCache,
|
||||
rng: &mut StdRng,
|
||||
socket: &mut UdpSocket,
|
||||
buffer: &mut [u8],
|
||||
request_sender: &ConnectedRequestSender,
|
||||
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
|
||||
connection_valid_until: ValidUntil,
|
||||
pending_scrape_valid_until: ValidUntil,
|
||||
) {
|
||||
let mut requests_received_ipv4: usize = 0;
|
||||
|
|
@ -297,13 +254,11 @@ fn read_requests(
|
|||
|
||||
handle_request(
|
||||
config,
|
||||
connections,
|
||||
connection_validator,
|
||||
pending_scrape_responses,
|
||||
access_list_cache,
|
||||
rng,
|
||||
request_sender,
|
||||
local_responses,
|
||||
connection_valid_until,
|
||||
pending_scrape_valid_until,
|
||||
res_request,
|
||||
src,
|
||||
|
|
@ -341,13 +296,11 @@ fn read_requests(
|
|||
|
||||
pub fn handle_request(
|
||||
config: &Config,
|
||||
connections: &mut ConnectionMap,
|
||||
connection_validator: &mut ConnectionValidator,
|
||||
pending_scrape_responses: &mut PendingScrapeResponseSlab,
|
||||
access_list_cache: &mut AccessListCache,
|
||||
rng: &mut StdRng,
|
||||
request_sender: &ConnectedRequestSender,
|
||||
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
|
||||
connection_valid_until: ValidUntil,
|
||||
pending_scrape_valid_until: ValidUntil,
|
||||
res_request: Result<Request, RequestParseError>,
|
||||
src: CanonicalSocketAddr,
|
||||
|
|
@ -356,9 +309,7 @@ pub fn handle_request(
|
|||
|
||||
match res_request {
|
||||
Ok(Request::Connect(request)) => {
|
||||
let connection_id = ConnectionId(rng.gen());
|
||||
|
||||
connections.insert(connection_id, src, connection_valid_until);
|
||||
let connection_id = connection_validator.create_connection_id(src);
|
||||
|
||||
let response = Response::Connect(ConnectResponse {
|
||||
connection_id,
|
||||
|
|
@ -368,7 +319,7 @@ pub fn handle_request(
|
|||
local_responses.push((response, src))
|
||||
}
|
||||
Ok(Request::Announce(request)) => {
|
||||
if connections.contains(request.connection_id, src) {
|
||||
if connection_validator.connection_id_valid(src, request.connection_id) {
|
||||
if access_list_cache
|
||||
.load()
|
||||
.allows(access_list_mode, &request.info_hash.0)
|
||||
|
|
@ -392,7 +343,7 @@ pub fn handle_request(
|
|||
}
|
||||
}
|
||||
Ok(Request::Scrape(request)) => {
|
||||
if connections.contains(request.connection_id, src) {
|
||||
if connection_validator.connection_id_valid(src, request.connection_id) {
|
||||
let split_requests = pending_scrape_responses.prepare_split_requests(
|
||||
config,
|
||||
request,
|
||||
|
|
@ -417,7 +368,7 @@ pub fn handle_request(
|
|||
err,
|
||||
} = err
|
||||
{
|
||||
if connections.contains(connection_id, src) {
|
||||
if connection_validator.connection_id_valid(src, connection_id) {
|
||||
let response = ErrorResponse {
|
||||
transaction_id,
|
||||
message: err.right_or("Parse error").into(),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue