http_private: get tls working

This commit is contained in:
Joakim Frostegård 2022-04-03 19:07:55 +02:00
parent e790727bc0
commit c21ed97cb2
5 changed files with 211 additions and 9 deletions

View file

@ -1,5 +1,6 @@
pub mod db;
mod routes;
mod tls;
use std::{
net::{SocketAddr, TcpListener},
@ -7,13 +8,23 @@ use std::{
};
use anyhow::Context;
use axum::{routing::get, Extension, Router};
use aquatic_common::rustls_config::RustlsConfig;
use axum::{extract::connect_info::Connected, routing::get, Extension, Router};
use hyper::server::conn::AddrIncoming;
use sqlx::mysql::MySqlPoolOptions;
use self::tls::{TlsAcceptor, TlsStream};
use crate::{common::ChannelRequestSender, config::Config};
impl<'a> Connected<&'a tls::TlsStream> for SocketAddr {
fn connect_info(target: &'a TlsStream) -> Self {
target.get_remote_addr()
}
}
pub fn run_socket_worker(
config: Config,
tls_config: Arc<RustlsConfig>,
request_sender: ChannelRequestSender,
) -> anyhow::Result<()> {
let tcp_listener = create_tcp_listener(config.network.address)?;
@ -22,19 +33,25 @@ pub fn run_socket_worker(
.enable_all()
.build()?;
runtime.block_on(run_app(config, tcp_listener, request_sender))?;
runtime.block_on(run_app(config, tls_config, tcp_listener, request_sender))?;
Ok(())
}
async fn run_app(
config: Config,
tls_config: Arc<RustlsConfig>,
tcp_listener: TcpListener,
request_sender: ChannelRequestSender,
) -> anyhow::Result<()> {
let db_url =
::std::env::var("DATABASE_URL").with_context(|| "Retrieve env var DATABASE_URL")?;
let tls_acceptor = TlsAcceptor::new(
tls_config,
AddrIncoming::from_listener(tokio::net::TcpListener::from_std(tcp_listener)?)?,
);
let pool = MySqlPoolOptions::new()
.max_connections(5)
.connect(&db_url)
@ -46,7 +63,7 @@ async fn run_app(
.layer(Extension(pool))
.layer(Extension(Arc::new(request_sender)));
axum::Server::from_tcp(tcp_listener)?
axum::Server::builder(tls_acceptor)
.http1_keepalive(false)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
@ -66,6 +83,9 @@ fn create_tcp_listener(addr: SocketAddr) -> anyhow::Result<TcpListener> {
socket
.set_reuse_port(true)
.with_context(|| "set_reuse_port")?;
socket
.set_nonblocking(true)
.with_context(|| "set_nonblocking")?;
socket
.bind(&addr.into())
.with_context(|| format!("bind to {}", addr))?;

View file

@ -0,0 +1,151 @@
//! hyper/rustls integration
//!
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
//! otherwise HTTP/1.1 will be used.
//!
//! Based on https://github.com/rustls/hyper-rustls/blob/9b7b1220f74de9b249ce2b8f8b922fd00074c53b/examples/server.rs
// ISC License (ISC)
// Copyright (c) 2016, Joseph Birr-Pixton <jpixton@gmail.com>
//
// Permission to use, copy, modify, and/or distribute this software for
// any purpose with or without fee is hereby granted, provided that the
// above copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
// DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR
// PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
// ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
// THIS SOFTWARE.
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
enum State {
Handshaking(tokio_rustls::Accept<AddrStream>, SocketAddr),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}
impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let remote_addr = stream.remote_addr();
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept, remote_addr),
}
}
pub fn get_remote_addr(&self) -> SocketAddr {
match &self.state {
State::Handshaking(_, remote_addr) => *remote_addr,
State::Streaming(stream) => stream.get_ref().0.remote_addr(),
}
}
}
impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept, ref mut socket_addr) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
*socket_addr = stream.get_ref().0.remote_addr();
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept, _) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_, _) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_, _) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}
impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}
impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}