grand refactory to multiple list-based control api

This commit is contained in:
yggverse 2026-03-28 04:42:41 +02:00
parent 827cb182f2
commit b93c1e8481
15 changed files with 271 additions and 282 deletions

13
src/config.rs Normal file
View file

@ -0,0 +1,13 @@
use serde::Deserialize;
use std::collections::HashMap;
#[derive(Deserialize)]
pub struct List {
pub is_enabled: bool,
pub source: String,
}
#[derive(Deserialize)]
pub struct Config {
pub list: HashMap<String, List>,
}

11
src/example/config.toml Normal file
View file

@ -0,0 +1,11 @@
[list.google]
is_enabled = true
source = "https://codeberg.org/postscriptum/psocks-list/src/branch/main/allow/google.txt"
[list.github]
is_enabled = false
source = "https://codeberg.org/postscriptum/psocks-list/src/branch/main/allow/github.txt"
#[[list.common]]
#is_enabled = false
#source = "/path/to/common.txt"

View file

@ -0,0 +1,6 @@
// Github list example
.github.com
.github.io
.githubassets.com
.githubusercontent.com

View file

@ -0,0 +1,13 @@
// Google list example
.gmail.com
.google
.google.ch
.google.com
.google.com.ua
.google.dev
.googlevideo.com
.gstatic.com
.withgoogle.com
.youtube.com
.ytimg.com

View file

