From d18117595ee9d2542f58d664f58bcd08886acfb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Sun, 14 Nov 2021 03:19:33 +0100 Subject: [PATCH] udp: move code shared by mio/uring impls to common.rs --- aquatic_udp/src/lib/other/common.rs | 140 +++++++++++++++++++- aquatic_udp/src/lib/other/network_mio.rs | 135 +++----------------- aquatic_udp/src/lib/other/network_uring.rs | 142 ++------------------- 3 files changed, 162 insertions(+), 255 deletions(-) diff --git a/aquatic_udp/src/lib/other/common.rs b/aquatic_udp/src/lib/other/common.rs index bcaff2f..a1f62f8 100644 --- a/aquatic_udp/src/lib/other/common.rs +++ b/aquatic_udp/src/lib/other/common.rs @@ -1,8 +1,17 @@ -use aquatic_common::access_list::AccessListArcSwap; +use aquatic_common::access_list::{AccessListArcSwap, AccessListCache}; +use aquatic_udp_protocol::*; +use crossbeam_channel::Sender; use parking_lot::Mutex; -use std::sync::{atomic::AtomicUsize, Arc}; +use rand::{prelude::StdRng, Rng}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::{atomic::AtomicUsize, Arc}, +}; use crate::common::*; +use crate::common::{handlers::ConnectedRequest, network::ConnectionMap}; +use crate::config::Config; #[derive(Default)] pub struct Statistics { @@ -28,3 +37,130 @@ impl Default for State { } } } + +pub fn handle_request( + config: &Config, + connections: &mut ConnectionMap, + access_list_cache: &mut AccessListCache, + rng: &mut StdRng, + request_sender: &Sender<(ConnectedRequest, SocketAddr)>, + local_responses: &mut Vec<(Response, SocketAddr)>, + valid_until: ValidUntil, + res_request: Result, + src: SocketAddr, +) { + let access_list_mode = config.access_list.mode; + + match res_request { + Ok(Request::Connect(request)) => { + let connection_id = ConnectionId(rng.gen()); + + connections.insert(connection_id, src, valid_until); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, + }); + + local_responses.push((response, src)) + } + Ok(Request::Announce(request)) => { + if connections.contains(request.connection_id, src) { + if access_list_cache + .load() + .allows(access_list_mode, &request.info_hash.0) + { + if let Err(err) = + request_sender.try_send((ConnectedRequest::Announce(request), src)) + { + ::log::warn!("request_sender.try_send failed: {:?}", err) + } + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + local_responses.push((response, src)) + } + } + } + Ok(Request::Scrape(request)) => { + if connections.contains(request.connection_id, src) { + let request = ConnectedRequest::Scrape { + request, + original_indices: Vec::new(), + }; + + if let Err(err) = request_sender.try_send((request, src)) { + ::log::warn!("request_sender.try_send failed: {:?}", err) + } + } + } + Err(err) => { + ::log::debug!("Request::from_bytes error: {:?}", err); + + if let RequestParseError::Sendable { + connection_id, + transaction_id, + err, + } = err + { + if connections.contains(connection_id, src) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + local_responses.push((response.into(), src)); + } + } + } + } +} + +pub fn ip_version_from_ip(ip: IpAddr) -> IpVersion { + match ip { + IpAddr::V4(_) => IpVersion::IPv4, + IpAddr::V6(ip) => { + if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() { + IpVersion::IPv4 + } else { + IpVersion::IPv6 + } + } + } +} + +pub fn create_socket(config: &Config) -> ::std::net::UdpSocket { + let socket = if config.network.address.is_ipv4() { + Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + } else { + Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) + } + .expect("create socket"); + + socket.set_reuse_port(true).expect("socket: set reuse port"); + + socket + .set_nonblocking(true) + .expect("socket: set nonblocking"); + + socket + .bind(&config.network.address.into()) + .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); + + let recv_buffer_size = config.network.socket_recv_buffer_size; + + if recv_buffer_size != 0 { + if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { + ::log::error!( + "socket: failed setting recv buffer to {}: {:?}", + recv_buffer_size, + err + ); + } + } + + socket.into() +} diff --git a/aquatic_udp/src/lib/other/network_mio.rs b/aquatic_udp/src/lib/other/network_mio.rs index dfe00d6..d04fb2a 100644 --- a/aquatic_udp/src/lib/other/network_mio.rs +++ b/aquatic_udp/src/lib/other/network_mio.rs @@ -1,5 +1,5 @@ use std::io::{Cursor, ErrorKind}; -use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -11,10 +11,9 @@ use aquatic_common::access_list::create_access_list_cache; use crossbeam_channel::{Receiver, Sender}; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; -use rand::prelude::{Rng, SeedableRng, StdRng}; -use socket2::{Domain, Protocol, Socket, Type}; +use rand::prelude::{SeedableRng, StdRng}; -use aquatic_udp_protocol::{IpVersion, Request, Response}; +use aquatic_udp_protocol::{Request, Response}; use crate::common::handlers::*; use crate::common::network::ConnectionMap; @@ -101,39 +100,6 @@ pub fn run_socket_worker( } } -fn create_socket(config: &Config) -> ::std::net::UdpSocket { - let socket = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) - } else { - Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) - } - .expect("create socket"); - - socket.set_reuse_port(true).expect("socket: set reuse port"); - - socket - .set_nonblocking(true) - .expect("socket: set nonblocking"); - - socket - .bind(&config.network.address.into()) - .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); - - let recv_buffer_size = config.network.socket_recv_buffer_size; - - if recv_buffer_size != 0 { - if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { - ::log::error!( - "socket: failed setting recv buffer to {}: {:?}", - recv_buffer_size, - err - ); - } - } - - socket.into() -} - #[inline] fn read_requests( config: &Config, @@ -149,88 +115,32 @@ fn read_requests( let mut bytes_received: usize = 0; let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - let access_list_mode = config.access_list.mode; let mut access_list_cache = create_access_list_cache(&state.access_list); loop { match socket.recv_from(&mut buffer[..]) { Ok((amt, src)) => { - let request = + let res_request = Request::from_bytes(&buffer[..amt], config.protocol.max_scrape_torrents); bytes_received += amt; - if request.is_ok() { + if res_request.is_ok() { requests_received += 1; } - match request { - Ok(Request::Connect(request)) => { - let connection_id = ConnectionId(rng.gen()); - - connections.insert(connection_id, src, valid_until); - - let response = Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - local_responses.push((response, src)) - } - Ok(Request::Announce(request)) => { - if connections.contains(request.connection_id, src) { - if access_list_cache - .load() - .allows(access_list_mode, &request.info_hash.0) - { - if let Err(err) = request_sender - .try_send((ConnectedRequest::Announce(request), src)) - { - ::log::warn!("request_sender.try_send failed: {:?}", err) - } - } else { - let response = Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - }); - - local_responses.push((response, src)) - } - } - } - Ok(Request::Scrape(request)) => { - if connections.contains(request.connection_id, src) { - let request = ConnectedRequest::Scrape { - request, - original_indices: Vec::new(), - }; - - if let Err(err) = request_sender.try_send((request, src)) { - ::log::warn!("request_sender.try_send failed: {:?}", err) - } - } - } - Err(err) => { - ::log::debug!("Request::from_bytes error: {:?}", err); - - if let RequestParseError::Sendable { - connection_id, - transaction_id, - err, - } = err - { - if connections.contains(connection_id, src) { - let response = ErrorResponse { - transaction_id, - message: err.right_or("Parse error").into(), - }; - - local_responses.push((response.into(), src)); - } - } - } - } + handle_request( + config, + connections, + &mut access_list_cache, + rng, + request_sender, + local_responses, + valid_until, + res_request, + src, + ); } Err(err) => { if err.kind() == ErrorKind::WouldBlock { @@ -314,16 +224,3 @@ fn send_responses( .fetch_add(bytes_sent, Ordering::SeqCst); } } - -fn ip_version_from_ip(ip: IpAddr) -> IpVersion { - match ip { - IpAddr::V4(_) => IpVersion::IPv4, - IpAddr::V6(ip) => { - if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() { - IpVersion::IPv4 - } else { - IpVersion::IPv6 - } - } - } -} diff --git a/aquatic_udp/src/lib/other/network_uring.rs b/aquatic_udp/src/lib/other/network_uring.rs index 520fff9..184e092 100644 --- a/aquatic_udp/src/lib/other/network_uring.rs +++ b/aquatic_udp/src/lib/other/network_uring.rs @@ -1,6 +1,6 @@ use std::io::Cursor; use std::mem::size_of; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; use std::os::unix::prelude::AsRawFd; use std::ptr::null_mut; use std::sync::{ @@ -9,18 +9,17 @@ use std::sync::{ }; use std::time::{Duration, Instant}; -use aquatic_common::access_list::{create_access_list_cache, AccessListCache}; +use aquatic_common::access_list::create_access_list_cache; use crossbeam_channel::{Receiver, Sender}; use io_uring::types::{Fixed, Timespec}; use io_uring::SubmissionQueue; use libc::{ c_void, in6_addr, in_addr, iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6, }; -use rand::prelude::{Rng, SeedableRng, StdRng}; +use rand::prelude::{SeedableRng, StdRng}; use slab::Slab; -use socket2::{Domain, Protocol, Socket, Type}; -use aquatic_udp_protocol::{IpVersion, Request, Response}; +use aquatic_udp_protocol::{Request, Response}; use crate::common::handlers::*; use crate::common::network::ConnectionMap; @@ -248,14 +247,17 @@ pub fn run_socket_worker( config.protocol.max_scrape_torrents, ); + // FIXME: don't run every iteration + let valid_until = ValidUntil::new(config.cleaning.max_connection_age); + handle_request( &config, - &state, &mut connections, &mut access_list_cache, &mut rng, &request_sender, &mut local_responses, + valid_until, res_request, addr, ); @@ -376,88 +378,6 @@ pub fn run_socket_worker( } } -fn handle_request( - config: &Config, - state: &State, - connections: &mut ConnectionMap, - access_list_cache: &mut AccessListCache, - rng: &mut StdRng, - request_sender: &Sender<(ConnectedRequest, SocketAddr)>, - local_responses: &mut Vec<(Response, SocketAddr)>, - res_request: Result, - src: SocketAddr, -) { - let valid_until = ValidUntil::new(config.cleaning.max_connection_age); - let access_list_mode = config.access_list.mode; - - match res_request { - Ok(Request::Connect(request)) => { - let connection_id = ConnectionId(rng.gen()); - - connections.insert(connection_id, src, valid_until); - - let response = Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - local_responses.push((response, src)) - } - Ok(Request::Announce(request)) => { - if connections.contains(request.connection_id, src) { - if access_list_cache - .load() - .allows(access_list_mode, &request.info_hash.0) - { - if let Err(err) = - request_sender.try_send((ConnectedRequest::Announce(request), src)) - { - ::log::warn!("request_sender.try_send failed: {:?}", err) - } - } else { - let response = Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - }); - - local_responses.push((response, src)) - } - } - } - Ok(Request::Scrape(request)) => { - if connections.contains(request.connection_id, src) { - let request = ConnectedRequest::Scrape { - request, - original_indices: Vec::new(), - }; - - if let Err(err) = request_sender.try_send((request, src)) { - ::log::warn!("request_sender.try_send failed: {:?}", err) - } - } - } - Err(err) => { - ::log::debug!("Request::from_bytes error: {:?}", err); - - if let RequestParseError::Sendable { - connection_id, - transaction_id, - err, - } = err - { - if connections.contains(connection_id, src) { - let response = ErrorResponse { - transaction_id, - message: err.right_or("Parse error").into(), - }; - - local_responses.push((response.into(), src)); - } - } - } - } -} - fn queue_response( config: &Config, sq: &mut SubmissionQueue, @@ -525,52 +445,6 @@ fn queue_response( } } -fn create_socket(config: &Config) -> ::std::net::UdpSocket { - let socket = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) - } else { - Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) - } - .expect("create socket"); - - socket.set_reuse_port(true).expect("socket: set reuse port"); - - socket - .set_nonblocking(true) - .expect("socket: set nonblocking"); - - socket - .bind(&config.network.address.into()) - .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); - - let recv_buffer_size = config.network.socket_recv_buffer_size; - - if recv_buffer_size != 0 { - if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { - ::log::error!( - "socket: failed setting recv buffer to {}: {:?}", - recv_buffer_size, - err - ); - } - } - - socket.into() -} - -fn ip_version_from_ip(ip: IpAddr) -> IpVersion { - match ip { - IpAddr::V4(_) => IpVersion::IPv4, - IpAddr::V6(ip) => { - if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() { - IpVersion::IPv4 - } else { - IpVersion::IPv6 - } - } - } -} - #[cfg(test)] mod tests { use quickcheck::Arbitrary;