diff --git a/Cargo.lock b/Cargo.lock index 968236c..399cb3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,8 @@ dependencies = [ "log", "privdrop", "rand", + "rustls 0.20.4", + "rustls-pemfile", "serde", ] @@ -150,14 +152,18 @@ dependencies = [ "aquatic_toml_config", "axum", "dotenv", + "futures-util", "hex", + "hyper", "log", "mimalloc", "rand", + "rustls 0.20.4", "serde", "socket2 0.4.4", "sqlx", "tokio", + "tokio-rustls 0.23.3", ] [[package]] @@ -2619,7 +2625,7 @@ checksum = "b555e70fbbf84e269ec3858b7a6515bcfe7a166a7cc9c636dd6efd20431678b6" dependencies = [ "once_cell", "tokio", - "tokio-rustls", + "tokio-rustls 0.22.0", ] [[package]] @@ -2797,6 +2803,17 @@ dependencies = [ "webpki 0.21.4", ] +[[package]] +name = "tokio-rustls" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4151fda0cf2798550ad0b34bcfc9b9dcc2a9d2471c895c68f3a8818e54f2389e" +dependencies = [ + "rustls 0.20.4", + "tokio", + "webpki 0.22.0", +] + [[package]] name = "tokio-stream" version = "0.1.8" diff --git a/aquatic_http_private/Cargo.toml b/aquatic_http_private/Cargo.toml index 56189a9..bad6901 100644 --- a/aquatic_http_private/Cargo.toml +++ b/aquatic_http_private/Cargo.toml @@ -11,19 +11,22 @@ name = "aquatic_http_private" [dependencies] aquatic_cli_helpers = "0.2.0" -aquatic_common = "0.2.0" +aquatic_common = { version = "0.2.0", features = ["rustls-config"] } aquatic_http_protocol = "0.2.0" aquatic_toml_config = "0.2.0" anyhow = "1" axum = { version = "0.5", default-features = false, features = ["headers", "http1", "matched-path", "original-uri"] } dotenv = "0.15" +futures-util = { version = "0.3", default-features = false } hex = "0.4" +hyper = "0.14" log = "0.4" mimalloc = { version = "0.1", default-features = false } rand = { version = "0.8", features = ["small_rng"] } +rustls = "0.20" serde = { version = "1", features = ["derive"] } socket2 = { version = "0.4", features = ["all"] } sqlx = { version = "0.5", features = [ "runtime-tokio-rustls" , "mysql" ] } tokio = { version = "1", features = ["full"] } - +tokio-rustls = "0.23" diff --git a/aquatic_http_private/src/lib.rs b/aquatic_http_private/src/lib.rs index 23053d1..8e9d7a7 100644 --- a/aquatic_http_private/src/lib.rs +++ b/aquatic_http_private/src/lib.rs @@ -2,18 +2,26 @@ mod common; pub mod config; mod workers; -use std::collections::VecDeque; +use std::{collections::VecDeque, sync::Arc}; +use aquatic_common::rustls_config::create_rustls_config; use common::ChannelRequestSender; use dotenv::dotenv; use tokio::sync::mpsc::channel; +use config::Config; + pub const APP_NAME: &str = "aquatic_http_private: private HTTP/TLS BitTorrent tracker"; pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); -pub fn run(config: config::Config) -> anyhow::Result<()> { +pub fn run(config: Config) -> anyhow::Result<()> { dotenv().ok(); + let tls_config = Arc::new(create_rustls_config( + &config.network.tls_certificate_path, + &config.network.tls_private_key_path, + )?); + let mut request_senders = Vec::new(); let mut request_receivers = VecDeque::new(); @@ -28,11 +36,14 @@ pub fn run(config: config::Config) -> anyhow::Result<()> { for _ in 0..config.socket_workers { let config = config.clone(); + let tls_config = tls_config.clone(); let request_sender = ChannelRequestSender::new(request_senders.clone()); let handle = ::std::thread::Builder::new() .name("socket".into()) - .spawn(move || workers::socket::run_socket_worker(config, request_sender))?; + .spawn(move || { + workers::socket::run_socket_worker(config, tls_config, request_sender) + })?; handles.push(handle); } diff --git a/aquatic_http_private/src/workers/socket/mod.rs b/aquatic_http_private/src/workers/socket/mod.rs index 5f38a85..d94e3f8 100644 --- a/aquatic_http_private/src/workers/socket/mod.rs +++ b/aquatic_http_private/src/workers/socket/mod.rs @@ -1,5 +1,6 @@ pub mod db; mod routes; +mod tls; use std::{ net::{SocketAddr, TcpListener}, @@ -7,13 +8,23 @@ use std::{ }; use anyhow::Context; -use axum::{routing::get, Extension, Router}; +use aquatic_common::rustls_config::RustlsConfig; +use axum::{extract::connect_info::Connected, routing::get, Extension, Router}; +use hyper::server::conn::AddrIncoming; use sqlx::mysql::MySqlPoolOptions; +use self::tls::{TlsAcceptor, TlsStream}; use crate::{common::ChannelRequestSender, config::Config}; +impl<'a> Connected<&'a tls::TlsStream> for SocketAddr { + fn connect_info(target: &'a TlsStream) -> Self { + target.get_remote_addr() + } +} + pub fn run_socket_worker( config: Config, + tls_config: Arc, request_sender: ChannelRequestSender, ) -> anyhow::Result<()> { let tcp_listener = create_tcp_listener(config.network.address)?; @@ -22,19 +33,25 @@ pub fn run_socket_worker( .enable_all() .build()?; - runtime.block_on(run_app(config, tcp_listener, request_sender))?; + runtime.block_on(run_app(config, tls_config, tcp_listener, request_sender))?; Ok(()) } async fn run_app( config: Config, + tls_config: Arc, tcp_listener: TcpListener, request_sender: ChannelRequestSender, ) -> anyhow::Result<()> { let db_url = ::std::env::var("DATABASE_URL").with_context(|| "Retrieve env var DATABASE_URL")?; + let tls_acceptor = TlsAcceptor::new( + tls_config, + AddrIncoming::from_listener(tokio::net::TcpListener::from_std(tcp_listener)?)?, + ); + let pool = MySqlPoolOptions::new() .max_connections(5) .connect(&db_url) @@ -46,7 +63,7 @@ async fn run_app( .layer(Extension(pool)) .layer(Extension(Arc::new(request_sender))); - axum::Server::from_tcp(tcp_listener)? + axum::Server::builder(tls_acceptor) .http1_keepalive(false) .serve(app.into_make_service_with_connect_info::()) .await?; @@ -66,6 +83,9 @@ fn create_tcp_listener(addr: SocketAddr) -> anyhow::Result { socket .set_reuse_port(true) .with_context(|| "set_reuse_port")?; + socket + .set_nonblocking(true) + .with_context(|| "set_nonblocking")?; socket .bind(&addr.into()) .with_context(|| format!("bind to {}", addr))?; diff --git a/aquatic_http_private/src/workers/socket/tls.rs b/aquatic_http_private/src/workers/socket/tls.rs new file mode 100644 index 0000000..3828b29 --- /dev/null +++ b/aquatic_http_private/src/workers/socket/tls.rs @@ -0,0 +1,151 @@ +//! hyper/rustls integration +//! +//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2, +//! otherwise HTTP/1.1 will be used. +//! +//! Based on https://github.com/rustls/hyper-rustls/blob/9b7b1220f74de9b249ce2b8f8b922fd00074c53b/examples/server.rs + +// ISC License (ISC) +// Copyright (c) 2016, Joseph Birr-Pixton +// +// Permission to use, copy, modify, and/or distribute this software for +// any purpose with or without fee is hereby granted, provided that the +// above copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL +// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE +// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +// DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR +// PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +// ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF +// THIS SOFTWARE. + +use core::task::{Context, Poll}; +use futures_util::ready; +use hyper::server::accept::Accept; +use hyper::server::conn::{AddrIncoming, AddrStream}; +use std::future::Future; +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::rustls::ServerConfig; + +enum State { + Handshaking(tokio_rustls::Accept, SocketAddr), + Streaming(tokio_rustls::server::TlsStream), +} + +// tokio_rustls::server::TlsStream doesn't expose constructor methods, +// so we have to TlsAcceptor::accept and handshake to have access to it +// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first +pub struct TlsStream { + state: State, +} + +impl TlsStream { + fn new(stream: AddrStream, config: Arc) -> TlsStream { + let remote_addr = stream.remote_addr(); + let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); + + TlsStream { + state: State::Handshaking(accept, remote_addr), + } + } + + pub fn get_remote_addr(&self) -> SocketAddr { + match &self.state { + State::Handshaking(_, remote_addr) => *remote_addr, + State::Streaming(stream) => stream.get_ref().0.remote_addr(), + } + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut ReadBuf, + ) -> Poll> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept, ref mut socket_addr) => { + match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + *socket_addr = stream.get_ref().0.remote_addr(); + let result = Pin::new(&mut stream).poll_read(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + } + } + State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept, _) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_write(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.state { + State::Handshaking(_, _) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.state { + State::Handshaking(_, _) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +pub struct TlsAcceptor { + config: Arc, + incoming: AddrIncoming, +} + +impl TlsAcceptor { + pub fn new(config: Arc, incoming: AddrIncoming) -> TlsAcceptor { + TlsAcceptor { config, incoming } + } +} + +impl Accept for TlsAcceptor { + type Conn = TlsStream; + type Error = io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let pin = self.get_mut(); + match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +}