diff --git a/Cargo.lock b/Cargo.lock index 64f6218..742e3c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,6 +259,7 @@ dependencies = [ "rand", "rustls-pemfile", "serde", + "signal-hook", "slab", "tungstenite", ] diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index af69a0a..f2a41bb 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -35,6 +35,7 @@ privdrop = "0.5" rand = { version = "0.8", features = ["small_rng"] } rustls-pemfile = "0.2" serde = { version = "1", features = ["derive"] } +signal-hook = { version = "0.3" } slab = "0.4" tungstenite = "0.15" diff --git a/aquatic_ws/src/lib/common.rs b/aquatic_ws/src/lib/common.rs index 9ee675c..1b136d9 100644 --- a/aquatic_ws/src/lib/common.rs +++ b/aquatic_ws/src/lib/common.rs @@ -1,13 +1,8 @@ -use std::borrow::Borrow; -use std::cell::RefCell; use std::net::{IpAddr, SocketAddr}; -use std::rc::Rc; +use std::sync::Arc; use std::time::Instant; -use aquatic_common::access_list::AccessList; -use futures_lite::AsyncBufReadExt; -use glommio::io::{BufferedFile, StreamReaderBuilder}; -use glommio::yield_if_needed; +use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use hashbrown::HashMap; use indexmap::IndexMap; @@ -99,16 +94,25 @@ pub struct TorrentMaps { } impl TorrentMaps { - pub fn clean(&mut self, config: &Config, access_list: &AccessList) { - Self::clean_torrent_map(config, access_list, &mut self.ipv4); - Self::clean_torrent_map(config, access_list, &mut self.ipv6); + pub fn clean(&mut self, config: &Config, access_list: &Arc) { + let mut access_list_cache = create_access_list_cache(access_list); + + Self::clean_torrent_map(config, &mut access_list_cache, &mut self.ipv4); + Self::clean_torrent_map(config, &mut access_list_cache, &mut self.ipv6); } - fn clean_torrent_map(config: &Config, access_list: &AccessList, torrent_map: &mut TorrentMap) { + fn clean_torrent_map( + config: &Config, + access_list_cache: &mut AccessListCache, + torrent_map: &mut TorrentMap, + ) { let now = Instant::now(); torrent_map.retain(|info_hash, torrent_data| { - if !access_list.allows(config.access_list.mode, &info_hash.0) { + if !access_list_cache + .load() + .allows(config.access_list.mode, &info_hash.0) + { return false; } @@ -140,44 +144,7 @@ impl TorrentMaps { } } -pub async fn update_access_list>( - config: C, - access_list: Rc>, -) { - if config.borrow().access_list.mode.is_on() { - match BufferedFile::open(&config.borrow().access_list.path).await { - Ok(file) => { - let mut reader = StreamReaderBuilder::new(file).build(); - let mut new_access_list = AccessList::default(); - - loop { - let mut buf = String::with_capacity(42); - - match reader.read_line(&mut buf).await { - Ok(_) => { - if let Err(err) = new_access_list.insert_from_line(&buf) { - ::log::error!( - "Couln't parse access list line '{}': {:?}", - buf, - err - ); - } - } - Err(err) => { - ::log::error!("Couln't read access list line {:?}", err); - - break; - } - } - - yield_if_needed().await; - } - - *access_list.borrow_mut() = new_access_list; - } - Err(err) => { - ::log::error!("Couldn't open access list file: {:?}", err) - } - }; - } +#[derive(Default, Clone)] +pub struct State { + pub access_list: Arc, } diff --git a/aquatic_ws/src/lib/handlers.rs b/aquatic_ws/src/lib/handlers.rs index 31e0388..9b6100d 100644 --- a/aquatic_ws/src/lib/handlers.rs +++ b/aquatic_ws/src/lib/handlers.rs @@ -2,7 +2,6 @@ use std::cell::RefCell; use std::rc::Rc; use std::time::Duration; -use aquatic_common::access_list::AccessList; use aquatic_common::extract_response_peers; use futures_lite::StreamExt; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; @@ -19,9 +18,9 @@ use crate::config::Config; pub async fn run_request_worker( config: Config, + state: State, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, - access_list: AccessList, ) { let (_, mut in_message_receivers) = in_message_mesh_builder.join(Role::Consumer).await.unwrap(); let (out_message_senders, _) = out_message_mesh_builder.join(Role::Producer).await.unwrap(); @@ -29,14 +28,12 @@ pub async fn run_request_worker( let out_message_senders = Rc::new(out_message_senders); let torrents = Rc::new(RefCell::new(TorrentMaps::default())); - let access_list = Rc::new(RefCell::new(access_list)); + let access_list = state.access_list; - // Periodically clean torrents and update access list + // Periodically clean torrents TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { enclose!((config, torrents, access_list) move || async move { - update_access_list(&config, access_list.clone()).await; - - torrents.borrow_mut().clean(&config, &*access_list.borrow()); + torrents.borrow_mut().clean(&config, &access_list); Some(Duration::from_secs(config.cleaning.interval)) })() diff --git a/aquatic_ws/src/lib/lib.rs b/aquatic_ws/src/lib/lib.rs index c1de1d2..61d8eb7 100644 --- a/aquatic_ws/src/lib/lib.rs +++ b/aquatic_ws/src/lib/lib.rs @@ -4,9 +4,12 @@ use std::{ sync::{atomic::AtomicUsize, Arc}, }; -use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding}; -use common::TlsConfig; +use aquatic_common::{ + access_list::AccessListQuery, privileges::drop_privileges_after_socket_binding, +}; +use common::{State, TlsConfig}; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; +use signal_hook::{consts::SIGUSR1, iterator::Signals}; use crate::config::Config; @@ -19,18 +22,44 @@ pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; const SHARED_CHANNEL_SIZE: usize = 1024; -pub fn run(config: Config) -> anyhow::Result<()> { +pub fn run(config: Config) -> ::anyhow::Result<()> { if config.cpu_pinning.active { core_affinity::set_for_current(core_affinity::CoreId { id: config.cpu_pinning.offset, }); } - let access_list = if config.access_list.mode.is_on() { - AccessList::create_from_path(&config.access_list.path).expect("Load access list") - } else { - AccessList::default() - }; + let state = State::default(); + + update_access_list(&config, &state)?; + + let mut signals = Signals::new(::std::iter::once(SIGUSR1))?; + + { + let config = config.clone(); + let state = state.clone(); + + ::std::thread::spawn(move || run_inner(config, state)); + } + + for signal in &mut signals { + match signal { + SIGUSR1 => { + let _ = update_access_list(&config, &state); + } + _ => unreachable!(), + } + } + + Ok(()) +} + +pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { + if config.cpu_pinning.active { + core_affinity::set_for_current(core_affinity::CoreId { + id: config.cpu_pinning.offset, + }); + } let num_peers = config.socket_workers + config.request_workers; @@ -45,11 +74,11 @@ pub fn run(config: Config) -> anyhow::Result<()> { for i in 0..(config.socket_workers) { let config = config.clone(); + let state = state.clone(); let tls_config = tls_config.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); let num_bound_sockets = num_bound_sockets.clone(); - let access_list = access_list.clone(); let mut builder = LocalExecutorBuilder::default(); @@ -60,11 +89,11 @@ pub fn run(config: Config) -> anyhow::Result<()> { let executor = builder.spawn(|| async move { network::run_socket_worker( config, + state, tls_config, request_mesh_builder, response_mesh_builder, num_bound_sockets, - access_list, ) .await }); @@ -74,9 +103,9 @@ pub fn run(config: Config) -> anyhow::Result<()> { for i in 0..(config.request_workers) { let config = config.clone(); + let state = state.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); - let access_list = access_list.clone(); let mut builder = LocalExecutorBuilder::default(); @@ -85,13 +114,8 @@ pub fn run(config: Config) -> anyhow::Result<()> { } let executor = builder.spawn(|| async move { - handlers::run_request_worker( - config, - request_mesh_builder, - response_mesh_builder, - access_list, - ) - .await + handlers::run_request_worker(config, state, request_mesh_builder, response_mesh_builder) + .await }); executors.push(executor); @@ -142,3 +166,20 @@ fn create_tls_config(config: &Config) -> anyhow::Result { Ok(tls_config) } + +fn update_access_list(config: &Config, state: &State) -> anyhow::Result<()> { + if config.access_list.mode.is_on() { + match state.access_list.update(&config.access_list) { + Ok(()) => { + ::log::info!("Access list updated") + } + Err(err) => { + ::log::error!("Updating access list failed: {:#}", err); + + return Err(err); + } + } + } + + Ok(()) +} diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/network.rs index d0e9679..64fbbfa 100644 --- a/aquatic_ws/src/lib/network.rs +++ b/aquatic_ws/src/lib/network.rs @@ -7,7 +7,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; -use aquatic_common::access_list::AccessList; +use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use aquatic_common::convert_ipv4_mapped_ipv6; use aquatic_ws_protocol::*; use async_tungstenite::WebSocketStream; @@ -40,14 +40,14 @@ struct ConnectionReference { pub async fn run_socket_worker( config: Config, + state: State, tls_config: Arc, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, num_bound_sockets: Arc, - access_list: AccessList, ) { let config = Rc::new(config); - let access_list = Rc::new(RefCell::new(access_list)); + let access_list = state.access_list; let listener = TcpListener::bind(config.network.address).expect("bind socket"); num_bound_sockets.fetch_add(1, Ordering::SeqCst); @@ -62,15 +62,6 @@ pub async fn run_socket_worker( let connection_slab = Rc::new(RefCell::new(Slab::new())); let connections_to_remove = Rc::new(RefCell::new(Vec::new())); - // Periodically update access list - TimerActionRepeat::repeat(enclose!((config, access_list) move || { - enclose!((config, access_list) move || async move { - update_access_list(config.clone(), access_list.clone()).await; - - Some(Duration::from_secs(config.cleaning.interval)) - })() - })); - // Periodically remove closed connections TimerActionRepeat::repeat( enclose!((config, connection_slab, connections_to_remove) move || { @@ -176,7 +167,7 @@ struct Connection; impl Connection { async fn run( config: Rc, - access_list: Rc>, + access_list: Arc, in_message_senders: Rc>, out_message_sender: Rc>, out_message_receiver: LocalReceiver<(ConnectionMeta, OutMessage)>, @@ -201,11 +192,12 @@ impl Connection { let (ws_out, ws_in) = futures::StreamExt::split(stream); let pending_scrape_slab = Rc::new(RefCell::new(Slab::new())); + let access_list_cache = create_access_list_cache(&access_list); let reader_handle = spawn_local(enclose!((pending_scrape_slab) async move { let mut reader = ConnectionReader { config, - access_list, + access_list_cache, in_message_senders, out_message_sender, pending_scrape_slab, @@ -237,7 +229,7 @@ impl Connection { struct ConnectionReader { config: Rc, - access_list: Rc>, + access_list_cache: AccessListCache, in_message_senders: Rc>, out_message_sender: Rc>, pending_scrape_slab: Rc>>, @@ -275,8 +267,8 @@ impl ConnectionReader { let info_hash = announce_request.info_hash; if self - .access_list - .borrow() + .access_list_cache + .load() .allows(self.config.access_list.mode, &info_hash.0) { let in_message = InMessage::AnnounceRequest(announce_request);