WIP: use shared swarm state in mio worker

This commit is contained in:
Joakim Frostegård 2024-02-10 11:21:48 +01:00
parent 53497308f1
commit 2da966098f
5 changed files with 79 additions and 255 deletions

View file

@ -13,6 +13,7 @@ use crossbeam_utils::CachePadded;
use hdrhistogram::Histogram; use hdrhistogram::Histogram;
use crate::config::Config; use crate::config::Config;
use crate::swarm::TorrentMaps;
pub const BUFFER_SIZE: usize = 8192; pub const BUFFER_SIZE: usize = 8192;
@ -230,13 +231,15 @@ pub enum StatisticsMessage {
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub access_list: Arc<AccessListArcSwap>, pub access_list: Arc<AccessListArcSwap>,
pub torrent_maps: TorrentMaps,
pub server_start_instant: ServerStartInstant, pub server_start_instant: ServerStartInstant,
} }
impl Default for State { impl State {
fn default() -> Self { pub fn new(config: &Config) -> Self {
Self { Self {
access_list: Arc::new(AccessListArcSwap::default()), access_list: Arc::new(AccessListArcSwap::default()),
torrent_maps: TorrentMaps::new(config),
server_start_instant: ServerStartInstant::new(), server_start_instant: ServerStartInstant::new(),
} }
} }

View file

@ -23,6 +23,7 @@ use common::{
SwarmWorkerIndex, SwarmWorkerIndex,
}; };
use config::Config; use config::Config;
use swarm::TorrentMaps;
use workers::socket::ConnectionValidator; use workers::socket::ConnectionValidator;
use workers::swarm::SwarmWorker; use workers::swarm::SwarmWorker;
@ -32,81 +33,24 @@ pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
pub fn run(config: Config) -> ::anyhow::Result<()> { pub fn run(config: Config) -> ::anyhow::Result<()> {
let mut signals = Signals::new([SIGUSR1])?; let mut signals = Signals::new([SIGUSR1])?;
let state = State::default(); let state = State::new(&config);
let statistics = Statistics::new(&config); let statistics = Statistics::new(&config);
let connection_validator = ConnectionValidator::new(&config)?; let connection_validator = ConnectionValidator::new(&config)?;
let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
let mut join_handles = Vec::new(); let mut join_handles = Vec::new();
update_access_list(&config.access_list, &state.access_list)?; update_access_list(&config.access_list, &state.access_list)?;
let mut request_senders = Vec::new();
let mut request_receivers = BTreeMap::new();
let mut response_senders = Vec::new();
let mut response_receivers = BTreeMap::new();
let (statistics_sender, statistics_receiver) = unbounded(); let (statistics_sender, statistics_receiver) = unbounded();
for i in 0..config.swarm_workers {
let (request_sender, request_receiver) = bounded(config.worker_channel_size);
request_senders.push(request_sender);
request_receivers.insert(i, request_receiver);
}
for i in 0..config.socket_workers {
let (response_sender, response_receiver) = bounded(config.worker_channel_size);
response_senders.push(response_sender);
response_receivers.insert(i, response_receiver);
}
for i in 0..config.swarm_workers {
let config = config.clone();
let state = state.clone();
let request_receiver = request_receivers.remove(&i).unwrap().clone();
let response_sender = ConnectedResponseSender::new(response_senders.clone());
let statistics_sender = statistics_sender.clone();
let statistics = statistics.swarm[i].clone();
let handle = Builder::new()
.name(format!("swarm-{:02}", i + 1))
.spawn(move || {
#[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to(
&config.cpu_pinning,
config.socket_workers,
config.swarm_workers,
WorkerIndex::SwarmWorker(i),
);
let mut worker = SwarmWorker {
config,
state,
statistics,
request_receiver,
response_sender,
statistics_sender,
worker_index: SwarmWorkerIndex(i),
};
worker.run()
})
.with_context(|| "spawn swarm worker")?;
join_handles.push((WorkerType::Swarm(i), handle));
}
for i in 0..config.socket_workers { for i in 0..config.socket_workers {
let state = state.clone(); let state = state.clone();
let config = config.clone(); let config = config.clone();
let connection_validator = connection_validator.clone(); let connection_validator = connection_validator.clone();
let request_sender =
ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone());
let response_receiver = response_receivers.remove(&i).unwrap();
let priv_dropper = priv_dropper.clone(); let priv_dropper = priv_dropper.clone();
let statistics = statistics.socket[i].clone(); let statistics = statistics.socket[i].clone();
let statistics_sender = statistics_sender.clone();
let handle = Builder::new() let handle = Builder::new()
.name(format!("socket-{:02}", i + 1)) .name(format!("socket-{:02}", i + 1))
@ -123,9 +67,8 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
config, config,
state, state,
statistics, statistics,
statistics_sender,
connection_validator, connection_validator,
request_sender,
response_receiver,
priv_dropper, priv_dropper,
) )
}) })

