diff --git a/Cargo.lock b/Cargo.lock index 4c33496..e9141fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -163,6 +163,7 @@ dependencies = [ "rand", "rustls 0.20.4", "serde", + "signal-hook", "socket2 0.4.4", "sqlx", "tokio", diff --git a/aquatic_http_private/Cargo.toml b/aquatic_http_private/Cargo.toml index 0d59f6a..7ec86c8 100644 --- a/aquatic_http_private/Cargo.toml +++ b/aquatic_http_private/Cargo.toml @@ -30,6 +30,7 @@ mimalloc = { version = "0.1", default-features = false } rand = { version = "0.8", features = ["small_rng"] } rustls = "0.20" serde = { version = "1", features = ["derive"] } +signal-hook = { version = "0.3" } socket2 = { version = "0.4", features = ["all"] } sqlx = { version = "0.5", features = [ "runtime-tokio-rustls" , "mysql" ] } tokio = { version = "1", features = ["full"] } diff --git a/aquatic_http_private/src/lib.rs b/aquatic_http_private/src/lib.rs index 8e9d7a7..45d3170 100644 --- a/aquatic_http_private/src/lib.rs +++ b/aquatic_http_private/src/lib.rs @@ -4,9 +4,10 @@ mod workers; use std::{collections::VecDeque, sync::Arc}; -use aquatic_common::rustls_config::create_rustls_config; +use aquatic_common::{rustls_config::create_rustls_config, PanicSentinelWatcher}; use common::ChannelRequestSender; use dotenv::dotenv; +use signal_hook::{consts::SIGTERM, iterator::Signals}; use tokio::sync::mpsc::channel; use config::Config; @@ -15,6 +16,8 @@ pub const APP_NAME: &str = "aquatic_http_private: private HTTP/TLS BitTorrent tr pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn run(config: Config) -> anyhow::Result<()> { + let mut signals = Signals::new([SIGTERM])?; + dotenv().ok(); let tls_config = Arc::new(create_rustls_config( @@ -32,9 +35,11 @@ pub fn run(config: Config) -> anyhow::Result<()> { request_receivers.push_back(request_receiver); } + let (sentinel_watcher, sentinel) = PanicSentinelWatcher::create_with_sentinel(); let mut handles = Vec::new(); for _ in 0..config.socket_workers { + let sentinel = sentinel.clone(); let config = config.clone(); let tls_config = tls_config.clone(); let request_sender = ChannelRequestSender::new(request_senders.clone()); @@ -42,27 +47,37 @@ pub fn run(config: Config) -> anyhow::Result<()> { let handle = ::std::thread::Builder::new() .name("socket".into()) .spawn(move || { - workers::socket::run_socket_worker(config, tls_config, request_sender) + workers::socket::run_socket_worker(sentinel, config, tls_config, request_sender) })?; handles.push(handle); } for _ in 0..config.request_workers { + let sentinel = sentinel.clone(); let config = config.clone(); let request_receiver = request_receivers.pop_front().unwrap(); let handle = ::std::thread::Builder::new() .name("request".into()) - .spawn(move || workers::request::run_request_worker(config, request_receiver))?; + .spawn(move || { + workers::request::run_request_worker(sentinel, config, request_receiver) + })?; handles.push(handle); } - for handle in handles { - handle - .join() - .map_err(|err| anyhow::anyhow!("thread join error: {:?}", err))??; + for signal in &mut signals { + match signal { + SIGTERM => { + if sentinel_watcher.panic_was_triggered() { + return Err(anyhow::anyhow!("worker thread panicked")); + } else { + return Ok(()); + } + } + _ => unreachable!(), + } } Ok(()) diff --git a/aquatic_http_private/src/workers/request/mod.rs b/aquatic_http_private/src/workers/request/mod.rs index 358ead6..c684256 100644 --- a/aquatic_http_private/src/workers/request/mod.rs +++ b/aquatic_http_private/src/workers/request/mod.rs @@ -11,7 +11,7 @@ use tokio::sync::mpsc::Receiver; use tokio::task::LocalSet; use tokio::time; -use aquatic_common::{extract_response_peers, CanonicalSocketAddr, ValidUntil}; +use aquatic_common::{extract_response_peers, CanonicalSocketAddr, PanicSentinel, ValidUntil}; use aquatic_http_protocol::response::{ AnnounceResponse, Response, ResponsePeer, ResponsePeerListV4, ResponsePeerListV6, }; @@ -22,6 +22,7 @@ use crate::config::Config; use common::*; pub fn run_request_worker( + _sentinel: PanicSentinel, config: Config, request_receiver: Receiver, ) -> anyhow::Result<()> { diff --git a/aquatic_http_private/src/workers/socket/mod.rs b/aquatic_http_private/src/workers/socket/mod.rs index 2b142c7..13304aa 100644 --- a/aquatic_http_private/src/workers/socket/mod.rs +++ b/aquatic_http_private/src/workers/socket/mod.rs @@ -8,7 +8,7 @@ use std::{ }; use anyhow::Context; -use aquatic_common::rustls_config::RustlsConfig; +use aquatic_common::{rustls_config::RustlsConfig, PanicSentinel}; use axum::{extract::connect_info::Connected, routing::get, Extension, Router}; use hyper::server::conn::AddrIncoming; use sqlx::mysql::MySqlPoolOptions; @@ -23,6 +23,7 @@ impl<'a> Connected<&'a tls::TlsStream> for SocketAddr { } pub fn run_socket_worker( + _sentinel: PanicSentinel, config: Config, tls_config: Arc, request_sender: ChannelRequestSender,