udp: rewrite socket worker to use SocketWorker struct

Also, stop checking token number all the time
This commit is contained in:
Joakim Frostegård 2022-10-25 01:33:35 +02:00
parent 9d37b3d285
commit 4587c267d6
4 changed files with 366 additions and 400 deletions

View file

@ -21,6 +21,7 @@ use common::{
}; };
use config::Config; use config::Config;
use workers::socket::validator::ConnectionValidator; use workers::socket::validator::ConnectionValidator;
use workers::socket::SocketWorker;
pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker"; pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker";
pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
@ -121,11 +122,10 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
WorkerIndex::SocketWorker(i), WorkerIndex::SocketWorker(i),
); );
workers::socket::run_socket_worker( SocketWorker::run(
sentinel, sentinel,
state, state,
config, config,
i,
connection_validator, connection_validator,
server_start_instant, server_start_instant,
request_sender, request_sender,

View file

@ -1,11 +1,12 @@
mod requests;
mod responses;
mod storage; mod storage;
pub mod validator; pub mod validator;
use std::io::{Cursor, ErrorKind};
use std::sync::atomic::Ordering;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use anyhow::Context; use anyhow::Context;
use aquatic_common::access_list::AccessListCache;
use aquatic_common::ServerStartInstant; use aquatic_common::ServerStartInstant;
use crossbeam_channel::Receiver; use crossbeam_channel::Receiver;
use mio::net::UdpSocket; use mio::net::UdpSocket;
@ -21,105 +22,397 @@ use aquatic_udp_protocol::*;
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
use requests::read_requests;
use responses::send_responses;
use storage::PendingScrapeResponseSlab; use storage::PendingScrapeResponseSlab;
use validator::ConnectionValidator; use validator::ConnectionValidator;
pub fn run_socket_worker( pub struct SocketWorker {
_sentinel: PanicSentinel,
state: State,
config: Config, config: Config,
token_num: usize, shared_state: State,
mut connection_validator: ConnectionValidator,
server_start_instant: ServerStartInstant,
request_sender: ConnectedRequestSender, request_sender: ConnectedRequestSender,
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
priv_dropper: PrivilegeDropper, access_list_cache: AccessListCache,
) { validator: ConnectionValidator,
let mut buffer = [0u8; BUFFER_SIZE]; server_start_instant: ServerStartInstant,
pending_scrape_responses: PendingScrapeResponseSlab,
socket: UdpSocket,
buffer: [u8; BUFFER_SIZE],
}
let mut socket = impl SocketWorker {
UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); pub fn run(
let mut poll = Poll::new().expect("create poll"); _sentinel: PanicSentinel,
shared_state: State,
config: Config,
validator: ConnectionValidator,
server_start_instant: ServerStartInstant,
request_sender: ConnectedRequestSender,
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
priv_dropper: PrivilegeDropper,
) {
let socket =
UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket"));
let access_list_cache = create_access_list_cache(&shared_state.access_list);
let interests = Interest::READABLE; let mut worker = Self {
config,
shared_state,
validator,
server_start_instant,
request_sender,
response_receiver,
access_list_cache,
pending_scrape_responses: Default::default(),
socket,
buffer: [0; BUFFER_SIZE],
};
poll.registry() worker.run_inner();
.register(&mut socket, Token(token_num), interests) }
.unwrap();
let mut events = Events::with_capacity(config.network.poll_event_capacity); pub fn run_inner(&mut self) {
let mut pending_scrape_responses = PendingScrapeResponseSlab::default(); let mut local_responses = Vec::new();
let mut access_list_cache = create_access_list_cache(&state.access_list); let mut opt_resend_buffer =
(self.config.network.resend_buffer_max_len > 0).then_some(Vec::new());
let mut local_responses: Vec<(Response, CanonicalSocketAddr)> = Vec::new(); let mut events = Events::with_capacity(self.config.network.poll_event_capacity);
let mut opt_resend_buffer = (config.network.resend_buffer_max_len > 0).then_some(Vec::new()); let mut poll = Poll::new().expect("create poll");
let poll_timeout = Duration::from_millis(config.network.poll_timeout_ms); poll.registry()
.register(&mut self.socket, Token(0), Interest::READABLE)
.expect("register poll");
let pending_scrape_cleaning_duration = let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms);
Duration::from_secs(config.cleaning.pending_scrape_cleaning_interval);
let mut pending_scrape_valid_until = let pending_scrape_cleaning_duration =
ValidUntil::new(server_start_instant, config.cleaning.max_pending_scrape_age); Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval);
let mut last_pending_scrape_cleaning = Instant::now();
let mut iter_counter = 0usize; let mut pending_scrape_valid_until = ValidUntil::new(
self.server_start_instant,
self.config.cleaning.max_pending_scrape_age,
);
let mut last_pending_scrape_cleaning = Instant::now();
loop { let mut iter_counter = 0usize;
poll.poll(&mut events, Some(poll_timeout))
.expect("failed polling");
for event in events.iter() { loop {
let token = event.token(); poll.poll(&mut events, Some(poll_timeout))
.expect("failed polling");
if (token.0 == token_num) & event.is_readable() { for event in events.iter() {
read_requests( if event.is_readable() {
&config, self.read_requests(&mut local_responses, pending_scrape_valid_until);
&state, }
&mut connection_validator, }
&mut pending_scrape_responses,
&mut access_list_cache, if let Some(resend_buffer) = opt_resend_buffer.as_mut() {
&mut socket, for (response, addr) in resend_buffer.drain(..) {
&mut buffer, Self::send_response(
&request_sender, &self.config,
&mut local_responses, &self.shared_state,
pending_scrape_valid_until, &mut self.socket,
&mut self.buffer,
&mut None,
response,
addr,
);
}
}
for (response, addr) in local_responses.drain(..) {
Self::send_response(
&self.config,
&self.shared_state,
&mut self.socket,
&mut self.buffer,
&mut opt_resend_buffer,
response,
addr,
); );
} }
for (response, addr) in self.response_receiver.try_iter() {
let opt_response = match response {
ConnectedResponse::Scrape(r) => self
.pending_scrape_responses
.add_and_get_finished(r)
.map(Response::Scrape),
ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)),
ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)),
};
if let Some(response) = opt_response {
Self::send_response(
&self.config,
&self.shared_state,
&mut self.socket,
&mut self.buffer,
&mut opt_resend_buffer,
response,
addr,
);
}
}
// Run periodic ValidUntil updates and state cleaning
if iter_counter % 256 == 0 {
let seconds_since_start = self.server_start_instant.seconds_elapsed();
pending_scrape_valid_until = ValidUntil::new_with_now(
seconds_since_start,
self.config.cleaning.max_pending_scrape_age,
);
let now = Instant::now();
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
self.pending_scrape_responses.clean(seconds_since_start);
last_pending_scrape_cleaning = now;
}
}
iter_counter = iter_counter.wrapping_add(1);
} }
}
send_responses( fn read_requests(
&state, &mut self,
&config, local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
&mut socket, pending_scrape_valid_until: ValidUntil,
&mut buffer, ) {
&response_receiver, let mut requests_received_ipv4: usize = 0;
&mut pending_scrape_responses, let mut requests_received_ipv6: usize = 0;
local_responses.drain(..), let mut bytes_received_ipv4: usize = 0;
&mut opt_resend_buffer, let mut bytes_received_ipv6 = 0;
);
// Run periodic ValidUntil updates and state cleaning loop {
if iter_counter % 256 == 0 { match self.socket.recv_from(&mut self.buffer[..]) {
let seconds_since_start = server_start_instant.seconds_elapsed(); Ok((bytes_read, src)) => {
if src.port() == 0 {
::log::info!("Ignored request from {} because source port is zero", src);
pending_scrape_valid_until = ValidUntil::new_with_now( continue;
seconds_since_start, }
config.cleaning.max_pending_scrape_age,
);
let now = Instant::now(); let res_request = Request::from_bytes(
&self.buffer[..bytes_read],
self.config.protocol.max_scrape_torrents,
);
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration { let src = CanonicalSocketAddr::new(src);
pending_scrape_responses.clean(seconds_since_start);
last_pending_scrape_cleaning = now; // Update statistics for converted address
if src.is_ipv4() {
if res_request.is_ok() {
requests_received_ipv4 += 1;
}
bytes_received_ipv4 += bytes_read;
} else {
if res_request.is_ok() {
requests_received_ipv6 += 1;
}
bytes_received_ipv6 += bytes_read;
}
self.handle_request(
local_responses,
pending_scrape_valid_until,
res_request,
src,
);
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
break;
}
Err(err) => {
::log::warn!("recv_from error: {:#}", err);
}
} }
} }
iter_counter = iter_counter.wrapping_add(1); if self.config.statistics.active() {
self.shared_state
.statistics_ipv4
.requests_received
.fetch_add(requests_received_ipv4, Ordering::Release);
self.shared_state
.statistics_ipv6
.requests_received
.fetch_add(requests_received_ipv6, Ordering::Release);
self.shared_state
.statistics_ipv4
.bytes_received
.fetch_add(bytes_received_ipv4, Ordering::Release);
self.shared_state
.statistics_ipv6
.bytes_received
.fetch_add(bytes_received_ipv6, Ordering::Release);
}
}
fn handle_request(
&mut self,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
pending_scrape_valid_until: ValidUntil,
res_request: Result<Request, RequestParseError>,
src: CanonicalSocketAddr,
) {
let access_list_mode = self.config.access_list.mode;
match res_request {
Ok(Request::Connect(request)) => {
let connection_id = self.validator.create_connection_id(src);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if self
.validator
.connection_id_valid(src, request.connection_id)
{
if self
.access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
{
let worker_index =
SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash);
self.request_sender.try_send_to(
worker_index,
ConnectedRequest::Announce(request),
src,
);
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
if self
.validator
.connection_id_valid(src, request.connection_id)
{
let split_requests = self.pending_scrape_responses.prepare_split_requests(
&self.config,
request,
pending_scrape_valid_until,
);
for (swarm_worker_index, request) in split_requests {
self.request_sender.try_send_to(
swarm_worker_index,
ConnectedRequest::Scrape(request),
src,
);
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if self.validator.connection_id_valid(src, connection_id) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_responses.push((response.into(), src));
}
}
}
}
}
fn send_response(
config: &Config,
shared_state: &State,
socket: &mut UdpSocket,
buffer: &mut [u8],
opt_resend_buffer: &mut Option<Vec<(Response, CanonicalSocketAddr)>>,
response: Response,
canonical_addr: CanonicalSocketAddr,
) {
let mut cursor = Cursor::new(buffer);
if let Err(err) = response.write(&mut cursor) {
::log::error!("Converting response to bytes failed: {:#}", err);
return;
}
let bytes_written = cursor.position() as usize;
let addr = if config.network.address.is_ipv4() {
canonical_addr
.get_ipv4()
.expect("found peer ipv6 address while running bound to ipv4 address")
} else {
canonical_addr.get_ipv6_mapped()
};
match socket.send_to(&cursor.get_ref()[..bytes_written], addr) {
Ok(amt) if config.statistics.active() => {
let stats = if canonical_addr.is_ipv4() {
&shared_state.statistics_ipv4
} else {
&shared_state.statistics_ipv6
};
stats.bytes_sent.fetch_add(amt, Ordering::Relaxed);
match response {
Response::Connect(_) => {
stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed);
}
Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => {
stats
.responses_sent_announce
.fetch_add(1, Ordering::Relaxed);
}
Response::Scrape(_) => {
stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed);
}
Response::Error(_) => {
stats.responses_sent_error.fetch_add(1, Ordering::Relaxed);
}
}
}
Ok(_) => (),
Err(err) => match opt_resend_buffer.as_mut() {
Some(resend_buffer)
if (err.raw_os_error() == Some(libc::ENOBUFS))
|| (err.kind() == ErrorKind::WouldBlock) =>
{
if resend_buffer.len() < config.network.resend_buffer_max_len {
::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err);
resend_buffer.push((response, canonical_addr));
} else {
::log::warn!("Response resend buffer full, dropping response");
}
}
_ => {
::log::warn!("Sending response to {} failed: {:#}", addr, err);
}
},
}
} }
} }