View file

@ -16,6 +16,7 @@ use aquatic_udp_protocol::*;
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use crossbeam_channel::Sender; use crossbeam_channel::Sender;
use hdrhistogram::Histogram; use hdrhistogram::Histogram;
use parking_lot::RwLockUpgradableReadGuard;
use rand::prelude::SmallRng; use rand::prelude::SmallRng;
use rand::Rng; use rand::Rng;
@ -37,7 +38,7 @@ pub struct TorrentMaps {
impl TorrentMaps { impl TorrentMaps {
pub fn new(config: &Config) -> Self { pub fn new(config: &Config) -> Self {
let num_shards = 128usize; let num_shards = 16usize;
Self { Self {
ipv4: TorrentMapShards::new(num_shards), ipv4: TorrentMapShards::new(num_shards),
@ -51,10 +52,10 @@ impl TorrentMaps {
statistics_sender: &Sender<StatisticsMessage>, statistics_sender: &Sender<StatisticsMessage>,
rng: &mut SmallRng, rng: &mut SmallRng,
request: &AnnounceRequest, request: &AnnounceRequest,
ip_address: CanonicalSocketAddr, src: CanonicalSocketAddr,
valid_until: ValidUntil, valid_until: ValidUntil,
) -> Response { ) -> Response {
match ip_address.get().ip() { match src.get().ip() {
IpAddr::V4(ip_address) => Response::AnnounceIpv4(self.ipv4.announce( IpAddr::V4(ip_address) => Response::AnnounceIpv4(self.ipv4.announce(
config, config,
statistics_sender, statistics_sender,
@ -74,8 +75,8 @@ impl TorrentMaps {
} }
} }
pub fn scrape(&self, ip_addr: CanonicalSocketAddr, request: ScrapeRequest) -> ScrapeResponse { pub fn scrape(&self, request: ScrapeRequest, src: CanonicalSocketAddr) -> ScrapeResponse {
if ip_addr.is_ipv4() { if src.is_ipv4() {
self.ipv4.scrape(request) self.ipv4.scrape(request)
} else { } else {
self.ipv6.scrape(request) self.ipv6.scrape(request)
@ -142,19 +143,19 @@ impl<I: Ip> TorrentMapShards<I> {
ip_address: I, ip_address: I,
valid_until: ValidUntil, valid_until: ValidUntil,
) -> AnnounceResponse<I> { ) -> AnnounceResponse<I> {
let torrent_map_shard = self.get_shard(&request.info_hash); let torrent_data = {
let torrent_map_shard = self.get_shard(&request.info_hash).upgradable_read();
// Clone Arc here to avoid keeping lock on whole shard // Clone Arc here to avoid keeping lock on whole shard
let torrent_data = if let Some(torrent_data) = torrent_map_shard.get(&request.info_hash) {
if let Some(torrent_data) = torrent_map_shard.read().get(&request.info_hash) {
torrent_data.clone() torrent_data.clone()
} else { } else {
// Don't overwrite entry if created in the meantime // Don't overwrite entry if created in the meantime
torrent_map_shard RwLockUpgradableReadGuard::upgrade(torrent_map_shard)
.write()
.entry(request.info_hash) .entry(request.info_hash)
.or_default() .or_default()
.clone() .clone()
}
}; };
let mut peer_map = torrent_data.peer_map.write(); let mut peer_map = torrent_data.peer_map.write();

View file

@ -1,9 +1,10 @@
use std::io::{Cursor, ErrorKind}; use std::io::{Cursor, ErrorKind};
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::{Duration, Instant}; use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use aquatic_common::access_list::AccessListCache; use aquatic_common::access_list::AccessListCache;
use crossbeam_channel::Sender;
use mio::net::UdpSocket; use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token}; use mio::{Events, Interest, Poll, Token};
@ -12,40 +13,26 @@ use aquatic_common::{
ValidUntil, ValidUntil,
}; };
use aquatic_udp_protocol::*; use aquatic_udp_protocol::*;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
use super::storage::PendingScrapeResponseSlab;
use super::validator::ConnectionValidator; use super::validator::ConnectionValidator;
use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6};
enum HandleRequestError {
RequestChannelFull(Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>),
}
#[derive(Clone, Copy, Debug)]
enum PollMode {
Regular,
SkipPolling,
SkipReceiving,
}
pub struct SocketWorker { pub struct SocketWorker {
config: Config, config: Config,
shared_state: State, shared_state: State,
statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>, statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>,
request_sender: ConnectedRequestSender, statistics_sender: Sender<StatisticsMessage>,
response_receiver: ConnectedResponseReceiver,
access_list_cache: AccessListCache, access_list_cache: AccessListCache,
validator: ConnectionValidator, validator: ConnectionValidator,
pending_scrape_responses: PendingScrapeResponseSlab,
socket: UdpSocket, socket: UdpSocket,
opt_resend_buffer: Option<Vec<(CanonicalSocketAddr, Response)>>, opt_resend_buffer: Option<Vec<(CanonicalSocketAddr, Response)>>,
buffer: [u8; BUFFER_SIZE], buffer: [u8; BUFFER_SIZE],
polling_mode: PollMode, rng: SmallRng,
/// Storage for requests that couldn't be sent to swarm worker because channel was full
pending_requests: Vec<(SwarmWorkerIndex, ConnectedRequest, CanonicalSocketAddr)>,
} }
impl SocketWorker { impl SocketWorker {
@ -53,9 +40,8 @@ impl SocketWorker {
config: Config, config: Config,
shared_state: State, shared_state: State,
statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>, statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>,
statistics_sender: Sender<StatisticsMessage>,
validator: ConnectionValidator, validator: ConnectionValidator,
request_sender: ConnectedRequestSender,
response_receiver: ConnectedResponseReceiver,
priv_dropper: PrivilegeDropper, priv_dropper: PrivilegeDropper,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let socket = UdpSocket::from_std(create_socket(&config, priv_dropper)?); let socket = UdpSocket::from_std(create_socket(&config, priv_dropper)?);
@ -66,16 +52,13 @@ impl SocketWorker {
config, config,
shared_state, shared_state,
statistics, statistics,
statistics_sender,
validator, validator,
request_sender,
response_receiver,
access_list_cache, access_list_cache,
pending_scrape_responses: Default::default(),
socket, socket,
opt_resend_buffer, opt_resend_buffer,
buffer: [0; BUFFER_SIZE], buffer: [0; BUFFER_SIZE],
polling_mode: PollMode::Regular, rng: SmallRng::from_entropy(),
pending_requests: Default::default(),
}; };
worker.run_inner() worker.run_inner()
@ -91,39 +74,12 @@ impl SocketWorker {
let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms); let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms);
let pending_scrape_cleaning_duration =
Duration::from_secs(self.config.cleaning.pending_scrape_cleaning_interval);
let mut pending_scrape_valid_until = ValidUntil::new(
self.shared_state.server_start_instant,
self.config.cleaning.max_pending_scrape_age,
);
let mut last_pending_scrape_cleaning = Instant::now();
let mut iter_counter = 0usize;
loop { loop {
match self.polling_mode {
PollMode::Regular => {
poll.poll(&mut events, Some(poll_timeout)).context("poll")?; poll.poll(&mut events, Some(poll_timeout)).context("poll")?;
for event in events.iter() { for event in events.iter() {
if event.is_readable() { if event.is_readable() {
self.read_and_handle_requests(pending_scrape_valid_until); self.read_and_handle_requests();
}
}
}
PollMode::SkipPolling => {
self.polling_mode = PollMode::Regular;
// Continue reading from socket without polling, since
// reading was previously cancelled
self.read_and_handle_requests(pending_scrape_valid_until);
}
PollMode::SkipReceiving => {
::log::debug!("Postponing receiving requests because swarm worker channel is full. This means that the OS will be relied on to buffer incoming packets. To prevent this, raise config.worker_channel_size.");
self.polling_mode = PollMode::SkipPolling;
} }
} }
@ -141,44 +97,10 @@ impl SocketWorker {
); );
} }
} }
// Check channel for any responses generated by swarm workers
self.handle_swarm_worker_responses();
// Try sending pending requests
while let Some((index, request, addr)) = self.pending_requests.pop() {
if let Err(r) = self.request_sender.try_send_to(index, request, addr) {
self.pending_requests.push(r);
self.polling_mode = PollMode::SkipReceiving;
break;
} }
} }
// Run periodic ValidUntil updates and state cleaning fn read_and_handle_requests(&mut self) {
if iter_counter % 256 == 0 {
let seconds_since_start = self.shared_state.server_start_instant.seconds_elapsed();
pending_scrape_valid_until = ValidUntil::new_with_now(
seconds_since_start,
self.config.cleaning.max_pending_scrape_age,
);
let now = Instant::now();
if now > last_pending_scrape_cleaning + pending_scrape_cleaning_duration {
self.pending_scrape_responses.clean(seconds_since_start);
last_pending_scrape_cleaning = now;
}
}
iter_counter = iter_counter.wrapping_add(1);
}
}
fn read_and_handle_requests(&mut self, pending_scrape_valid_until: ValidUntil) {
let max_scrape_torrents = self.config.protocol.max_scrape_torrents; let max_scrape_torrents = self.config.protocol.max_scrape_torrents;
loop { loop {
@ -222,14 +144,7 @@ impl SocketWorker {
statistics.requests.fetch_add(1, Ordering::Relaxed); statistics.requests.fetch_add(1, Ordering::Relaxed);
} }
if let Err(HandleRequestError::RequestChannelFull(failed_requests)) = self.handle_request(request, src);
self.handle_request(pending_scrape_valid_until, request, src)
{
self.pending_requests.extend(failed_requests);
self.polling_mode = PollMode::SkipReceiving;
break;
}
} }
Err(RequestParseError::Sendable { Err(RequestParseError::Sendable {
connection_id, connection_id,
@ -271,20 +186,13 @@ impl SocketWorker {
} }
} }
fn handle_request( fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) {
&mut self,
pending_scrape_valid_until: ValidUntil,
request: Request,
src: CanonicalSocketAddr,
) -> Result<(), HandleRequestError> {
let access_list_mode = self.config.access_list.mode; let access_list_mode = self.config.access_list.mode;
match request { match request {
Request::Connect(request) => { Request::Connect(request) => {
let connection_id = self.validator.create_connection_id(src);
let response = ConnectResponse { let response = ConnectResponse {
connection_id, connection_id: self.validator.create_connection_id(src),
transaction_id: request.transaction_id, transaction_id: request.transaction_id,
}; };
@ -297,8 +205,6 @@ impl SocketWorker {
Response::Connect(response), Response::Connect(response),
src, src,
); );
Ok(())
} }
Request::Announce(request) => { Request::Announce(request) => {
if self if self
@ -310,14 +216,29 @@ impl SocketWorker {
.load() .load()
.allows(access_list_mode, &request.info_hash.0) .allows(access_list_mode, &request.info_hash.0)
{ {
let worker_index = let peer_valid_until = ValidUntil::new(
SwarmWorkerIndex::from_info_hash(&self.config, request.info_hash); self.shared_state.server_start_instant,
self.config.cleaning.max_peer_age,
);
self.request_sender let response = self.shared_state.torrent_maps.announce(
.try_send_to(worker_index, ConnectedRequest::Announce(request), src) &self.config,
.map_err(|request| { &self.statistics_sender,
HandleRequestError::RequestChannelFull(vec![request]) &mut self.rng,
}) &request,
src,
peer_valid_until,
);
send_response(
&self.config,
&self.statistics,
&mut self.socket,
&mut self.buffer,
&mut self.opt_resend_buffer,
response,
src,
);
} else { } else {
let response = ErrorResponse { let response = ErrorResponse {
transaction_id: request.transaction_id, transaction_id: request.transaction_id,
@ -333,11 +254,7 @@ impl SocketWorker {
Response::Error(response), Response::Error(response),
src, src,
); );
Ok(())
} }
} else {
Ok(())
} }
} }
Request::Scrape(request) => { Request::Scrape(request) => {
@ -345,52 +262,7 @@ impl SocketWorker {
.validator .validator
.connection_id_valid(src, request.connection_id) .connection_id_valid(src, request.connection_id)
{ {
let split_requests = self.pending_scrape_responses.prepare_split_requests( let response = self.shared_state.torrent_maps.scrape(request, src);
&self.config,
request,
pending_scrape_valid_until,
);
let mut failed = Vec::new();
for (swarm_worker_index, request) in split_requests {
if let Err(request) = self.request_sender.try_send_to(
swarm_worker_index,
ConnectedRequest::Scrape(request),
src,
) {
failed.push(request);
}
}
if failed.is_empty() {
Ok(())
} else {
Err(HandleRequestError::RequestChannelFull(failed))
}
} else {
Ok(())
}
}
}
}
fn handle_swarm_worker_responses(&mut self) {
for (addr, response) in self.response_receiver.try_iter() {
let response = match response {
ConnectedResponse::Scrape(response) => {
if let Some(r) = self
.pending_scrape_responses
.add_and_get_finished(&response)
{
Response::Scrape(r)
} else {
continue;
}
}
ConnectedResponse::AnnounceIpv4(r) => Response::AnnounceIpv4(r),
ConnectedResponse::AnnounceIpv6(r) => Response::AnnounceIpv6(r),
};
send_response( send_response(
&self.config, &self.config,
@ -398,11 +270,13 @@ impl SocketWorker {
&mut self.socket, &mut self.socket,
&mut self.buffer, &mut self.buffer,
&mut self.opt_resend_buffer, &mut self.opt_resend_buffer,
response, Response::Scrape(response),
addr, src,
); );
} }
} }
}
}
} }
fn send_response( fn send_response(
@ -488,4 +362,6 @@ fn send_response(
} }
} }
} }
::log::debug!("send response fn finished");
} }

View file

@ -6,12 +6,13 @@ mod validator;
use anyhow::Context; use anyhow::Context;
use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::privileges::PrivilegeDropper;
use crossbeam_channel::Sender;
use socket2::{Domain, Protocol, Socket, Type}; use socket2::{Domain, Protocol, Socket, Type};
use crate::{ use crate::{
common::{ common::{
CachePaddedArc, ConnectedRequestSender, ConnectedResponseReceiver, IpVersionStatistics, CachePaddedArc, ConnectedRequestSender, ConnectedResponseReceiver, IpVersionStatistics,
SocketWorkerStatistics, State, SocketWorkerStatistics, State, StatisticsMessage,
}, },
config::Config, config::Config,
}; };
@ -43,11 +44,11 @@ pub fn run_socket_worker(
config: Config, config: Config,
shared_state: State, shared_state: State,
statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>, statistics: CachePaddedArc<IpVersionStatistics<SocketWorkerStatistics>>,
statistics_sender: Sender<StatisticsMessage>,
validator: ConnectionValidator, validator: ConnectionValidator,
request_sender: ConnectedRequestSender,
response_receiver: ConnectedResponseReceiver,
priv_dropper: PrivilegeDropper, priv_dropper: PrivilegeDropper,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
/*
#[cfg(all(target_os = "linux", feature = "io-uring"))] #[cfg(all(target_os = "linux", feature = "io-uring"))]
if config.network.use_io_uring { if config.network.use_io_uring {
self::uring::supported_on_current_kernel().context("check for io_uring compatibility")?; self::uring::supported_on_current_kernel().context("check for io_uring compatibility")?;
@ -62,14 +63,14 @@ pub fn run_socket_worker(
priv_dropper, priv_dropper,
); );
} }
*/
self::mio::SocketWorker::run( self::mio::SocketWorker::run(
config, config,
shared_state, shared_state,
statistics, statistics,
statistics_sender,
validator, validator,
request_sender,
response_receiver,
priv_dropper, priv_dropper,
) )
} }