From 7ccd5fcbf77e940652304e1a5467b6674a63823c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Sat, 16 Oct 2021 17:26:40 +0200 Subject: [PATCH] access lists: filter requests in socket workers instead --- Cargo.lock | 9 ++ aquatic_common/Cargo.toml | 1 + aquatic_common/src/access_list.rs | 19 ++-- aquatic_http/src/lib/common.rs | 11 ++- aquatic_http/src/lib/handler.rs | 120 +++++++++++------------ aquatic_http/src/lib/lib.rs | 10 +- aquatic_http/src/lib/network/mod.rs | 29 +++++- aquatic_http/src/lib/tasks.rs | 4 +- aquatic_udp/src/lib/common.rs | 6 +- aquatic_udp/src/lib/handlers/announce.rs | 52 ++++------ aquatic_udp/src/lib/lib.rs | 9 +- aquatic_udp/src/lib/network.rs | 17 +++- aquatic_udp/src/lib/tasks.rs | 4 +- aquatic_ws/src/lib/common.rs | 11 ++- aquatic_ws/src/lib/handler.rs | 17 ---- aquatic_ws/src/lib/lib.rs | 10 +- aquatic_ws/src/lib/network/mod.rs | 51 +++++++++- aquatic_ws/src/lib/tasks.rs | 4 +- 18 files changed, 221 insertions(+), 163 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ec80e36..f1dcab0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +version = 3 + [[package]] name = "addr2line" version = "0.16.0" @@ -73,6 +75,7 @@ name = "aquatic_common" version = "0.1.0" dependencies = [ "anyhow", + "arc-swap", "hashbrown 0.11.2", "hex", "indexmap", @@ -279,6 +282,12 @@ dependencies = [ "tungstenite", ] +[[package]] +name = "arc-swap" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6df5aef5c5830360ce5218cecb8f018af3438af5686ae945094affc86fdec63" + [[package]] name = "arrayvec" version = "0.4.12" diff --git a/aquatic_common/Cargo.toml b/aquatic_common/Cargo.toml index 0072453..02fbe6e 100644 --- a/aquatic_common/Cargo.toml +++ b/aquatic_common/Cargo.toml @@ -12,6 +12,7 @@ name = "aquatic_common" [dependencies] anyhow = "1" +arc-swap = "1" hashbrown = "0.11.2" hex = "0.4" indexmap = "1" diff --git a/aquatic_common/src/access_list.rs b/aquatic_common/src/access_list.rs index 32792e6..588e7a8 100644 --- a/aquatic_common/src/access_list.rs +++ b/aquatic_common/src/access_list.rs @@ -1,7 +1,9 @@ use std::fs::File; use std::io::{BufRead, BufReader}; use std::path::PathBuf; +use std::sync::Arc; +use arc_swap::ArcSwap; use hashbrown::HashSet; use serde::{Deserialize, Serialize}; @@ -34,8 +36,13 @@ impl Default for AccessListConfig { } } -#[derive(Default)] -pub struct AccessList(HashSet<[u8; 20]>); +pub struct AccessList(ArcSwap>); + +impl Default for AccessList { + fn default() -> Self { + Self(ArcSwap::from(Arc::new(HashSet::default()))) + } +} impl AccessList { fn parse_info_hash(line: String) -> anyhow::Result<[u8; 20]> { @@ -46,7 +53,7 @@ impl AccessList { Ok(bytes) } - pub fn update_from_path(&mut self, path: &PathBuf) -> anyhow::Result<()> { + pub fn update_from_path(&self, path: &PathBuf) -> anyhow::Result<()> { let file = File::open(path)?; let reader = BufReader::new(file); @@ -56,15 +63,15 @@ impl AccessList { new_list.insert(Self::parse_info_hash(line?)?); } - self.0 = new_list; + self.0.store(Arc::new(new_list)); Ok(()) } pub fn allows(&self, list_type: AccessListMode, info_hash_bytes: &[u8; 20]) -> bool { match list_type { - AccessListMode::Require => self.0.contains(info_hash_bytes), - AccessListMode::Forbid => !self.0.contains(info_hash_bytes), + AccessListMode::Require => self.0.load().contains(info_hash_bytes), + AccessListMode::Forbid => !self.0.load().contains(info_hash_bytes), AccessListMode::Ignore => true, } } diff --git a/aquatic_http/src/lib/common.rs b/aquatic_http/src/lib/common.rs index 7de0444..e36e390 100644 --- a/aquatic_http/src/lib/common.rs +++ b/aquatic_http/src/lib/common.rs @@ -115,18 +115,17 @@ pub type TorrentMap = HashMap>; pub struct TorrentMaps { pub ipv4: TorrentMap, pub ipv6: TorrentMap, - pub access_list: AccessList, } impl TorrentMaps { - pub fn clean(&mut self, config: &Config) { - Self::clean_torrent_map(config, &self.access_list, &mut self.ipv4); - Self::clean_torrent_map(config, &self.access_list, &mut self.ipv6); + pub fn clean(&mut self, config: &Config, access_list: &Arc) { + Self::clean_torrent_map(config, access_list, &mut self.ipv4); + Self::clean_torrent_map(config, access_list, &mut self.ipv6); } fn clean_torrent_map( config: &Config, - access_list: &AccessList, + access_list: &Arc, torrent_map: &mut TorrentMap, ) { let now = Instant::now(); @@ -166,12 +165,14 @@ impl TorrentMaps { #[derive(Clone)] pub struct State { + pub access_list: Arc, pub torrent_maps: Arc>, } impl Default for State { fn default() -> Self { Self { + access_list: Arc::new(Default::default()), torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), } } diff --git a/aquatic_http/src/lib/handler.rs b/aquatic_http/src/lib/handler.rs index feef6b1..cd823d8 100644 --- a/aquatic_http/src/lib/handler.rs +++ b/aquatic_http/src/lib/handler.rs @@ -105,77 +105,69 @@ pub fn handle_announce_requests( let valid_until = ValidUntil::new(config.cleaning.max_peer_age); for (meta, request) in requests { - let info_hash_allowed = torrent_maps - .access_list - .allows(config.access_list.mode, &request.info_hash.0); + let peer_ip = convert_ipv4_mapped_ipv6(meta.peer_addr.ip()); - let response = if info_hash_allowed { - let peer_ip = convert_ipv4_mapped_ipv6(meta.peer_addr.ip()); + ::log::debug!("peer ip: {:?}", peer_ip); - ::log::debug!("peer ip: {:?}", peer_ip); + let response = match peer_ip { + IpAddr::V4(peer_ip_address) => { + let torrent_data: &mut TorrentData = + torrent_maps.ipv4.entry(request.info_hash).or_default(); - match peer_ip { - IpAddr::V4(peer_ip_address) => { - let torrent_data: &mut TorrentData = - torrent_maps.ipv4.entry(request.info_hash).or_default(); + let peer_connection_meta = PeerConnectionMeta { + worker_index: meta.worker_index, + poll_token: meta.poll_token, + peer_ip_address, + }; - let peer_connection_meta = PeerConnectionMeta { - worker_index: meta.worker_index, - poll_token: meta.poll_token, - peer_ip_address, - }; + let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers( + config, + rng, + peer_connection_meta, + torrent_data, + request, + valid_until, + ); - let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers( - config, - rng, - peer_connection_meta, - torrent_data, - request, - valid_until, - ); + let response = AnnounceResponse { + complete: seeders, + incomplete: leechers, + announce_interval: config.protocol.peer_announce_interval, + peers: ResponsePeerListV4(response_peers), + peers6: ResponsePeerListV6(vec![]), + }; - let response = AnnounceResponse { - complete: seeders, - incomplete: leechers, - announce_interval: config.protocol.peer_announce_interval, - peers: ResponsePeerListV4(response_peers), - peers6: ResponsePeerListV6(vec![]), - }; - - Response::Announce(response) - } - IpAddr::V6(peer_ip_address) => { - let torrent_data: &mut TorrentData = - torrent_maps.ipv6.entry(request.info_hash).or_default(); - - let peer_connection_meta = PeerConnectionMeta { - worker_index: meta.worker_index, - poll_token: meta.poll_token, - peer_ip_address, - }; - - let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers( - config, - rng, - peer_connection_meta, - torrent_data, - request, - valid_until, - ); - - let response = AnnounceResponse { - complete: seeders, - incomplete: leechers, - announce_interval: config.protocol.peer_announce_interval, - peers: ResponsePeerListV4(vec![]), - peers6: ResponsePeerListV6(response_peers), - }; - - Response::Announce(response) - } + Response::Announce(response) + } + IpAddr::V6(peer_ip_address) => { + let torrent_data: &mut TorrentData = + torrent_maps.ipv6.entry(request.info_hash).or_default(); + + let peer_connection_meta = PeerConnectionMeta { + worker_index: meta.worker_index, + poll_token: meta.poll_token, + peer_ip_address, + }; + + let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers( + config, + rng, + peer_connection_meta, + torrent_data, + request, + valid_until, + ); + + let response = AnnounceResponse { + complete: seeders, + incomplete: leechers, + announce_interval: config.protocol.peer_announce_interval, + peers: ResponsePeerListV4(vec![]), + peers6: ResponsePeerListV6(response_peers), + }; + + Response::Announce(response) } - } else { - Response::Failure(FailureResponse::new("Info hash not allowed")) }; response_channel_sender.send(meta, response); diff --git a/aquatic_http/src/lib/lib.rs b/aquatic_http/src/lib/lib.rs index 759f607..2662b2b 100644 --- a/aquatic_http/src/lib/lib.rs +++ b/aquatic_http/src/lib/lib.rs @@ -22,18 +22,16 @@ pub const APP_NAME: &str = "aquatic_http: HTTP/TLS BitTorrent tracker"; pub fn run(config: Config) -> anyhow::Result<()> { let state = State::default(); - tasks::update_access_list(&config, &mut state.torrent_maps.lock()); + tasks::update_access_list(&config, &state); start_workers(config.clone(), state.clone())?; loop { ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); - let mut torrent_maps = state.torrent_maps.lock(); + tasks::update_access_list(&config, &state); - tasks::update_access_list(&config, &mut torrent_maps); - - torrent_maps.clean(&config); + state.torrent_maps.lock().clean(&config, &state.access_list); } } @@ -57,6 +55,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { for i in 0..config.socket_workers { let config = config.clone(); + let state = state.clone(); let socket_worker_statuses = socket_worker_statuses.clone(); let request_channel_sender = request_channel_sender.clone(); let opt_tls_acceptor = opt_tls_acceptor.clone(); @@ -73,6 +72,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { .spawn(move || { network::run_socket_worker( config, + state, i, socket_worker_statuses, request_channel_sender, diff --git a/aquatic_http/src/lib/network/mod.rs b/aquatic_http/src/lib/network/mod.rs index 8fa0ffd..7ef5e27 100644 --- a/aquatic_http/src/lib/network/mod.rs +++ b/aquatic_http/src/lib/network/mod.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use std::vec::Drain; +use aquatic_http_protocol::request::Request; use hashbrown::HashMap; use log::{debug, error, info}; use mio::net::TcpListener; @@ -25,6 +26,7 @@ const CONNECTION_CLEAN_INTERVAL: usize = 2 ^ 22; pub fn run_socket_worker( config: Config, + state: State, socket_worker_index: usize, socket_worker_statuses: SocketWorkerStatuses, request_channel_sender: RequestChannelSender, @@ -38,6 +40,7 @@ pub fn run_socket_worker( run_poll_loop( config, + &state, socket_worker_index, request_channel_sender, response_channel_receiver, @@ -55,6 +58,7 @@ pub fn run_socket_worker( pub fn run_poll_loop( config: Config, + state: &State, socket_worker_index: usize, request_channel_sender: RequestChannelSender, response_channel_receiver: ResponseChannelReceiver, @@ -100,6 +104,7 @@ pub fn run_poll_loop( } else if token != CHANNEL_TOKEN { handle_connection_read_event( &config, + &state, socket_worker_index, &mut poll, &request_channel_sender, @@ -179,6 +184,7 @@ fn accept_new_streams( /// then read requests and pass on through channel. pub fn handle_connection_read_event( config: &Config, + state: &State, socket_worker_index: usize, poll: &mut Poll, request_channel_sender: &RequestChannelSender, @@ -187,6 +193,7 @@ pub fn handle_connection_read_event( poll_token: Token, ) { let valid_until = ValidUntil::new(config.cleaning.max_connection_age); + let access_list_mode = config.access_list.mode; loop { // Get connection, updating valid_until @@ -210,10 +217,26 @@ pub fn handle_connection_read_event( peer_addr: established.peer_addr, }; - debug!("read request, sending to handler"); + let opt_allowed_request = if let Request::Announce(ref r) = request { + if state.access_list.allows(access_list_mode, &r.info_hash.0){ + Some(request) + } else { + None + } + } else { + Some(request) + }; - if let Err(err) = request_channel_sender.send((meta, request)) { - error!("RequestChannelSender: couldn't send message: {:?}", err); + if let Some(request) = opt_allowed_request { + debug!("read allowed request, sending on to channel"); + + if let Err(err) = request_channel_sender.send((meta, request)) { + error!("RequestChannelSender: couldn't send message: {:?}", err); + } + } else { + let response = FailureResponse::new("Info hash not allowed"); + + local_responses.push((meta, Response::Failure(response))) } break; diff --git a/aquatic_http/src/lib/tasks.rs b/aquatic_http/src/lib/tasks.rs index 93f931c..44dd26a 100644 --- a/aquatic_http/src/lib/tasks.rs +++ b/aquatic_http/src/lib/tasks.rs @@ -4,10 +4,10 @@ use aquatic_common::access_list::AccessListMode; use crate::{common::*, config::Config}; -pub fn update_access_list(config: &Config, torrent_maps: &mut TorrentMaps) { +pub fn update_access_list(config: &Config, state: &State) { match config.access_list.mode { AccessListMode::Require | AccessListMode::Forbid => { - if let Err(err) = torrent_maps + if let Err(err) = state .access_list .update_from_path(&config.access_list.path) { diff --git a/aquatic_udp/src/lib/common.rs b/aquatic_udp/src/lib/common.rs index 69ae48d..67f0f4b 100644 --- a/aquatic_udp/src/lib/common.rs +++ b/aquatic_udp/src/lib/common.rs @@ -118,15 +118,13 @@ pub type TorrentMap = HashMap>; pub struct TorrentMaps { pub ipv4: TorrentMap, pub ipv6: TorrentMap, - pub access_list: AccessList, } impl TorrentMaps { /// Remove disallowed and inactive torrents - pub fn clean(&mut self, config: &Config) { + pub fn clean(&mut self, config: &Config, access_list: &Arc) { let now = Instant::now(); - let access_list = &self.access_list; let access_list_mode = config.access_list.mode; self.ipv4.retain(|info_hash, torrent| { @@ -181,6 +179,7 @@ pub struct Statistics { #[derive(Clone)] pub struct State { + pub access_list: Arc, pub connections: Arc>, pub torrents: Arc>, pub statistics: Arc, @@ -189,6 +188,7 @@ pub struct State { impl Default for State { fn default() -> Self { Self { + access_list: Arc::new(AccessList::default()), connections: Arc::new(Mutex::new(HashMap::new())), torrents: Arc::new(Mutex::new(TorrentMaps::default())), statistics: Arc::new(Statistics::default()), diff --git a/aquatic_udp/src/lib/handlers/announce.rs b/aquatic_udp/src/lib/handlers/announce.rs index 5d73f54..bda60d6 100644 --- a/aquatic_udp/src/lib/handlers/announce.rs +++ b/aquatic_udp/src/lib/handlers/announce.rs @@ -19,44 +19,30 @@ pub fn handle_announce_requests( responses: &mut Vec<(Response, SocketAddr)>, ) { let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age); - let access_list_mode = config.access_list.mode; responses.extend(requests.map(|(request, src)| { - let info_hash_allowed = torrents - .access_list - .allows(access_list_mode, &request.info_hash.0); + let peer_ip = convert_ipv4_mapped_ipv6(src.ip()); - let response = if info_hash_allowed { - let peer_ip = convert_ipv4_mapped_ipv6(src.ip()); - - let response = match peer_ip { - IpAddr::V4(ip) => handle_announce_request( - config, - rng, - &mut torrents.ipv4, - request, - ip, - peer_valid_until, - ), - IpAddr::V6(ip) => handle_announce_request( - config, - rng, - &mut torrents.ipv6, - request, - ip, - peer_valid_until, - ), - }; - - Response::Announce(response) - } else { - Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - }) + let response = match peer_ip { + IpAddr::V4(ip) => handle_announce_request( + config, + rng, + &mut torrents.ipv4, + request, + ip, + peer_valid_until, + ), + IpAddr::V6(ip) => handle_announce_request( + config, + rng, + &mut torrents.ipv6, + request, + ip, + peer_valid_until, + ), }; - (response, src) + (Response::Announce(response), src) })); } diff --git a/aquatic_udp/src/lib/lib.rs b/aquatic_udp/src/lib/lib.rs index 20c70e1..5e32d83 100644 --- a/aquatic_udp/src/lib/lib.rs +++ b/aquatic_udp/src/lib/lib.rs @@ -23,7 +23,7 @@ pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker"; pub fn run(config: Config) -> ::anyhow::Result<()> { let state = State::default(); - tasks::update_access_list(&config, &mut state.torrents.lock()); + tasks::update_access_list(&config, &state); let num_bound_sockets = start_workers(config.clone(), state.clone())?; @@ -56,12 +56,9 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); tasks::clean_connections(&state); + tasks::update_access_list(&config, &state); - let mut torrent_maps = state.torrents.lock(); - - tasks::update_access_list(&config, &mut torrent_maps); - - torrent_maps.clean(&config); + state.torrents.lock().clean(&config, &state.access_list); } } diff --git a/aquatic_udp/src/lib/network.rs b/aquatic_udp/src/lib/network.rs index d0a5506..306f167 100644 --- a/aquatic_udp/src/lib/network.rs +++ b/aquatic_udp/src/lib/network.rs @@ -131,6 +131,8 @@ fn read_requests( let mut requests_received: usize = 0; let mut bytes_received: usize = 0; + let access_list_mode = config.access_list.mode; + loop { match socket.recv_from(&mut buffer[..]) { Ok((amt, src)) => { @@ -145,7 +147,20 @@ fn read_requests( match request { Ok(request) => { - requests.push((request, src)); + if let Request::Announce(AnnounceRequest { info_hash, transaction_id, ..}) = request { + if state.access_list.allows(access_list_mode, &info_hash.0) { + requests.push((request, src)); + } else { + let response = Response::Error(ErrorResponse { + transaction_id, + message: "Info hash not allowed".into() + }); + + local_responses.push((response, src)) + } + } else { + requests.push((request, src)); + } } Err(err) => { ::log::debug!("request_from_bytes error: {:?}", err); diff --git a/aquatic_udp/src/lib/tasks.rs b/aquatic_udp/src/lib/tasks.rs index 9b18679..1d887bf 100644 --- a/aquatic_udp/src/lib/tasks.rs +++ b/aquatic_udp/src/lib/tasks.rs @@ -8,10 +8,10 @@ use aquatic_common::access_list::AccessListMode; use crate::common::*; use crate::config::Config; -pub fn update_access_list(config: &Config, torrent_maps: &mut TorrentMaps) { +pub fn update_access_list(config: &Config, state: &State) { match config.access_list.mode { AccessListMode::Require | AccessListMode::Forbid => { - if let Err(err) = torrent_maps + if let Err(err) = state .access_list .update_from_path(&config.access_list.path) { diff --git a/aquatic_ws/src/lib/common.rs b/aquatic_ws/src/lib/common.rs index dcaa96f..287b8a2 100644 --- a/aquatic_ws/src/lib/common.rs +++ b/aquatic_ws/src/lib/common.rs @@ -86,16 +86,15 @@ pub type TorrentMap = HashMap; pub struct TorrentMaps { pub ipv4: TorrentMap, pub ipv6: TorrentMap, - pub access_list: AccessList, } impl TorrentMaps { - pub fn clean(&mut self, config: &Config) { - Self::clean_torrent_map(config, &self.access_list, &mut self.ipv4); - Self::clean_torrent_map(config, &self.access_list, &mut self.ipv6); + pub fn clean(&mut self, config: &Config, access_list: &Arc) { + Self::clean_torrent_map(config, access_list, &mut self.ipv4); + Self::clean_torrent_map(config, access_list, &mut self.ipv6); } - fn clean_torrent_map(config: &Config, access_list: &AccessList, torrent_map: &mut TorrentMap) { + fn clean_torrent_map(config: &Config, access_list: &Arc, torrent_map: &mut TorrentMap) { let now = Instant::now(); torrent_map.retain(|info_hash, torrent_data| { @@ -133,12 +132,14 @@ impl TorrentMaps { #[derive(Clone)] pub struct State { + pub access_list: Arc, pub torrent_maps: Arc>, } impl Default for State { fn default() -> Self { Self { + access_list: Arc::new(Default::default()), torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), } } diff --git a/aquatic_ws/src/lib/handler.rs b/aquatic_ws/src/lib/handler.rs index d862926..1f61a17 100644 --- a/aquatic_ws/src/lib/handler.rs +++ b/aquatic_ws/src/lib/handler.rs @@ -99,23 +99,6 @@ pub fn handle_announce_requests( let valid_until = ValidUntil::new(config.cleaning.max_peer_age); for (request_sender_meta, request) in requests { - let info_hash_allowed = torrent_maps - .access_list - .allows(config.access_list.mode, &request.info_hash.0); - - if !info_hash_allowed { - let response = OutMessage::ErrorResponse(ErrorResponse { - failure_reason: "Info hash not allowed".into(), - action: Some(ErrorResponseAction::Announce), - info_hash: Some(request.info_hash), - }); - - out_message_sender.send(request_sender_meta, response); - wake_socket_workers[request_sender_meta.worker_index] = true; - - continue; - } - let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() { torrent_maps.ipv4.entry(request.info_hash).or_default() } else { diff --git a/aquatic_ws/src/lib/lib.rs b/aquatic_ws/src/lib/lib.rs index c02e2da..f6ee599 100644 --- a/aquatic_ws/src/lib/lib.rs +++ b/aquatic_ws/src/lib/lib.rs @@ -24,18 +24,16 @@ pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker"; pub fn run(config: Config) -> anyhow::Result<()> { let state = State::default(); - tasks::update_access_list(&config, &mut state.torrent_maps.lock()); + tasks::update_access_list(&config, &state); start_workers(config.clone(), state.clone())?; loop { ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); - let mut torrent_maps = state.torrent_maps.lock(); + tasks::update_access_list(&config, &state); - tasks::update_access_list(&config, &mut torrent_maps); - - torrent_maps.clean(&config); + state.torrent_maps.lock().clean(&config, &state.access_list); } } @@ -59,6 +57,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { for i in 0..config.socket_workers { let config = config.clone(); + let state = state.clone(); let socket_worker_statuses = socket_worker_statuses.clone(); let in_message_sender = in_message_sender.clone(); let opt_tls_acceptor = opt_tls_acceptor.clone(); @@ -75,6 +74,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> { .spawn(move || { network::run_socket_worker( config, + state, i, socket_worker_statuses, poll, diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs index 4d0e457..31e277d 100644 --- a/aquatic_ws/src/lib/network/mod.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -1,7 +1,9 @@ use std::io::ErrorKind; use std::time::Duration; +use std::vec::Drain; use crossbeam_channel::Receiver; +use either::Either; use hashbrown::HashMap; use log::{debug, error, info}; use mio::net::TcpListener; @@ -23,6 +25,7 @@ use utils::*; pub fn run_socket_worker( config: Config, + state: State, socket_worker_index: usize, socket_worker_statuses: SocketWorkerStatuses, poll: Poll, @@ -36,6 +39,7 @@ pub fn run_socket_worker( run_poll_loop( config, + &state, socket_worker_index, poll, in_message_sender, @@ -53,6 +57,7 @@ pub fn run_socket_worker( pub fn run_poll_loop( config: Config, + state: &State, socket_worker_index: usize, mut poll: Poll, in_message_sender: InMessageSender, @@ -76,6 +81,7 @@ pub fn run_poll_loop( .unwrap(); let mut connections: ConnectionMap = HashMap::new(); + let mut local_responses = Vec::new(); let mut poll_token_counter = Token(0usize); let mut iter_counter = 0usize; @@ -100,7 +106,10 @@ pub fn run_poll_loop( ); } else if token != CHANNEL_TOKEN { run_handshakes_and_read_messages( + &config, + state, socket_worker_index, + &mut local_responses, &in_message_sender, &opt_tls_acceptor, &mut poll, @@ -110,7 +119,12 @@ pub fn run_poll_loop( ); } - send_out_messages(&mut poll, &out_message_receiver, &mut connections); + send_out_messages( + &mut poll, + local_responses.drain(..), + &out_message_receiver, + &mut connections + ); } // Remove inactive connections, but not every iteration @@ -165,7 +179,10 @@ fn accept_new_streams( /// On the stream given by poll_token, get TLS (if requested) and tungstenite /// up and running, then read messages and pass on through channel. pub fn run_handshakes_and_read_messages( + config: &Config, + state: &State, socket_worker_index: usize, + local_responses: &mut Vec<(ConnectionMeta, OutMessage)>, in_message_sender: &InMessageSender, opt_tls_acceptor: &Option, // If set, run TLS poll: &mut Poll, @@ -173,6 +190,8 @@ pub fn run_handshakes_and_read_messages( poll_token: Token, valid_until: ValidUntil, ) { + let access_list_mode = config.access_list.mode; + loop { if let Some(established_ws) = connections .get_mut(&poll_token) @@ -201,8 +220,31 @@ pub fn run_handshakes_and_read_messages( debug!("read message"); - if let Err(err) = in_message_sender.send((meta, in_message)) { - error!("InMessageSender: couldn't send message: {:?}", err); + let message = if let InMessage::AnnounceRequest(ref request) = in_message { + if state.access_list.allows(access_list_mode, &request.info_hash.0){ + Either::Left(in_message) + } else { + let out_message = OutMessage::ErrorResponse(ErrorResponse { + failure_reason: "Info hash not allowed".into(), + action: Some(ErrorResponseAction::Announce), + info_hash: Some(request.info_hash), + }); + + Either::Right(out_message) + } + } else { + Either::Left(in_message) + }; + + match message { + Either::Left(in_message) => { + if let Err(err) = in_message_sender.send((meta, in_message)) { + error!("InMessageSender: couldn't send message: {:?}", err); + } + }, + Either::Right(out_message) => { + local_responses.push((meta, out_message)); + } } } } @@ -242,12 +284,13 @@ pub fn run_handshakes_and_read_messages( /// Read messages from channel, send to peers pub fn send_out_messages( poll: &mut Poll, + local_responses: Drain<(ConnectionMeta, OutMessage)>, out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>, connections: &mut ConnectionMap, ) { let len = out_message_receiver.len(); - for (meta, out_message) in out_message_receiver.try_iter().take(len) { + for (meta, out_message) in local_responses.chain(out_message_receiver.try_iter().take(len)) { let opt_established_ws = connections .get_mut(&meta.poll_token) .and_then(Connection::get_established_ws); diff --git a/aquatic_ws/src/lib/tasks.rs b/aquatic_ws/src/lib/tasks.rs index 014591d..bec0285 100644 --- a/aquatic_ws/src/lib/tasks.rs +++ b/aquatic_ws/src/lib/tasks.rs @@ -4,10 +4,10 @@ use histogram::Histogram; use crate::common::*; use crate::config::Config; -pub fn update_access_list(config: &Config, torrent_maps: &mut TorrentMaps) { +pub fn update_access_list(config: &Config, state: &State) { match config.access_list.mode { AccessListMode::Require | AccessListMode::Forbid => { - if let Err(err) = torrent_maps + if let Err(err) = state .access_list .update_from_path(&config.access_list.path) {