http_private: get tls working

This commit is contained in:
Joakim Frostegård 2022-04-03 19:07:55 +02:00
parent e790727bc0
commit c21ed97cb2
5 changed files with 211 additions and 9 deletions

19
Cargo.lock generated
View file

@ -83,6 +83,8 @@ dependencies = [
"log",
"privdrop",
"rand",
"rustls 0.20.4",
"rustls-pemfile",
"serde",
]
@ -150,14 +152,18 @@ dependencies = [
"aquatic_toml_config",
"axum",
"dotenv",
"futures-util",
"hex",
"hyper",
"log",
"mimalloc",
"rand",
"rustls 0.20.4",
"serde",
"socket2 0.4.4",
"sqlx",
"tokio",
"tokio-rustls 0.23.3",
]
[[package]]
@ -2619,7 +2625,7 @@ checksum = "b555e70fbbf84e269ec3858b7a6515bcfe7a166a7cc9c636dd6efd20431678b6"
dependencies = [
"once_cell",
"tokio",
"tokio-rustls",
"tokio-rustls 0.22.0",
]
[[package]]
@ -2797,6 +2803,17 @@ dependencies = [
"webpki 0.21.4",
]
[[package]]
name = "tokio-rustls"
version = "0.23.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4151fda0cf2798550ad0b34bcfc9b9dcc2a9d2471c895c68f3a8818e54f2389e"
dependencies = [
"rustls 0.20.4",
"tokio",
"webpki 0.22.0",
]
[[package]]
name = "tokio-stream"
version = "0.1.8"

View file

@ -11,19 +11,22 @@ name = "aquatic_http_private"
[dependencies]
aquatic_cli_helpers = "0.2.0"
aquatic_common = "0.2.0"
aquatic_common = { version = "0.2.0", features = ["rustls-config"] }
aquatic_http_protocol = "0.2.0"
aquatic_toml_config = "0.2.0"
anyhow = "1"
axum = { version = "0.5", default-features = false, features = ["headers", "http1", "matched-path", "original-uri"] }
dotenv = "0.15"
futures-util = { version = "0.3", default-features = false }
hex = "0.4"
hyper = "0.14"
log = "0.4"
mimalloc = { version = "0.1", default-features = false }
rand = { version = "0.8", features = ["small_rng"] }
rustls = "0.20"
serde = { version = "1", features = ["derive"] }
socket2 = { version = "0.4", features = ["all"] }
sqlx = { version = "0.5", features = [ "runtime-tokio-rustls" , "mysql" ] }
tokio = { version = "1", features = ["full"] }
tokio-rustls = "0.23"

View file

