use AtomicU64 for stats, rename some members

This commit is contained in:
postscriptum 2026-03-22 04:34:19 +02:00
parent fcd59fd306
commit b03bdd0e3a
2 changed files with 48 additions and 28 deletions

View file

@ -11,17 +11,14 @@ use list::List;
use log::*; use log::*;
use opt::{AuthMode, Opt}; use opt::{AuthMode, Opt};
use rocket::serde::json::Json; use rocket::serde::json::Json;
use stats::Stats; use stats::{Snapshot, Total};
use std::{future::Future, sync::Arc}; use std::{future::Future, sync::Arc};
use structopt::StructOpt; use structopt::StructOpt;
use tokio::{net::TcpListener, sync::RwLock, task}; use tokio::{net::TcpListener, task};
type SharedStats = Arc<RwLock<Stats>>;
type SharedList = Arc<List>;
#[rocket::get("/")] #[rocket::get("/")]
async fn index(stats: &rocket::State<SharedStats>) -> Json<Stats> { async fn index(totals: &rocket::State<Arc<Total>>) -> Json<Snapshot> {
Json(*stats.read().await) Json(totals.inner().snapshot())
} }
#[rocket::launch] #[rocket::launch]
@ -29,12 +26,12 @@ async fn rocket() -> _ {
env_logger::init(); env_logger::init();
let opt: &'static Opt = Box::leak(Box::new(Opt::from_args())); let opt: &'static Opt = Box::leak(Box::new(Opt::from_args()));
let stats = Arc::new(RwLock::new(Stats::default())); let totals = Arc::new(Total::default());
tokio::spawn({ tokio::spawn({
let socks_stats = stats.clone(); let socks_totals = totals.clone();
async move { async move {
if let Err(err) = spawn_socks_server(opt, socks_stats).await { if let Err(err) = spawn_socks_server(opt, socks_totals).await {
error!("SOCKS server failed: `{err}`"); error!("SOCKS server failed: `{err}`");
} }
} }
@ -46,11 +43,11 @@ async fn rocket() -> _ {
address: opt.api_addr.ip(), address: opt.api_addr.ip(),
..rocket::Config::release_default() ..rocket::Config::release_default()
}) })
.manage(stats) .manage(totals)
.mount("/", rocket::routes![index]) .mount("/", rocket::routes![index])
} }
async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()> { async fn spawn_socks_server(opt: &'static Opt, totals: Arc<Total>) -> Result<()> {
if opt.allow_udp && opt.public_addr.is_none() { if opt.allow_udp && opt.public_addr.is_none() {
return Err(SocksError::ArgumentInputError( return Err(SocksError::ArgumentInputError(
"Can't allow UDP if public-addr is not set", "Can't allow UDP if public-addr is not set",
@ -66,10 +63,8 @@ async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()>
error!("Can't parse whitelist: `{err}`"); error!("Can't parse whitelist: `{err}`");
SocksError::ArgumentInputError("Can't parse whitelist") SocksError::ArgumentInputError("Can't parse whitelist")
})?); })?);
{
let mut s = stats.write().await; totals.set_entries(list.entries() as u64);
s.entries = list.entries();
}
let listener = TcpListener::bind(&opt.listen_addr).await?; let listener = TcpListener::bind(&opt.listen_addr).await?;
@ -78,7 +73,7 @@ async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()>
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((socket, _client_addr)) => { Ok((socket, _client_addr)) => {
spawn_and_log_error(serve_socks5(opt, socket, list.clone(), stats.clone())); spawn_and_log_error(serve_socks5(opt, socket, list.clone(), totals.clone()));
} }
Err(err) => error!("accept error = {:?}", err), Err(err) => error!("accept error = {:?}", err),
} }
@ -88,11 +83,10 @@ async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()>
async fn serve_socks5( async fn serve_socks5(
opt: &Opt, opt: &Opt,
socket: tokio::net::TcpStream, socket: tokio::net::TcpStream,
list: SharedList, list: Arc<List>,
stats: SharedStats, totals: Arc<Total>,
) -> Result<(), SocksError> { ) -> Result<(), SocksError> {
let mut s = stats.write().await; totals.increase_request();
s.request += 1;
let request = match &opt.auth { let request = match &opt.auth {
AuthMode::NoAuth if opt.skip_auth => { AuthMode::NoAuth if opt.skip_auth => {
@ -114,7 +108,7 @@ async fn serve_socks5(
let (proto, cmd, addr) = request.resolve_dns().await?; let (proto, cmd, addr) = request.resolve_dns().await?;
if !list.has(&host) && !list.has(&addr.to_string()) { if !list.has(&host) && !list.has(&addr.to_string()) {
s.blocked += 1; totals.increase_blocked();
info!("Blocked connection attempt to: {host}"); info!("Blocked connection attempt to: {host}");
proto.reply_error(&ReplyError::ConnectionNotAllowed).await?; proto.reply_error(&ReplyError::ConnectionNotAllowed).await?;
return Err(ReplyError::ConnectionNotAllowed.into()); return Err(ReplyError::ConnectionNotAllowed.into());

View file

@ -1,9 +1,35 @@
use rocket::serde::Serialize; use rocket::serde::Serialize;
use std::sync::atomic::{AtomicU64, Ordering};
#[derive(Debug, Default, Clone, Copy, Serialize)] #[derive(Serialize)]
#[serde(crate = "rocket::serde")] pub struct Snapshot {
pub struct Stats { pub request: u64,
pub entries: usize, pub blocked: u64,
pub blocked: usize, pub entries: u64,
pub request: usize, }
#[derive(Default)]
pub struct Total {
entries: AtomicU64,
blocked: AtomicU64,
request: AtomicU64,
}
impl Total {
pub fn snapshot(&self) -> Snapshot {
Snapshot {
request: self.request.load(Ordering::Relaxed),
blocked: self.blocked.load(Ordering::Relaxed),
entries: self.entries.load(Ordering::Relaxed),
}
}
pub fn set_entries(&self, value: u64) {
self.entries.store(value, Ordering::Relaxed);
}
pub fn increase_blocked(&self) {
self.blocked.fetch_add(1, Ordering::Relaxed);
}
pub fn increase_request(&self) {
self.request.fetch_add(1, Ordering::Relaxed);
}
} }