mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-04-02 10:45:30 +00:00
aquatic_udp: split some code into mio and glommio versions
This commit is contained in:
parent
65ef9a8ab2
commit
f2b157a149
12 changed files with 527 additions and 53 deletions
|
|
@ -1,5 +1,5 @@
|
|||
use std::hash::Hash;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::{atomic::AtomicUsize, Arc};
|
||||
use std::time::Instant;
|
||||
|
||||
|
|
@ -195,6 +195,31 @@ impl Default for State {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>);
|
||||
|
||||
impl ConnectionMap {
|
||||
pub fn insert(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
socket_addr: SocketAddr,
|
||||
valid_until: ValidUntil,
|
||||
) {
|
||||
self.0.insert((connection_id, socket_addr), valid_until);
|
||||
}
|
||||
|
||||
pub fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool {
|
||||
self.0.contains_key(&(connection_id, socket_addr))
|
||||
}
|
||||
|
||||
pub fn clean(&mut self) {
|
||||
let now = Instant::now();
|
||||
|
||||
self.0.retain(|_, v| v.0 > now);
|
||||
self.0.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
|
|
|
|||
1
aquatic_udp/src/lib/glommio/mod.rs
Normal file
1
aquatic_udp/src/lib/glommio/mod.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub mod network;
|
||||
200
aquatic_udp/src/lib/glommio/network.rs
Normal file
200
aquatic_udp/src/lib/glommio/network.rs
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
/// TODO
|
||||
/// - move connection checks to socket workers
|
||||
/// - ignore scrape requests. forward announce requests to request workers
|
||||
/// sharded by info hash (with some nice algo to make it difficult for an
|
||||
/// attacker to know which one they get forwarded to). this way, shared
|
||||
/// state can be avoided.
|
||||
use std::io::Cursor;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::rc::Rc;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use futures_lite::StreamExt;
|
||||
use glommio::channels::local_channel::{new_unbounded, LocalReceiver, LocalSender};
|
||||
use glommio::channels::shared_channel::{SharedReceiver, SharedSender};
|
||||
use glommio::net::UdpSocket;
|
||||
use glommio::prelude::*;
|
||||
use rand::prelude::{Rng, SeedableRng, StdRng};
|
||||
|
||||
use aquatic_udp_protocol::{IpVersion, Request, Response};
|
||||
|
||||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
|
||||
pub fn run_socket_worker(
|
||||
state: State,
|
||||
config: Config,
|
||||
request_sender: SharedSender<(ConnectedRequest, SocketAddr)>,
|
||||
response_receiver: SharedReceiver<(ConnectedResponse, SocketAddr)>,
|
||||
num_bound_sockets: Arc<AtomicUsize>,
|
||||
) {
|
||||
LocalExecutorBuilder::default()
|
||||
.spawn(|| async move {
|
||||
let (local_sender, local_receiver) = new_unbounded();
|
||||
|
||||
let mut socket = UdpSocket::bind(config.network.address).unwrap();
|
||||
|
||||
let recv_buffer_size = config.network.socket_recv_buffer_size;
|
||||
|
||||
if recv_buffer_size != 0 {
|
||||
socket.set_buffer_size(recv_buffer_size);
|
||||
}
|
||||
|
||||
let socket = Rc::new(socket);
|
||||
|
||||
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
spawn_local(read_requests(
|
||||
config.clone(),
|
||||
state.access_list.clone(),
|
||||
request_sender,
|
||||
local_sender,
|
||||
socket.clone(),
|
||||
))
|
||||
.await;
|
||||
spawn_local(send_responses(response_receiver, local_receiver, socket)).await;
|
||||
})
|
||||
.expect("failed to spawn local executor")
|
||||
.join()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn read_requests(
|
||||
config: Config,
|
||||
access_list: Arc<AccessList>,
|
||||
request_sender: SharedSender<(ConnectedRequest, SocketAddr)>,
|
||||
local_sender: LocalSender<(Response, SocketAddr)>,
|
||||
socket: Rc<UdpSocket>,
|
||||
) {
|
||||
let request_sender = request_sender.connect().await;
|
||||
|
||||
let mut rng = StdRng::from_entropy();
|
||||
|
||||
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
|
||||
let access_list_mode = config.access_list.mode;
|
||||
|
||||
let mut connections = ConnectionMap::default();
|
||||
|
||||
let mut buf = [0u8; 2048];
|
||||
|
||||
loop {
|
||||
match socket.recv_from(&mut buf).await {
|
||||
Ok((amt, src)) => {
|
||||
let request = Request::from_bytes(&buf[..amt], config.protocol.max_scrape_torrents);
|
||||
|
||||
match request {
|
||||
Ok(Request::Connect(request)) => {
|
||||
let connection_id = ConnectionId(rng.gen());
|
||||
|
||||
connections.insert(connection_id, src, valid_until);
|
||||
|
||||
let response = Response::Connect(ConnectResponse {
|
||||
connection_id,
|
||||
transaction_id: request.transaction_id,
|
||||
});
|
||||
|
||||
local_sender.try_send((response, src));
|
||||
}
|
||||
Ok(Request::Announce(request)) => {
|
||||
if connections.contains(request.connection_id, src) {
|
||||
if access_list.allows(access_list_mode, &request.info_hash.0) {
|
||||
if let Err(err) = request_sender
|
||||
.try_send((ConnectedRequest::Announce(request), src))
|
||||
{
|
||||
::log::warn!("request_sender.try_send failed: {:?}", err)
|
||||
}
|
||||
} else {
|
||||
let response = Response::Error(ErrorResponse {
|
||||
transaction_id: request.transaction_id,
|
||||
message: "Info hash not allowed".into(),
|
||||
});
|
||||
|
||||
local_sender.try_send((response, src));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Request::Scrape(request)) => {
|
||||
if connections.contains(request.connection_id, src) {
|
||||
if let Err(err) =
|
||||
request_sender.try_send((ConnectedRequest::Scrape(request), src))
|
||||
{
|
||||
::log::warn!("request_sender.try_send failed: {:?}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
::log::debug!("Request::from_bytes error: {:?}", err);
|
||||
|
||||
if let RequestParseError::Sendable {
|
||||
connection_id,
|
||||
transaction_id,
|
||||
err,
|
||||
} = err
|
||||
{
|
||||
if connections.contains(connection_id, src) {
|
||||
let response = ErrorResponse {
|
||||
transaction_id,
|
||||
message: err.right_or("Parse error").into(),
|
||||
};
|
||||
|
||||
local_sender.try_send((response.into(), src));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
::log::error!("recv_from: {:?}", err);
|
||||
}
|
||||
}
|
||||
|
||||
yield_if_needed().await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_responses(
|
||||
response_receiver: SharedReceiver<(ConnectedResponse, SocketAddr)>,
|
||||
local_receiver: LocalReceiver<(Response, SocketAddr)>,
|
||||
socket: Rc<UdpSocket>,
|
||||
) {
|
||||
let response_receiver = response_receiver.connect().await;
|
||||
|
||||
let mut buf = [0u8; MAX_PACKET_SIZE];
|
||||
let mut buf = Cursor::new(&mut buf[..]);
|
||||
|
||||
let mut stream = local_receiver
|
||||
.stream()
|
||||
.race(response_receiver.map(|(response, addr)| (response.into(), addr)));
|
||||
|
||||
while let Some((response, src)) = stream.next().await {
|
||||
buf.set_position(0);
|
||||
|
||||
response
|
||||
.write(&mut buf, ip_version_from_ip(src.ip()))
|
||||
.expect("write response");
|
||||
|
||||
let position = buf.position() as usize;
|
||||
|
||||
if let Err(err) = socket.send_to(&buf.get_ref()[..position], src).await {
|
||||
::log::info!("send_to failed: {:?}", err);
|
||||
}
|
||||
|
||||
yield_if_needed().await;
|
||||
}
|
||||
}
|
||||
|
||||
fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
|
||||
match ip {
|
||||
IpAddr::V4(_) => IpVersion::IPv4,
|
||||
IpAddr::V6(ip) => {
|
||||
if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
|
||||
IpVersion::IPv4
|
||||
} else {
|
||||
IpVersion::IPv6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -11,8 +11,8 @@ use privdrop::PrivDrop;
|
|||
|
||||
pub mod common;
|
||||
pub mod config;
|
||||
pub mod handlers;
|
||||
pub mod network;
|
||||
pub mod glommio;
|
||||
pub mod mio;
|
||||
pub mod tasks;
|
||||
|
||||
use common::State;
|
||||
|
|
@ -74,7 +74,7 @@ pub fn start_workers(config: Config, state: State) -> ::anyhow::Result<Arc<Atomi
|
|||
Builder::new()
|
||||
.name(format!("request-{:02}", i + 1))
|
||||
.spawn(move || {
|
||||
handlers::run_request_worker(state, config, request_receiver, response_sender)
|
||||
mio::handlers::run_request_worker(state, config, request_receiver, response_sender)
|
||||
})
|
||||
.with_context(|| "spawn request worker")?;
|
||||
}
|
||||
|
|
@ -91,7 +91,7 @@ pub fn start_workers(config: Config, state: State) -> ::anyhow::Result<Arc<Atomi
|
|||
Builder::new()
|
||||
.name(format!("socket-{:02}", i + 1))
|
||||
.spawn(move || {
|
||||
network::run_socket_worker(
|
||||
mio::network::run_socket_worker(
|
||||
state,
|
||||
config,
|
||||
i,
|
||||
|
|
|
|||
2
aquatic_udp/src/lib/mio/mod.rs
Normal file
2
aquatic_udp/src/lib/mio/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod handlers;
|
||||
pub mod network;
|
||||
|
|
@ -8,7 +8,6 @@ use std::time::{Duration, Instant};
|
|||
use std::vec::Drain;
|
||||
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
use hashbrown::HashMap;
|
||||
use mio::net::UdpSocket;
|
||||
use mio::{Events, Interest, Poll, Token};
|
||||
use rand::prelude::{Rng, SeedableRng, StdRng};
|
||||
|
|
@ -19,31 +18,6 @@ use aquatic_udp_protocol::{IpVersion, Request, Response};
|
|||
use crate::common::*;
|
||||
use crate::config::Config;
|
||||
|
||||
#[derive(Default)]
|
||||
struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>);
|
||||
|
||||
impl ConnectionMap {
|
||||
fn insert(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
socket_addr: SocketAddr,
|
||||
valid_until: ValidUntil,
|
||||
) {
|
||||
self.0.insert((connection_id, socket_addr), valid_until);
|
||||
}
|
||||
|
||||
fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool {
|
||||
self.0.contains_key(&(connection_id, socket_addr))
|
||||
}
|
||||
|
||||
fn clean(&mut self) {
|
||||
let now = Instant::now();
|
||||
|
||||
self.0.retain(|_, v| v.0 > now);
|
||||
self.0.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_socket_worker(
|
||||
state: State,
|
||||
config: Config,
|
||||
Loading…
Add table
Add a link
Reference in a new issue