@ -2,18 +2,26 @@ mod common;
pub mod config;
mod workers;
use std::collections::VecDeque;
use std::{collections::VecDeque, sync::Arc};
use aquatic_common::rustls_config::create_rustls_config;
use common::ChannelRequestSender;
use dotenv::dotenv;
use tokio::sync::mpsc::channel;
use config::Config;
pub const APP_NAME: &str = "aquatic_http_private: private HTTP/TLS BitTorrent tracker";
pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
pub fn run(config: config::Config) -> anyhow::Result<()> {
pub fn run(config: Config) -> anyhow::Result<()> {
dotenv().ok();
let tls_config = Arc::new(create_rustls_config(
&config.network.tls_certificate_path,
&config.network.tls_private_key_path,
)?);
let mut request_senders = Vec::new();
let mut request_receivers = VecDeque::new();
@ -28,11 +36,14 @@ pub fn run(config: config::Config) -> anyhow::Result<()> {
for _ in 0..config.socket_workers {
let config = config.clone();
let tls_config = tls_config.clone();
let request_sender = ChannelRequestSender::new(request_senders.clone());
let handle = ::std::thread::Builder::new()
.name("socket".into())
.spawn(move || workers::socket::run_socket_worker(config, request_sender))?;
.spawn(move || {
workers::socket::run_socket_worker(config, tls_config, request_sender)
})?;
handles.push(handle);
}

View file

@ -1,5 +1,6 @@
pub mod db;
mod routes;
mod tls;
use std::{
net::{SocketAddr, TcpListener},
@ -7,13 +8,23 @@ use std::{
};
use anyhow::Context;
use axum::{routing::get, Extension, Router};
use aquatic_common::rustls_config::RustlsConfig;
use axum::{extract::connect_info::Connected, routing::get, Extension, Router};
use hyper::server::conn::AddrIncoming;
use sqlx::mysql::MySqlPoolOptions;
use self::tls::{TlsAcceptor, TlsStream};
use crate::{common::ChannelRequestSender, config::Config};
impl<'a> Connected<&'a tls::TlsStream> for SocketAddr {
fn connect_info(target: &'a TlsStream) -> Self {
target.get_remote_addr()
}
}
pub fn run_socket_worker(
config: Config,
tls_config: Arc<RustlsConfig>,
request_sender: ChannelRequestSender,
) -> anyhow::Result<()> {
let tcp_listener = create_tcp_listener(config.network.address)?;
@ -22,19 +33,25 @@ pub fn run_socket_worker(
.enable_all()
.build()?;
runtime.block_on(run_app(config, tcp_listener, request_sender))?;
runtime.block_on(run_app(config, tls_config, tcp_listener, request_sender))?;
Ok(())
}
async fn run_app(
config: Config,
tls_config: Arc<RustlsConfig>,
tcp_listener: TcpListener,
request_sender: ChannelRequestSender,
) -> anyhow::Result<()> {
let db_url =
::std::env::var("DATABASE_URL").with_context(|| "Retrieve env var DATABASE_URL")?;
let tls_acceptor = TlsAcceptor::new(
tls_config,
AddrIncoming::from_listener(tokio::net::TcpListener::from_std(tcp_listener)?)?,
);
let pool = MySqlPoolOptions::new()
.max_connections(5)
.connect(&db_url)
@ -46,7 +63,7 @@ async fn run_app(
.layer(Extension(pool))
.layer(Extension(Arc::new(request_sender)));
axum::Server::from_tcp(tcp_listener)?
axum::Server::builder(tls_acceptor)
.http1_keepalive(false)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
@ -66,6 +83,9 @@ fn create_tcp_listener(addr: SocketAddr) -> anyhow::Result<TcpListener> {
socket
.set_reuse_port(true)
.with_context(|| "set_reuse_port")?;
socket
.set_nonblocking(true)
.with_context(|| "set_nonblocking")?;
socket
.bind(&addr.into())
.with_context(|| format!("bind to {}", addr))?;

View file

@ -0,0 +1,151 @@
//! hyper/rustls integration
//!
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
//! otherwise HTTP/1.1 will be used.
//!
//! Based on https://github.com/rustls/hyper-rustls/blob/9b7b1220f74de9b249ce2b8f8b922fd00074c53b/examples/server.rs
// ISC License (ISC)
// Copyright (c) 2016, Joseph Birr-Pixton <jpixton@gmail.com>
//
// Permission to use, copy, modify, and/or distribute this software for
// any purpose with or without fee is hereby granted, provided that the
// above copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
// DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR
// PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
// ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
// THIS SOFTWARE.
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
enum State {
Handshaking(tokio_rustls::Accept<AddrStream>, SocketAddr),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}
impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let remote_addr = stream.remote_addr();
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept, remote_addr),
}
}
pub fn get_remote_addr(&self) -> SocketAddr {
match &self.state {
State::Handshaking(_, remote_addr) => *remote_addr,
State::Streaming(stream) => stream.get_ref().0.remote_addr(),
}
}
}
impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept, ref mut socket_addr) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
*socket_addr = stream.get_ref().0.remote_addr();
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept, _) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_, _) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_, _) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}
impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}
impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}