@ -1,8 +1,10 @@
mod config;
mod opt;
mod rules;
mod stats;
use anyhow::Context;
use config::Config;
use fast_socks5::{
ReplyError, Result, Socks5Command, SocksError,
server::{DnsResolveHelper as _, Socks5ServerProtocol, run_tcp_proxy, run_udp_proxy},
@ -14,7 +16,7 @@ use rules::Rules;
use stats::{Snap, Total};
use std::{future::Future, sync::Arc, time::Instant};
use structopt::StructOpt;
use tokio::{net::TcpListener, task};
use tokio::{net::TcpListener, sync::RwLock, task};
#[rocket::get("/")]
async fn index(totals: &State<Arc<Total>>, startup_time: &State<Instant>) -> Json<Snap> {
@ -26,79 +28,20 @@ async fn api_totals(totals: &State<Arc<Total>>, startup_time: &State<Instant>) -
Json(totals.inner().snap(startup_time.elapsed().as_secs()))
}
#[rocket::get("/api/allow/<rule>")]
async fn api_allow(
rule: &str,
rules: &State<Arc<Rules>>,
totals: &State<Arc<Total>>,
) -> Result<Json<bool>, Status> {
let result = rules.allow(rule).await;
totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals
info!("Delete `{rule}` from the in-memory rules (operation status: {result:?})");
Ok(Json(result))
}
#[rocket::get("/api/block/<rule>")]
async fn api_block(
rule: &str,
rules: &State<Arc<Rules>>,
totals: &State<Arc<Total>>,
) -> Result<Json<()>, Status> {
let result = rules.block(rule).await;
totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals
info!("Add `{rule}` to the in-memory rules (operation status: {result:?})");
Ok(Json(result))
}
#[rocket::get("/api/rules")]
async fn api_rules(rules: &State<Arc<Rules>>) -> Result<Json<Vec<String>>, Status> {
let active = rules.active().await;
debug!("Get rules (total: {})", active.len());
Ok(Json(active))
}
#[rocket::get("/api/lists")]
async fn api_lists(rules: &State<Arc<Rules>>) -> Result<Json<Vec<rules::ListEntry>>, Status> {
let lists = rules.lists().await;
debug!("Get lists index (total: {})", lists.len());
Ok(Json(lists))
}
#[rocket::get("/api/list/<id>")]
async fn api_list(
id: usize,
rules: &State<Arc<Rules>>,
) -> Result<Json<Option<rules::List>>, Status> {
let list = rules.list(&id).await;
debug!(
"Get list #{id} rules (total: {:?})",
list.as_ref().map(|l| l.items.len())
);
Ok(Json(list))
}
#[rocket::get("/api/list/enable/<id>")]
#[rocket::get("/api/list/enable/<alias>")]
async fn api_list_enable(
id: usize,
rules: &State<Arc<Rules>>,
totals: &State<Arc<Total>>,
) -> Result<Json<Option<()>>, Status> {
let affected = rules.enable(&id, true).await;
totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals
info!("Enabled {affected:?} rules from the active rule set");
Ok(Json(affected)) // @TODO handle empty result
alias: &str,
rules: &State<Arc<RwLock<Rules>>>,
) -> Result<Json<bool>, Status> {
Ok(Json(rules.write().await.set_status(alias, true)))
}
#[rocket::get("/api/list/disable/<id>")]
#[rocket::get("/api/list/disable/<alias>")]
async fn api_list_disable(
id: usize,
rules: &State<Arc<Rules>>,
totals: &State<Arc<Total>>,
) -> Result<Json<Option<()>>, Status> {
let affected = rules.enable(&id, false).await;
totals.set_entries(rules.total(true).await); // @TODO separate active/inactive totals
info!("Disabled {affected:?} rules from the active rule set");
Ok(Json(affected)) // @TODO handle empty result
alias: &str,
rules: &State<Arc<RwLock<Rules>>>,
) -> Result<Json<bool>, Status> {
Ok(Json(rules.write().await.set_status(alias, false)))
}
#[rocket::launch]
@ -106,10 +49,20 @@ async fn rocket() -> _ {
env_logger::init();
let opt: &'static Opt = Box::leak(Box::new(Opt::from_args()));
let rules = Arc::new(Rules::from_opt(&opt.allow_list).await.unwrap());
let totals = Arc::new(Total::with_rules(rules.total(true).await)); // @TODO separate active/inactive totals
let config: Config = toml::from_str(&std::fs::read_to_string(&opt.config).unwrap()).unwrap();
let totals = Arc::new(Total::default());
let rules = Arc::new(RwLock::new({
let mut rules = Rules::new();
for (alias, rule) in config.list {
assert!(
rules
.push(&alias, &rule.source, rule.is_enabled)
.await
.unwrap()
)
}
rules
}));
tokio::spawn({
let socks_rules = rules.clone();
@ -132,23 +85,13 @@ async fn rocket() -> _ {
.manage(Instant::now())
.mount(
"/",
rocket::routes![
index,
api_totals,
api_allow,
api_block,
api_rules,
api_lists,
api_list,
api_list_enable,
api_list_disable
],
rocket::routes![index, api_totals, api_list_enable, api_list_disable],
)
}
async fn spawn_socks_server(
opt: &'static Opt,
rules: Arc<Rules>,
rules: Arc<RwLock<Rules>>,
totals: Arc<Total>,
) -> Result<()> {
if opt.allow_udp && opt.public_addr.is_none() {
@ -179,7 +122,7 @@ async fn spawn_socks_server(
async fn serve_socks5(
opt: &Opt,
socket: tokio::net::TcpStream,
rules: Arc<Rules>,
rules: Arc<RwLock<Rules>>,
totals: Arc<Total>,
) -> Result<(), SocksError> {
totals.increase_request();
@ -201,7 +144,7 @@ async fn serve_socks5(
let (host, _) = request.2.clone().into_string_and_port();
if !rules.any(&host).await {
if !rules.read().await.any(&host) {
totals.increase_blocked();
info!("Blocked connection attempt to: {host}");
request

View file

@ -1,4 +1,4 @@
use std::{net::SocketAddr, num::ParseFloatError, time::Duration};
use std::{net::SocketAddr, num::ParseFloatError, path::PathBuf, time::Duration};
use structopt::StructOpt;
/// # How to use it:
@ -45,11 +45,9 @@ pub struct Opt {
#[structopt(short = "U", long)]
pub allow_udp: bool,
/// Allow list:
/// * local filename
/// * remote URL
#[structopt(short = "a", long)]
pub allow_list: Vec<String>,
/// Path to config file
#[structopt(short = "c", long)]
pub config: PathBuf,
}
/// Choose the authentication type

View file

@ -1,162 +1,39 @@
mod item;
mod list;
use anyhow::{Result, bail};
use item::Item;
use log::*;
use std::collections::{HashMap, HashSet};
use tokio::{fs, sync::RwLock};
use anyhow::Result;
use list::List;
use std::collections::HashMap;
pub struct Rules(RwLock<HashMap<usize, List>>);
pub struct Rules(HashMap<String, List>);
impl Rules {
pub async fn from_opt(list: &[String]) -> Result<Self> {
let mut index = HashMap::with_capacity(list.len());
assert!(
index
.insert(
0,
List {
alias: "session".into(),
items: HashSet::new(),
status: true
}
)
.is_none()
);
for (i, l) in list.iter().enumerate() {
let mut items = HashSet::new();
for line in if l.contains("://") {
let response = reqwest::get(l).await?;
let status = response.status();
if status.is_success() {
response.text().await?
} else {
bail!("Could not receive remote list `{l}`: `{status}`")
}
} else {
fs::read_to_string(l).await?
}
.lines()
{
if line.starts_with("/") || line.starts_with("#") || line.is_empty() {
continue; // skip comments
}
if !items.insert(Item::from_line(line)) {
warn!("List `{l}` contains duplicated entry: `{line}`")
}
}
assert!(
index
.insert(
i + 1,
List {
alias: l.clone(),
items,
status: true // @TODO implement config file
}
)
.is_none()
)
}
Ok(Self(RwLock::new(index)))
pub fn new() -> Self {
Self(HashMap::new())
}
/// Check if rule is exist in the (allow) index
pub async fn any(&self, value: &str) -> bool {
self.0.read().await.values().any(|list| list.any(value))
}
/// Get total rules from the current session
pub async fn total(&self, status: bool) -> u64 {
self.0
.read()
.await
.values()
.filter(|list| list.status == status)
.map(|list| list.total())
.sum()
}
/// Allow given `rule`(in the session index)
/// * return `false` if the `rule` is exist
pub async fn allow(&self, rule: &str) -> bool {
self.0
.write()
.await
.get_mut(&0)
.unwrap()
.items
.insert(Item::from_line(rule))
}
/// Block given `rule` (in the session index)
pub async fn block(&self, rule: &str) {
self.0
.write()
.await
.get_mut(&0)
.unwrap()
.items
.retain(|item| rule == item.as_str())
}
/// Return active rules
pub async fn active(&self) -> Vec<String> {
let mut rules: Vec<String> = self
pub async fn push(&mut self, list_alias: &str, src: &str, status: bool) -> Result<bool> {
Ok(self
.0
.read()
.await
.values()
.filter(|list| list.status)
.flat_map(|list| list.items.iter().map(|item| item.to_string()))
.collect();
rules.sort(); // HashSet does not keep the order
rules
}
/// Return list references
pub async fn lists(&self) -> Vec<ListEntry> {
let this = self.0.read().await;
let mut list = Vec::with_capacity(this.len());
for l in this.iter() {
list.push(ListEntry {
id: *l.0,
alias: l.1.alias.clone(),
})
}
list
}
/// Return original list references
pub async fn list(&self, id: &usize) -> Option<List> {
self.0.read().await.get(id).cloned()
.insert(list_alias.into(), List::init(src, status).await?)
.is_none())
}
/// Change rule set status by list ID
pub async fn enable(&self, list_id: &usize, status: bool) -> Option<()> {
pub fn set_status(&mut self, list_alias: &str, status: bool) -> bool {
self.0
.write()
.await
.get_mut(list_id)
.map(|this| this.status = status)
.get_mut(list_alias)
.map(|list| list.set_status(status))
.is_some()
}
}
#[derive(serde::Serialize)]
pub struct ListEntry {
pub id: usize,
pub alias: String,
}
#[derive(serde::Serialize, Clone)]
pub struct List {
pub alias: String,
pub items: HashSet<Item>,
pub status: bool,
}
impl List {
/// Check if rule is exist in the items index
/// Check if rule is exist in the index
pub fn any(&self, value: &str) -> bool {
self.items.iter().any(|item| match item {
Item::Exact(v) => v == value,
Item::Ending(v) => value.ends_with(v),
})
}
/// Get total rules in list
pub fn total(&self) -> u64 {
self.items.len() as u64
self.0.values().any(|list| list.contains(value))
}
/*
/// Get total rules in session by `status`
pub fn total(&self, status: bool) -> u64 {
self.0
.values()
.filter(|list| status == list.is_enabled())
.map(|list| list.total())
.sum()
}*/
}

44
src/rules/list.rs Normal file
View file

@ -0,0 +1,44 @@
mod rule;
mod source;
use anyhow::Result;
use log::warn;
use rule::Rule;
use source::Source;
use std::collections::HashSet;
pub struct List {
pub is_enabled: bool,
//pub source: Source,
pub rules: HashSet<Rule>,
}
impl List {
pub async fn init(src: &str, is_enabled: bool) -> Result<Self> {
let mut rules = HashSet::new();
let source = Source::from_str(src)?;
if is_enabled {
for line in source.get().await?.lines() {
if line.starts_with("/") || line.starts_with("#") || line.is_empty() {
continue; // skip comments
}
if !rules.insert(Rule::from_line(line)) {
warn!("List `{src}` contains duplicated entry: `{line}`")
}
}
}
Ok(Self {
rules,
//source,
is_enabled,
})
}
/// Change rule set status by list ID
pub fn set_status(&mut self, is_enabled: bool) {
self.is_enabled = is_enabled;
}
/// Check if rule is exist in the items index
pub fn contains(&self, value: &str) -> bool {
self.is_enabled && self.rules.iter().any(|rule| rule.contains(value))
}
}

View file

@ -1,12 +1,12 @@
use log::debug;
#[derive(PartialEq, Eq, Hash, Clone, serde::Serialize)]
pub enum Item {
#[derive(PartialEq, Eq, Hash)]
pub enum Rule {
Ending(String),
Exact(String),
}
impl Item {
impl Rule {
pub fn from_line(rule: &str) -> Self {
if let Some(item) = rule.strip_prefix(".") {
debug!("Init `{rule}` rule");
@ -16,14 +16,20 @@ impl Item {
Self::Exact(rule.to_string())
}
}
pub fn as_str(&self) -> &str {
/*pub fn as_str(&self) -> &str {
match self {
Item::Ending(s) | Item::Exact(s) => s,
Self::Ending(s) | Self::Exact(s) => s,
}
}*/
pub fn contains(&self, value: &str) -> bool {
match self {
Rule::Exact(v) => v == value,
Rule::Ending(v) => value.ends_with(v),
}
}
}
impl std::fmt::Display for Item {
impl std::fmt::Display for Rule {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Ending(s) => write!(f, ".{}", s),

40
src/rules/list/source.rs Normal file
View file

@ -0,0 +1,40 @@
use anyhow::{Result, bail};
use std::{path::PathBuf, str::FromStr};
use url::Url;
pub enum Source {
Path(PathBuf),
Url(Url),
}
impl Source {
pub fn from_str(source: &str) -> Result<Self> {
Ok(if source.contains("://") {
Self::Url(Url::from_str(source)?)
} else {
Self::Path(PathBuf::from_str(source)?.canonicalize()?)
})
}
pub async fn get(&self) -> Result<String> {
Ok(match self {
Source::Path(path) => tokio::fs::read_to_string(path).await?,
Source::Url(url) => {
let request = url.as_str();
let response = reqwest::get(request).await?;
let status = response.status();
if status.is_success() {
response.text().await?
} else {
bail!("Could not receive remote list `{request}`: `{status}`")
}
}
})
}
}
impl FromStr for Source {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Source::from_str(s)
}
}

View file

@ -6,29 +6,18 @@ pub use snap::Snap;
#[derive(Default)]
pub struct Total {
entries: AtomicU64,
blocked: AtomicU64,
request: AtomicU64,
}
impl Total {
pub fn with_rules(entries: u64) -> Self {
Self {
entries: entries.into(),
..Self::default()
}
}
pub fn snap(&self, seconds_from_startup: u64) -> Snap {
Snap::shot(
self.entries.load(Ordering::Relaxed),
self.request.load(Ordering::Relaxed),
self.blocked.load(Ordering::Relaxed),
seconds_from_startup,
)
}
pub fn set_entries(&self, value: u64) {
self.entries.store(value, Ordering::Relaxed);
}
pub fn increase_blocked(&self) {
self.blocked.fetch_add(1, Ordering::Relaxed);
}

View file

@ -40,20 +40,18 @@ pub struct Request {
#[derive(Serialize)]
pub struct Snap {
rules: u64,
request: Request,
up: Up,
}
impl Snap {
pub fn shot(rules: u64, request: u64, blocked: u64, seconds_from_startup: u64) -> Self {
pub fn shot(request: u64, blocked: u64, seconds_from_startup: u64) -> Self {
let blocked_percent = if request > 0 {
blocked as f32 * 100.0 / request as f32
} else {
0.0
};
Self {
rules,
request: Request {
total: request,
allowed: Sum {