diff --git a/Cargo.lock b/Cargo.lock index 2eebe57..51ebcce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,31 @@ dependencies = [ "rand", ] +[[package]] +name = "aquatic_http" +version = "0.1.0" +dependencies = [ + "anyhow", + "aquatic_cli_helpers", + "aquatic_common", + "either", + "flume", + "hashbrown", + "indexmap", + "log", + "mimalloc", + "mio", + "native-tls", + "parking_lot", + "privdrop", + "quickcheck", + "quickcheck_macros", + "rand", + "serde", + "simplelog", + "socket2", +] + [[package]] name = "aquatic_udp" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e890ce2..bf89258 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "aquatic_cli_helpers", "aquatic_common", + "aquatic_http", "aquatic_udp", "aquatic_udp_bench", "aquatic_udp_load_test", diff --git a/TODO.md b/TODO.md index e41c749..7c3b3e9 100644 --- a/TODO.md +++ b/TODO.md @@ -5,6 +5,13 @@ * avx-512 should be avoided, maybe this should be mentioned in README and maybe run scripts should be adjusted +## aquatic_http +* setup tls connection: support TLS and plain at the same time?? +* parse http requests incrementally when data comes in. crate for streaming + parse? +* serde for request/responses, also url encoded info hashes and peer id's +* move stuff to common crate with ws: what about Request/InMessage etc? + ## aquatic_ws * tests * ipv4 and ipv6 state split: think about this more.. diff --git a/aquatic_http/Cargo.toml b/aquatic_http/Cargo.toml new file mode 100644 index 0000000..323dbda --- /dev/null +++ b/aquatic_http/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "aquatic_http" +version = "0.1.0" +authors = ["Joakim FrostegĂ„rd "] +edition = "2018" +license = "Apache-2.0" + +[lib] +name = "aquatic_http" +path = "src/lib/lib.rs" + +[[bin]] +name = "aquatic_http" +path = "src/bin/main.rs" + +[dependencies] +anyhow = "1" +aquatic_cli_helpers = { path = "../aquatic_cli_helpers" } +aquatic_common = { path = "../aquatic_common" } +either = "1" +flume = "0.7" +hashbrown = { version = "0.7", features = ["serde"] } +indexmap = "1" +log = "0.4" +mimalloc = { version = "0.1", default-features = false } +mio = { version = "0.7", features = ["tcp", "os-poll", "os-util"] } +native-tls = "0.2" +parking_lot = "0.10" +privdrop = "0.3" +rand = { version = "0.7", features = ["small_rng"] } +serde = { version = "1", features = ["derive"] } +socket2 = { version = "0.3", features = ["reuseport"] } +simplelog = "0.8" + +[dev-dependencies] +quickcheck = "0.9" +quickcheck_macros = "0.9" diff --git a/aquatic_http/src/bin/main.rs b/aquatic_http/src/bin/main.rs new file mode 100644 index 0000000..4d093fd --- /dev/null +++ b/aquatic_http/src/bin/main.rs @@ -0,0 +1,44 @@ +use anyhow::Context; +use aquatic_cli_helpers::run_app_with_cli_and_config; +use simplelog::{ConfigBuilder, LevelFilter, TermLogger, TerminalMode}; + +use aquatic_http::config::{Config, LogLevel}; + + +#[global_allocator] +static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; + + +fn main(){ + run_app_with_cli_and_config::( + "aquatic: BitTorrent (HTTP/TLS) tracker", + run + ) +} + + +// almost identical to ws version +fn run(config: Config) -> anyhow::Result<()> { + let level_filter = match config.log_level { + LogLevel::Off => LevelFilter::Off, + LogLevel::Error => LevelFilter::Error, + LogLevel::Warn => LevelFilter::Warn, + LogLevel::Info => LevelFilter::Info, + LogLevel::Debug => LevelFilter::Debug, + LogLevel::Trace => LevelFilter::Trace, + }; + + // Note: logger doesn't seem to pick up thread names. Not a huge loss. + let simplelog_config = ConfigBuilder::new() + .set_time_to_local(true) + .set_location_level(LevelFilter::Off) + .build(); + + TermLogger::init( + level_filter, + simplelog_config, + TerminalMode::Stderr + ).context("Couldn't initialize logger")?; + + aquatic_http::run(config) +} \ No newline at end of file diff --git a/aquatic_http/src/lib/common.rs b/aquatic_http/src/lib/common.rs new file mode 100644 index 0000000..3ac09a4 --- /dev/null +++ b/aquatic_http/src/lib/common.rs @@ -0,0 +1,148 @@ +use std::net::{SocketAddr, IpAddr}; +use std::sync::Arc; + +use flume::{Sender, Receiver}; +use hashbrown::HashMap; +use indexmap::IndexMap; +use log::error; +use mio::Token; +use parking_lot::Mutex; + +pub use aquatic_common::ValidUntil; + +use crate::protocol::*; + + + +// identical to ws version +#[derive(Clone, Copy, Debug)] +pub struct ConnectionMeta { + /// Index of socket worker responsible for this connection. Required for + /// sending back response through correct channel to correct worker. + pub worker_index: usize, + pub peer_addr: SocketAddr, + pub poll_token: Token, +} + + +// identical to ws version +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum PeerStatus { + Seeding, + Leeching, + Stopped +} + + +// identical to ws version - FIXME only if bytes left is optional +impl PeerStatus { + /// Determine peer status from announce event and number of bytes left. + /// + /// Likely, the last branch will be taken most of the time. + #[inline] + pub fn from_event_and_bytes_left( + event: AnnounceEvent, + opt_bytes_left: Option + ) -> Self { + if let AnnounceEvent::Stopped = event { + Self::Stopped + } else if let Some(0) = opt_bytes_left { + Self::Seeding + } else { + Self::Leeching + } + } +} + + +#[derive(Clone, Copy)] +pub struct Peer { + pub connection_meta: ConnectionMeta, + pub port: u16, + pub status: PeerStatus, + pub valid_until: ValidUntil, +} + + +// identical to ws version +pub type PeerMap = IndexMap; + + +// identical to ws version +pub struct TorrentData { + pub peers: PeerMap, + pub num_seeders: usize, + pub num_leechers: usize, +} + + +// identical to ws version +impl Default for TorrentData { + #[inline] + fn default() -> Self { + Self { + peers: IndexMap::new(), + num_seeders: 0, + num_leechers: 0, + } + } +} + + +// identical to ws version +pub type TorrentMap = HashMap; + + +// identical to ws version +#[derive(Default)] +pub struct TorrentMaps { + pub ipv4: TorrentMap, + pub ipv6: TorrentMap, +} + + +// identical to ws version +#[derive(Clone)] +pub struct State { + pub torrent_maps: Arc>, +} + + +// identical to ws version +impl Default for State { + fn default() -> Self { + Self { + torrent_maps: Arc::new(Mutex::new(TorrentMaps::default())), + } + } +} + + +pub type RequestChannelSender = Sender<(ConnectionMeta, Request)>; +pub type RequestChannelReceiver = Receiver<(ConnectionMeta, Request)>; +pub type ResponseChannelReceiver = Receiver<(ConnectionMeta, Response)>; + + +pub struct ResponseChannelSender(Vec>); + + +impl ResponseChannelSender { + pub fn new(senders: Vec>) -> Self { + Self(senders) + } + + #[inline] + pub fn send( + &self, + meta: ConnectionMeta, + message: Response + ){ + if let Err(err) = self.0[meta.worker_index].send((meta, message)){ + error!("ResponseChannelSender: couldn't send message: {:?}", err); + } + } +} + + +pub type SocketWorkerStatus = Option>; +pub type SocketWorkerStatuses = Arc>>; diff --git a/aquatic_http/src/lib/config.rs b/aquatic_http/src/lib/config.rs new file mode 100644 index 0000000..982ef58 --- /dev/null +++ b/aquatic_http/src/lib/config.rs @@ -0,0 +1,179 @@ +use std::net::SocketAddr; + +use serde::{Serialize, Deserialize}; + + +// identical to ws version +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + Off, + Error, + Warn, + Info, + Debug, + Trace +} + + +// identical to ws version +impl Default for LogLevel { + fn default() -> Self { + Self::Error + } +} + + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct Config { + /// Socket workers receive requests from the socket, parse them and send + /// them on to the request handler. They then recieve responses from the + /// request handler, encode them and send them back over the socket. + pub socket_workers: usize, + pub log_level: LogLevel, + pub network: NetworkConfig, + pub protocol: ProtocolConfig, + pub handlers: HandlerConfig, + pub cleaning: CleaningConfig, + pub privileges: PrivilegeConfig, +} + + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct NetworkConfig { + /// Bind to this address + pub address: SocketAddr, + pub ipv6_only: bool, + pub use_tls: bool, + pub tls_pkcs12_path: String, + pub tls_pkcs12_password: String, + pub poll_event_capacity: usize, + pub poll_timeout_milliseconds: u64, +} + + +// identical to ws version +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct HandlerConfig { + /// Maximum number of requests to receive from channel before locking + /// mutex and starting work + pub max_requests_per_iter: usize, + pub channel_recv_timeout_microseconds: u64, +} + + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct ProtocolConfig { + /// Maximum number of torrents to accept in scrape request + pub max_scrape_torrents: usize, + /// Maximum number of requested peers to accept in announce request + pub max_peers: usize, + /// Ask peers to announce this often (seconds) + pub peer_announce_interval: usize, +} + + +// identical to ws version +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct CleaningConfig { + /// Clean peers this often (seconds) + pub interval: u64, + /// Remove peers that haven't announced for this long (seconds) + pub max_peer_age: u64, + /// Remove connections that are older than this (seconds) + pub max_connection_age: u64, +} + + +// identical to ws version +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct PrivilegeConfig { + /// Chroot and switch user after binding to sockets + pub drop_privileges: bool, + /// Chroot to this path + pub chroot_path: String, + /// User to switch to after chrooting + pub user: String, +} + + +impl Default for Config { + fn default() -> Self { + Self { + socket_workers: 1, + log_level: LogLevel::default(), + network: NetworkConfig::default(), + protocol: ProtocolConfig::default(), + handlers: HandlerConfig::default(), + cleaning: CleaningConfig::default(), + privileges: PrivilegeConfig::default(), + } + } +} + + +impl Default for NetworkConfig { + fn default() -> Self { + Self { + address: SocketAddr::from(([0, 0, 0, 0], 3000)), + ipv6_only: false, + use_tls: false, + tls_pkcs12_path: "".into(), + tls_pkcs12_password: "".into(), + poll_event_capacity: 4096, + poll_timeout_milliseconds: 50, + } + } +} + + +impl Default for ProtocolConfig { + fn default() -> Self { + Self { + max_scrape_torrents: 255, // FIXME: what value is reasonable? + max_peers: 50, + peer_announce_interval: 120, + } + } +} + + +// identical to ws version +impl Default for HandlerConfig { + fn default() -> Self { + Self { + max_requests_per_iter: 10000, + channel_recv_timeout_microseconds: 200, + } + } +} + + +// identical to ws version +impl Default for CleaningConfig { + fn default() -> Self { + Self { + interval: 30, + max_peer_age: 180, + max_connection_age: 180, + } + } +} + + +// identical to ws version +impl Default for PrivilegeConfig { + fn default() -> Self { + Self { + drop_privileges: false, + chroot_path: ".".to_string(), + user: "nobody".to_string(), + } + } +} \ No newline at end of file diff --git a/aquatic_http/src/lib/handler.rs b/aquatic_http/src/lib/handler.rs new file mode 100644 index 0000000..95f614e --- /dev/null +++ b/aquatic_http/src/lib/handler.rs @@ -0,0 +1,208 @@ +use std::time::Duration; +use std::vec::Drain; + +use hashbrown::HashMap; +use parking_lot::MutexGuard; +use rand::{Rng, SeedableRng, rngs::SmallRng}; + +use aquatic_common::extract_response_peers; + +use crate::common::*; +use crate::config::Config; +use crate::protocol::*; + + +// almost identical to ws version +pub fn run_request_worker( + config: Config, + state: State, + request_channel_receiver: RequestChannelReceiver, + response_channel_sender: ResponseChannelSender, +){ + let mut responses = Vec::new(); + + let mut announce_requests = Vec::new(); + let mut scrape_requests = Vec::new(); + + let mut rng = SmallRng::from_entropy(); + + let timeout = Duration::from_micros( + config.handlers.channel_recv_timeout_microseconds + ); + + loop { + let mut opt_torrent_map_guard: Option> = None; + + for i in 0..config.handlers.max_requests_per_iter { + let opt_in_message = if i == 0 { + request_channel_receiver.recv().ok() + } else { + request_channel_receiver.recv_timeout(timeout).ok() + }; + + match opt_in_message { + Some((meta, Request::Announce(r))) => { + announce_requests.push((meta, r)); + }, + Some((meta, Request::Scrape(r))) => { + scrape_requests.push((meta, r)); + }, + None => { + if let Some(torrent_guard) = state.torrent_maps.try_lock(){ + opt_torrent_map_guard = Some(torrent_guard); + + break + } + } + } + } + + let mut torrent_map_guard = opt_torrent_map_guard + .unwrap_or_else(|| state.torrent_maps.lock()); + + handle_announce_requests( + &config, + &mut rng, + &mut torrent_map_guard, + &mut responses, + announce_requests.drain(..) + ); + + handle_scrape_requests( + &config, + &mut torrent_map_guard, + &mut responses, + scrape_requests.drain(..) + ); + + ::std::mem::drop(torrent_map_guard); + + for (meta, response) in responses.drain(..){ + response_channel_sender.send(meta, response); + } + } +} + + +pub fn handle_announce_requests( + config: &Config, + rng: &mut impl Rng, + torrent_maps: &mut TorrentMaps, + responses: &mut Vec<(ConnectionMeta, Response)>, + requests: Drain<(ConnectionMeta, AnnounceRequest)>, +){ + let valid_until = ValidUntil::new(config.cleaning.max_peer_age); + + responses.extend(requests.into_iter().map(|(request_sender_meta, request)| { + let torrent_data: &mut TorrentData = if request_sender_meta.peer_addr.is_ipv4(){ + torrent_maps.ipv4.entry(request.info_hash).or_default() + } else { + torrent_maps.ipv6.entry(request.info_hash).or_default() + }; + + // Insert/update/remove peer who sent this request + { + let peer_status = PeerStatus::from_event_and_bytes_left( + request.event, + Some(request.bytes_left) + ); + + let peer = Peer { + connection_meta: request_sender_meta, + port: request.port, + status: peer_status, + valid_until, + }; + + let opt_removed_peer = match peer_status { + PeerStatus::Leeching => { + torrent_data.num_leechers += 1; + + torrent_data.peers.insert(request.peer_id, peer) + }, + PeerStatus::Seeding => { + torrent_data.num_seeders += 1; + + torrent_data.peers.insert(request.peer_id, peer) + }, + PeerStatus::Stopped => { + torrent_data.peers.remove(&request.peer_id) + } + }; + + match opt_removed_peer.map(|peer| peer.status){ + Some(PeerStatus::Leeching) => { + torrent_data.num_leechers -= 1; + }, + Some(PeerStatus::Seeding) => { + torrent_data.num_seeders -= 1; + }, + _ => {} + } + } + + let max_num_peers_to_take = if request.numwant <= 0 { + config.protocol.max_peers + } else { + request.numwant.min(config.protocol.max_peers) + }; + + let response_peers: Vec = extract_response_peers( + rng, + &torrent_data.peers, + max_num_peers_to_take, + ResponsePeer::from_peer + ); + + let response = Response::AnnounceSuccess(AnnounceResponseSuccess { + complete: torrent_data.num_seeders, + incomplete: torrent_data.num_leechers, + announce_interval: config.protocol.peer_announce_interval, + peers: response_peers, + tracker_id: "".to_string() + }); + + (request_sender_meta, response) + })); +} + + +// almost identical to ws version +pub fn handle_scrape_requests( + config: &Config, + torrent_maps: &mut TorrentMaps, + messages_out: &mut Vec<(ConnectionMeta, Response)>, + requests: Drain<(ConnectionMeta, ScrapeRequest)>, +){ + messages_out.extend(requests.map(|(meta, request)| { + let num_to_take = request.info_hashes.len().min( + config.protocol.max_scrape_torrents + ); + + let mut response = ScrapeResponse { + files: HashMap::with_capacity(num_to_take), + }; + + let torrent_map: &mut TorrentMap = if meta.peer_addr.is_ipv4(){ + &mut torrent_maps.ipv4 + } else { + &mut torrent_maps.ipv6 + }; + + // If request.info_hashes is empty, don't return scrape for all + // torrents, even though reference server does it. It is too expensive. + for info_hash in request.info_hashes.into_iter().take(num_to_take){ + if let Some(torrent_data) = torrent_map.get(&info_hash){ + let stats = ScrapeStatistics { + complete: torrent_data.num_seeders, + downloaded: 0, // No implementation planned + incomplete: torrent_data.num_leechers, + }; + + response.files.insert(info_hash, stats); + } + } + + (meta, Response::Scrape(response)) + })); +} \ No newline at end of file diff --git a/aquatic_http/src/lib/lib.rs b/aquatic_http/src/lib/lib.rs new file mode 100644 index 0000000..2216e8a --- /dev/null +++ b/aquatic_http/src/lib/lib.rs @@ -0,0 +1,143 @@ +use std::time::Duration; +use std::fs::File; +use std::io::Read; +use std::sync::Arc; +use std::thread::Builder; + +use anyhow::Context; +use native_tls::{Identity, TlsAcceptor}; +use parking_lot::Mutex; +use privdrop::PrivDrop; + +pub mod common; +pub mod config; +pub mod handler; +pub mod network; +pub mod protocol; +pub mod tasks; + +use common::*; +use config::Config; + + +// almost identical to ws version +pub fn run(config: Config) -> anyhow::Result<()> { + let opt_tls_acceptor = create_tls_acceptor(&config)?; + + let state = State::default(); + + let (request_channel_sender, request_channel_receiver) = ::flume::unbounded(); + + let mut out_message_senders = Vec::new(); + + let socket_worker_statuses: SocketWorkerStatuses = { + let mut statuses = Vec::new(); + + for _ in 0..config.socket_workers { + statuses.push(None); + } + + Arc::new(Mutex::new(statuses)) + }; + + for i in 0..config.socket_workers { + let config = config.clone(); + let socket_worker_statuses = socket_worker_statuses.clone(); + let request_channel_sender = request_channel_sender.clone(); + let opt_tls_acceptor = opt_tls_acceptor.clone(); + + let (response_channel_sender, response_channel_receiver) = ::flume::unbounded(); + + out_message_senders.push(response_channel_sender); + + Builder::new().name(format!("socket-{:02}", i + 1)).spawn(move || { + network::run_socket_worker( + config, + i, + socket_worker_statuses, + request_channel_sender, + response_channel_receiver, + opt_tls_acceptor + ); + })?; + } + + // Wait for socket worker statuses. On error from any, quit program. + // On success from all, drop privileges if corresponding setting is set + // and continue program. + loop { + ::std::thread::sleep(::std::time::Duration::from_millis(10)); + + if let Some(statuses) = socket_worker_statuses.try_lock(){ + for opt_status in statuses.iter(){ + match opt_status { + Some(Err(err)) => { + return Err(::anyhow::anyhow!(err.to_owned())); + }, + _ => {}, + } + } + + if statuses.iter().all(Option::is_some){ + if config.privileges.drop_privileges { + PrivDrop::default() + .chroot(config.privileges.chroot_path.clone()) + .user(config.privileges.user.clone()) + .apply() + .context("Couldn't drop root privileges")?; + } + + break + } + } + } + + let response_channel_sender = ResponseChannelSender::new(out_message_senders); + + { + let config = config.clone(); + let state = state.clone(); + + Builder::new().name("request".to_string()).spawn(move || { + handler::run_request_worker( + config, + state, + request_channel_receiver, + response_channel_sender, + ); + })?; + } + + loop { + ::std::thread::sleep(Duration::from_secs(config.cleaning.interval)); + + tasks::clean_torrents(&state); + } +} + + +// identical to ws version +pub fn create_tls_acceptor( + config: &Config, +) -> anyhow::Result> { + if config.network.use_tls { + let mut identity_bytes = Vec::new(); + let mut file = File::open(&config.network.tls_pkcs12_path) + .context("Couldn't open pkcs12 identity file")?; + + file.read_to_end(&mut identity_bytes) + .context("Couldn't read pkcs12 identity file")?; + + let identity = Identity::from_pkcs12( + &mut identity_bytes, + &config.network.tls_pkcs12_password + ).context("Couldn't parse pkcs12 identity file")?; + + let acceptor = TlsAcceptor::new(identity) + .context("Couldn't create TlsAcceptor from pkcs12 identity")?; + + Ok(Some(acceptor)) + } else { + Ok(None) + } +} \ No newline at end of file diff --git a/aquatic_http/src/lib/network/connection.rs b/aquatic_http/src/lib/network/connection.rs new file mode 100644 index 0000000..a79c8ef --- /dev/null +++ b/aquatic_http/src/lib/network/connection.rs @@ -0,0 +1,278 @@ +use std::net::{SocketAddr}; +use std::io::{Read, Write}; + +use either::Either; +use hashbrown::HashMap; +use log::info; +use mio::Token; +use mio::net::TcpStream; +use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream}; +use tungstenite::WebSocket; +use tungstenite::handshake::{MidHandshake, HandshakeError, server::NoCallback}; +use tungstenite::server::{ServerHandshake}; +use tungstenite::protocol::WebSocketConfig; + +use crate::common::*; + + +pub enum Stream { + TcpStream(TcpStream), + TlsStream(TlsStream), +} + + +impl Stream { + #[inline] + pub fn get_peer_addr(&self) -> SocketAddr { + match self { + Self::TcpStream(stream) => stream.peer_addr().unwrap(), + Self::TlsStream(stream) => stream.get_ref().peer_addr().unwrap(), + } + } +} + + +impl Read for Stream { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> Result { + match self { + Self::TcpStream(stream) => stream.read(buf), + Self::TlsStream(stream) => stream.read(buf), + } + } + + /// Not used but provided for completeness + #[inline] + fn read_vectored( + &mut self, + bufs: &mut [::std::io::IoSliceMut<'_>] + ) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.read_vectored(bufs), + Self::TlsStream(stream) => stream.read_vectored(bufs), + } + } +} + + +impl Write for Stream { + #[inline] + fn write(&mut self, buf: &[u8]) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.write(buf), + Self::TlsStream(stream) => stream.write(buf), + } + } + + /// Not used but provided for completeness + #[inline] + fn write_vectored( + &mut self, + bufs: &[::std::io::IoSlice<'_>] + ) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.write_vectored(bufs), + Self::TlsStream(stream) => stream.write_vectored(bufs), + } + } + + #[inline] + fn flush(&mut self) -> ::std::io::Result<()> { + match self { + Self::TcpStream(stream) => stream.flush(), + Self::TlsStream(stream) => stream.flush(), + } + } +} + + +enum HandshakeMachine { + TcpStream(TcpStream), + TlsStream(TlsStream), + TlsMidHandshake(MidHandshakeTlsStream), + WsMidHandshake(MidHandshake>), +} + + +impl HandshakeMachine { + #[inline] + fn new(tcp_stream: TcpStream) -> Self { + Self::TcpStream(tcp_stream) + } + + #[inline] + fn advance( + self, + ws_config: WebSocketConfig, + opt_tls_acceptor: &Option, // If set, run TLS + ) -> (Option>, bool) { // bool = stop looping + match self { + HandshakeMachine::TcpStream(stream) => { + if let Some(tls_acceptor) = opt_tls_acceptor { + Self::handle_tls_handshake_result( + tls_acceptor.accept(stream) + ) + } else { + let handshake_result = ::tungstenite::server::accept_with_config( + Stream::TcpStream(stream), + Some(ws_config) + ); + + Self::handle_ws_handshake_result(handshake_result) + } + }, + HandshakeMachine::TlsStream(stream) => { + let handshake_result = ::tungstenite::server::accept( + Stream::TlsStream(stream), + ); + + Self::handle_ws_handshake_result(handshake_result) + }, + HandshakeMachine::TlsMidHandshake(handshake) => { + Self::handle_tls_handshake_result(handshake.handshake()) + }, + HandshakeMachine::WsMidHandshake(handshake) => { + Self::handle_ws_handshake_result(handshake.handshake()) + }, + } + } + + #[inline] + fn handle_tls_handshake_result( + result: Result, ::native_tls::HandshakeError>, + ) -> (Option>, bool) { + match result { + Ok(stream) => { + (Some(Either::Right(Self::TlsStream(stream))), false) + }, + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + (Some(Either::Right(Self::TlsMidHandshake(handshake))), true) + }, + Err(native_tls::HandshakeError::Failure(err)) => { + info!("tls handshake error: {}", err); + + (None, false) + } + } + } + + #[inline] + fn handle_ws_handshake_result( + result: Result, HandshakeError>> , + ) -> (Option>, bool) { + match result { + Ok(mut ws) => { + let peer_addr = ws.get_mut().get_peer_addr(); + + let established_ws = EstablishedWs { + ws, + peer_addr, + }; + + (Some(Either::Left(established_ws)), false) + }, + Err(HandshakeError::Interrupted(handshake)) => { + (Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), true) + }, + Err(HandshakeError::Failure(err)) => { + info!("ws handshake error: {}", err); + + (None, false) + } + } + } +} + + +pub struct EstablishedWs { + pub ws: WebSocket, + pub peer_addr: SocketAddr, +} + + +pub struct Connection { + ws_config: WebSocketConfig, + pub valid_until: ValidUntil, + inner: Either, +} + + +/// Create from TcpStream. Run `advance_handshakes` until `get_established_ws` +/// returns Some(EstablishedWs). +/// +/// advance_handshakes takes ownership of self because the TLS and WebSocket +/// handshake methods do. get_established_ws doesn't, since work can be done +/// on a mutable reference to a tungstenite websocket, and this way, the whole +/// Connection doesn't have to be removed from and reinserted into the +/// TorrentMap. This is also the reason for wrapping Container.inner in an +/// Either instead of combining all states into one structure just having a +/// single method for advancing handshakes and maybe returning a websocket. +impl Connection { + #[inline] + pub fn new( + ws_config: WebSocketConfig, + valid_until: ValidUntil, + tcp_stream: TcpStream, + ) -> Self { + Self { + ws_config, + valid_until, + inner: Either::Right(HandshakeMachine::new(tcp_stream)) + } + } + + #[inline] + pub fn get_established_ws<'a>(&mut self) -> Option<&mut EstablishedWs> { + match self.inner { + Either::Left(ref mut ews) => Some(ews), + Either::Right(_) => None, + } + } + + #[inline] + pub fn advance_handshakes( + self, + opt_tls_acceptor: &Option, + valid_until: ValidUntil, + ) -> (Option, bool) { + match self.inner { + Either::Left(_) => (Some(self), false), + Either::Right(machine) => { + let ws_config = self.ws_config; + + let (opt_inner, stop_loop) = machine.advance( + ws_config, + opt_tls_acceptor + ); + + let opt_new_self = opt_inner.map(|inner| Self { + ws_config, + valid_until, + inner + }); + + (opt_new_self, stop_loop) + } + } + } + + #[inline] + pub fn close(&mut self){ + if let Either::Left(ref mut ews) = self.inner { + if ews.ws.can_read(){ + ews.ws.close(None).unwrap(); + + // Required after ws.close() + if let Err(err) = ews.ws.write_pending(){ + info!( + "error writing pending messages after closing ws: {}", + err + ) + } + } + } + } +} + + +pub type ConnectionMap = HashMap; \ No newline at end of file diff --git a/aquatic_http/src/lib/network/mod.rs b/aquatic_http/src/lib/network/mod.rs new file mode 100644 index 0000000..7479092 --- /dev/null +++ b/aquatic_http/src/lib/network/mod.rs @@ -0,0 +1,279 @@ +use std::time::Duration; +use std::io::ErrorKind; + +use hashbrown::HashMap; +use log::{info, debug, error}; +use native_tls::TlsAcceptor; +use mio::{Events, Poll, Interest, Token}; +use mio::net::TcpListener; +use tungstenite::protocol::WebSocketConfig; + +use crate::common::*; +use crate::config::Config; +use crate::protocol::*; + +pub mod connection; +pub mod utils; + +use connection::*; +use utils::*; + + +// will be almost identical to ws version +pub fn run_socket_worker( + config: Config, + socket_worker_index: usize, + socket_worker_statuses: SocketWorkerStatuses, + request_channel_sender: InMessageSender, + response_channel_receiver: OutMessageReceiver, + opt_tls_acceptor: Option, +){ + match create_listener(&config){ + Ok(listener) => { + socket_worker_statuses.lock()[socket_worker_index] = Some(Ok(())); + + run_poll_loop( + config, + socket_worker_index, + request_channel_sender, + response_channel_receiver, + listener, + opt_tls_acceptor + ); + }, + Err(err) => { + socket_worker_statuses.lock()[socket_worker_index] = Some( + Err(format!("Couldn't open socket: {:#}", err)) + ); + } + } +} + + +// will be almost identical to ws version +pub fn run_poll_loop( + config: Config, + socket_worker_index: usize, + request_channel_sender: InMessageSender, + response_channel_receiver: OutMessageReceiver, + listener: ::std::net::TcpListener, + opt_tls_acceptor: Option, +){ + let poll_timeout = Duration::from_millis( + config.network.poll_timeout_milliseconds + ); + let ws_config = WebSocketConfig { + max_message_size: Some(config.network.websocket_max_message_size), + max_frame_size: Some(config.network.websocket_max_frame_size), + max_send_queue: None, + }; + + let mut listener = TcpListener::from_std(listener); + let mut poll = Poll::new().expect("create poll"); + let mut events = Events::with_capacity(config.network.poll_event_capacity); + + poll.registry() + .register(&mut listener, Token(0), Interest::READABLE) + .unwrap(); + + let mut connections: ConnectionMap = HashMap::new(); + + let mut poll_token_counter = Token(0usize); + let mut iter_counter = 0usize; + + loop { + poll.poll(&mut events, Some(poll_timeout)) + .expect("failed polling"); + + let valid_until = ValidUntil::new(config.cleaning.max_connection_age); + + for event in events.iter(){ + let token = event.token(); + + if token.0 == 0 { + accept_new_streams( + ws_config, + &mut listener, + &mut poll, + &mut connections, + valid_until, + &mut poll_token_counter, + ); + } else { + run_handshake_and_read_requests( + socket_worker_index, + &request_channel_sender, + &opt_tls_acceptor, + &mut connections, + token, + valid_until, + ); + } + } + + send_out_messages( + response_channel_receiver.drain(), + &mut connections + ); + + // Remove inactive connections, but not every iteration + if iter_counter % 128 == 0 { + remove_inactive_connections(&mut connections); + } + + iter_counter = iter_counter.wrapping_add(1); + } +} + + +// will be identical to ws version +fn accept_new_streams( + ws_config: WebSocketConfig, + listener: &mut TcpListener, + poll: &mut Poll, + connections: &mut ConnectionMap, + valid_until: ValidUntil, + poll_token_counter: &mut Token, +){ + loop { + match listener.accept(){ + Ok((mut stream, _)) => { + poll_token_counter.0 = poll_token_counter.0.wrapping_add(1); + + if poll_token_counter.0 == 0 { + poll_token_counter.0 = 1; + } + + let token = *poll_token_counter; + + remove_connection_if_exists(connections, token); + + poll.registry() + .register(&mut stream, token, Interest::READABLE) + .unwrap(); + + let connection = Connection::new(ws_config, valid_until, stream); + + connections.insert(token, connection); + }, + Err(err) => { + if err.kind() == ErrorKind::WouldBlock { + break + } + + info!("error while accepting streams: {}", err); + } + } + } +} + + +/// On the stream given by poll_token, get TLS (if requested) and tungstenite +/// up and running, then read messages and pass on through channel. +pub fn run_handshake_and_read_requests( + socket_worker_index: usize, + request_channel_sender: &InMessageSender, + opt_tls_acceptor: &Option, // If set, run TLS + connections: &mut ConnectionMap, + poll_token: Token, + valid_until: ValidUntil, +){ + loop { + if let Some(established_ws) = connections.get_mut(&poll_token) + .and_then(Connection::get_established_ws) + { + use ::tungstenite::Error::Io; + + match established_ws.ws.read_message(){ + Ok(ws_message) => { + if let Some(in_message) = InMessage::from_ws_message(ws_message){ + let meta = ConnectionMeta { + worker_index: socket_worker_index, + poll_token, + peer_addr: established_ws.peer_addr + }; + + debug!("read message"); + + if let Err(err) = request_channel_sender + .send((meta, in_message)) + { + error!( + "InMessageSender: couldn't send message: {:?}", + err + ); + } + } + }, + Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { + break; + }, + Err(tungstenite::Error::ConnectionClosed) => { + remove_connection_if_exists(connections, poll_token); + + break + }, + Err(err) => { + info!("error reading messages: {}", err); + + remove_connection_if_exists(connections, poll_token); + + break; + } + } + } else if let Some(connection) = connections.remove(&poll_token){ + let (opt_new_connection, stop_loop) = connection.advance_handshakes( + opt_tls_acceptor, + valid_until + ); + + if let Some(connection) = opt_new_connection { + connections.insert(poll_token, connection); + } + + if stop_loop { + break; + } + } + } +} + + +/// Read messages from channel, send to peers +pub fn send_out_messages( + response_channel_receiver: ::flume::Drain<(ConnectionMeta, OutMessage)>, + connections: &mut ConnectionMap, +){ + for (meta, out_message) in response_channel_receiver { + let opt_established_ws = connections.get_mut(&meta.poll_token) + .and_then(Connection::get_established_ws); + + if let Some(established_ws) = opt_established_ws { + if established_ws.peer_addr != meta.peer_addr { + info!("socket worker error: peer socket addrs didn't match"); + + continue; + } + + use ::tungstenite::Error::Io; + + match established_ws.ws.write_message(out_message.to_ws_message()){ + Ok(()) => { + debug!("sent message"); + }, + Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {}, + Err(tungstenite::Error::ConnectionClosed) => { + remove_connection_if_exists(connections, meta.poll_token); + }, + Err(err) => { + info!("error writing ws message: {}", err); + + remove_connection_if_exists( + connections, + meta.poll_token + ); + }, + } + } + } +} \ No newline at end of file diff --git a/aquatic_http/src/lib/network/utils.rs b/aquatic_http/src/lib/network/utils.rs new file mode 100644 index 0000000..b10112d --- /dev/null +++ b/aquatic_http/src/lib/network/utils.rs @@ -0,0 +1,74 @@ +use std::time::Instant; + +use anyhow::Context; +use mio::Token; +use socket2::{Socket, Domain, Type, Protocol}; + +use crate::config::Config; + +use super::connection::*; + + +// will be identical to ws version +pub fn create_listener( + config: &Config +) -> ::anyhow::Result<::std::net::TcpListener> { + let builder = if config.network.address.is_ipv4(){ + Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())) + } else { + Socket::new(Domain::ipv6(), Type::stream(), Some(Protocol::tcp())) + }.context("Couldn't create socket2::Socket")?; + + if config.network.ipv6_only { + builder.set_only_v6(true) + .context("Couldn't put socket in ipv6 only mode")? + } + + builder.set_nonblocking(true) + .context("Couldn't put socket in non-blocking mode")?; + builder.set_reuse_port(true) + .context("Couldn't put socket in reuse_port mode")?; + builder.bind(&config.network.address.into()).with_context(|| + format!("Couldn't bind socket to address {}", config.network.address) + )?; + builder.listen(128) + .context("Couldn't listen for connections on socket")?; + + Ok(builder.into_tcp_listener()) +} + + +// will be identical to ws version +/// Don't bother with deregistering from Poll. In my understanding, this is +/// done automatically when the stream is dropped, as long as there are no +/// other references to the file descriptor, such as when it is accessed +/// in multiple threads. +pub fn remove_connection_if_exists( + connections: &mut ConnectionMap, + token: Token, +){ + if let Some(mut connection) = connections.remove(&token){ + connection.close(); + } +} + + +// will be identical to ws version +// Close and remove inactive connections +pub fn remove_inactive_connections( + connections: &mut ConnectionMap, +){ + let now = Instant::now(); + + connections.retain(|_, connection| { + if connection.valid_until.0 < now { + connection.close(); + + false + } else { + true + } + }); + + connections.shrink_to_fit(); +} \ No newline at end of file diff --git a/aquatic_http/src/lib/protocol/mod.rs b/aquatic_http/src/lib/protocol/mod.rs new file mode 100644 index 0000000..bf61d6e --- /dev/null +++ b/aquatic_http/src/lib/protocol/mod.rs @@ -0,0 +1,147 @@ +use std::net::IpAddr; +use hashbrown::HashMap; +use serde::{Serialize, Deserialize}; + +use crate::common::Peer; + +mod serde_helpers; + +use serde_helpers::*; + + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct PeerId( + #[serde( + deserialize_with = "deserialize_20_bytes", + serialize_with = "serialize_20_bytes" + )] + pub [u8; 20] +); + + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct InfoHash( + #[serde( + deserialize_with = "deserialize_20_bytes", + serialize_with = "serialize_20_bytes" + )] + pub [u8; 20] +); + + +#[derive(Clone, Copy, Debug, Serialize)] +pub struct ResponsePeer { + pub ip_address: IpAddr, + pub port: u16 +} + + +impl ResponsePeer { + pub fn from_peer(peer: &Peer) -> Self { + let ip_address = peer.connection_meta.peer_addr.ip(); + + Self { + ip_address, + port: peer.port + } + } +} + + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AnnounceEvent { + Started, + Stopped, + Completed, + Empty +} + + +impl Default for AnnounceEvent { + fn default() -> Self { + Self::Empty + } +} + + +#[derive(Debug, Clone, Deserialize)] +pub struct AnnounceRequest { + pub info_hash: InfoHash, + pub peer_id: PeerId, + pub port: u16, + #[serde(rename = "left")] + pub bytes_left: usize, + #[serde(default)] + pub event: AnnounceEvent, + /// FIXME: number: 0 or 1 + pub compact: bool, + /// Requested number of peers to return + pub numwant: usize, +} + + +#[derive(Debug, Clone, Serialize)] +pub struct AnnounceResponseSuccess { + #[serde(rename = "interval")] + pub announce_interval: usize, + pub tracker_id: String, // Optional?? + pub complete: usize, + pub incomplete: usize, + pub peers: Vec, +} + + +#[derive(Debug, Clone, Serialize)] +pub struct AnnounceResponseFailure { + pub failure_reason: String, +} + + +#[derive(Debug, Clone, Deserialize)] +pub struct ScrapeRequest { + #[serde( + rename = "info_hash", + deserialize_with = "deserialize_info_hashes", + default + )] + pub info_hashes: Vec, +} + + +#[derive(Debug, Clone, Serialize)] +pub struct ScrapeStatistics { + pub complete: usize, + pub incomplete: usize, + pub downloaded: usize, +} + + +#[derive(Debug, Clone, Serialize)] +pub struct ScrapeResponse { + pub files: HashMap, +} + + +#[derive(Debug, Clone, Deserialize)] +pub enum Request { + Announce(AnnounceRequest), + Scrape(ScrapeRequest), +} + + +impl Request { + pub fn from_http() -> Self { + unimplemented!() + } +} + + +#[derive(Debug, Clone, Serialize)] +pub enum Response { + AnnounceSuccess(AnnounceResponseSuccess), + AnnounceFailure(AnnounceResponseFailure), + Scrape(ScrapeResponse) +} \ No newline at end of file diff --git a/aquatic_http/src/lib/protocol/serde_helpers.rs b/aquatic_http/src/lib/protocol/serde_helpers.rs new file mode 100644 index 0000000..f7cea32 --- /dev/null +++ b/aquatic_http/src/lib/protocol/serde_helpers.rs @@ -0,0 +1,247 @@ +use serde::{Serializer, Deserializer, de::{Visitor, SeqAccess}}; + +use super::InfoHash; + + +pub fn serialize_20_bytes( + data: &[u8; 20], + serializer: S +) -> Result where S: Serializer { + let text: String = data.iter().map(|byte| *byte as char).collect(); + + serializer.serialize_str(&text) +} + + +struct TwentyByteVisitor; + +impl<'de> Visitor<'de> for TwentyByteVisitor { + type Value = [u8; 20]; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("string consisting of 20 bytes") + } + + #[inline] + fn visit_str(self, value: &str) -> Result + where E: ::serde::de::Error, + { + // Value is encoded in nodejs reference client something as follows: + // ``` + // var infoHash = 'abcd..'; // 40 hexadecimals + // Buffer.from(infoHash, 'hex').toString('binary'); + // ``` + // As I understand it: + // - the code above produces a UTF16 string of 20 chars, each having + // only the "low byte" set (e.g., numeric value ranges from 0-255) + // - serde_json decodes this to string of 20 chars (tested), each in + // the aforementioned range (tested), so the bytes can be extracted + // by casting each char to u8. + + let mut arr = [0u8; 20]; + let mut char_iter = value.chars(); + + for a in arr.iter_mut(){ + if let Some(c) = char_iter.next(){ + if c as u32 > 255 { + return Err(E::custom(format!( + "character not in single byte range: {:#?}", + c + ))); + } + + *a = c as u8; + } else { + return Err(E::custom(format!("not 20 bytes: {:#?}", value))); + } + } + + Ok(arr) + } +} + + +#[inline] +pub fn deserialize_20_bytes<'de, D>( + deserializer: D +) -> Result<[u8; 20], D::Error> + where D: Deserializer<'de> +{ + deserializer.deserialize_any(TwentyByteVisitor) +} + + +pub struct InfoHashVecVisitor; + + +impl<'de> Visitor<'de> for InfoHashVecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("string or array of strings consisting of 20 bytes") + } + + #[inline] + fn visit_str(self, value: &str) -> Result + where E: ::serde::de::Error, + { + match TwentyByteVisitor::visit_str::(TwentyByteVisitor, value){ + Ok(arr) => Ok(vec![InfoHash(arr)]), + Err(err) => Err(E::custom(format!("got string, but {}", err))) + } + } + + #[inline] + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> + { + let mut info_hashes: Self::Value = Vec::new(); + + while let Ok(Some(value)) = seq.next_element::<&str>(){ + let arr = TwentyByteVisitor::visit_str( + TwentyByteVisitor, value + )?; + + info_hashes.push(InfoHash(arr)); + } + + Ok(info_hashes) + } + + #[inline] + fn visit_none(self) -> Result + where E: ::serde::de::Error + { + Ok(vec![]) + } +} + + +/// Empty vector is returned if value is null or any invalid info hash +/// is present +#[inline] +pub fn deserialize_info_hashes<'de, D>( + deserializer: D +) -> Result, D::Error> + where D: Deserializer<'de>, +{ + Ok(deserializer.deserialize_any(InfoHashVecVisitor).unwrap_or_default()) +} + + +#[cfg(test)] +mod tests { + use serde::Deserialize; + + use super::*; + + fn info_hash_from_bytes(bytes: &[u8]) -> InfoHash { + let mut arr = [0u8; 20]; + + assert!(bytes.len() == 20); + + arr.copy_from_slice(&bytes[..]); + + InfoHash(arr) + } + + #[test] + fn test_deserialize_20_bytes(){ + let input = r#""aaaabbbbccccddddeeee""#; + + let expected = info_hash_from_bytes(b"aaaabbbbccccddddeeee"); + let observed: InfoHash = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + + let input = r#""aaaabbbbccccddddeee""#; + let res_info_hash: Result = serde_json::from_str(input); + + assert!(res_info_hash.is_err()); + + let input = r#""aaaabbbbccccddddeee𝕊""#; + let res_info_hash: Result = serde_json::from_str(input); + + assert!(res_info_hash.is_err()); + } + + #[test] + fn test_serde_20_bytes(){ + let info_hash = info_hash_from_bytes(b"aaaabbbbccccddddeeee"); + + let out = serde_json::to_string(&info_hash).unwrap(); + let info_hash_2 = serde_json::from_str(&out).unwrap(); + + assert_eq!(info_hash, info_hash_2); + } + + #[derive(Debug, PartialEq, Eq, Deserialize)] + struct Test { + #[serde(deserialize_with = "deserialize_info_hashes", default)] + info_hashes: Vec, + } + + + #[test] + fn test_deserialize_info_hashes_vec(){ + let input = r#"{ + "info_hashes": ["aaaabbbbccccddddeeee", "aaaabbbbccccddddeeee"] + }"#; + + let expected = Test { + info_hashes: vec![ + info_hash_from_bytes(b"aaaabbbbccccddddeeee"), + info_hash_from_bytes(b"aaaabbbbccccddddeeee"), + ] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } + + #[test] + fn test_deserialize_info_hashes_str(){ + let input = r#"{ + "info_hashes": "aaaabbbbccccddddeeee" + }"#; + + let expected = Test { + info_hashes: vec![ + info_hash_from_bytes(b"aaaabbbbccccddddeeee"), + ] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } + + #[test] + fn test_deserialize_info_hashes_null(){ + let input = r#"{ + "info_hashes": null + }"#; + + let expected = Test { + info_hashes: vec![] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } + + #[test] + fn test_deserialize_info_hashes_missing(){ + let input = r#"{}"#; + + let expected = Test { + info_hashes: vec![] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } +} \ No newline at end of file diff --git a/aquatic_http/src/lib/tasks.rs b/aquatic_http/src/lib/tasks.rs new file mode 100644 index 0000000..2415650 --- /dev/null +++ b/aquatic_http/src/lib/tasks.rs @@ -0,0 +1,28 @@ +use std::time::Instant; + +use crate::common::*; + + +// identical to ws version +pub fn clean_torrents(state: &State){ + fn clean_torrent_map( + torrent_map: &mut TorrentMap, + ){ + let now = Instant::now(); + + torrent_map.retain(|_, torrent_data| { + torrent_data.peers.retain(|_, peer| { + peer.valid_until.0 >= now + }); + + !torrent_data.peers.is_empty() + }); + + torrent_map.shrink_to_fit(); + } + + let mut torrent_maps = state.torrent_maps.lock(); + + clean_torrent_map(&mut torrent_maps.ipv4); + clean_torrent_map(&mut torrent_maps.ipv6); +} \ No newline at end of file