access lists: filter requests in socket workers instead

This commit is contained in:
Joakim Frostegård 2021-10-16 17:26:40 +02:00
parent 33966bed57
commit 7ccd5fcbf7
18 changed files with 221 additions and 163 deletions

9
Cargo.lock generated
View file

@ -1,5 +1,7 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3
[[package]] [[package]]
name = "addr2line" name = "addr2line"
version = "0.16.0" version = "0.16.0"
@ -73,6 +75,7 @@ name = "aquatic_common"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"arc-swap",
"hashbrown 0.11.2", "hashbrown 0.11.2",
"hex", "hex",
"indexmap", "indexmap",
@ -279,6 +282,12 @@ dependencies = [
"tungstenite", "tungstenite",
] ]
[[package]]
name = "arc-swap"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6df5aef5c5830360ce5218cecb8f018af3438af5686ae945094affc86fdec63"
[[package]] [[package]]
name = "arrayvec" name = "arrayvec"
version = "0.4.12" version = "0.4.12"

View file

@ -12,6 +12,7 @@ name = "aquatic_common"
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
arc-swap = "1"
hashbrown = "0.11.2" hashbrown = "0.11.2"
hex = "0.4" hex = "0.4"
indexmap = "1" indexmap = "1"

View file

@ -1,7 +1,9 @@
use std::fs::File; use std::fs::File;
use std::io::{BufRead, BufReader}; use std::io::{BufRead, BufReader};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use arc_swap::ArcSwap;
use hashbrown::HashSet; use hashbrown::HashSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -34,8 +36,13 @@ impl Default for AccessListConfig {
} }
} }
#[derive(Default)] pub struct AccessList(ArcSwap<HashSet<[u8; 20]>>);
pub struct AccessList(HashSet<[u8; 20]>);
impl Default for AccessList {
fn default() -> Self {
Self(ArcSwap::from(Arc::new(HashSet::default())))
}
}
impl AccessList { impl AccessList {
fn parse_info_hash(line: String) -> anyhow::Result<[u8; 20]> { fn parse_info_hash(line: String) -> anyhow::Result<[u8; 20]> {
@ -46,7 +53,7 @@ impl AccessList {
Ok(bytes) Ok(bytes)
} }
pub fn update_from_path(&mut self, path: &PathBuf) -> anyhow::Result<()> { pub fn update_from_path(&self, path: &PathBuf) -> anyhow::Result<()> {
let file = File::open(path)?; let file = File::open(path)?;
let reader = BufReader::new(file); let reader = BufReader::new(file);
@ -56,15 +63,15 @@ impl AccessList {
new_list.insert(Self::parse_info_hash(line?)?); new_list.insert(Self::parse_info_hash(line?)?);
} }
self.0 = new_list; self.0.store(Arc::new(new_list));
Ok(()) Ok(())
} }
pub fn allows(&self, list_type: AccessListMode, info_hash_bytes: &[u8; 20]) -> bool { pub fn allows(&self, list_type: AccessListMode, info_hash_bytes: &[u8; 20]) -> bool {
match list_type { match list_type {
AccessListMode::Require => self.0.contains(info_hash_bytes), AccessListMode::Require => self.0.load().contains(info_hash_bytes),
AccessListMode::Forbid => !self.0.contains(info_hash_bytes), AccessListMode::Forbid => !self.0.load().contains(info_hash_bytes),
AccessListMode::Ignore => true, AccessListMode::Ignore => true,
} }
} }

View file

@ -115,18 +115,17 @@ pub type TorrentMap<I> = HashMap<InfoHash, TorrentData<I>>;
pub struct TorrentMaps { pub struct TorrentMaps {
pub ipv4: TorrentMap<Ipv4Addr>, pub ipv4: TorrentMap<Ipv4Addr>,
pub ipv6: TorrentMap<Ipv6Addr>, pub ipv6: TorrentMap<Ipv6Addr>,
pub access_list: AccessList,
} }
impl TorrentMaps { impl TorrentMaps {
pub fn clean(&mut self, config: &Config) { pub fn clean(&mut self, config: &Config, access_list: &Arc<AccessList>) {
Self::clean_torrent_map(config, &self.access_list, &mut self.ipv4); Self::clean_torrent_map(config, access_list, &mut self.ipv4);
Self::clean_torrent_map(config, &self.access_list, &mut self.ipv6); Self::clean_torrent_map(config, access_list, &mut self.ipv6);
} }
fn clean_torrent_map<I: Ip>( fn clean_torrent_map<I: Ip>(
config: &Config, config: &Config,
access_list: &AccessList, access_list: &Arc<AccessList>,
torrent_map: &mut TorrentMap<I>, torrent_map: &mut TorrentMap<I>,
) { ) {
let now = Instant::now(); let now = Instant::now();
@ -166,12 +165,14 @@ impl TorrentMaps {
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub access_list: Arc<AccessList>,
pub torrent_maps: Arc<Mutex<TorrentMaps>>, pub torrent_maps: Arc<Mutex<TorrentMaps>>,
} }
impl Default for State { impl Default for State {
fn default() -> Self { fn default() -> Self {
Self { Self {
access_list: Arc::new(Default::default()),
torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())),
} }
} }

View file

@ -105,77 +105,69 @@ pub fn handle_announce_requests(
let valid_until = ValidUntil::new(config.cleaning.max_peer_age); let valid_until = ValidUntil::new(config.cleaning.max_peer_age);
for (meta, request) in requests { for (meta, request) in requests {
let info_hash_allowed = torrent_maps let peer_ip = convert_ipv4_mapped_ipv6(meta.peer_addr.ip());
.access_list
.allows(config.access_list.mode, &request.info_hash.0);
let response = if info_hash_allowed { ::log::debug!("peer ip: {:?}", peer_ip);
let peer_ip = convert_ipv4_mapped_ipv6(meta.peer_addr.ip());
::log::debug!("peer ip: {:?}", peer_ip); let response = match peer_ip {
IpAddr::V4(peer_ip_address) => {
let torrent_data: &mut TorrentData<Ipv4Addr> =
torrent_maps.ipv4.entry(request.info_hash).or_default();
match peer_ip { let peer_connection_meta = PeerConnectionMeta {
IpAddr::V4(peer_ip_address) => { worker_index: meta.worker_index,
let torrent_data: &mut TorrentData<Ipv4Addr> = poll_token: meta.poll_token,
torrent_maps.ipv4.entry(request.info_hash).or_default(); peer_ip_address,
};
let peer_connection_meta = PeerConnectionMeta { let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers(
worker_index: meta.worker_index, config,
poll_token: meta.poll_token, rng,
peer_ip_address, peer_connection_meta,
}; torrent_data,
request,
valid_until,
);
let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers( let response = AnnounceResponse {
config, complete: seeders,
rng, incomplete: leechers,
peer_connection_meta, announce_interval: config.protocol.peer_announce_interval,
torrent_data, peers: ResponsePeerListV4(response_peers),
request, peers6: ResponsePeerListV6(vec![]),
valid_until, };
);
let response = AnnounceResponse { Response::Announce(response)
complete: seeders, }
incomplete: leechers, IpAddr::V6(peer_ip_address) => {
announce_interval: config.protocol.peer_announce_interval, let torrent_data: &mut TorrentData<Ipv6Addr> =
peers: ResponsePeerListV4(response_peers), torrent_maps.ipv6.entry(request.info_hash).or_default();
peers6: ResponsePeerListV6(vec![]),
}; let peer_connection_meta = PeerConnectionMeta {
worker_index: meta.worker_index,
Response::Announce(response) poll_token: meta.poll_token,
} peer_ip_address,
IpAddr::V6(peer_ip_address) => { };
let torrent_data: &mut TorrentData<Ipv6Addr> =
torrent_maps.ipv6.entry(request.info_hash).or_default(); let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers(
config,
let peer_connection_meta = PeerConnectionMeta { rng,
worker_index: meta.worker_index, peer_connection_meta,
poll_token: meta.poll_token, torrent_data,
peer_ip_address, request,
}; valid_until,
);
let (seeders, leechers, response_peers) = upsert_peer_and_get_response_peers(
config, let response = AnnounceResponse {
rng, complete: seeders,
peer_connection_meta, incomplete: leechers,
torrent_data, announce_interval: config.protocol.peer_announce_interval,
request, peers: ResponsePeerListV4(vec![]),
valid_until, peers6: ResponsePeerListV6(response_peers),
); };
let response = AnnounceResponse { Response::Announce(response)
complete: seeders,
incomplete: leechers,
announce_interval: config.protocol.peer_announce_interval,
peers: ResponsePeerListV4(vec![]),
peers6: ResponsePeerListV6(response_peers),
};
Response::Announce(response)
}
} }
} else {
Response::Failure(FailureResponse::new("Info hash not allowed"))
}; };
response_channel_sender.send(meta, response); response_channel_sender.send(meta, response);

View file

@ -22,18 +22,16 @@ pub const APP_NAME: &str = "aquatic_http: HTTP/TLS BitTorrent tracker";
pub fn run(config: Config) -> anyhow::Result<()> { pub fn run(config: Config) -> anyhow::Result<()> {
let state = State::default(); let state = State::default();
tasks::update_access_list(&config, &mut state.torrent_maps.lock()); tasks::update_access_list(&config, &state);
start_workers(config.clone(), state.clone())?; start_workers(config.clone(), state.clone())?;
loop { loop {
::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); ::std::thread::sleep(Duration::from_secs(config.cleaning.interval));
let mut torrent_maps = state.torrent_maps.lock(); tasks::update_access_list(&config, &state);
tasks::update_access_list(&config, &mut torrent_maps); state.torrent_maps.lock().clean(&config, &state.access_list);
torrent_maps.clean(&config);
} }
} }
@ -57,6 +55,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
for i in 0..config.socket_workers { for i in 0..config.socket_workers {
let config = config.clone(); let config = config.clone();
let state = state.clone();
let socket_worker_statuses = socket_worker_statuses.clone(); let socket_worker_statuses = socket_worker_statuses.clone();
let request_channel_sender = request_channel_sender.clone(); let request_channel_sender = request_channel_sender.clone();
let opt_tls_acceptor = opt_tls_acceptor.clone(); let opt_tls_acceptor = opt_tls_acceptor.clone();
@ -73,6 +72,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
.spawn(move || { .spawn(move || {
network::run_socket_worker( network::run_socket_worker(
config, config,
state,
i, i,
socket_worker_statuses, socket_worker_statuses,
request_channel_sender, request_channel_sender,

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::vec::Drain; use std::vec::Drain;
use aquatic_http_protocol::request::Request;
use hashbrown::HashMap; use hashbrown::HashMap;
use log::{debug, error, info}; use log::{debug, error, info};
use mio::net::TcpListener; use mio::net::TcpListener;
@ -25,6 +26,7 @@ const CONNECTION_CLEAN_INTERVAL: usize = 2 ^ 22;
pub fn run_socket_worker( pub fn run_socket_worker(
config: Config, config: Config,
state: State,
socket_worker_index: usize, socket_worker_index: usize,
socket_worker_statuses: SocketWorkerStatuses, socket_worker_statuses: SocketWorkerStatuses,
request_channel_sender: RequestChannelSender, request_channel_sender: RequestChannelSender,
@ -38,6 +40,7 @@ pub fn run_socket_worker(
run_poll_loop( run_poll_loop(
config, config,
&state,
socket_worker_index, socket_worker_index,
request_channel_sender, request_channel_sender,
response_channel_receiver, response_channel_receiver,
@ -55,6 +58,7 @@ pub fn run_socket_worker(
pub fn run_poll_loop( pub fn run_poll_loop(
config: Config, config: Config,
state: &State,
socket_worker_index: usize, socket_worker_index: usize,
request_channel_sender: RequestChannelSender, request_channel_sender: RequestChannelSender,
response_channel_receiver: ResponseChannelReceiver, response_channel_receiver: ResponseChannelReceiver,
@ -100,6 +104,7 @@ pub fn run_poll_loop(
} else if token != CHANNEL_TOKEN { } else if token != CHANNEL_TOKEN {
handle_connection_read_event( handle_connection_read_event(
&config, &config,
&state,
socket_worker_index, socket_worker_index,
&mut poll, &mut poll,
&request_channel_sender, &request_channel_sender,
@ -179,6 +184,7 @@ fn accept_new_streams(
/// then read requests and pass on through channel. /// then read requests and pass on through channel.
pub fn handle_connection_read_event( pub fn handle_connection_read_event(
config: &Config, config: &Config,
state: &State,
socket_worker_index: usize, socket_worker_index: usize,
poll: &mut Poll, poll: &mut Poll,
request_channel_sender: &RequestChannelSender, request_channel_sender: &RequestChannelSender,
@ -187,6 +193,7 @@ pub fn handle_connection_read_event(
poll_token: Token, poll_token: Token,
) { ) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age); let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode;
loop { loop {
// Get connection, updating valid_until // Get connection, updating valid_until
@ -210,10 +217,26 @@ pub fn handle_connection_read_event(
peer_addr: established.peer_addr, peer_addr: established.peer_addr,
}; };
debug!("read request, sending to handler"); let opt_allowed_request = if let Request::Announce(ref r) = request {
if state.access_list.allows(access_list_mode, &r.info_hash.0){
Some(request)
} else {
None
}
} else {
Some(request)
};
if let Err(err) = request_channel_sender.send((meta, request)) { if let Some(request) = opt_allowed_request {
error!("RequestChannelSender: couldn't send message: {:?}", err); debug!("read allowed request, sending on to channel");
if let Err(err) = request_channel_sender.send((meta, request)) {
error!("RequestChannelSender: couldn't send message: {:?}", err);
}
} else {
let response = FailureResponse::new("Info hash not allowed");
local_responses.push((meta, Response::Failure(response)))
} }
break; break;

View file

@ -4,10 +4,10 @@ use aquatic_common::access_list::AccessListMode;
use crate::{common::*, config::Config}; use crate::{common::*, config::Config};
pub fn update_access_list(config: &Config, torrent_maps: &mut TorrentMaps) { pub fn update_access_list(config: &Config, state: &State) {
match config.access_list.mode { match config.access_list.mode {
AccessListMode::Require | AccessListMode::Forbid => { AccessListMode::Require | AccessListMode::Forbid => {
if let Err(err) = torrent_maps if let Err(err) = state
.access_list .access_list
.update_from_path(&config.access_list.path) .update_from_path(&config.access_list.path)
{ {

View file

@ -118,15 +118,13 @@ pub type TorrentMap<I> = HashMap<InfoHash, TorrentData<I>>;
pub struct TorrentMaps { pub struct TorrentMaps {
pub ipv4: TorrentMap<Ipv4Addr>, pub ipv4: TorrentMap<Ipv4Addr>,
pub ipv6: TorrentMap<Ipv6Addr>, pub ipv6: TorrentMap<Ipv6Addr>,
pub access_list: AccessList,
} }
impl TorrentMaps { impl TorrentMaps {
/// Remove disallowed and inactive torrents /// Remove disallowed and inactive torrents
pub fn clean(&mut self, config: &Config) { pub fn clean(&mut self, config: &Config, access_list: &Arc<AccessList>) {
let now = Instant::now(); let now = Instant::now();
let access_list = &self.access_list;
let access_list_mode = config.access_list.mode; let access_list_mode = config.access_list.mode;
self.ipv4.retain(|info_hash, torrent| { self.ipv4.retain(|info_hash, torrent| {
@ -181,6 +179,7 @@ pub struct Statistics {
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub access_list: Arc<AccessList>,
pub connections: Arc<Mutex<ConnectionMap>>, pub connections: Arc<Mutex<ConnectionMap>>,
pub torrents: Arc<Mutex<TorrentMaps>>, pub torrents: Arc<Mutex<TorrentMaps>>,
pub statistics: Arc<Statistics>, pub statistics: Arc<Statistics>,
@ -189,6 +188,7 @@ pub struct State {
impl Default for State { impl Default for State {
fn default() -> Self { fn default() -> Self {
Self { Self {
access_list: Arc::new(AccessList::default()),
connections: Arc::new(Mutex::new(HashMap::new())), connections: Arc::new(Mutex::new(HashMap::new())),
torrents: Arc::new(Mutex::new(TorrentMaps::default())), torrents: Arc::new(Mutex::new(TorrentMaps::default())),
statistics: Arc::new(Statistics::default()), statistics: Arc::new(Statistics::default()),

View file

@ -19,44 +19,30 @@ pub fn handle_announce_requests(
responses: &mut Vec<(Response, SocketAddr)>, responses: &mut Vec<(Response, SocketAddr)>,
) { ) {
let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age); let peer_valid_until = ValidUntil::new(config.cleaning.max_peer_age);
let access_list_mode = config.access_list.mode;
responses.extend(requests.map(|(request, src)| { responses.extend(requests.map(|(request, src)| {
let info_hash_allowed = torrents let peer_ip = convert_ipv4_mapped_ipv6(src.ip());
.access_list
.allows(access_list_mode, &request.info_hash.0);
let response = if info_hash_allowed { let response = match peer_ip {
let peer_ip = convert_ipv4_mapped_ipv6(src.ip()); IpAddr::V4(ip) => handle_announce_request(
config,
let response = match peer_ip { rng,
IpAddr::V4(ip) => handle_announce_request( &mut torrents.ipv4,
config, request,
rng, ip,
&mut torrents.ipv4, peer_valid_until,
request, ),
ip, IpAddr::V6(ip) => handle_announce_request(
peer_valid_until, config,
), rng,
IpAddr::V6(ip) => handle_announce_request( &mut torrents.ipv6,
config, request,
rng, ip,
&mut torrents.ipv6, peer_valid_until,
request, ),
ip,
peer_valid_until,
),
};
Response::Announce(response)
} else {
Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
})
}; };
(response, src) (Response::Announce(response), src)
})); }));
} }

View file

@ -23,7 +23,7 @@ pub const APP_NAME: &str = "aquatic_udp: UDP BitTorrent tracker";
pub fn run(config: Config) -> ::anyhow::Result<()> { pub fn run(config: Config) -> ::anyhow::Result<()> {
let state = State::default(); let state = State::default();
tasks::update_access_list(&config, &mut state.torrents.lock()); tasks::update_access_list(&config, &state);
let num_bound_sockets = start_workers(config.clone(), state.clone())?; let num_bound_sockets = start_workers(config.clone(), state.clone())?;
@ -56,12 +56,9 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); ::std::thread::sleep(Duration::from_secs(config.cleaning.interval));
tasks::clean_connections(&state); tasks::clean_connections(&state);
tasks::update_access_list(&config, &state);
let mut torrent_maps = state.torrents.lock(); state.torrents.lock().clean(&config, &state.access_list);
tasks::update_access_list(&config, &mut torrent_maps);
torrent_maps.clean(&config);
} }
} }

View file

@ -131,6 +131,8 @@ fn read_requests(
let mut requests_received: usize = 0; let mut requests_received: usize = 0;
let mut bytes_received: usize = 0; let mut bytes_received: usize = 0;
let access_list_mode = config.access_list.mode;
loop { loop {
match socket.recv_from(&mut buffer[..]) { match socket.recv_from(&mut buffer[..]) {
Ok((amt, src)) => { Ok((amt, src)) => {
@ -145,7 +147,20 @@ fn read_requests(
match request { match request {
Ok(request) => { Ok(request) => {
requests.push((request, src)); if let Request::Announce(AnnounceRequest { info_hash, transaction_id, ..}) = request {
if state.access_list.allows(access_list_mode, &info_hash.0) {
requests.push((request, src));
} else {
let response = Response::Error(ErrorResponse {
transaction_id,
message: "Info hash not allowed".into()
});
local_responses.push((response, src))
}
} else {
requests.push((request, src));
}
} }
Err(err) => { Err(err) => {
::log::debug!("request_from_bytes error: {:?}", err); ::log::debug!("request_from_bytes error: {:?}", err);

View file

@ -8,10 +8,10 @@ use aquatic_common::access_list::AccessListMode;
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
pub fn update_access_list(config: &Config, torrent_maps: &mut TorrentMaps) { pub fn update_access_list(config: &Config, state: &State) {
match config.access_list.mode { match config.access_list.mode {
AccessListMode::Require | AccessListMode::Forbid => { AccessListMode::Require | AccessListMode::Forbid => {
if let Err(err) = torrent_maps if let Err(err) = state
.access_list .access_list
.update_from_path(&config.access_list.path) .update_from_path(&config.access_list.path)
{ {

View file

@ -86,16 +86,15 @@ pub type TorrentMap = HashMap<InfoHash, TorrentData>;
pub struct TorrentMaps { pub struct TorrentMaps {
pub ipv4: TorrentMap, pub ipv4: TorrentMap,
pub ipv6: TorrentMap, pub ipv6: TorrentMap,
pub access_list: AccessList,
} }
impl TorrentMaps { impl TorrentMaps {
pub fn clean(&mut self, config: &Config) { pub fn clean(&mut self, config: &Config, access_list: &Arc<AccessList>) {
Self::clean_torrent_map(config, &self.access_list, &mut self.ipv4); Self::clean_torrent_map(config, access_list, &mut self.ipv4);
Self::clean_torrent_map(config, &self.access_list, &mut self.ipv6); Self::clean_torrent_map(config, access_list, &mut self.ipv6);
} }
fn clean_torrent_map(config: &Config, access_list: &AccessList, torrent_map: &mut TorrentMap) { fn clean_torrent_map(config: &Config, access_list: &Arc<AccessList>, torrent_map: &mut TorrentMap) {
let now = Instant::now(); let now = Instant::now();
torrent_map.retain(|info_hash, torrent_data| { torrent_map.retain(|info_hash, torrent_data| {
@ -133,12 +132,14 @@ impl TorrentMaps {
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub access_list: Arc<AccessList>,
pub torrent_maps: Arc<Mutex<TorrentMaps>>, pub torrent_maps: Arc<Mutex<TorrentMaps>>,
} }
impl Default for State { impl Default for State {
fn default() -> Self { fn default() -> Self {
Self { Self {
access_list: Arc::new(Default::default()),
torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())),
} }
} }

View file

@ -99,23 +99,6 @@ pub fn handle_announce_requests(
let valid_until = ValidUntil::new(config.cleaning.max_peer_age); let valid_until = ValidUntil::new(config.cleaning.max_peer_age);
for (request_sender_meta, request) in requests { for (request_sender_meta, request) in requests {
let info_hash_allowed = torrent_maps
.access_list
.allows(config.access_list.mode, &request.info_hash.0);
if !info_hash_allowed {
let response = OutMessage::ErrorResponse(ErrorResponse {
failure_reason: "Info hash not allowed".into(),
action: Some(ErrorResponseAction::Announce),
info_hash: Some(request.info_hash),
});
out_message_sender.send(request_sender_meta, response);
wake_socket_workers[request_sender_meta.worker_index] = true;
continue;
}
let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() { let torrent_data: &mut TorrentData = if request_sender_meta.converted_peer_ip.is_ipv4() {
torrent_maps.ipv4.entry(request.info_hash).or_default() torrent_maps.ipv4.entry(request.info_hash).or_default()
} else { } else {

View file

@ -24,18 +24,16 @@ pub const APP_NAME: &str = "aquatic_ws: WebTorrent tracker";
pub fn run(config: Config) -> anyhow::Result<()> { pub fn run(config: Config) -> anyhow::Result<()> {
let state = State::default(); let state = State::default();
tasks::update_access_list(&config, &mut state.torrent_maps.lock()); tasks::update_access_list(&config, &state);
start_workers(config.clone(), state.clone())?; start_workers(config.clone(), state.clone())?;
loop { loop {
::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); ::std::thread::sleep(Duration::from_secs(config.cleaning.interval));
let mut torrent_maps = state.torrent_maps.lock(); tasks::update_access_list(&config, &state);
tasks::update_access_list(&config, &mut torrent_maps); state.torrent_maps.lock().clean(&config, &state.access_list);
torrent_maps.clean(&config);
} }
} }
@ -59,6 +57,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
for i in 0..config.socket_workers { for i in 0..config.socket_workers {
let config = config.clone(); let config = config.clone();
let state = state.clone();
let socket_worker_statuses = socket_worker_statuses.clone(); let socket_worker_statuses = socket_worker_statuses.clone();
let in_message_sender = in_message_sender.clone(); let in_message_sender = in_message_sender.clone();
let opt_tls_acceptor = opt_tls_acceptor.clone(); let opt_tls_acceptor = opt_tls_acceptor.clone();
@ -75,6 +74,7 @@ pub fn start_workers(config: Config, state: State) -> anyhow::Result<()> {
.spawn(move || { .spawn(move || {
network::run_socket_worker( network::run_socket_worker(
config, config,
state,
i, i,
socket_worker_statuses, socket_worker_statuses,
poll, poll,

View file

@ -1,7 +1,9 @@
use std::io::ErrorKind; use std::io::ErrorKind;
use std::time::Duration; use std::time::Duration;
use std::vec::Drain;
use crossbeam_channel::Receiver; use crossbeam_channel::Receiver;
use either::Either;
use hashbrown::HashMap; use hashbrown::HashMap;
use log::{debug, error, info}; use log::{debug, error, info};
use mio::net::TcpListener; use mio::net::TcpListener;
@ -23,6 +25,7 @@ use utils::*;
pub fn run_socket_worker( pub fn run_socket_worker(
config: Config, config: Config,
state: State,
socket_worker_index: usize, socket_worker_index: usize,
socket_worker_statuses: SocketWorkerStatuses, socket_worker_statuses: SocketWorkerStatuses,
poll: Poll, poll: Poll,
@ -36,6 +39,7 @@ pub fn run_socket_worker(
run_poll_loop( run_poll_loop(
config, config,
&state,
socket_worker_index, socket_worker_index,
poll, poll,
in_message_sender, in_message_sender,
@ -53,6 +57,7 @@ pub fn run_socket_worker(
pub fn run_poll_loop( pub fn run_poll_loop(
config: Config, config: Config,
state: &State,
socket_worker_index: usize, socket_worker_index: usize,
mut poll: Poll, mut poll: Poll,
in_message_sender: InMessageSender, in_message_sender: InMessageSender,
@ -76,6 +81,7 @@ pub fn run_poll_loop(
.unwrap(); .unwrap();
let mut connections: ConnectionMap = HashMap::new(); let mut connections: ConnectionMap = HashMap::new();
let mut local_responses = Vec::new();
let mut poll_token_counter = Token(0usize); let mut poll_token_counter = Token(0usize);
let mut iter_counter = 0usize; let mut iter_counter = 0usize;
@ -100,7 +106,10 @@ pub fn run_poll_loop(
); );
} else if token != CHANNEL_TOKEN { } else if token != CHANNEL_TOKEN {
run_handshakes_and_read_messages( run_handshakes_and_read_messages(
&config,
state,
socket_worker_index, socket_worker_index,
&mut local_responses,
&in_message_sender, &in_message_sender,
&opt_tls_acceptor, &opt_tls_acceptor,
&mut poll, &mut poll,
@ -110,7 +119,12 @@ pub fn run_poll_loop(
); );
} }
send_out_messages(&mut poll, &out_message_receiver, &mut connections); send_out_messages(
&mut poll,
local_responses.drain(..),
&out_message_receiver,
&mut connections
);
} }
// Remove inactive connections, but not every iteration // Remove inactive connections, but not every iteration
@ -165,7 +179,10 @@ fn accept_new_streams(
/// On the stream given by poll_token, get TLS (if requested) and tungstenite /// On the stream given by poll_token, get TLS (if requested) and tungstenite
/// up and running, then read messages and pass on through channel. /// up and running, then read messages and pass on through channel.
pub fn run_handshakes_and_read_messages( pub fn run_handshakes_and_read_messages(
config: &Config,
state: &State,
socket_worker_index: usize, socket_worker_index: usize,
local_responses: &mut Vec<(ConnectionMeta, OutMessage)>,
in_message_sender: &InMessageSender, in_message_sender: &InMessageSender,
opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
poll: &mut Poll, poll: &mut Poll,
@ -173,6 +190,8 @@ pub fn run_handshakes_and_read_messages(
poll_token: Token, poll_token: Token,
valid_until: ValidUntil, valid_until: ValidUntil,
) { ) {
let access_list_mode = config.access_list.mode;
loop { loop {
if let Some(established_ws) = connections if let Some(established_ws) = connections
.get_mut(&poll_token) .get_mut(&poll_token)
@ -201,8 +220,31 @@ pub fn run_handshakes_and_read_messages(
debug!("read message"); debug!("read message");
if let Err(err) = in_message_sender.send((meta, in_message)) { let message = if let InMessage::AnnounceRequest(ref request) = in_message {
error!("InMessageSender: couldn't send message: {:?}", err); if state.access_list.allows(access_list_mode, &request.info_hash.0){
Either::Left(in_message)
} else {
let out_message = OutMessage::ErrorResponse(ErrorResponse {
failure_reason: "Info hash not allowed".into(),
action: Some(ErrorResponseAction::Announce),
info_hash: Some(request.info_hash),
});
Either::Right(out_message)
}
} else {
Either::Left(in_message)
};
match message {
Either::Left(in_message) => {
if let Err(err) = in_message_sender.send((meta, in_message)) {
error!("InMessageSender: couldn't send message: {:?}", err);
}
},
Either::Right(out_message) => {
local_responses.push((meta, out_message));
}
} }
} }
} }
@ -242,12 +284,13 @@ pub fn run_handshakes_and_read_messages(
/// Read messages from channel, send to peers /// Read messages from channel, send to peers
pub fn send_out_messages( pub fn send_out_messages(
poll: &mut Poll, poll: &mut Poll,
local_responses: Drain<(ConnectionMeta, OutMessage)>,
out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>, out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>,
connections: &mut ConnectionMap, connections: &mut ConnectionMap,
) { ) {
let len = out_message_receiver.len(); let len = out_message_receiver.len();
for (meta, out_message) in out_message_receiver.try_iter().take(len) { for (meta, out_message) in local_responses.chain(out_message_receiver.try_iter().take(len)) {
let opt_established_ws = connections let opt_established_ws = connections
.get_mut(&meta.poll_token) .get_mut(&meta.poll_token)
.and_then(Connection::get_established_ws); .and_then(Connection::get_established_ws);

View file

@ -4,10 +4,10 @@ use histogram::Histogram;
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
pub fn update_access_list(config: &Config, torrent_maps: &mut TorrentMaps) { pub fn update_access_list(config: &Config, state: &State) {
match config.access_list.mode { match config.access_list.mode {
AccessListMode::Require | AccessListMode::Forbid => { AccessListMode::Require | AccessListMode::Forbid => {
if let Err(err) = torrent_maps if let Err(err) = state
.access_list .access_list
.update_from_path(&config.access_list.path) .update_from_path(&config.access_list.path)
{ {