aquatic_udp: split some code into mio and glommio versions

This commit is contained in:
Joakim Frostegård 2021-10-18 22:31:56 +02:00
parent 65ef9a8ab2
commit f2b157a149
12 changed files with 527 additions and 53 deletions

View file

@ -1,5 +1,5 @@
use std::hash::Hash;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant;
@ -195,6 +195,31 @@ impl Default for State {
}
}
#[derive(Default)]
pub struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>);
impl ConnectionMap {
pub fn insert(
&mut self,
connection_id: ConnectionId,
socket_addr: SocketAddr,
valid_until: ValidUntil,
) {
self.0.insert((connection_id, socket_addr), valid_until);
}
pub fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool {
self.0.contains_key(&(connection_id, socket_addr))
}
pub fn clean(&mut self) {
let now = Instant::now();
self.0.retain(|_, v| v.0 > now);
self.0.shrink_to_fit();
}
}
#[cfg(test)]
mod tests {
#[test]

View file

@ -0,0 +1 @@
pub mod network;

View file

@ -0,0 +1,200 @@
/// TODO
/// - move connection checks to socket workers
/// - ignore scrape requests. forward announce requests to request workers
/// sharded by info hash (with some nice algo to make it difficult for an
/// attacker to know which one they get forwarded to). this way, shared
/// state can be avoided.
use std::io::Cursor;
use std::net::{IpAddr, SocketAddr};
use std::rc::Rc;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use futures_lite::StreamExt;
use glommio::channels::local_channel::{new_unbounded, LocalReceiver, LocalSender};
use glommio::channels::shared_channel::{SharedReceiver, SharedSender};
use glommio::net::UdpSocket;
use glommio::prelude::*;
use rand::prelude::{Rng, SeedableRng, StdRng};
use aquatic_udp_protocol::{IpVersion, Request, Response};
use crate::common::*;
use crate::config::Config;
pub fn run_socket_worker(
state: State,
config: Config,
request_sender: SharedSender<(ConnectedRequest, SocketAddr)>,
response_receiver: SharedReceiver<(ConnectedResponse, SocketAddr)>,
num_bound_sockets: Arc<AtomicUsize>,
) {
LocalExecutorBuilder::default()
.spawn(|| async move {
let (local_sender, local_receiver) = new_unbounded();
let mut socket = UdpSocket::bind(config.network.address).unwrap();
let recv_buffer_size = config.network.socket_recv_buffer_size;
if recv_buffer_size != 0 {
socket.set_buffer_size(recv_buffer_size);
}
let socket = Rc::new(socket);
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
spawn_local(read_requests(
config.clone(),
state.access_list.clone(),
request_sender,
local_sender,
socket.clone(),
))
.await;
spawn_local(send_responses(response_receiver, local_receiver, socket)).await;
})
.expect("failed to spawn local executor")
.join()
.unwrap();
}
async fn read_requests(
config: Config,
access_list: Arc<AccessList>,
request_sender: SharedSender<(ConnectedRequest, SocketAddr)>,
local_sender: LocalSender<(Response, SocketAddr)>,
socket: Rc<UdpSocket>,
) {
let request_sender = request_sender.connect().await;
let mut rng = StdRng::from_entropy();
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode;
let mut connections = ConnectionMap::default();
let mut buf = [0u8; 2048];
loop {
match socket.recv_from(&mut buf).await {
Ok((amt, src)) => {
let request = Request::from_bytes(&buf[..amt], config.protocol.max_scrape_torrents);
match request {
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(connection_id, src, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_sender.try_send((response, src));
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if access_list.allows(access_list_mode, &request.info_hash.0) {
if let Err(err) = request_sender
.try_send((ConnectedRequest::Announce(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_sender.try_send((response, src));
}
}
}
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
if let Err(err) =
request_sender.try_send((ConnectedRequest::Scrape(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if connections.contains(connection_id, src) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_sender.try_send((response.into(), src));
}
}
}
}
}
Err(err) => {
::log::error!("recv_from: {:?}", err);
}
}
yield_if_needed().await;
}
}
async fn send_responses(
response_receiver: SharedReceiver<(ConnectedResponse, SocketAddr)>,
local_receiver: LocalReceiver<(Response, SocketAddr)>,
socket: Rc<UdpSocket>,
) {
let response_receiver = response_receiver.connect().await;
let mut buf = [0u8; MAX_PACKET_SIZE];
let mut buf = Cursor::new(&mut buf[..]);
let mut stream = local_receiver
.stream()
.race(response_receiver.map(|(response, addr)| (response.into(), addr)));
while let Some((response, src)) = stream.next().await {
buf.set_position(0);
response
.write(&mut buf, ip_version_from_ip(src.ip()))
.expect("write response");
let position = buf.position() as usize;
if let Err(err) = socket.send_to(&buf.get_ref()[..position], src).await {
::log::info!("send_to failed: {:?}", err);
}
yield_if_needed().await;
}
}
fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
match ip {
IpAddr::V4(_) => IpVersion::IPv4,
IpAddr::V6(ip) => {
if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
IpVersion::IPv4
} else {
IpVersion::IPv6
}
}
}
}

View file

@ -11,8 +11,8 @@ use privdrop::PrivDrop;
pub mod common;
pub mod config;
pub mod handlers;
pub mod network;
pub mod glommio;
pub mod mio;
pub mod tasks;
use common::State;
@ -74,7 +74,7 @@ pub fn start_workers(config: Config, state: State) -> ::anyhow::Result<Arc<Atomi
Builder::new()
.name(format!("request-{:02}", i + 1))
.spawn(move || {
handlers::run_request_worker(state, config, request_receiver, response_sender)
mio::handlers::run_request_worker(state, config, request_receiver, response_sender)
})
.with_context(|| "spawn request worker")?;
}
@ -91,7 +91,7 @@ pub fn start_workers(config: Config, state: State) -> ::anyhow::Result<Arc<Atomi
Builder::new()
.name(format!("socket-{:02}", i + 1))
.spawn(move || {
network::run_socket_worker(
mio::network::run_socket_worker(
state,
config,
i,

View file

@ -0,0 +1,2 @@
pub mod handlers;
pub mod network;

View file

@ -8,7 +8,6 @@ use std::time::{Duration, Instant};
use std::vec::Drain;
use crossbeam_channel::{Receiver, Sender};
use hashbrown::HashMap;
use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token};
use rand::prelude::{Rng, SeedableRng, StdRng};
@ -19,31 +18,6 @@ use aquatic_udp_protocol::{IpVersion, Request, Response};
use crate::common::*;
use crate::config::Config;
#[derive(Default)]
struct ConnectionMap(HashMap<(ConnectionId, SocketAddr), ValidUntil>);
impl ConnectionMap {
fn insert(
&mut self,
connection_id: ConnectionId,
socket_addr: SocketAddr,
valid_until: ValidUntil,
) {
self.0.insert((connection_id, socket_addr), valid_until);
}
fn contains(&mut self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool {
self.0.contains_key(&(connection_id, socket_addr))
}
fn clean(&mut self) {
let now = Instant::now();
self.0.retain(|_, v| v.0 > now);
self.0.shrink_to_fit();
}
}
pub fn run_socket_worker(
state: State,
config: Config,