diff --git a/Cargo.lock b/Cargo.lock index 1f2308c..64f6218 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,7 @@ dependencies = [ "rand", "rustls-pemfile", "serde", + "signal-hook", "slab", "smartstring", ] diff --git a/aquatic_http/Cargo.toml b/aquatic_http/Cargo.toml index b3f9dca..bc0c43b 100644 --- a/aquatic_http/Cargo.toml +++ b/aquatic_http/Cargo.toml @@ -37,6 +37,7 @@ privdrop = "0.5" rand = { version = "0.8", features = ["small_rng"] } rustls-pemfile = "0.2" serde = { version = "1", features = ["derive"] } +signal-hook = { version = "0.3" } slab = "0.4" smartstring = "0.2" diff --git a/aquatic_http/src/lib/common.rs b/aquatic_http/src/lib/common.rs index df1dfce..19e552f 100644 --- a/aquatic_http/src/lib/common.rs +++ b/aquatic_http/src/lib/common.rs @@ -1,7 +1,8 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; use std::time::Instant; -use aquatic_common::access_list::AccessList; +use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use either::Either; use hashbrown::HashMap; use indexmap::IndexMap; @@ -14,14 +15,6 @@ use aquatic_http_protocol::response::ResponsePeer; use crate::config::Config; -use std::borrow::Borrow; -use std::cell::RefCell; -use std::rc::Rc; - -use futures_lite::AsyncBufReadExt; -use glommio::io::{BufferedFile, StreamReaderBuilder}; -use glommio::prelude::*; - use aquatic_http_protocol::{ request::{AnnounceRequest, ScrapeRequest}, response::{AnnounceResponse, ScrapeResponse}, @@ -80,48 +73,6 @@ impl ChannelResponse { } } -pub async fn update_access_list>( - config: C, - access_list: Rc>, -) { - if config.borrow().access_list.mode.is_on() { - match BufferedFile::open(&config.borrow().access_list.path).await { - Ok(file) => { - let mut reader = StreamReaderBuilder::new(file).build(); - let mut new_access_list = AccessList::default(); - - loop { - let mut buf = String::with_capacity(42); - - match reader.read_line(&mut buf).await { - Ok(_) => { - if let Err(err) = new_access_list.insert_from_line(&buf) { - ::log::error!( - "Couln't parse access list line '{}': {:?}", - buf, - err - ); - } - } - Err(err) => { - ::log::error!("Couln't read access list line {:?}", err); - - break; - } - } - - yield_if_needed().await; - } - - *access_list.borrow_mut() = new_access_list; - } - Err(err) => { - ::log::error!("Couldn't open access list file: {:?}", err) - } - }; - } -} - pub trait Ip: ::std::fmt::Debug + Copy + Eq + ::std::hash::Hash {} impl Ip for Ipv4Addr {} @@ -218,20 +169,25 @@ pub struct TorrentMaps { } impl TorrentMaps { - pub fn clean(&mut self, config: &Config, access_list: &AccessList) { - Self::clean_torrent_map(config, access_list, &mut self.ipv4); - Self::clean_torrent_map(config, access_list, &mut self.ipv6); + pub fn clean(&mut self, config: &Config, access_list: &Arc) { + let mut access_list_cache = create_access_list_cache(access_list); + + Self::clean_torrent_map(config, &mut access_list_cache, &mut self.ipv4); + Self::clean_torrent_map(config, &mut access_list_cache, &mut self.ipv6); } fn clean_torrent_map( config: &Config, - access_list: &AccessList, + access_list_cache: &mut AccessListCache, torrent_map: &mut TorrentMap, ) { let now = Instant::now(); torrent_map.retain(|info_hash, torrent_data| { - if !access_list.allows(config.access_list.mode, &info_hash.0) { + if !access_list_cache + .load() + .allows(config.access_list.mode, &info_hash.0) + { return false; } @@ -263,6 +219,11 @@ impl TorrentMaps { } } +#[derive(Default, Clone)] +pub struct State { + pub access_list: Arc, +} + pub fn num_digits_in_usize(mut number: usize) -> usize { let mut num_digits = 1usize; diff --git a/aquatic_http/src/lib/handlers.rs b/aquatic_http/src/lib/handlers.rs index d30fe64..eb46d50 100644 --- a/aquatic_http/src/lib/handlers.rs +++ b/aquatic_http/src/lib/handlers.rs @@ -12,7 +12,6 @@ use std::cell::RefCell; use std::rc::Rc; use std::time::Duration; -use aquatic_common::access_list::AccessList; use futures_lite::{Stream, StreamExt}; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::timer::TimerActionRepeat; @@ -25,9 +24,9 @@ use crate::config::Config; pub async fn run_request_worker( config: Config, + state: State, request_mesh_builder: MeshBuilder, response_mesh_builder: MeshBuilder, - access_list: AccessList, ) { let (_, mut request_receivers) = request_mesh_builder.join(Role::Consumer).await.unwrap(); let (response_senders, _) = response_mesh_builder.join(Role::Producer).await.unwrap(); @@ -35,14 +34,12 @@ pub async fn run_request_worker( let response_senders = Rc::new(response_senders); let torrents = Rc::new(RefCell::new(TorrentMaps::default())); - let access_list = Rc::new(RefCell::new(access_list)); + let access_list = state.access_list; - // Periodically clean torrents and update access list + // Periodically clean torrents TimerActionRepeat::repeat(enclose!((config, torrents, access_list) move || { enclose!((config, torrents, access_list) move || async move { - update_access_list(&config, access_list.clone()).await; - - torrents.borrow_mut().clean(&config, &*access_list.borrow()); + torrents.borrow_mut().clean(&config, &access_list); Some(Duration::from_secs(config.cleaning.interval)) })() diff --git a/aquatic_http/src/lib/lib.rs b/aquatic_http/src/lib/lib.rs index 690df06..c9e4fca 100644 --- a/aquatic_http/src/lib/lib.rs +++ b/aquatic_http/src/lib/lib.rs @@ -4,9 +4,12 @@ use std::{ sync::{atomic::AtomicUsize, Arc}, }; -use aquatic_common::{access_list::AccessList, privileges::drop_privileges_after_socket_binding}; -use common::TlsConfig; +use aquatic_common::{ + access_list::AccessListQuery, privileges::drop_privileges_after_socket_binding, +}; +use common::{State, TlsConfig}; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; +use signal_hook::{consts::SIGUSR1, iterator::Signals}; use crate::config::Config; @@ -19,18 +22,44 @@ pub const APP_NAME: &str = "aquatic_http: HTTP/TLS BitTorrent tracker"; const SHARED_CHANNEL_SIZE: usize = 1024; -pub fn run(config: Config) -> anyhow::Result<()> { +pub fn run(config: Config) -> ::anyhow::Result<()> { if config.cpu_pinning.active { core_affinity::set_for_current(core_affinity::CoreId { id: config.cpu_pinning.offset, }); } - let access_list = if config.access_list.mode.is_on() { - AccessList::create_from_path(&config.access_list.path).expect("Load access list") - } else { - AccessList::default() - }; + let state = State::default(); + + update_access_list(&config, &state)?; + + let mut signals = Signals::new(::std::iter::once(SIGUSR1))?; + + { + let config = config.clone(); + let state = state.clone(); + + ::std::thread::spawn(move || run_inner(config, state)); + } + + for signal in &mut signals { + match signal { + SIGUSR1 => { + let _ = update_access_list(&config, &state); + } + _ => unreachable!(), + } + } + + Ok(()) +} + +pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { + if config.cpu_pinning.active { + core_affinity::set_for_current(core_affinity::CoreId { + id: config.cpu_pinning.offset, + }); + } let num_peers = config.socket_workers + config.request_workers; @@ -45,11 +74,11 @@ pub fn run(config: Config) -> anyhow::Result<()> { for i in 0..(config.socket_workers) { let config = config.clone(); + let state = state.clone(); let tls_config = tls_config.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); let num_bound_sockets = num_bound_sockets.clone(); - let access_list = access_list.clone(); let mut builder = LocalExecutorBuilder::default(); @@ -60,11 +89,11 @@ pub fn run(config: Config) -> anyhow::Result<()> { let executor = builder.spawn(|| async move { network::run_socket_worker( config, + state, tls_config, request_mesh_builder, response_mesh_builder, num_bound_sockets, - access_list, ) .await }); @@ -74,9 +103,9 @@ pub fn run(config: Config) -> anyhow::Result<()> { for i in 0..(config.request_workers) { let config = config.clone(); + let state = state.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); - let access_list = access_list.clone(); let mut builder = LocalExecutorBuilder::default(); @@ -85,13 +114,8 @@ pub fn run(config: Config) -> anyhow::Result<()> { } let executor = builder.spawn(|| async move { - handlers::run_request_worker( - config, - request_mesh_builder, - response_mesh_builder, - access_list, - ) - .await + handlers::run_request_worker(config, state, request_mesh_builder, response_mesh_builder) + .await }); executors.push(executor); @@ -142,3 +166,20 @@ fn create_tls_config(config: &Config) -> anyhow::Result { Ok(tls_config) } + +fn update_access_list(config: &Config, state: &State) -> anyhow::Result<()> { + if config.access_list.mode.is_on() { + match state.access_list.update(&config.access_list) { + Ok(()) => { + ::log::info!("Access list updated") + } + Err(err) => { + ::log::error!("Updating access list failed: {:#}", err); + + return Err(err); + } + } + } + + Ok(()) +} diff --git a/aquatic_http/src/lib/network.rs b/aquatic_http/src/lib/network.rs index 6970343..fd35ab1 100644 --- a/aquatic_http/src/lib/network.rs +++ b/aquatic_http/src/lib/network.rs @@ -6,7 +6,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; -use aquatic_common::access_list::AccessList; +use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use aquatic_http_protocol::common::InfoHash; use aquatic_http_protocol::request::{Request, RequestParseError, ScrapeRequest}; use aquatic_http_protocol::response::{ @@ -41,29 +41,16 @@ struct ConnectionReference { response_sender: LocalSender, } -struct Connection { - config: Rc, - access_list: Rc>, - request_senders: Rc>, - response_receiver: LocalReceiver, - response_consumer_id: ConsumerId, - stream: TlsStream, - peer_addr: SocketAddr, - connection_id: ConnectionId, - request_buffer: [u8; MAX_REQUEST_SIZE], - request_buffer_position: usize, -} - pub async fn run_socket_worker( config: Config, + state: State, tls_config: Arc, request_mesh_builder: MeshBuilder, response_mesh_builder: MeshBuilder, num_bound_sockets: Arc, - access_list: AccessList, ) { let config = Rc::new(config); - let access_list = Rc::new(RefCell::new(access_list)); + let access_list = state.access_list; let listener = TcpListener::bind(config.network.address).expect("bind socket"); num_bound_sockets.fetch_add(1, Ordering::SeqCst); @@ -77,15 +64,6 @@ pub async fn run_socket_worker( let connection_slab = Rc::new(RefCell::new(Slab::new())); let connections_to_remove = Rc::new(RefCell::new(Vec::new())); - // Periodically update access list - TimerActionRepeat::repeat(enclose!((config, access_list) move || { - enclose!((config, access_list) move || async move { - update_access_list(config.clone(), access_list.clone()).await; - - Some(Duration::from_secs(config.cleaning.interval)) - })() - })); - // Periodically remove closed connections TimerActionRepeat::repeat( enclose!((config, connection_slab, connections_to_remove) move || { @@ -177,10 +155,23 @@ async fn receive_responses( } } +struct Connection { + config: Rc, + access_list_cache: AccessListCache, + request_senders: Rc>, + response_receiver: LocalReceiver, + response_consumer_id: ConsumerId, + stream: TlsStream, + peer_addr: SocketAddr, + connection_id: ConnectionId, + request_buffer: [u8; MAX_REQUEST_SIZE], + request_buffer_position: usize, +} + impl Connection { async fn run( config: Rc, - access_list: Rc>, + access_list: Arc, request_senders: Rc>, response_receiver: LocalReceiver, response_consumer_id: ConsumerId, @@ -197,7 +188,7 @@ impl Connection { let mut conn = Connection { config: config.clone(), - access_list: access_list.clone(), + access_list_cache: create_access_list_cache(&access_list), request_senders: request_senders.clone(), response_receiver, response_consumer_id, @@ -297,14 +288,14 @@ impl Connection { /// response /// - If it is a scrape requests, split it up, pass on the parts to /// relevant request workers and await a response - async fn handle_request(&self, request: Request) -> anyhow::Result { + async fn handle_request(&mut self, request: Request) -> anyhow::Result { match request { Request::Announce(request) => { let info_hash = request.info_hash; if self - .access_list - .borrow() + .access_list_cache + .load() .allows(self.config.access_list.mode, &info_hash.0) { let request = ChannelRequest::Announce {