mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-03-31 17:55:36 +00:00
491 lines
18 KiB
Rust
use std::io::{Cursor, ErrorKind};
|
|
use std::sync::atomic::Ordering;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use anyhow::Context;
|
|
use aquatic_common::access_list::AccessListCache;
|
|
use mio::net::UdpSocket;
|
|
use mio::{Events, Interest, Poll, Token};
|
|
|
|
use aquatic_common::{
|
|
access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr,
|
|
ValidUntil,
|
|
};
|
|
use aquatic_udp_protocol::*;
|
|
|
|
use crate::common::*;
|
|
use crate::config::Config;
|
|
|
|
use super::storage::PendingScrapeResponseSlab;
|
|
use super::validator::ConnectionValidator;
|
|
use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6};
|
|
|
|
enum HandleRequestError {
|
|
RequestChannelFull(Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>),
|
|
}
|
|
|
|
/// What the socket worker's event loop does with the socket during the next
/// iteration.
///
/// Responses and pending requests are handled on every iteration regardless
/// of the mode.
#[derive(Clone, Copy, Debug)]
enum PollMode {
    /// Wait for readiness events, then read from the socket when it is readable.
    Regular,
    /// Read from the socket right away without polling, then return to `Regular`.
    SkipPolling,
    /// Neither poll nor read this iteration, then switch to `SkipPolling`.
    SkipReceiving,
}
|
|
|
|
pub struct SocketWorker {
|
|
config: Config,
|
|
shared_state: State,
|
|
statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>,
|
|
request_sender: ConnectedRequestSender,
|
|
response_receiver: ConnectedResponseReceiver,
|
|
access_list_cache: AccessListCache,
|
|
validator: ConnectionValidator,
|
|
pending_scrape_responses: PendingScrapeResponseSlab,
|
|
socket: UdpSocket,
|
|
opt_resend_buffer: Option<Vec<(CanonicalSocketAddr, Response)>>,
|
|
buffer: [u8; BUFFER_SIZE],
|
|
polling_mode: PollMode,
|
|
/// Storage for requests that couldn't be sent to swarm worker because channel was full
|
|
pending_requests: Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>,
|
|
}
|
|
|
|
impl SocketWorker {
|
|
pub fn run(
|
|
config: Config,
|
|
shared_state: State,
|
|
statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>,
|
|
validator: ConnectionValidator,
|
|
request_sender: ConnectedRequestSender,
|
|
response_receiver: ConnectedResponseReceiver,
|
|
priv_dropper: PrivilegeDropper,
|
|
) -> anyhow::Result<()> {
|
|
let socket = UdpSocket::from_std(create_socket(&config, priv_dropper)?);
|
|
let access_list_cache = create_access_list_cache(&shared_state.access_list);
|
|
let opt_resend_buffer = (config.network.resend_buffer_max_len > 0).then_some(Vec::new());
|
|
|
|
let mut worker = Self {
|
|
config,
|
|
shared_state,
|
|
statistics,
|
|
validator,
|
|
request_sender,
|
|
response_receiver,
|
|
access_list_cache,
|
|
pending_scrape_responses: Default::default(),
|
|
socket,
|
|
opt_resend_buffer,
|
|
buffer: [0; BUFFER_SIZE],
|
|
polling_mode: PollMode::Regular,
|
|
pending_requests: Default::default(),
|
|
};
|
|
|
|
worker.run_inner()
|
|
}
|
|
|
|
pub fn run_inner(&mut self) -> anyhow::Result<()> {
|
|
let mut events = Events::with_capacity(1);
|
|
let mut poll = Poll::new().context("create poll")?;
|
|
|
|
poll.registry()
|
|
.register(&mut self.socket, Token(0), Interest::READABLE)
|
|
.context("register poll")?;
|
|
|
|
let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms);
|
|
|
|
let pending_scrape_cleaning_duration =
|
|
Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval);
|
|
|
|
let mut pending_scrape_valid_until = ValidUntil::new(
|
|
self.shared_state.server_start_instant,
|
|
self.config.cleaning.max_pending_scrape_age,
|
|
);
|
|
let mut last_pending_scrape_cleaning = Instant::now();
|
|
|
|
let mut iter_counter = 0usize;
|
|
|
|
loop {
|
|
match self.polling_mode {
|
|
PollMode::Regular => {
|
|
poll.poll(&mut events, Some(poll_timeout)).context("poll")?;
|
|
|
|
for event in events.iter() {
|
|
if event.is_readable() {
|
|
self.read_and_handle_requests(pending_scrape_valid_until);
|
|
}
|
|
}
|
|
}
|
|
PollMode::SkipPolling => {
|
|
self.polling_mode = PollMode::Regular;
|
|
|
|
// Continue reading from socket without polling, since
|
|
// reading was previouly cancelled
|
|
self.read_and_handle_requests(pending_scrape_valid_until);
|
|
}
|
|
PollMode::SkipReceiving => {
|
|
::log::debug!("Postponing receiving requests because swarm worker channel is full. This means that the OS will be relied on to buffer incoming packets. To prevent this, raise config.worker_channel_size.");
|
|
|
|
self.polling_mode = PollMode::SkipPolling;
|
|
}
|
|
}
|
|
|
|
// If resend buffer is enabled, send any responses in it
|
|
if let Some(resend_buffer) = self.opt_resend_buffer.as_mut() {
|
|
for (addr, response) in resend_buffer.drain(..) {
|
|
send_response(
|
|
&self.config,
|
|
&self.statistics,
|
|
&mut self.socket,
|
|
&mut self.buffer,
|
|
&mut None,
|
|
response,
|
|
addr,
|
|
);
|
|
}
|
|
}
|
|
|
|
// Check channel for any responses generated by swarm workers
|
|
self.handle_swarm_worker_responses();
|
|
|
|
// Try sending pending requests
|
|
while let Some((index, request, addr)) = self.pending_requests.pop() {
|
|
if let Err(r) = self.request_sender.try_send_to(index, request, addr) {
|
|
self.pending_requests.push(r);
|
|
|
|
self.polling_mode = PollMode::SkipReceiving;
|
|
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Run periodic ValidUntil updates and state cleaning
|
|
if iter_counter % 256 == 0 {
|
|
let seconds_since_start = self.shared_state.server_start_instant.seconds_elapsed();
|
|
|
|
pending_scrape_valid_until = ValidUntil::new_with_now(
|
|
seconds_since_start,
|
|
self.config.cleaning.max_pending_scrape_age,
|
|
);
|
|
|
|
let now = Instant::now();
|
|
|
|
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
|
|
self.pending_scrape_responses.clean(seconds_since_start);
|
|
|
|
last_pending_scrape_cleaning = now;
|
|
}
|
|
}
|
|
|
|
iter_counter = iter_counter.wrapping_add(1);
|
|
}
|
|
}
|
|
|
|
fn read_and_handle_requests(&mut self, pending_scrape_valid_until: ValidUntil) {
|
|
let max_scrape_torrents = self.config.protocol.max_scrape_torrents;
|
|
|
|
loop {
|
|
match self.socket.recv_from(&mut self.buffer[..]) {
|
|
Ok((bytes_read, src)) => {
|
|
let src_port = src.port();
|
|
let src = CanonicalSocketAddr::new(src);
|
|
|
|
// Use canonical address for statistics
|
|
let opt_statistics = if self.config.statistics.active() {
|
|
if src.is_ipv4() {
|
|
let statistics = &self.statistics.ipv4;
|
|
|
|
statistics
|
|
.bytes_received
|
|
.fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed);
|
|
|
|
Some(statistics)
|
|
} else {
|
|
let statistics = &self.statistics.ipv6;
|
|
|
|
statistics
|
|
.bytes_received
|
|
.fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed);
|
|
|
|
Some(statistics)
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if src_port == 0 {
|
|
::log::debug!("Ignored request because source port is zero");
|
|
|
|
continue;
|
|
}
|
|
|
|
match Request::parse_bytes(&self.buffer[..bytes_read], max_scrape_torrents) {
|
|
Ok(request) => {
|
|
if let Some(statistics) = opt_statistics {
|
|
statistics.requests.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
|
|
if let Err(HandleRequestError::RequestChannelFull(failed_requests)) =
|
|
self.handle_request(pending_scrape_valid_until, request, src)
|
|
{
|
|
self.pending_requests.extend(failed_requests);
|
|
self.polling_mode = PollMode::SkipReceiving;
|
|
|
|
break;
|
|
}
|
|
}
|
|
Err(RequestParseError::Sendable {
|
|
connection_id,
|
|
transaction_id,
|
|
err,
|
|
}) if self.validator.connection_id_valid(src, connection_id) => {
|
|
let response = ErrorResponse {
|
|
transaction_id,
|
|
message: err.into(),
|
|
};
|
|
|
|
send_response(
|
|
&self.config,
|
|
&self.statistics,
|
|
&mut self.socket,
|
|
&mut self.buffer,
|
|
&mut self.opt_resend_buffer,
|
|
Response::Error(response),
|
|
src,
|
|
);
|
|
|
|
::log::debug!("request parse error (sent error response): {:?}", err);
|
|
}
|
|
Err(err) => {
|
|
::log::debug!(
|
|
"request parse error (didn't send error response): {:?}",
|
|
err
|
|
);
|
|
}
|
|
};
|
|
}
|
|
Err(err) if err.kind() == ErrorKind::WouldBlock => {
|
|
break;
|
|
}
|
|
Err(err) => {
|
|
::log::warn!("recv_from error: {:#}", err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn handle_request(
|
|
&mut self,
|
|
pending_scrape_valid_until: ValidUntil,
|
|
request: Request,
|
|
src: CanonicalSocketAddr,
|
|
) -> Result<(), HandleRequestError> {
|
|
let access_list_mode = self.config.access_list.mode;
|
|
|
|
match request {
|
|
Request::Connect(request) => {
|
|
let connection_id = self.validator.create_connection_id(src);
|
|
|
|
let response = ConnectResponse {
|
|
connection_id,
|
|
transaction_id: request.transaction_id,
|
|
};
|
|
|
|
send_response(
|
|
&self.config,
|
|
&self.statistics,
|
|
&mut self.socket,
|
|
&mut self.buffer,
|
|
&mut self.opt_resend_buffer,
|
|
Response::Connect(response),
|
|
src,
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
Request::Announce(request) => {
|
|
if self
|
|
.validator
|
|
.connection_id_valid(src, request.connection_id)
|
|
{
|
|
if self
|
|
.access_list_cache
|
|
.load()
|
|
.allows(access_list_mode, &request.info_hash.0)
|
|
{
|
|
let worker_index =
|
|
SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash);
|
|
|
|
self.request_sender
|
|
.try_send_to(worker_index, ConnectedRequest::Announce(request), src)
|
|
.map_err(|request| {
|
|
HandleRequestError::RequestChannelFull(vec![request])
|
|
})
|
|
} else {
|
|
let response = ErrorResponse {
|
|
transaction_id: request.transaction_id,
|
|
message: "Info hash not allowed".into(),
|
|
};
|
|
|
|
send_response(
|
|
&self.config,
|
|
&self.statistics,
|
|
&mut self.socket,
|
|
&mut self.buffer,
|
|
&mut self.opt_resend_buffer,
|
|
Response::Error(response),
|
|
src,
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
Request::Scrape(request) => {
|
|
if self
|
|
.validator
|
|
.connection_id_valid(src, request.connection_id)
|
|
{
|
|
let split_requests = self.pending_scrape_responses.prepare_split_requests(
|
|
&self.config,
|
|
request,
|
|
pending_scrape_valid_until,
|
|
);
|
|
|
|
let mut failed = Vec::new();
|
|
|
|
for (swarm_worker_index, request) in split_requests {
|
|
if let Err(request) = self.request_sender.try_send_to(
|
|
swarm_worker_index,
|
|
ConnectedRequest::Scrape(request),
|
|
src,
|
|
) {
|
|
failed.push(request);
|
|
}
|
|
}
|
|
|
|
if failed.is_empty() {
|
|
Ok(())
|
|
} else {
|
|
Err(HandleRequestError::RequestChannelFull(failed))
|
|
}
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn handle_swarm_worker_responses(&mut self) {
|
|
for (addr, response) in self.response_receiver.try_iter() {
|
|
let response = match response {
|
|
ConnectedResponse::Scrape(response) => {
|
|
if let Some(r) = self
|
|
.pending_scrape_responses
|
|
.add_and_get_finished(&response)
|
|
{
|
|
Response::Scrape(r)
|
|
} else {
|
|
continue;
|
|
}
|
|
}
|
|
ConnectedResponse::AnnounceIpv4(r) => Response::AnnounceIpv4(r),
|
|
ConnectedResponse::AnnounceIpv6(r) => Response::AnnounceIpv6(r),
|
|
};
|
|
|
|
send_response(
|
|
&self.config,
|
|
&self.statistics,
|
|
&mut self.socket,
|
|
&mut self.buffer,
|
|
&mut self.opt_resend_buffer,
|
|
response,
|
|
addr,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn send_response(
|
|
config: &Config,
|
|
statistics: &CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>,
|
|
socket: &mut UdpSocket,
|
|
buffer: &mut [u8],
|
|
opt_resend_buffer: &mut Option<Vec<(CanonicalSocketAddr, Response)>>,
|
|
response: Response,
|
|
canonical_addr: CanonicalSocketAddr,
|
|
) {
|
|
let mut buffer = Cursor::new(&mut buffer[..]);
|
|
|
|
if let Err(err) = response.write_bytes(&mut buffer) {
|
|
::log::error!("failed writing response to buffer: {:#}", err);
|
|
|
|
return;
|
|
}
|
|
|
|
let bytes_written = buffer.position() as usize;
|
|
|
|
let addr = if config.network.address.is_ipv4() {
|
|
canonical_addr
|
|
.get_ipv4()
|
|
.expect("found peer ipv6 address while running bound to ipv4 address")
|
|
} else {
|
|
canonical_addr.get_ipv6_mapped()
|
|
};
|
|
|
|
match socket.send_to(&buffer.into_inner()[..bytes_written], addr) {
|
|
Ok(amt) if config.statistics.active() => {
|
|
let stats = if canonical_addr.is_ipv4() {
|
|
let stats = &statistics.ipv4;
|
|
|
|
stats
|
|
.bytes_sent
|
|
.fetch_add(amt + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed);
|
|
|
|
stats
|
|
} else {
|
|
let stats = &statistics.ipv6;
|
|
|
|
stats
|
|
.bytes_sent
|
|
.fetch_add(amt + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed);
|
|
|
|
stats
|
|
};
|
|
|
|
match response {
|
|
Response::Connect(_) => {
|
|
stats.responses_connect.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => {
|
|
stats.responses_announce.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
Response::Scrape(_) => {
|
|
stats.responses_scrape.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
Response::Error(_) => {
|
|
stats.responses_error.fetch_add(1, Ordering::Relaxed);
|
|
}
|
|
}
|
|
}
|
|
Ok(_) => (),
|
|
Err(err) => {
|
|
match opt_resend_buffer.as_mut() {
|
|
Some(resend_buffer)
|
|
if (err.raw_os_error() == Some(libc::ENOBUFS))
|
|
|| (err.kind() == ErrorKind::WouldBlock) =>
|
|
{
|
|
if resend_buffer.len() < config.network.resend_buffer_max_len {
|
|
::log::debug!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err);
|
|
|
|
resend_buffer.push((canonical_addr, response));
|
|
} else {
|
|
::log::warn!("Response resend buffer full, dropping response");
|
|
}
|
|
}
|
|
_ => {
|
|
::log::warn!("Sending response to {} failed: {:#}", addr, err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|