diff --git a/aquatic_udp/src/lib/common/network.rs b/aquatic_udp/src/lib/common/network.rs index 833c99f..469d658 100644 --- a/aquatic_udp/src/lib/common/network.rs +++ b/aquatic_udp/src/lib/common/network.rs @@ -17,7 +17,7 @@ impl ConnectionMap { self.0.insert((connection_id, socket_addr), valid_until); } - pub fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool { + pub fn contains(&self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool { self.0.contains_key(&(connection_id, socket_addr)) } diff --git a/aquatic_udp/src/lib/glommio/network.rs b/aquatic_udp/src/lib/glommio/network.rs index 93d5f04..0a87d4b 100644 --- a/aquatic_udp/src/lib/glommio/network.rs +++ b/aquatic_udp/src/lib/glommio/network.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::io::Cursor; use std::net::{IpAddr, SocketAddr}; use std::rc::Rc; @@ -5,16 +6,21 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; +use std::time::Duration; use futures_lite::{Stream, StreamExt}; +use glommio::enclose; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::local_channel::{new_unbounded, LocalSender}; use glommio::net::UdpSocket; use glommio::prelude::*; +use glommio::timer::TimerActionRepeat; use rand::prelude::{Rng, SeedableRng, StdRng}; use aquatic_udp_protocol::{IpVersion, Request, Response}; +use super::common::update_access_list; + use crate::common::network::ConnectionMap; use crate::common::*; use crate::config::Config; @@ -76,12 +82,37 @@ async fn read_requests( let access_list_mode = config.access_list.mode; - // Needs to be updated periodically: use timer? - let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - // Needs to be updated periodically: use timer? - let access_list = AccessList::default(); - // Needs to be cleaned periodically: use timer? - let mut connections = ConnectionMap::default(); + let max_connection_age = config.cleaning.max_connection_age; + let connection_valid_until = Rc::new(RefCell::new(ValidUntil::new(max_connection_age))); + let access_list = Rc::new(RefCell::new(AccessList::default())); + let connections = Rc::new(RefCell::new(ConnectionMap::default())); + + // Periodically update connection_valid_until + TimerActionRepeat::repeat(enclose!((connection_valid_until) move || { + enclose!((connection_valid_until) move || async move { + *connection_valid_until.borrow_mut() = ValidUntil::new(max_connection_age); + + Some(Duration::from_secs(1)) + })() + })); + + // Periodically update access list + TimerActionRepeat::repeat(enclose!((config, access_list) move || { + enclose!((config, access_list) move || async move { + update_access_list(config.clone(), access_list.clone()).await; + + Some(Duration::from_secs(config.cleaning.interval)) + })() + })); + + // Periodically clean connections + TimerActionRepeat::repeat(enclose!((config, connections) move || { + enclose!((config, connections) move || async move { + connections.borrow_mut().clean(); + + Some(Duration::from_secs(config.cleaning.interval)) + })() + })); let mut buf = [0u8; 2048]; @@ -96,7 +127,7 @@ async fn read_requests( Ok(Request::Connect(request)) => { let connection_id = ConnectionId(rng.gen()); - connections.insert(connection_id, src, valid_until); + connections.borrow_mut().insert(connection_id, src, connection_valid_until.borrow().to_owned()); let response = Response::Connect(ConnectResponse { connection_id, @@ -106,8 +137,8 @@ async fn read_requests( local_sender.try_send((response, src)).unwrap(); } Ok(Request::Announce(request)) => { - if connections.contains(request.connection_id, src) { - if access_list.allows(access_list_mode, &request.info_hash.0) { + if connections.borrow().contains(request.connection_id, src) { + if access_list.borrow().allows(access_list_mode, &request.info_hash.0) { let request_consumer_index = (request.info_hash.0[0] as usize) % config.request_workers; @@ -128,7 +159,7 @@ async fn read_requests( } } Ok(Request::Scrape(request)) => { - if connections.contains(request.connection_id, src) { + if connections.borrow().contains(request.connection_id, src) { let response = Response::Error(ErrorResponse { transaction_id: request.transaction_id, message: "Scrape requests not supported".into(), @@ -146,7 +177,7 @@ async fn read_requests( err, } = err { - if connections.contains(connection_id, src) { + if connections.borrow().contains(connection_id, src) { let response = ErrorResponse { transaction_id, message: err.right_or("Parse error").into(),