diff --git a/aquatic_udp/src/lib/common/mod.rs b/aquatic_udp/src/lib/common.rs similarity index 80% rename from aquatic_udp/src/lib/common/mod.rs rename to aquatic_udp/src/lib/common.rs index be382b9..e9fc24e 100644 --- a/aquatic_udp/src/lib/common/mod.rs +++ b/aquatic_udp/src/lib/common.rs @@ -6,7 +6,6 @@ use std::sync::Arc; use std::time::Instant; use crossbeam_channel::Sender; -use socket2::{Domain, Protocol, Socket, Type}; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap}; use aquatic_common::AHashIndexMap; @@ -15,8 +14,6 @@ use aquatic_udp_protocol::*; use crate::config::Config; -pub mod network; - pub const MAX_PACKET_SIZE: usize = 8192; pub trait Ip: Hash + PartialEq + Eq + Clone + Copy { @@ -60,58 +57,6 @@ pub enum ConnectedResponse { Scrape(PendingScrapeResponse), } -#[derive(Clone, PartialEq, Debug)] -pub struct ProtocolResponsePeer { - pub ip_address: I, - pub port: Port, -} - -pub struct ProtocolAnnounceResponse { - pub transaction_id: TransactionId, - pub announce_interval: AnnounceInterval, - pub leechers: NumberOfPeers, - pub seeders: NumberOfPeers, - pub peers: Vec>, -} - -impl Into for ProtocolAnnounceResponse { - fn into(self) -> ConnectedResponse { - ConnectedResponse::AnnounceIpv4(AnnounceResponseIpv4 { - transaction_id: self.transaction_id, - announce_interval: self.announce_interval, - leechers: self.leechers, - seeders: self.seeders, - peers: self - .peers - .into_iter() - .map(|peer| ResponsePeerIpv4 { - ip_address: peer.ip_address, - port: peer.port, - }) - .collect(), - }) - } -} - -impl Into for ProtocolAnnounceResponse { - fn into(self) -> ConnectedResponse { - ConnectedResponse::AnnounceIpv6(AnnounceResponseIpv6 { - transaction_id: self.transaction_id, - announce_interval: self.announce_interval, - leechers: self.leechers, - seeders: self.seeders, - peers: self - .peers - .into_iter() - .map(|peer| ResponsePeerIpv6 { - ip_address: peer.ip_address, - port: peer.port, - }) - .collect(), - }) - } -} - #[derive(Clone, Copy, Debug)] pub struct SocketWorkerIndex(pub usize); @@ -119,7 +64,7 @@ pub struct SocketWorkerIndex(pub usize); pub struct RequestWorkerIndex(pub usize); impl RequestWorkerIndex { - fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self { + pub fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self { Self(info_hash.0[0] as usize % config.request_workers) } } @@ -201,16 +146,6 @@ pub struct Peer { pub valid_until: ValidUntil, } -impl Peer { - #[inline(always)] - pub fn to_response_peer(&self) -> ProtocolResponsePeer { - ProtocolResponsePeer { - ip_address: self.ip_address, - port: self.port, - } - } -} - pub type PeerMap = AHashIndexMap>; pub struct TorrentData { diff --git a/aquatic_udp/src/lib/common/network.rs b/aquatic_udp/src/lib/common/network.rs deleted file mode 100644 index ee3ca4f..0000000 --- a/aquatic_udp/src/lib/common/network.rs +++ /dev/null @@ -1,235 +0,0 @@ -use std::{net::SocketAddr, time::Instant}; - -use aquatic_common::access_list::AccessListCache; -use aquatic_common::AHashIndexMap; -use aquatic_common::ValidUntil; -use aquatic_udp_protocol::*; -use rand::{prelude::StdRng, Rng}; - -use crate::common::*; - -#[derive(Default)] -pub struct ConnectionMap(AHashIndexMap<(ConnectionId, SocketAddr), ValidUntil>); - -impl ConnectionMap { - pub fn insert( - &mut self, - connection_id: ConnectionId, - socket_addr: SocketAddr, - valid_until: ValidUntil, - ) { - self.0.insert((connection_id, socket_addr), valid_until); - } - - pub fn contains(&self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool { - self.0.contains_key(&(connection_id, socket_addr)) - } - - pub fn clean(&mut self) { - let now = Instant::now(); - - self.0.retain(|_, v| v.0 > now); - self.0.shrink_to_fit(); - } -} - -pub struct PendingScrapeResponseMeta { - num_pending: usize, - valid_until: ValidUntil, -} - -#[derive(Default)] -pub struct PendingScrapeResponseMap( - AHashIndexMap, -); - -impl PendingScrapeResponseMap { - pub fn prepare( - &mut self, - transaction_id: TransactionId, - num_pending: usize, - valid_until: ValidUntil, - ) { - let meta = PendingScrapeResponseMeta { - num_pending, - valid_until, - }; - let response = PendingScrapeResponse { - transaction_id, - torrent_stats: BTreeMap::new(), - }; - - self.0.insert(transaction_id, (meta, response)); - } - - pub fn add_and_get_finished(&mut self, response: PendingScrapeResponse) -> Option { - let finished = if let Some(r) = self.0.get_mut(&response.transaction_id) { - r.0.num_pending -= 1; - - r.1.torrent_stats.extend(response.torrent_stats.into_iter()); - - r.0.num_pending == 0 - } else { - ::log::warn!("PendingScrapeResponses.add didn't find PendingScrapeResponse in map"); - - false - }; - - if finished { - let response = self.0.remove(&response.transaction_id).unwrap().1; - - Some(Response::Scrape(ScrapeResponse { - transaction_id: response.transaction_id, - torrent_stats: response.torrent_stats.into_values().collect(), - })) - } else { - None - } - } - - pub fn clean(&mut self) { - let now = Instant::now(); - - self.0.retain(|_, v| v.0.valid_until.0 > now); - self.0.shrink_to_fit(); - } -} - -pub fn handle_request( - config: &Config, - connections: &mut ConnectionMap, - pending_scrape_responses: &mut PendingScrapeResponseMap, - access_list_cache: &mut AccessListCache, - rng: &mut StdRng, - request_sender: &ConnectedRequestSender, - local_responses: &mut Vec<(Response, SocketAddr)>, - valid_until: ValidUntil, - res_request: Result, - src: SocketAddr, -) { - let access_list_mode = config.access_list.mode; - - match res_request { - Ok(Request::Connect(request)) => { - let connection_id = ConnectionId(rng.gen()); - - connections.insert(connection_id, src, valid_until); - - let response = Response::Connect(ConnectResponse { - connection_id, - transaction_id: request.transaction_id, - }); - - local_responses.push((response, src)) - } - Ok(Request::Announce(request)) => { - if connections.contains(request.connection_id, src) { - if access_list_cache - .load() - .allows(access_list_mode, &request.info_hash.0) - { - let worker_index = - RequestWorkerIndex::from_info_hash(config, request.info_hash); - - request_sender.try_send_to( - worker_index, - ConnectedRequest::Announce(request), - src, - ); - } else { - let response = Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - }); - - local_responses.push((response, src)) - } - } - } - Ok(Request::Scrape(request)) => { - if connections.contains(request.connection_id, src) { - let mut requests: AHashIndexMap = - Default::default(); - - let transaction_id = request.transaction_id; - - for (i, info_hash) in request.info_hashes.into_iter().enumerate() { - let pending = requests - .entry(RequestWorkerIndex::from_info_hash(&config, info_hash)) - .or_insert_with(|| PendingScrapeRequest { - transaction_id, - info_hashes: BTreeMap::new(), - }); - - pending.info_hashes.insert(i, info_hash); - } - - pending_scrape_responses.prepare(transaction_id, requests.len(), valid_until); - - for (request_worker_index, request) in requests { - request_sender.try_send_to( - request_worker_index, - ConnectedRequest::Scrape(request), - src, - ); - } - } - } - Err(err) => { - ::log::debug!("Request::from_bytes error: {:?}", err); - - if let RequestParseError::Sendable { - connection_id, - transaction_id, - err, - } = err - { - if connections.contains(connection_id, src) { - let response = ErrorResponse { - transaction_id, - message: err.right_or("Parse error").into(), - }; - - local_responses.push((response.into(), src)); - } - } - } - } -} - -pub fn create_socket(config: &Config) -> ::std::net::UdpSocket { - let socket = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) - } else { - Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) - } - .expect("create socket"); - - if config.network.only_ipv6 { - socket.set_only_v6(true).expect("socket: set only ipv6"); - } - - socket.set_reuse_port(true).expect("socket: set reuse port"); - - socket - .set_nonblocking(true) - .expect("socket: set nonblocking"); - - socket - .bind(&config.network.address.into()) - .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); - - let recv_buffer_size = config.network.socket_recv_buffer_size; - - if recv_buffer_size != 0 { - if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { - ::log::error!( - "socket: failed setting recv buffer to {}: {:?}", - recv_buffer_size, - err - ); - } - } - - socket.into() -} diff --git a/aquatic_udp/src/lib/handlers.rs b/aquatic_udp/src/lib/handlers.rs index dcf2d28..dc1d719 100644 --- a/aquatic_udp/src/lib/handlers.rs +++ b/aquatic_udp/src/lib/handlers.rs @@ -1,5 +1,7 @@ use std::collections::BTreeMap; use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; use std::net::SocketAddr; use std::time::Duration; use std::time::Instant; @@ -15,6 +17,68 @@ use aquatic_udp_protocol::*; use crate::common::*; use crate::config::Config; +#[derive(Clone, PartialEq, Debug)] +pub struct ProtocolResponsePeer { + pub ip_address: I, + pub port: Port, +} + +impl ProtocolResponsePeer { + #[inline(always)] + fn from_peer(peer: &Peer) -> Self { + Self { + ip_address: peer.ip_address, + port: peer.port, + } + } +} + +pub struct ProtocolAnnounceResponse { + pub transaction_id: TransactionId, + pub announce_interval: AnnounceInterval, + pub leechers: NumberOfPeers, + pub seeders: NumberOfPeers, + pub peers: Vec>, +} + +impl Into for ProtocolAnnounceResponse { + fn into(self) -> ConnectedResponse { + ConnectedResponse::AnnounceIpv4(AnnounceResponseIpv4 { + transaction_id: self.transaction_id, + announce_interval: self.announce_interval, + leechers: self.leechers, + seeders: self.seeders, + peers: self + .peers + .into_iter() + .map(|peer| ResponsePeerIpv4 { + ip_address: peer.ip_address, + port: peer.port, + }) + .collect(), + }) + } +} + +impl Into for ProtocolAnnounceResponse { + fn into(self) -> ConnectedResponse { + ConnectedResponse::AnnounceIpv6(AnnounceResponseIpv6 { + transaction_id: self.transaction_id, + announce_interval: self.announce_interval, + leechers: self.leechers, + seeders: self.seeders, + peers: self + .peers + .into_iter() + .map(|peer| ResponsePeerIpv6 { + ip_address: peer.ip_address, + port: peer.port, + }) + .collect(), + }) + } +} + pub fn run_request_worker( config: Config, state: State, @@ -147,7 +211,7 @@ fn handle_announce_request_inner( &torrent_data.peers, max_num_peers_to_take, request.peer_id, - Peer::to_response_peer, + ProtocolResponsePeer::from_peer, ); ProtocolAnnounceResponse { @@ -262,14 +326,14 @@ mod tests { for i in 0..gen_num_peers { let key = gen_peer_id(i); - let value = gen_peer((i << 16) + i); + let peer = gen_peer((i << 16) + i); if i == 0 { opt_sender_key = Some(key); - opt_sender_peer = Some(value.to_response_peer()); + opt_sender_peer = Some(ProtocolResponsePeer::from_peer(&peer)); } - peer_map.insert(key, value); + peer_map.insert(key, peer); } let mut rng = thread_rng(); @@ -279,7 +343,7 @@ mod tests { &peer_map, req_num_peers, opt_sender_key.unwrap_or_else(|| gen_peer_id(1)), - Peer::to_response_peer, + ProtocolResponsePeer::from_peer, ); // Check that number of returned peers is correct diff --git a/aquatic_udp/src/lib/network.rs b/aquatic_udp/src/lib/network.rs index 5e4cac1..281ef3f 100644 --- a/aquatic_udp/src/lib/network.rs +++ b/aquatic_udp/src/lib/network.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::io::{Cursor, ErrorKind}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::sync::{ @@ -7,19 +8,247 @@ use std::sync::{ use std::time::{Duration, Instant}; use std::vec::Drain; -use aquatic_common::access_list::create_access_list_cache; -use aquatic_common::ValidUntil; use crossbeam_channel::Receiver; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; -use rand::prelude::{SeedableRng, StdRng}; +use rand::prelude::{SeedableRng, StdRng, Rng}; -use aquatic_udp_protocol::{Request, Response}; +use aquatic_common::access_list::create_access_list_cache; +use aquatic_common::ValidUntil; +use aquatic_common::access_list::AccessListCache; +use aquatic_common::AHashIndexMap; +use aquatic_udp_protocol::*; +use socket2::{Domain, Protocol, Socket, Type}; -use crate::common::network::*; use crate::common::*; use crate::config::Config; +#[derive(Default)] +pub struct ConnectionMap(AHashIndexMap<(ConnectionId, SocketAddr), ValidUntil>); + +impl ConnectionMap { + pub fn insert( + &mut self, + connection_id: ConnectionId, + socket_addr: SocketAddr, + valid_until: ValidUntil, + ) { + self.0.insert((connection_id, socket_addr), valid_until); + } + + pub fn contains(&self, connection_id: ConnectionId, socket_addr: SocketAddr) -> bool { + self.0.contains_key(&(connection_id, socket_addr)) + } + + pub fn clean(&mut self) { + let now = Instant::now(); + + self.0.retain(|_, v| v.0 > now); + self.0.shrink_to_fit(); + } +} + +pub struct PendingScrapeResponseMeta { + num_pending: usize, + valid_until: ValidUntil, +} + +#[derive(Default)] +pub struct PendingScrapeResponseMap( + AHashIndexMap, +); + +impl PendingScrapeResponseMap { + pub fn prepare( + &mut self, + transaction_id: TransactionId, + num_pending: usize, + valid_until: ValidUntil, + ) { + let meta = PendingScrapeResponseMeta { + num_pending, + valid_until, + }; + let response = PendingScrapeResponse { + transaction_id, + torrent_stats: BTreeMap::new(), + }; + + self.0.insert(transaction_id, (meta, response)); + } + + pub fn add_and_get_finished(&mut self, response: PendingScrapeResponse) -> Option { + let finished = if let Some(r) = self.0.get_mut(&response.transaction_id) { + r.0.num_pending -= 1; + + r.1.torrent_stats.extend(response.torrent_stats.into_iter()); + + r.0.num_pending == 0 + } else { + ::log::warn!("PendingScrapeResponses.add didn't find PendingScrapeResponse in map"); + + false + }; + + if finished { + let response = self.0.remove(&response.transaction_id).unwrap().1; + + Some(Response::Scrape(ScrapeResponse { + transaction_id: response.transaction_id, + torrent_stats: response.torrent_stats.into_values().collect(), + })) + } else { + None + } + } + + pub fn clean(&mut self) { + let now = Instant::now(); + + self.0.retain(|_, v| v.0.valid_until.0 > now); + self.0.shrink_to_fit(); + } +} + +pub fn handle_request( + config: &Config, + connections: &mut ConnectionMap, + pending_scrape_responses: &mut PendingScrapeResponseMap, + access_list_cache: &mut AccessListCache, + rng: &mut StdRng, + request_sender: &ConnectedRequestSender, + local_responses: &mut Vec<(Response, SocketAddr)>, + valid_until: ValidUntil, + res_request: Result, + src: SocketAddr, +) { + let access_list_mode = config.access_list.mode; + + match res_request { + Ok(Request::Connect(request)) => { + let connection_id = ConnectionId(rng.gen()); + + connections.insert(connection_id, src, valid_until); + + let response = Response::Connect(ConnectResponse { + connection_id, + transaction_id: request.transaction_id, + }); + + local_responses.push((response, src)) + } + Ok(Request::Announce(request)) => { + if connections.contains(request.connection_id, src) { + if access_list_cache + .load() + .allows(access_list_mode, &request.info_hash.0) + { + let worker_index = + RequestWorkerIndex::from_info_hash(config, request.info_hash); + + request_sender.try_send_to( + worker_index, + ConnectedRequest::Announce(request), + src, + ); + } else { + let response = Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + }); + + local_responses.push((response, src)) + } + } + } + Ok(Request::Scrape(request)) => { + if connections.contains(request.connection_id, src) { + let mut requests: AHashIndexMap = + Default::default(); + + let transaction_id = request.transaction_id; + + for (i, info_hash) in request.info_hashes.into_iter().enumerate() { + let pending = requests + .entry(RequestWorkerIndex::from_info_hash(&config, info_hash)) + .or_insert_with(|| PendingScrapeRequest { + transaction_id, + info_hashes: BTreeMap::new(), + }); + + pending.info_hashes.insert(i, info_hash); + } + + pending_scrape_responses.prepare(transaction_id, requests.len(), valid_until); + + for (request_worker_index, request) in requests { + request_sender.try_send_to( + request_worker_index, + ConnectedRequest::Scrape(request), + src, + ); + } + } + } + Err(err) => { + ::log::debug!("Request::from_bytes error: {:?}", err); + + if let RequestParseError::Sendable { + connection_id, + transaction_id, + err, + } = err + { + if connections.contains(connection_id, src) { + let response = ErrorResponse { + transaction_id, + message: err.right_or("Parse error").into(), + }; + + local_responses.push((response.into(), src)); + } + } + } + } +} + +pub fn create_socket(config: &Config) -> ::std::net::UdpSocket { + let socket = if config.network.address.is_ipv4() { + Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + } else { + Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) + } + .expect("create socket"); + + if config.network.only_ipv6 { + socket.set_only_v6(true).expect("socket: set only ipv6"); + } + + socket.set_reuse_port(true).expect("socket: set reuse port"); + + socket + .set_nonblocking(true) + .expect("socket: set nonblocking"); + + socket + .bind(&config.network.address.into()) + .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); + + let recv_buffer_size = config.network.socket_recv_buffer_size; + + if recv_buffer_size != 0 { + if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { + ::log::error!( + "socket: failed setting recv buffer to {}: {:?}", + recv_buffer_size, + err + ); + } + } + + socket.into() +} + pub fn run_socket_worker( state: State, config: Config,