udp: move code shared by mio/uring impls to common.rs

This commit is contained in:
Joakim Frostegård 2021-11-14 03:19:33 +01:00
parent ce1c0b24c3
commit d18117595e
3 changed files with 162 additions and 255 deletions

View file

@ -1,8 +1,17 @@
use aquatic_common::access_list::AccessListArcSwap;
use aquatic_common::access_list::{AccessListArcSwap, AccessListCache};
use aquatic_udp_protocol::*;
use crossbeam_channel::Sender;
use parking_lot::Mutex;
use std::sync::{atomic::AtomicUsize, Arc};
use rand::{prelude::StdRng, Rng};
use socket2::{Domain, Protocol, Socket, Type};
use std::{
net::{IpAddr, SocketAddr},
sync::{atomic::AtomicUsize, Arc},
};
use crate::common::*;
use crate::common::{handlers::ConnectedRequest, network::ConnectionMap};
use crate::config::Config;
#[derive(Default)]
pub struct Statistics {
@ -28,3 +37,130 @@ impl Default for State {
}
}
}
pub fn handle_request(
config: &Config,
connections: &mut ConnectionMap,
access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
local_responses: &mut Vec<(Response, SocketAddr)>,
valid_until: ValidUntil,
res_request: Result<Request, RequestParseError>,
src: SocketAddr,
) {
let access_list_mode = config.access_list.mode;
match res_request {
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(connection_id, src, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
{
if let Err(err) =
request_sender.try_send((ConnectedRequest::Announce(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
let request = ConnectedRequest::Scrape {
request,
original_indices: Vec::new(),
};
if let Err(err) = request_sender.try_send((request, src)) {
::log::warn!("request_sender.try_send failed: {:?}", err)
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if connections.contains(connection_id, src) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_responses.push((response.into(), src));
}
}
}
}
}
pub fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
match ip {
IpAddr::V4(_) => IpVersion::IPv4,
IpAddr::V6(ip) => {
if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
IpVersion::IPv4
} else {
IpVersion::IPv6
}
}
}
}
pub fn create_socket(config: &Config) -> ::std::net::UdpSocket {
let socket = if config.network.address.is_ipv4() {
Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
} else {
Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
}
.expect("create socket");
socket.set_reuse_port(true).expect("socket: set reuse port");
socket
.set_nonblocking(true)
.expect("socket: set nonblocking");
socket
.bind(&config.network.address.into())
.unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err));
let recv_buffer_size = config.network.socket_recv_buffer_size;
if recv_buffer_size != 0 {
if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) {
::log::error!(
"socket: failed setting recv buffer to {}: {:?}",
recv_buffer_size,
err
);
}
}
socket.into()
}

View file

