diff --git a/CHANGELOG.md b/CHANGELOG.md index 3134f9c..ede1de0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,7 @@ * Actually close connections that are too slow to send responses to * If peers announce with AnnounceEvent::Stopped, allow them to later announce on same torrent with different peer_id +* Quit whole application if any worker thread quits ## 0.8.0 - 2023-03-17 diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 033fffc..a04cc7e 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -1,7 +1,5 @@ use std::fmt::Display; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; use std::time::Instant; use ahash::RandomState; @@ -57,42 +55,6 @@ impl ServerStartInstant { #[derive(Debug, Clone, Copy)] pub struct SecondsSinceServerStart(u32); -pub struct PanicSentinelWatcher(Arc); - -impl PanicSentinelWatcher { - pub fn create_with_sentinel() -> (Self, PanicSentinel) { - let triggered = Arc::new(AtomicBool::new(false)); - let sentinel = PanicSentinel(triggered.clone()); - - (Self(triggered), sentinel) - } - - pub fn panic_was_triggered(&self) -> bool { - self.0.load(Ordering::SeqCst) - } -} - -/// Raises SIGTERM when dropped -/// -/// Pass to threads to have panics in them cause whole program to exit. -#[derive(Clone)] -pub struct PanicSentinel(Arc); - -impl Drop for PanicSentinel { - fn drop(&mut self) { - if ::std::thread::panicking() { - let already_triggered = self.0.fetch_or(true, Ordering::SeqCst); - - if !already_triggered && unsafe { libc::raise(15) } == -1 { - panic!( - "Could not raise SIGTERM: {:#}", - ::std::io::Error::last_os_error() - ) - } - } - } -} - /// SocketAddr that is not an IPv6-mapped IPv4 address #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] pub struct CanonicalSocketAddr(SocketAddr); diff --git a/crates/ws/Cargo.toml b/crates/ws/Cargo.toml index 662bc61..5f4f7fb 100644 --- a/crates/ws/Cargo.toml +++ b/crates/ws/Cargo.toml @@ -26,7 +26,7 @@ metrics = ["dep:metrics", "metrics-util"] mimalloc = ["dep:mimalloc"] [dependencies] -aquatic_common = { workspace = true, features = ["rustls", "glommio"] } +aquatic_common = { workspace = true, features = ["rustls"] } aquatic_peer_id.workspace = true aquatic_toml_config.workspace = true aquatic_ws_protocol.workspace = true diff --git a/crates/ws/src/lib.rs b/crates/ws/src/lib.rs index 806267f..cfaaaad 100644 --- a/crates/ws/src/lib.rs +++ b/crates/ws/src/lib.rs @@ -3,19 +3,15 @@ pub mod config; pub mod workers; use std::sync::Arc; +use std::thread::{sleep, Builder, JoinHandle}; use std::time::Duration; use anyhow::Context; -use aquatic_common::cpu_pinning::glommio::{get_worker_placement, set_affinity_for_util_worker}; -use aquatic_common::cpu_pinning::WorkerIndex; use aquatic_common::rustls_config::create_rustls_config; -use aquatic_common::{PanicSentinelWatcher, ServerStartInstant}; +use aquatic_common::{ServerStartInstant, WorkerType}; use arc_swap::ArcSwap; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; -use signal_hook::{ - consts::{SIGTERM, SIGUSR1}, - iterator::Signals, -}; +use signal_hook::{consts::SIGUSR1, iterator::Signals}; use aquatic_common::access_list::update_access_list; use aquatic_common::privileges::PrivilegeDropper; @@ -35,45 +31,18 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { )); } - let mut signals = Signals::new([SIGUSR1, SIGTERM])?; - - #[cfg(feature = "prometheus")] - if config.metrics.run_prometheus_endpoint { - use metrics_exporter_prometheus::PrometheusBuilder; - - let idle_timeout = config - .cleaning - .connection_cleaning_interval - .max(config.cleaning.torrent_cleaning_interval) - .max(config.metrics.torrent_count_update_interval) - * 2; - - PrometheusBuilder::new() - .idle_timeout( - metrics_util::MetricKindMask::GAUGE, - Some(Duration::from_secs(idle_timeout)), - ) - .with_http_listener(config.metrics.prometheus_endpoint_address) - .install() - .with_context(|| { - format!( - "Install prometheus endpoint on {}", - config.metrics.prometheus_endpoint_address - ) - })?; - } + let mut signals = Signals::new([SIGUSR1])?; let state = State::default(); update_access_list(&config.access_list, &state.access_list)?; - let num_peers = config.socket_workers + config.swarm_workers; + let num_mesh_peers = config.socket_workers + config.swarm_workers; - let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); - let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); - let control_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); + let request_mesh_builder = MeshBuilder::partial(num_mesh_peers, SHARED_IN_CHANNEL_SIZE); + let response_mesh_builder = MeshBuilder::partial(num_mesh_peers, SHARED_IN_CHANNEL_SIZE * 16); + let control_mesh_builder = MeshBuilder::partial(num_mesh_peers, SHARED_IN_CHANNEL_SIZE * 16); - let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel(); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let opt_tls_config = if config.network.enable_tls { @@ -98,10 +67,9 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let server_start_instant = ServerStartInstant::new(); - let mut executors = Vec::new(); + let mut join_handles = Vec::new(); for i in 0..(config.socket_workers) { - let sentinel = sentinel.clone(); let config = config.clone(); let state = state.clone(); let opt_tls_config = opt_tls_config.clone(); @@ -110,120 +78,137 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let response_mesh_builder = response_mesh_builder.clone(); let priv_dropper = priv_dropper.clone(); - let placement = get_worker_placement( - &config.cpu_pinning, - config.socket_workers, - config.swarm_workers, - WorkerIndex::SocketWorker(i), - )?; - let builder = LocalExecutorBuilder::new(placement).name(&format!("socket-{:02}", i + 1)); - - let executor = builder - .spawn(move || async move { - workers::socket::run_socket_worker( - sentinel, - config, - state, - opt_tls_config, - control_mesh_builder, - request_mesh_builder, - response_mesh_builder, - priv_dropper, - server_start_instant, - i, - ) - .await + let handle = Builder::new() + .name(format!("socket-{:02}", i + 1)) + .spawn(move || { + LocalExecutorBuilder::default() + .make() + .map_err(|err| anyhow::anyhow!("Spawning executor failed: {:#}", err))? + .run(workers::socket::run_socket_worker( + config, + state, + opt_tls_config, + control_mesh_builder, + request_mesh_builder, + response_mesh_builder, + priv_dropper, + server_start_instant, + i, + )) }) - .map_err(|err| anyhow::anyhow!("Spawning executor failed: {:#}", err))?; + .context("spawn socket worker")?; - executors.push(executor); + join_handles.push((WorkerType::Socket(i), handle)); } - ::log::info!("spawned socket workers"); - for i in 0..(config.swarm_workers) { - let sentinel = sentinel.clone(); let config = config.clone(); let state = state.clone(); let control_mesh_builder = control_mesh_builder.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); - let placement = get_worker_placement( - &config.cpu_pinning, - config.socket_workers, - config.swarm_workers, - WorkerIndex::SwarmWorker(i), - )?; - let builder = LocalExecutorBuilder::new(placement).name(&format!("swarm-{:02}", i + 1)); - - let executor = builder - .spawn(move || async move { - workers::swarm::run_swarm_worker( - sentinel, - config, - state, - control_mesh_builder, - request_mesh_builder, - response_mesh_builder, - server_start_instant, - i, - ) - .await + let handle = Builder::new() + .name(format!("socket-{:02}", i + 1)) + .spawn(move || { + LocalExecutorBuilder::default() + .make() + .map_err(|err| anyhow::anyhow!("Spawning executor failed: {:#}", err))? + .run(workers::swarm::run_swarm_worker( + config, + state, + control_mesh_builder, + request_mesh_builder, + response_mesh_builder, + server_start_instant, + i, + )) }) - .map_err(|err| anyhow::anyhow!("Spawning executor failed: {:#}", err))?; + .context("spawn swarm worker")?; - executors.push(executor); + join_handles.push((WorkerType::Socket(i), handle)); } - ::log::info!("spawned swarm workers"); + #[cfg(feature = "prometheus")] + if config.metrics.run_prometheus_endpoint { + let idle_timeout = config + .cleaning + .connection_cleaning_interval + .max(config.cleaning.torrent_cleaning_interval) + .max(config.metrics.torrent_count_update_interval) + * 2; - if config.cpu_pinning.active { - set_affinity_for_util_worker( - &config.cpu_pinning, - config.socket_workers, - config.swarm_workers, + let handle = aquatic_common::spawn_prometheus_endpoint( + config.metrics.prometheus_endpoint_address, + Some(Duration::from_secs(idle_timeout)), )?; + + join_handles.push((WorkerType::Prometheus, handle)); } - for signal in &mut signals { - match signal { - SIGUSR1 => { - let _ = update_access_list(&config.access_list, &state.access_list); + // Spawn signal handler thread + { + let handle: JoinHandle> = Builder::new() + .name("signals".into()) + .spawn(move || { + for signal in &mut signals { + match signal { + SIGUSR1 => { + let _ = update_access_list(&config.access_list, &state.access_list); - if let Some(tls_config) = opt_tls_config.as_ref() { - match ::std::fs::read(&config.network.tls_certificate_path) { - Ok(data) if &data == opt_tls_cert_data.as_ref().unwrap() => { - ::log::info!("skipping tls config update: certificate identical to currently loaded"); - } - Ok(data) => { - match create_rustls_config( - &config.network.tls_certificate_path, - &config.network.tls_private_key_path, - ) { - Ok(config) => { - tls_config.store(Arc::new(config)); - opt_tls_cert_data = Some(data); + if let Some(tls_config) = opt_tls_config.as_ref() { + match ::std::fs::read(&config.network.tls_certificate_path) { + Ok(data) if &data == opt_tls_cert_data.as_ref().unwrap() => { + ::log::info!("skipping tls config update: certificate identical to currently loaded"); + } + Ok(data) => { + match create_rustls_config( + &config.network.tls_certificate_path, + &config.network.tls_private_key_path, + ) { + Ok(config) => { + tls_config.store(Arc::new(config)); + opt_tls_cert_data = Some(data); - ::log::info!("successfully updated tls config"); + ::log::info!("successfully updated tls config"); + } + Err(err) => ::log::error!("could not update tls config: {:#}", err), + } + } + Err(err) => ::log::error!("couldn't read tls certificate file: {:#}", err), } - Err(err) => ::log::error!("could not update tls config: {:#}", err), } } - Err(err) => ::log::error!("couldn't read tls certificate file: {:#}", err), + _ => unreachable!(), + } + } + + Ok(()) + }) + .context("spawn signal worker")?; + + join_handles.push((WorkerType::Signals, handle)); + } + + loop { + for (i, (_, handle)) in join_handles.iter().enumerate() { + if handle.is_finished() { + let (worker_type, handle) = join_handles.remove(i); + + match handle.join() { + Ok(Ok(())) => { + return Err(anyhow::anyhow!("{} stopped", worker_type)); + } + Ok(Err(err)) => { + return Err(err.context(format!("{} stopped", worker_type))); + } + Err(_) => { + return Err(anyhow::anyhow!("{} panicked", worker_type)); } } } - SIGTERM => { - if sentinel_watcher.panic_was_triggered() { - return Err(anyhow::anyhow!("worker thread panicked")); - } else { - return Ok(()); - } - } - _ => unreachable!(), } - } - Ok(()) + sleep(Duration::from_secs(5)); + } } diff --git a/crates/ws/src/workers/socket/mod.rs b/crates/ws/src/workers/socket/mod.rs index 8d79679..c633d63 100644 --- a/crates/ws/src/workers/socket/mod.rs +++ b/crates/ws/src/workers/socket/mod.rs @@ -7,7 +7,7 @@ use std::time::Duration; use anyhow::Context; use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::rustls_config::RustlsConfig; -use aquatic_common::{PanicSentinel, ServerStartInstant}; +use aquatic_common::ServerStartInstant; use aquatic_ws_protocol::common::InfoHash; use aquatic_ws_protocol::incoming::InMessage; use aquatic_ws_protocol::outgoing::OutMessage; @@ -50,7 +50,6 @@ struct ConnectionHandle { #[allow(clippy::too_many_arguments)] pub async fn run_socket_worker( - _sentinel: PanicSentinel, config: Config, state: State, opt_tls_config: Option>>, @@ -60,26 +59,41 @@ pub async fn run_socket_worker( priv_dropper: PrivilegeDropper, server_start_instant: ServerStartInstant, worker_index: usize, -) { +) -> anyhow::Result<()> { #[cfg(feature = "metrics")] WORKER_INDEX.with(|index| index.set(worker_index)); let config = Rc::new(config); let access_list = state.access_list; - let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener"); + let listener = create_tcp_listener(&config, priv_dropper).context("create tcp listener")?; ::log::info!("created tcp listener"); let (control_message_senders, _) = control_message_mesh_builder .join(Role::Producer) .await - .unwrap(); - let control_message_senders = Rc::new(control_message_senders); + .map_err(|err| anyhow::anyhow!("join control message mesh: {:#}", err))?; + let (in_message_senders, _) = in_message_mesh_builder + .join(Role::Producer) + .await + .map_err(|err| anyhow::anyhow!("join in message mesh: {:#}", err))?; + let (_, mut out_message_receivers) = out_message_mesh_builder + .join(Role::Consumer) + .await + .map_err(|err| anyhow::anyhow!("join out message mesh: {:#}", err))?; - let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); + let control_message_senders = Rc::new(control_message_senders); let in_message_senders = Rc::new(in_message_senders); + let out_message_consumer_id = ConsumerId( + out_message_receivers + .consumer_id() + .unwrap() + .try_into() + .unwrap(), + ); + let tq_prioritized = executor().create_task_queue( Shares::Static(100), Latency::Matters(Duration::from_millis(1)), @@ -88,16 +102,6 @@ pub async fn run_socket_worker( let tq_regular = executor().create_task_queue(Shares::Static(1), Latency::NotImportant, "regular"); - let (_, mut out_message_receivers) = - out_message_mesh_builder.join(Role::Consumer).await.unwrap(); - let out_message_consumer_id = ConsumerId( - out_message_receivers - .consumer_id() - .unwrap() - .try_into() - .unwrap(), - ); - ::log::info!("joined channels"); let connection_handles = Rc::new(RefCell::new(ConnectionHandles::default())); @@ -114,14 +118,14 @@ pub async fn run_socket_worker( }), tq_prioritized, ) - .unwrap(); + .map_err(|err| anyhow::anyhow!("spawn connection cleaning task: {:#}", err))?; for (_, out_message_receiver) in out_message_receivers.streams() { spawn_local_into( receive_out_messages(out_message_receiver, connection_handles.clone()), tq_regular, ) - .unwrap() + .map_err(|err| anyhow::anyhow!("spawn out message receiving task: {:#}", err))? .detach(); } @@ -197,6 +201,8 @@ pub async fn run_socket_worker( } } } + + Ok(()) } async fn clean_connections( diff --git a/crates/ws/src/workers/swarm/mod.rs b/crates/ws/src/workers/swarm/mod.rs index f885b97..419ee56 100644 --- a/crates/ws/src/workers/swarm/mod.rs +++ b/crates/ws/src/workers/swarm/mod.rs @@ -13,7 +13,7 @@ use glommio::prelude::*; use glommio::timer::TimerActionRepeat; use rand::{rngs::SmallRng, SeedableRng}; -use aquatic_common::{PanicSentinel, ServerStartInstant}; +use aquatic_common::ServerStartInstant; use crate::common::*; use crate::config::Config; @@ -21,9 +21,7 @@ use crate::SHARED_IN_CHANNEL_SIZE; use self::storage::TorrentMaps; -#[allow(clippy::too_many_arguments)] pub async fn run_swarm_worker( - _sentinel: PanicSentinel, config: Config, state: State, control_message_mesh_builder: MeshBuilder, @@ -31,14 +29,19 @@ pub async fn run_swarm_worker( out_message_mesh_builder: MeshBuilder<(OutMessageMeta, OutMessage), Partial>, server_start_instant: ServerStartInstant, worker_index: usize, -) { +) -> anyhow::Result<()> { let (_, mut control_message_receivers) = control_message_mesh_builder .join(Role::Consumer) .await - .unwrap(); - - let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); - let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); + .map_err(|err| anyhow::anyhow!("join control message mesh: {:#}", err))?; + let (_, mut in_message_receivers) = in_message_mesh_builder + .join(Role::Consumer) + .await + .map_err(|err| anyhow::anyhow!("join in message mesh: {:#}", err))?; + let (out_message_senders, _) = out_message_mesh_builder + .join(Role::Producer) + .await + .map_err(|err| anyhow::anyhow!("join out message mesh: {:#}", err))?; let out_message_senders = Rc::new(out_message_senders); @@ -89,6 +92,8 @@ pub async fn run_swarm_worker( for handle in handles { handle.await; } + + Ok(()) } async fn handle_control_message_stream(torrents: Rc>, mut stream: S)