From 30a77072ed28b20eed903e5756aac8051ae49f49 Mon Sep 17 00:00:00 2001 From: postscriptum Date: Sun, 22 Mar 2026 09:18:09 +0200 Subject: [PATCH] implement in-memory ruleset update --- README.md | 7 +++++-- src/list.rs | 54 ++++++++++++++++++++++++++++++++++++---------------- src/main.rs | 50 ++++++++++++++++++++++++++++++++++-------------- src/stats.rs | 6 ++++++ 4 files changed, 85 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index cab0278..1f0dd2e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ Experimental async SOCKS5 (TCP/UDP) proxy server based on [fast-socks5](https:// * [ ] Range support * [ ] Local Web-API * [x] Block stats - * [ ] In-memory list update (without server restart) + * [x] In-memory list update (without server restart) + * [ ] Persist changes option * [ ] Performance optimization ## Usage @@ -18,7 +19,9 @@ RUST_LOG=psocks=trace cargo run -- -a=/path/to/allow1.txt \ no-auth ``` * set `socks5://127.0.0.1:1080` proxy in your application -* open http://127.0.0.1:8010 in browser for stats & control API +* open http://127.0.0.1:8010 in browser for global stats: + * http://127.0.0.1:8010/allow/domain.com - add rule to the current session + * http://127.0.0.1:8010/block/domain.com - delete rule from the current session ### Allow list example diff --git a/src/list.rs b/src/list.rs index 3cf626b..733b5a9 100644 --- a/src/list.rs +++ b/src/list.rs @@ -1,6 +1,7 @@ use anyhow::Result; use log::*; use std::collections::HashSet; +use tokio::sync::RwLock; #[derive(PartialEq, Eq, Hash)] pub enum Item { @@ -8,8 +9,20 @@ pub enum Item { Exact(String), } +impl Item { + pub fn from_line(rule: &str) -> Self { + if let Some(item) = rule.strip_prefix(".") { + debug!("Add `{rule}` to the whitelist"); + Self::Ending(item.to_string()) + } else { + debug!("Add `{rule}` exact match to the whitelist"); + Self::Exact(rule.to_string()) + } + } +} + pub enum List { - Allow(HashSet), + Allow(RwLock>), } impl List { @@ -26,31 +39,40 @@ impl List { if line.starts_with("/") || line.starts_with("#") || line.is_empty() { continue; } - if !this.insert(if let Some(item) = line.strip_prefix(".") { - debug!("Add `{line}` to the whitelist"); - Item::Ending(item.to_string()) - } else { - debug!("Add `{line}` exact match to the whitelist"); - Item::Exact(line.to_string()) - }) { + if !this.insert(Item::from_line(line)) { warn!("Duplicated whitelist record: `{line}`") } } } info!("Total whitelist entries parsed: {}", this.len()); - Ok(Self::Allow(this)) + Ok(Self::Allow(RwLock::new(this))) } - pub fn has(&self, value: &str) -> bool { + pub async fn any(&self, values: &[&str]) -> bool { match self { - Self::Allow(list) => list.iter().any(|item| match item { - Item::Exact(s) => s == value, - Item::Ending(s) => value.ends_with(s), - }), + Self::Allow(list) => { + let guard = list.read().await; + values.iter().any(|&value| { + guard.iter().any(|item| match item { + Item::Exact(v) => v == value, + Item::Ending(v) => value.ends_with(v), + }) + }) + } } } - pub fn entries(&self) -> usize { + pub async fn entries(&self) -> u64 { match self { - Self::Allow(list) => list.len(), + Self::Allow(list) => list.read().await.len() as u64, + } + } + pub async fn allow(&self, rule: &str) -> bool { + match self { + List::Allow(list) => list.write().await.insert(Item::from_line(rule)), + } + } + pub async fn block(&self, rule: &str) -> bool { + match self { + List::Allow(list) => list.write().await.remove(&Item::from_line(rule)), } } } diff --git a/src/main.rs b/src/main.rs index b92f320..4837a6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,28 +10,56 @@ use fast_socks5::{ use list::List; use log::*; use opt::{AuthMode, Opt}; -use rocket::serde::json::Json; +use rocket::{State, serde::json::Json}; use stats::{Snapshot, Total}; use std::{future::Future, sync::Arc}; use structopt::StructOpt; use tokio::{net::TcpListener, task}; #[rocket::get("/")] -async fn index(totals: &rocket::State>) -> Json { +async fn index(totals: &State>) -> Json { Json(totals.inner().snapshot()) } +#[rocket::get("/allow/")] +async fn allow(rule: &str, list: &State>, totals: &State>) -> Json { + let result = list.allow(rule).await; + totals.set_entries(list.entries().await); + info!("Delete `{rule}` from the in-memory rules (operation status: {result:?})"); + Json(result) +} + +#[rocket::get("/block/")] +async fn block(rule: &str, list: &State>, totals: &State>) -> Json { + let result = list.block(rule).await; + totals.set_entries(list.entries().await); + info!("Add `{rule}` to the in-memory rules (operation status: {result:?})"); + Json(result) +} + #[rocket::launch] async fn rocket() -> _ { env_logger::init(); let opt: &'static Opt = Box::leak(Box::new(Opt::from_args())); - let totals = Arc::new(Total::default()); + + let list = Arc::new( + List::from_opt(&opt.allow_list) + .await + .map_err(|err| { + error!("Can't parse whitelist: `{err}`"); + SocksError::ArgumentInputError("Can't parse whitelist") + }) + .unwrap(), + ); + + let totals = Arc::new(Total::with_entries(list.entries().await)); tokio::spawn({ + let socks_list = list.clone(); let socks_totals = totals.clone(); async move { - if let Err(err) = spawn_socks_server(opt, socks_totals).await { + if let Err(err) = spawn_socks_server(opt, socks_list, socks_totals).await { error!("SOCKS server failed: `{err}`"); } } @@ -43,11 +71,12 @@ async fn rocket() -> _ { address: opt.api_addr.ip(), ..rocket::Config::release_default() }) + .manage(list) .manage(totals) - .mount("/", rocket::routes![index]) + .mount("/", rocket::routes![index, allow, block]) } -async fn spawn_socks_server(opt: &'static Opt, totals: Arc) -> Result<()> { +async fn spawn_socks_server(opt: &'static Opt, list: Arc, totals: Arc) -> Result<()> { if opt.allow_udp && opt.public_addr.is_none() { return Err(SocksError::ArgumentInputError( "Can't allow UDP if public-addr is not set", @@ -59,13 +88,6 @@ async fn spawn_socks_server(opt: &'static Opt, totals: Arc) -> Result<()> )); } - let list = Arc::new(List::from_opt(&opt.allow_list).await.map_err(|err| { - error!("Can't parse whitelist: `{err}`"); - SocksError::ArgumentInputError("Can't parse whitelist") - })?); - - totals.set_entries(list.entries() as u64); - let listener = TcpListener::bind(&opt.listen_addr).await?; info!("Listen for socks connections @ {}", &opt.listen_addr); @@ -107,7 +129,7 @@ async fn serve_socks5( let (host, _) = request.2.clone().into_string_and_port(); // @TODO ref let (proto, cmd, addr) = request.resolve_dns().await?; - if !list.has(&host) && !list.has(&addr.to_string()) { + if !list.any(&[&host, &addr.to_string()]).await { totals.increase_blocked(); info!("Blocked connection attempt to: {host}"); proto.reply_error(&ReplyError::ConnectionNotAllowed).await?; diff --git a/src/stats.rs b/src/stats.rs index 2aaf6f8..52e480b 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -16,6 +16,12 @@ pub struct Total { } impl Total { + pub fn with_entries(entries: u64) -> Self { + Self { + entries: entries.into(), + ..Self::default() + } + } pub fn snapshot(&self) -> Snapshot { Snapshot { request: self.request.load(Ordering::Relaxed),