@ -1,5 +1,5 @@
use std::io::{Cursor, ErrorKind};
use std::net::{IpAddr, SocketAddr};
use std::net::SocketAddr;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
@ -11,10 +11,9 @@ use aquatic_common::access_list::create_access_list_cache;
use crossbeam_channel::{Receiver, Sender};
use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token};
use rand::prelude::{Rng, SeedableRng, StdRng};
use socket2::{Domain, Protocol, Socket, Type};
use rand::prelude::{SeedableRng, StdRng};
use aquatic_udp_protocol::{IpVersion, Request, Response};
use aquatic_udp_protocol::{Request, Response};
use crate::common::handlers::*;
use crate::common::network::ConnectionMap;
@ -101,39 +100,6 @@ pub fn run_socket_worker(
}
}
fn create_socket(config: &Config) -> ::std::net::UdpSocket {
let socket = if config.network.address.is_ipv4() {
Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
} else {
Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
}
.expect("create socket");
socket.set_reuse_port(true).expect("socket: set reuse port");
socket
.set_nonblocking(true)
.expect("socket: set nonblocking");
socket
.bind(&config.network.address.into())
.unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err));
let recv_buffer_size = config.network.socket_recv_buffer_size;
if recv_buffer_size != 0 {
if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) {
::log::error!(
"socket: failed setting recv buffer to {}: {:?}",
recv_buffer_size,
err
);
}
}
socket.into()
}
#[inline]
fn read_requests(
config: &Config,
@ -149,88 +115,32 @@ fn read_requests(
let mut bytes_received: usize = 0;
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode;
let mut access_list_cache = create_access_list_cache(&state.access_list);
loop {
match socket.recv_from(&mut buffer[..]) {
Ok((amt, src)) => {
let request =
let res_request =
Request::from_bytes(&buffer[..amt], config.protocol.max_scrape_torrents);
bytes_received += amt;
if request.is_ok() {
if res_request.is_ok() {
requests_received += 1;
}
match request {
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(connection_id, src, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
{
if let Err(err) = request_sender
.try_send((ConnectedRequest::Announce(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
let request = ConnectedRequest::Scrape {
request,
original_indices: Vec::new(),
};
if let Err(err) = request_sender.try_send((request, src)) {
::log::warn!("request_sender.try_send failed: {:?}", err)
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if connections.contains(connection_id, src) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_responses.push((response.into(), src));
}
}
}
}
handle_request(
config,
connections,
&mut access_list_cache,
rng,
request_sender,
local_responses,
valid_until,
res_request,
src,
);
}
Err(err) => {
if err.kind() == ErrorKind::WouldBlock {
@ -314,16 +224,3 @@ fn send_responses(
.fetch_add(bytes_sent, Ordering::SeqCst);
}
}
fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
match ip {
IpAddr::V4(_) => IpVersion::IPv4,
IpAddr::V6(ip) => {
if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
IpVersion::IPv4
} else {
IpVersion::IPv6
}
}
}
}

View file

@ -1,6 +1,6 @@
use std::io::Cursor;
use std::mem::size_of;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::os::unix::prelude::AsRawFd;
use std::ptr::null_mut;
use std::sync::{
@ -9,18 +9,17 @@ use std::sync::{
};
use std::time::{Duration, Instant};
use aquatic_common::access_list::{create_access_list_cache, AccessListCache};
use aquatic_common::access_list::create_access_list_cache;
use crossbeam_channel::{Receiver, Sender};
use io_uring::types::{Fixed, Timespec};
use io_uring::SubmissionQueue;
use libc::{
c_void, in6_addr, in_addr, iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6,
};
use rand::prelude::{Rng, SeedableRng, StdRng};
use rand::prelude::{SeedableRng, StdRng};
use slab::Slab;
use socket2::{Domain, Protocol, Socket, Type};
use aquatic_udp_protocol::{IpVersion, Request, Response};
use aquatic_udp_protocol::{Request, Response};
use crate::common::handlers::*;
use crate::common::network::ConnectionMap;
@ -248,14 +247,17 @@ pub fn run_socket_worker(
config.protocol.max_scrape_torrents,
);
// FIXME: don't run every iteration
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
handle_request(
&config,
&state,
&mut connections,
&mut access_list_cache,
&mut rng,
&request_sender,
&mut local_responses,
valid_until,
res_request,
addr,
);
@ -376,88 +378,6 @@ pub fn run_socket_worker(
}
}
fn handle_request(
config: &Config,
state: &State,
connections: &mut ConnectionMap,
access_list_cache: &mut AccessListCache,
rng: &mut StdRng,
request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
local_responses: &mut Vec<(Response, SocketAddr)>,
res_request: Result<Request, RequestParseError>,
src: SocketAddr,
) {
let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
let access_list_mode = config.access_list.mode;
match res_request {
Ok(Request::Connect(request)) => {
let connection_id = ConnectionId(rng.gen());
connections.insert(connection_id, src, valid_until);
let response = Response::Connect(ConnectResponse {
connection_id,
transaction_id: request.transaction_id,
});
local_responses.push((response, src))
}
Ok(Request::Announce(request)) => {
if connections.contains(request.connection_id, src) {
if access_list_cache
.load()
.allows(access_list_mode, &request.info_hash.0)
{
if let Err(err) =
request_sender.try_send((ConnectedRequest::Announce(request), src))
{
::log::warn!("request_sender.try_send failed: {:?}", err)
}
} else {
let response = Response::Error(ErrorResponse {
transaction_id: request.transaction_id,
message: "Info hash not allowed".into(),
});
local_responses.push((response, src))
}
}
}
Ok(Request::Scrape(request)) => {
if connections.contains(request.connection_id, src) {
let request = ConnectedRequest::Scrape {
request,
original_indices: Vec::new(),
};
if let Err(err) = request_sender.try_send((request, src)) {
::log::warn!("request_sender.try_send failed: {:?}", err)
}
}
}
Err(err) => {
::log::debug!("Request::from_bytes error: {:?}", err);
if let RequestParseError::Sendable {
connection_id,
transaction_id,
err,
} = err
{
if connections.contains(connection_id, src) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
};
local_responses.push((response.into(), src));
}
}
}
}
}
fn queue_response(
config: &Config,
sq: &mut SubmissionQueue,
@ -525,52 +445,6 @@ fn queue_response(
}
}
fn create_socket(config: &Config) -> ::std::net::UdpSocket {
let socket = if config.network.address.is_ipv4() {
Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
} else {
Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
}
.expect("create socket");
socket.set_reuse_port(true).expect("socket: set reuse port");
socket
.set_nonblocking(true)
.expect("socket: set nonblocking");
socket
.bind(&config.network.address.into())
.unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err));
let recv_buffer_size = config.network.socket_recv_buffer_size;
if recv_buffer_size != 0 {
if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) {
::log::error!(
"socket: failed setting recv buffer to {}: {:?}",
recv_buffer_size,
err
);
}
}
socket.into()
}
fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
match ip {
IpAddr::V4(_) => IpVersion::IPv4,
IpAddr::V6(ip) => {
if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
IpVersion::IPv4
} else {
IpVersion::IPv6
}
}
}
}
#[cfg(test)]
mod tests {
use quickcheck::Arbitrary;