diff --git a/src/main.rs b/src/main.rs index 3bc8184..b92f320 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,17 +11,14 @@ use list::List; use log::*; use opt::{AuthMode, Opt}; use rocket::serde::json::Json; -use stats::Stats; +use stats::{Snapshot, Total}; use std::{future::Future, sync::Arc}; use structopt::StructOpt; -use tokio::{net::TcpListener, sync::RwLock, task}; - -type SharedStats = Arc>; -type SharedList = Arc; +use tokio::{net::TcpListener, task}; #[rocket::get("/")] -async fn index(stats: &rocket::State) -> Json { - Json(*stats.read().await) +async fn index(totals: &rocket::State>) -> Json { + Json(totals.inner().snapshot()) } #[rocket::launch] @@ -29,12 +26,12 @@ async fn rocket() -> _ { env_logger::init(); let opt: &'static Opt = Box::leak(Box::new(Opt::from_args())); - let stats = Arc::new(RwLock::new(Stats::default())); + let totals = Arc::new(Total::default()); tokio::spawn({ - let socks_stats = stats.clone(); + let socks_totals = totals.clone(); async move { - if let Err(err) = spawn_socks_server(opt, socks_stats).await { + if let Err(err) = spawn_socks_server(opt, socks_totals).await { error!("SOCKS server failed: `{err}`"); } } @@ -46,11 +43,11 @@ async fn rocket() -> _ { address: opt.api_addr.ip(), ..rocket::Config::release_default() }) - .manage(stats) + .manage(totals) .mount("/", rocket::routes![index]) } -async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()> { +async fn spawn_socks_server(opt: &'static Opt, totals: Arc) -> Result<()> { if opt.allow_udp && opt.public_addr.is_none() { return Err(SocksError::ArgumentInputError( "Can't allow UDP if public-addr is not set", @@ -66,10 +63,8 @@ async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()> error!("Can't parse whitelist: `{err}`"); SocksError::ArgumentInputError("Can't parse whitelist") })?); - { - let mut s = stats.write().await; - s.entries = list.entries(); - } + + totals.set_entries(list.entries() as u64); let listener = TcpListener::bind(&opt.listen_addr).await?; @@ -78,7 +73,7 @@ async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()> loop { match listener.accept().await { Ok((socket, _client_addr)) => { - spawn_and_log_error(serve_socks5(opt, socket, list.clone(), stats.clone())); + spawn_and_log_error(serve_socks5(opt, socket, list.clone(), totals.clone())); } Err(err) => error!("accept error = {:?}", err), } @@ -88,11 +83,10 @@ async fn spawn_socks_server(opt: &'static Opt, stats: SharedStats) -> Result<()> async fn serve_socks5( opt: &Opt, socket: tokio::net::TcpStream, - list: SharedList, - stats: SharedStats, + list: Arc, + totals: Arc, ) -> Result<(), SocksError> { - let mut s = stats.write().await; - s.request += 1; + totals.increase_request(); let request = match &opt.auth { AuthMode::NoAuth if opt.skip_auth => { @@ -114,7 +108,7 @@ async fn serve_socks5( let (proto, cmd, addr) = request.resolve_dns().await?; if !list.has(&host) && !list.has(&addr.to_string()) { - s.blocked += 1; + totals.increase_blocked(); info!("Blocked connection attempt to: {host}"); proto.reply_error(&ReplyError::ConnectionNotAllowed).await?; return Err(ReplyError::ConnectionNotAllowed.into()); diff --git a/src/stats.rs b/src/stats.rs index 1d9d093..2aaf6f8 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -1,9 +1,35 @@ use rocket::serde::Serialize; +use std::sync::atomic::{AtomicU64, Ordering}; -#[derive(Debug, Default, Clone, Copy, Serialize)] -#[serde(crate = "rocket::serde")] -pub struct Stats { - pub entries: usize, - pub blocked: usize, - pub request: usize, +#[derive(Serialize)] +pub struct Snapshot { + pub request: u64, + pub blocked: u64, + pub entries: u64, +} + +#[derive(Default)] +pub struct Total { + entries: AtomicU64, + blocked: AtomicU64, + request: AtomicU64, +} + +impl Total { + pub fn snapshot(&self) -> Snapshot { + Snapshot { + request: self.request.load(Ordering::Relaxed), + blocked: self.blocked.load(Ordering::Relaxed), + entries: self.entries.load(Ordering::Relaxed), + } + } + pub fn set_entries(&self, value: u64) { + self.entries.store(value, Ordering::Relaxed); + } + pub fn increase_blocked(&self) { + self.blocked.fetch_add(1, Ordering::Relaxed); + } + pub fn increase_request(&self) { + self.request.fetch_add(1, Ordering::Relaxed); + } }