diff --git a/Cargo.lock b/Cargo.lock index f483c42..d765d5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -455,7 +455,7 @@ dependencies = [ "atomic 0.6.1", "pear", "serde", - "toml", + "toml 0.8.23", "uncased", "version_check", ] @@ -1438,6 +1438,8 @@ dependencies = [ "serde", "structopt", "tokio", + "toml 1.1.0+spec-1.1.0", + "url", ] [[package]] @@ -1979,6 +1981,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876ac351060d4f882bb1032b6369eb0aef79ad9df1ea8bc404874d8cc3d0cd98" +dependencies = [ + "serde_core", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2361,11 +2372,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", "toml_edit", ] +[[package]] +name = "toml" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8195ca05e4eb728f4ba94f3e3291661320af739c4e43779cbdfae82ab239fcc" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned 1.1.0", + "toml_datetime 1.1.0+spec-1.1.0", + "toml_parser", + "toml_writer", + "winnow 1.0.0", +] + [[package]] name = "toml_datetime" version = "0.6.11" @@ -2375,6 +2401,15 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_datetime" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f" +dependencies = [ + "serde_core", +] + [[package]] name = "toml_edit" version = "0.22.27" @@ -2383,10 +2418,19 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ "indexmap", "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", "toml_write", - "winnow", + "winnow 0.7.15", +] + +[[package]] +name = "toml_parser" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011" +dependencies = [ + "winnow 1.0.0", ] [[package]] @@ -2395,6 +2439,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "toml_writer" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed" + [[package]] name = "tower" version = "0.5.3" @@ -3126,6 +3176,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index 0521a3e..31b1079 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,5 @@ rocket = { version = "0.5.1", features = ["json"] } serde = { version = "1.0.228", features = ["derive"] } structopt = "0.3.26" tokio = { version = "1", features = ["full"] } +toml = "1.1.0" +url = "2.5.8" diff --git a/README.md b/README.md index e1b66ef..82bbf81 100644 --- a/README.md +++ b/README.md @@ -16,19 +16,13 @@ Filtering asynchronous SOCKS5 (TCP/UDP) proxy server based on [fast-socks5](http ## Usage ``` bash -RUST_LOG=trace cargo run -- --allow=http://localhost/allow.txt \ - --allow=/path/to/allow.txt \ - no-auth +RUST_LOG=trace cargo run -- -c=/path/to/config.toml no-auth ``` * set `socks5://127.0.0.1:1080` proxy in your application * use http://127.0.0.1:8010 for API: - * `/api/allow/` - add rule to the current session - * `/api/block/` - delete rule from the current session - * `/api/rules` - return active rules (from server memory) - * `/api/lists` - get parsed lists with its ID - * `/api/list/` - get all parsed rules for list ID (see `/api/lists`) - * `/api/list/enable/` - enable all parsed rules of given list ID (see `/api/lists`) - * `/api/list/disable/` - disable all parsed rules of given list ID (see `/api/lists`) + * `/api/totals` - blocking summary + * `/api/list/enable/` - enable all parsed rules of given list ID (`[list.ID]` in your config) + * `/api/list/disable/` - disable all parsed rules of given list ID (`[list.ID]` in your config) ### Allow list example @@ -53,6 +47,7 @@ git clone https://codeberg.org/postscriptum/psocks.git cd psocks cargo build --release --locked sudo install target/release/psocks /usr/local/bin +sudo cp example/config.toml /etc/psocks.toml sudo useradd -s /usr/sbin/nologin -Mr psocks sudo mkdir /var/log/psocks && sudo chown psocks:psocks /var/log/psocks ``` @@ -68,9 +63,7 @@ Wants=network-online.target User=psocks Group=psocks -ExecStart=/usr/local/bin/psocks \ - -a=http://localhost/allow.txt \ - no-auth +ExecStart=/usr/local/bin/psocks -c=/etc/psocks.toml no-auth Restart=always diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..85667ec --- /dev/null +++ b/src/config.rs @@ -0,0 +1,13 @@ +use serde::Deserialize; +use std::collections::HashMap; + +#[derive(Deserialize)] +pub struct List { + pub is_enabled: bool, + pub source: String, +} + +#[derive(Deserialize)] +pub struct Config { + pub list: HashMap, +} diff --git a/src/example/config.toml b/src/example/config.toml new file mode 100644 index 0000000..d2cf16b --- /dev/null +++ b/src/example/config.toml @@ -0,0 +1,11 @@ +[list.google] +is_enabled = true +source = "https://codeberg.org/postscriptum/psocks-list/src/branch/main/allow/google.txt" + +[list.github] +is_enabled = false +source = "https://codeberg.org/postscriptum/psocks-list/src/branch/main/allow/github.txt" + +#[[list.common]] +#is_enabled = false +#source = "/path/to/common.txt" \ No newline at end of file diff --git a/src/example/list/github.txt b/src/example/list/github.txt new file mode 100644 index 0000000..0988718 --- /dev/null +++ b/src/example/list/github.txt @@ -0,0 +1,6 @@ +// Github list example + +.github.com +.github.io +.githubassets.com +.githubusercontent.com \ No newline at end of file diff --git a/src/example/list/google.txt b/src/example/list/google.txt new file mode 100644 index 0000000..01f17dd --- /dev/null +++ b/src/example/list/google.txt @@ -0,0 +1,13 @@ +// Google list example + +.gmail.com +.google +.google.ch +.google.com +.google.com.ua +.google.dev +.googlevideo.com +.gstatic.com +.withgoogle.com +.youtube.com +.ytimg.com \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index b945436..52c32c4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,10 @@ +mod config; mod opt; mod rules; mod stats; use anyhow::Context; +use config::Config; use fast_socks5::{ ReplyError, Result, Socks5Command, SocksError, server::{DnsResolveHelper as _, Socks5ServerProtocol, run_tcp_proxy, run_udp_proxy}, @@ -14,7 +16,7 @@ use rules::Rules; use stats::{Snap, Total}; use std::{future::Future, sync::Arc, time::Instant}; use structopt::StructOpt; -use tokio::{net::TcpListener, task}; +use tokio::{net::TcpListener, sync::RwLock, task}; #[rocket::get("/")] async fn index(totals: &State>, startup_time: &State) -> Json { @@ -26,79 +28,20 @@ async fn api_totals(totals: &State>, startup_time: &State) - Json(totals.inner().snap(startup_time.elapsed().as_secs())) } -#[rocket::get("/api/allow/")] -async fn api_allow( - rule: &str, - rules: &State>, - totals: &State>, -) -> Result, Status> { - let result = rules.allow(rule).await; - totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals - info!("Delete `{rule}` from the in-memory rules (operation status: {result:?})"); - Ok(Json(result)) -} - -#[rocket::get("/api/block/")] -async fn api_block( - rule: &str, - rules: &State>, - totals: &State>, -) -> Result, Status> { - let result = rules.block(rule).await; - totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals - info!("Add `{rule}` to the in-memory rules (operation status: {result:?})"); - Ok(Json(result)) -} - -#[rocket::get("/api/rules")] -async fn api_rules(rules: &State>) -> Result>, Status> { - let active = rules.active().await; - debug!("Get rules (total: {})", active.len()); - Ok(Json(active)) -} - -#[rocket::get("/api/lists")] -async fn api_lists(rules: &State>) -> Result>, Status> { - let lists = rules.lists().await; - debug!("Get lists index (total: {})", lists.len()); - Ok(Json(lists)) -} - -#[rocket::get("/api/list/")] -async fn api_list( - id: usize, - rules: &State>, -) -> Result>, Status> { - let list = rules.list(&id).await; - debug!( - "Get list #{id} rules (total: {:?})", - list.as_ref().map(|l| l.items.len()) - ); - Ok(Json(list)) -} - -#[rocket::get("/api/list/enable/")] +#[rocket::get("/api/list/enable/")] async fn api_list_enable( - id: usize, - rules: &State>, - totals: &State>, -) -> Result>, Status> { - let affected = rules.enable(&id, true).await; - totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals - info!("Enabled {affected:?} rules from the active rule set"); - Ok(Json(affected)) // @TODO handle empty result + alias: &str, + rules: &State>>, +) -> Result, Status> { + Ok(Json(rules.write().await.set_status(alias, true))) } -#[rocket::get("/api/list/disable/")] +#[rocket::get("/api/list/disable/")] async fn api_list_disable( - id: usize, - rules: &State>, - totals: &State>, -) -> Result>, Status> { - let affected = rules.enable(&id, false).await; - totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals - info!("Disabled {affected:?} rules from the active rule set"); - Ok(Json(affected)) // @TODO handle empty result + alias: &str, + rules: &State>>, +) -> Result, Status> { + Ok(Json(rules.write().await.set_status(alias, false))) } #[rocket::launch] @@ -106,10 +49,20 @@ async fn rocket() -> _ { env_logger::init(); let opt: &'static Opt = Box::leak(Box::new(Opt::from_args())); - - let rules = Arc::new(Rules::from_opt(&opt.allow_list).await.unwrap()); - - let totals = Arc::new(Total::with_rules(rules.total(true).await)); // @TODO separate active/inactive totals + let config: Config = toml::from_str(&std::fs::read_to_string(&opt.config).unwrap()).unwrap(); + let totals = Arc::new(Total::default()); + let rules = Arc::new(RwLock::new({ + let mut rules = Rules::new(); + for (alias, rule) in config.list { + assert!( + rules + .push(&alias, &rule.source, rule.is_enabled) + .await + .unwrap() + ) + } + rules + })); tokio::spawn({ let socks_rules = rules.clone(); @@ -132,23 +85,13 @@ async fn rocket() -> _ { .manage(Instant::now()) .mount( "/", - rocket::routes![ - index, - api_totals, - api_allow, - api_block, - api_rules, - api_lists, - api_list, - api_list_enable, - api_list_disable - ], + rocket::routes![index, api_totals, api_list_enable, api_list_disable], ) } async fn spawn_socks_server( opt: &'static Opt, - rules: Arc, + rules: Arc>, totals: Arc, ) -> Result<()> { if opt.allow_udp && opt.public_addr.is_none() { @@ -179,7 +122,7 @@ async fn spawn_socks_server( async fn serve_socks5( opt: &Opt, socket: tokio::net::TcpStream, - rules: Arc, + rules: Arc>, totals: Arc, ) -> Result<(), SocksError> { totals.increase_request(); @@ -201,7 +144,7 @@ async fn serve_socks5( let (host, _) = request.2.clone().into_string_and_port(); - if !rules.any(&host).await { + if !rules.read().await.any(&host) { totals.increase_blocked(); info!("Blocked connection attempt to: {host}"); request diff --git a/src/opt.rs b/src/opt.rs index 5a26c13..db8819f 100644 --- a/src/opt.rs +++ b/src/opt.rs @@ -1,4 +1,4 @@ -use std::{net::SocketAddr, num::ParseFloatError, time::Duration}; +use std::{net::SocketAddr, num::ParseFloatError, path::PathBuf, time::Duration}; use structopt::StructOpt; /// # How to use it: @@ -45,11 +45,9 @@ pub struct Opt { #[structopt(short = "U", long)] pub allow_udp: bool, - /// Allow list: - /// * local filename - /// * remote URL - #[structopt(short = "a", long)] - pub allow_list: Vec, + /// Path to config file + #[structopt(short = "c", long)] + pub config: PathBuf, } /// Choose the authentication type diff --git a/src/rules.rs b/src/rules.rs index 619d22a..3d823ea 100644 --- a/src/rules.rs +++ b/src/rules.rs @@ -1,162 +1,39 @@ -mod item; +mod list; -use anyhow::{Result, bail}; -use item::Item; -use log::*; -use std::collections::{HashMap, HashSet}; -use tokio::{fs, sync::RwLock}; +use anyhow::Result; +use list::List; +use std::collections::HashMap; -pub struct Rules(RwLock>); +pub struct Rules(HashMap); impl Rules { - pub async fn from_opt(list: &[String]) -> Result { - let mut index = HashMap::with_capacity(list.len()); - assert!( - index - .insert( - 0, - List { - alias: "session".into(), - items: HashSet::new(), - status: true - } - ) - .is_none() - ); - for (i, l) in list.iter().enumerate() { - let mut items = HashSet::new(); - for line in if l.contains("://") { - let response = reqwest::get(l).await?; - let status = response.status(); - if status.is_success() { - response.text().await? - } else { - bail!("Could not receive remote list `{l}`: `{status}`") - } - } else { - fs::read_to_string(l).await? - } - .lines() - { - if line.starts_with("/") || line.starts_with("#") || line.is_empty() { - continue; // skip comments - } - if !items.insert(Item::from_line(line)) { - warn!("List `{l}` contains duplicated entry: `{line}`") - } - } - assert!( - index - .insert( - i + 1, - List { - alias: l.clone(), - items, - status: true // @TODO implement config file - } - ) - .is_none() - ) - } - Ok(Self(RwLock::new(index))) + pub fn new() -> Self { + Self(HashMap::new()) } - /// Check if rule is exist in the (allow) index - pub async fn any(&self, value: &str) -> bool { - self.0.read().await.values().any(|list| list.any(value)) - } - /// Get total rules from the current session - pub async fn total(&self, status: bool) -> u64 { - self.0 - .read() - .await - .values() - .filter(|list| list.status == status) - .map(|list| list.total()) - .sum() - } - /// Allow given `rule`(in the session index) - /// * return `false` if the `rule` is exist - pub async fn allow(&self, rule: &str) -> bool { - self.0 - .write() - .await - .get_mut(&0) - .unwrap() - .items - .insert(Item::from_line(rule)) - } - /// Block given `rule` (in the session index) - pub async fn block(&self, rule: &str) { - self.0 - .write() - .await - .get_mut(&0) - .unwrap() - .items - .retain(|item| rule == item.as_str()) - } - /// Return active rules - pub async fn active(&self) -> Vec { - let mut rules: Vec = self + pub async fn push(&mut self, list_alias: &str, src: &str, status: bool) -> Result { + Ok(self .0 - .read() - .await - .values() - .filter(|list| list.status) - .flat_map(|list| list.items.iter().map(|item| item.to_string())) - .collect(); - rules.sort(); // HashSet does not keep the order - rules - } - /// Return list references - pub async fn lists(&self) -> Vec { - let this = self.0.read().await; - let mut list = Vec::with_capacity(this.len()); - for l in this.iter() { - list.push(ListEntry { - id: *l.0, - alias: l.1.alias.clone(), - }) - } - list - } - /// Return original list references - pub async fn list(&self, id: &usize) -> Option { - self.0.read().await.get(id).cloned() + .insert(list_alias.into(), List::init(src, status).await?) + .is_none()) } /// Change rule set status by list ID - pub async fn enable(&self, list_id: &usize, status: bool) -> Option<()> { + pub fn set_status(&mut self, list_alias: &str, status: bool) -> bool { self.0 - .write() - .await - .get_mut(list_id) - .map(|this| this.status = status) + .get_mut(list_alias) + .map(|list| list.set_status(status)) + .is_some() } -} - -#[derive(serde::Serialize)] -pub struct ListEntry { - pub id: usize, - pub alias: String, -} - -#[derive(serde::Serialize, Clone)] -pub struct List { - pub alias: String, - pub items: HashSet, - pub status: bool, -} - -impl List { - /// Check if rule is exist in the items index + /// Check if rule is exist in the index pub fn any(&self, value: &str) -> bool { - self.items.iter().any(|item| match item { - Item::Exact(v) => v == value, - Item::Ending(v) => value.ends_with(v), - }) - } - /// Get total rules in list - pub fn total(&self) -> u64 { - self.items.len() as u64 + self.0.values().any(|list| list.contains(value)) } + /* + /// Get total rules in session by `status` + pub fn total(&self, status: bool) -> u64 { + self.0 + .values() + .filter(|list| status == list.is_enabled()) + .map(|list| list.total()) + .sum() + }*/ } diff --git a/src/rules/list.rs b/src/rules/list.rs new file mode 100644 index 0000000..b9a9c1e --- /dev/null +++ b/src/rules/list.rs @@ -0,0 +1,44 @@ +mod rule; +mod source; + +use anyhow::Result; +use log::warn; +use rule::Rule; +use source::Source; +use std::collections::HashSet; + +pub struct List { + pub is_enabled: bool, + //pub source: Source, + pub rules: HashSet, +} + +impl List { + pub async fn init(src: &str, is_enabled: bool) -> Result { + let mut rules = HashSet::new(); + let source = Source::from_str(src)?; + if is_enabled { + for line in source.get().await?.lines() { + if line.starts_with("/") || line.starts_with("#") || line.is_empty() { + continue; // skip comments + } + if !rules.insert(Rule::from_line(line)) { + warn!("List `{src}` contains duplicated entry: `{line}`") + } + } + } + Ok(Self { + rules, + //source, + is_enabled, + }) + } + /// Change rule set status by list ID + pub fn set_status(&mut self, is_enabled: bool) { + self.is_enabled = is_enabled; + } + /// Check if rule is exist in the items index + pub fn contains(&self, value: &str) -> bool { + self.is_enabled && self.rules.iter().any(|rule| rule.contains(value)) + } +} diff --git a/src/rules/item.rs b/src/rules/list/rule.rs similarity index 63% rename from src/rules/item.rs rename to src/rules/list/rule.rs index 23678fc..c88ff56 100644 --- a/src/rules/item.rs +++ b/src/rules/list/rule.rs @@ -1,12 +1,12 @@ use log::debug; -#[derive(PartialEq, Eq, Hash, Clone, serde::Serialize)] -pub enum Item { +#[derive(PartialEq, Eq, Hash)] +pub enum Rule { Ending(String), Exact(String), } -impl Item { +impl Rule { pub fn from_line(rule: &str) -> Self { if let Some(item) = rule.strip_prefix(".") { debug!("Init `{rule}` rule"); @@ -16,14 +16,20 @@ impl Item { Self::Exact(rule.to_string()) } } - pub fn as_str(&self) -> &str { + /*pub fn as_str(&self) -> &str { match self { - Item::Ending(s) | Item::Exact(s) => s, + Self::Ending(s) | Self::Exact(s) => s, + } + }*/ + pub fn contains(&self, value: &str) -> bool { + match self { + Rule::Exact(v) => v == value, + Rule::Ending(v) => value.ends_with(v), } } } -impl std::fmt::Display for Item { +impl std::fmt::Display for Rule { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Ending(s) => write!(f, ".{}", s), diff --git a/src/rules/list/source.rs b/src/rules/list/source.rs new file mode 100644 index 0000000..f2e63c6 --- /dev/null +++ b/src/rules/list/source.rs @@ -0,0 +1,40 @@ +use anyhow::{Result, bail}; +use std::{path::PathBuf, str::FromStr}; +use url::Url; + +pub enum Source { + Path(PathBuf), + Url(Url), +} + +impl Source { + pub fn from_str(source: &str) -> Result { + Ok(if source.contains("://") { + Self::Url(Url::from_str(source)?) + } else { + Self::Path(PathBuf::from_str(source)?.canonicalize()?) + }) + } + pub async fn get(&self) -> Result { + Ok(match self { + Source::Path(path) => tokio::fs::read_to_string(path).await?, + Source::Url(url) => { + let request = url.as_str(); + let response = reqwest::get(request).await?; + let status = response.status(); + if status.is_success() { + response.text().await? + } else { + bail!("Could not receive remote list `{request}`: `{status}`") + } + } + }) + } +} + +impl FromStr for Source { + type Err = anyhow::Error; + fn from_str(s: &str) -> Result { + Source::from_str(s) + } +} diff --git a/src/stats.rs b/src/stats.rs index e79e1e6..12a8a54 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -6,29 +6,18 @@ pub use snap::Snap; #[derive(Default)] pub struct Total { - entries: AtomicU64, blocked: AtomicU64, request: AtomicU64, } impl Total { - pub fn with_rules(entries: u64) -> Self { - Self { - entries: entries.into(), - ..Self::default() - } - } pub fn snap(&self, seconds_from_startup: u64) -> Snap { Snap::shot( - self.entries.load(Ordering::Relaxed), self.request.load(Ordering::Relaxed), self.blocked.load(Ordering::Relaxed), seconds_from_startup, ) } - pub fn set_entries(&self, value: u64) { - self.entries.store(value, Ordering::Relaxed); - } pub fn increase_blocked(&self) { self.blocked.fetch_add(1, Ordering::Relaxed); } diff --git a/src/stats/snap.rs b/src/stats/snap.rs index 3d26fe5..30d25a6 100644 --- a/src/stats/snap.rs +++ b/src/stats/snap.rs @@ -40,20 +40,18 @@ pub struct Request { #[derive(Serialize)] pub struct Snap { - rules: u64, request: Request, up: Up, } impl Snap { - pub fn shot(rules: u64, request: u64, blocked: u64, seconds_from_startup: u64) -> Self { + pub fn shot(request: u64, blocked: u64, seconds_from_startup: u64) -> Self { let blocked_percent = if request > 0 { blocked as f32 * 100.0 / request as f32 } else { 0.0 }; Self { - rules, request: Request { total: request, allowed: Sum {