View file

@ -1,184 +0,0 @@
use std::io::ErrorKind;
use std::sync::atomic::Ordering;
use mio::net::UdpSocket;
use aquatic_common::{access_list::AccessListCache, CanonicalSocketAddr, ValidUntil};
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
use super::storage::PendingScrapeResponseSlab;
use super::validator::ConnectionValidator;
pub fn read_requests(
config: &Config,
state: &State,
connection_validator: &mut ConnectionValidator,
pending_scrape_responses: &mut PendingScrapeResponseSlab,
access_list_cache: &mut AccessListCache,
socket: &mut UdpSocket,
buffer: &mut [u8],
request_sender: &ConnectedRequestSender,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
pending_scrape_valid_until: ValidUntil,
) {
let mut requests_received_ipv4: usize = 0;
let mut requests_received_ipv6: usize = 0;
let mut bytes_received_ipv4: usize = 0;
let mut bytes_received_ipv6 = 0;
loop {
match socket.recv_from(&mut buffer[..]) {
Ok((bytes_read, src)) => {
if src.port() == 0 {
::log::info!("Ignored request from {} because source port is zero", src);
continue;
}
let res_request =
Request::from_bytes(&buffer[..bytes_read], config.protocol.max_scrape_torrents);
let src = CanonicalSocketAddr::new(src);
// Update statistics for converted address
if src.is_ipv4() {
if res_request.is_ok() {
requests_received_ipv4 += 1;
}
bytes_received_ipv4 += bytes_read;
} else {
if res_request.is_ok() {
requests_received_ipv6 += 1;
}
bytes_received_ipv6 += bytes_read;
}
handle_request(
config,
connection_validator,
pending_scrape_responses,
access_list_cache,
request_sender,
local_responses,
pending_scrape_valid_until,
res_request,
src,
);
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
break;
}
Err(err) => {
::log::warn!("recv_from error: {:#}", err);
}
}
}
if config.statistics.active() {
state
.statistics_ipv4
.requests_received
.fetch_add(requests_received_ipv4, Ordering::Release);
state
.statistics_ipv6
.requests_received
.fetch_add(requests_received_ipv6, Ordering::Release);
state
.statistics_ipv4
.bytes_received
.fetch_add(bytes_received_ipv4, Ordering::Release);
state
.statistics_ipv6
.bytes_received
.fetch_add(bytes_received_ipv6, Ordering::Release);
}
}
fn handle_request(
config: &Config,
connection_validator: &mut ConnectionValidator,
pending_scrape_responses: &mut PendingScrapeResponseSlab,
access_list_cache: &mut AccessListCache,
request_sender: &ConnectedRequestSender,
local_responses: &mut Vec<(Response, CanonicalSocketAddr)>,
pending_scrape_valid_until: ValidUntil,
res_request: Result<Request, RequestParseError>,
src: CanonicalSocketAddr,
) {
let access_list_mode = config.access_list.mode;
match res_request {
Ok(Request::Connect(request)) => {
let connection_id = connection_validator.create_connection_id(src);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connection_validator.connection_id_valid(src, request.connection_id) {
if access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
{
let worker_index = SwarmWorkerIndex::from_info_hash(config, request.info_hash);
request_sender.try_send_to(
worker_index,
ConnectedRequest::Announce(request),
src,
);
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
if connection_validator.connection_id_valid(src, request.connection_id) {
let split_requests = pending_scrape_responses.prepare_split_requests(
config,
request,
pending_scrape_valid_until,
);
for (swarm_worker_index, request) in split_requests {
request_sender.try_send_to(
swarm_worker_index,
ConnectedRequest::Scrape(request),
src,
);
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if connection_validator.connection_id_valid(src, connection_id) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_responses.push((response.into(), src));
}
}
}
}
}

View file

@ -1,143 +0,0 @@
use std::io::{Cursor, ErrorKind};
use std::sync::atomic::Ordering;
use std::vec::Drain;
use crossbeam_channel::Receiver;
use libc::ENOBUFS;
use mio::net::UdpSocket;
use aquatic_common::CanonicalSocketAddr;
use aquatic_udp_protocol::*;
use crate::common::*;
use crate::config::Config;
use super::storage::PendingScrapeResponseSlab;
pub fn send_responses(
state: &State,
config: &Config,
socket: &mut UdpSocket,
buffer: &mut [u8],
response_receiver: &Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
pending_scrape_responses: &mut PendingScrapeResponseSlab,
local_responses: Drain<(Response, CanonicalSocketAddr)>,
opt_resend_buffer: &mut Option<Vec<(Response, CanonicalSocketAddr)>>,
) {
if let Some(resend_buffer) = opt_resend_buffer {
for (response, addr) in resend_buffer.drain(..) {
send_response(state, config, socket, buffer, response, addr, &mut None);
}
}
for (response, addr) in local_responses {
send_response(
state,
config,
socket,
buffer,
response,
addr,
opt_resend_buffer,
);
}
for (response, addr) in response_receiver.try_iter() {
let opt_response = match response {
ConnectedResponse::Scrape(r) => pending_scrape_responses
.add_and_get_finished(r)
.map(Response::Scrape),
ConnectedResponse::AnnounceIpv4(r) => Some(Response::AnnounceIpv4(r)),
ConnectedResponse::AnnounceIpv6(r) => Some(Response::AnnounceIpv6(r)),
};
if let Some(response) = opt_response {
send_response(
state,
config,
socket,
buffer,
response,
addr,
opt_resend_buffer,
);
}
}
}
fn send_response(
state: &State,
config: &Config,
socket: &mut UdpSocket,
buffer: &mut [u8],
response: Response,
canonical_addr: CanonicalSocketAddr,
resend_buffer: &mut Option<Vec<(Response, CanonicalSocketAddr)>>,
) {
let mut cursor = Cursor::new(buffer);
if let Err(err) = response.write(&mut cursor) {
::log::error!("Converting response to bytes failed: {:#}", err);
return;
}
let bytes_written = cursor.position() as usize;
let addr = if config.network.address.is_ipv4() {
canonical_addr
.get_ipv4()
.expect("found peer ipv6 address while running bound to ipv4 address")
} else {
canonical_addr.get_ipv6_mapped()
};
match socket.send_to(&cursor.get_ref()[..bytes_written], addr) {
Ok(amt) if config.statistics.active() => {
let stats = if canonical_addr.is_ipv4() {
&state.statistics_ipv4
} else {
&state.statistics_ipv6
};
stats.bytes_sent.fetch_add(amt, Ordering::Relaxed);
match response {
Response::Connect(_) => {
stats.responses_sent_connect.fetch_add(1, Ordering::Relaxed);
}
Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => {
stats
.responses_sent_announce
.fetch_add(1, Ordering::Relaxed);
}
Response::Scrape(_) => {
stats.responses_sent_scrape.fetch_add(1, Ordering::Relaxed);
}
Response::Error(_) => {
stats.responses_sent_error.fetch_add(1, Ordering::Relaxed);
}
}
}
Ok(_) => (),
Err(err) => {
match resend_buffer {
Some(resend_buffer)
if (err.raw_os_error() == Some(ENOBUFS))
|| (err.kind() == ErrorKind::WouldBlock) =>
{
if resend_buffer.len() < config.network.resend_buffer_max_len {
::log::info!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err);
resend_buffer.push((response, canonical_addr));
} else {
::log::warn!("Response resend buffer full, dropping response");
}
}
_ => {
::log::warn!("Sending response to {} failed: {:#}", addr, err);
}
}
}
}
}