aquatic_http: simplify tls handshake code further

This commit is contained in:
Joakim Frostegård 2020-07-02 22:43:54 +02:00
parent ccfa03f6cc
commit d5b82bcf70
2 changed files with 49 additions and 48 deletions

View file

@ -1,13 +1,14 @@
use std::net::{SocketAddr}; use std::net::{SocketAddr};
use std::io::ErrorKind; use std::io::ErrorKind;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::sync::Arc;
use either::Either; use either::Either;
use hashbrown::HashMap; use hashbrown::HashMap;
use log::info; use log::info;
use mio::Token; use mio::Token;
use mio::net::TcpStream; use mio::net::TcpStream;
use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream}; use native_tls::{TlsAcceptor, MidHandshakeTlsStream};
use aquatic_common_tcp::network::stream::Stream; use aquatic_common_tcp::network::stream::Stream;
@ -109,55 +110,58 @@ impl EstablishedConnection {
} }
pub enum HandshakeMachine { enum HandshakeMachineInner {
TcpStream(TcpStream), TcpStream(TcpStream),
TlsMidHandshake(MidHandshakeTlsStream<TcpStream>), TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
} }
impl <'a>HandshakeMachine { pub struct TlsHandshakeMachine {
tls_acceptor: Arc<TlsAcceptor>,
inner: HandshakeMachineInner,
}
impl <'a>TlsHandshakeMachine {
#[inline] #[inline]
fn new(tcp_stream: TcpStream) -> Self { fn new(
Self::TcpStream(tcp_stream) tls_acceptor: Arc<TlsAcceptor>,
tcp_stream: TcpStream
) -> Self {
Self {
tls_acceptor,
inner: HandshakeMachineInner::TcpStream(tcp_stream)
}
} }
#[inline] #[inline]
pub fn advance( pub fn advance(
self, self,
opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS ) -> (Option<Either<EstablishedConnection, Self>>, bool) { // bool = would block
) -> (Option<Either<EstablishedConnection, Self>>, bool) { // bool = stop looping let handshake_result = match self.inner {
match self { HandshakeMachineInner::TcpStream(stream) => {
HandshakeMachine::TcpStream(stream) => { self.tls_acceptor.accept(stream)
if let Some(tls_acceptor) = opt_tls_acceptor {
Self::handle_tls_handshake_result(
tls_acceptor.accept(stream)
)
} else {
log::debug!("established connection");
let established_connection = EstablishedConnection::new(
Stream::TcpStream(stream)
);
(Some(Either::Left(established_connection)), false)
}
}, },
HandshakeMachine::TlsMidHandshake(handshake) => { HandshakeMachineInner::TlsMidHandshake(handshake) => {
Self::handle_tls_handshake_result(handshake.handshake()) handshake.handshake()
}, },
} };
}
#[inline] match handshake_result {
fn handle_tls_handshake_result(
result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
) -> (Option<Either<EstablishedConnection, Self>>, bool) {
match result {
Ok(stream) => { Ok(stream) => {
(Some(Either::Left(EstablishedConnection::new(Stream::TlsStream(stream)))), false) let established = EstablishedConnection::new(
Stream::TlsStream(stream)
);
(Some(Either::Left(established)), false)
}, },
Err(native_tls::HandshakeError::WouldBlock(handshake)) => { Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
(Some(Either::Right(Self::TlsMidHandshake(handshake))), true) let machine = Self {
tls_acceptor: self.tls_acceptor,
inner: HandshakeMachineInner::TlsMidHandshake(handshake),
};
(Some(Either::Right(machine)), true)
}, },
Err(native_tls::HandshakeError::Failure(err)) => { Err(native_tls::HandshakeError::Failure(err)) => {
info!("tls handshake error: {}", err); info!("tls handshake error: {}", err);
@ -171,22 +175,21 @@ impl <'a>HandshakeMachine {
pub struct Connection { pub struct Connection {
pub valid_until: ValidUntil, pub valid_until: ValidUntil,
pub inner: Either<EstablishedConnection, HandshakeMachine>, pub inner: Either<EstablishedConnection, TlsHandshakeMachine>,
} }
impl Connection { impl Connection {
#[inline] #[inline]
pub fn new( pub fn new(
use_tls: bool, opt_tls_acceptor: &Option<Arc<TlsAcceptor>>,
valid_until: ValidUntil, valid_until: ValidUntil,
tcp_stream: TcpStream, tcp_stream: TcpStream,
) -> Self { ) -> Self {
let inner = if use_tls { // Setup handshake machine if TLS is requested
Either::Right(HandshakeMachine::new(tcp_stream)) let inner = if let Some(tls_acceptor) = opt_tls_acceptor {
Either::Right(TlsHandshakeMachine::new(tls_acceptor.clone(), tcp_stream))
} else { } else {
// If no TLS should be used, just go directly to established
// connection
Either::Left(EstablishedConnection::new(Stream::TcpStream(tcp_stream))) Either::Left(EstablishedConnection::new(Stream::TcpStream(tcp_stream)))
}; };

View file

@ -2,6 +2,7 @@ pub mod connection;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::io::ErrorKind; use std::io::ErrorKind;
use std::sync::Arc;
use hashbrown::HashMap; use hashbrown::HashMap;
use log::{info, debug, error}; use log::{info, debug, error};
@ -24,7 +25,7 @@ fn accept_new_streams(
connections: &mut ConnectionMap, connections: &mut ConnectionMap,
valid_until: ValidUntil, valid_until: ValidUntil,
poll_token_counter: &mut Token, poll_token_counter: &mut Token,
use_tls: bool, opt_tls_acceptor: &Option<Arc<TlsAcceptor>>,
){ ){
loop { loop {
match listener.accept(){ match listener.accept(){
@ -44,7 +45,7 @@ fn accept_new_streams(
.register(&mut stream, token, Interest::READABLE) .register(&mut stream, token, Interest::READABLE)
.unwrap(); .unwrap();
let connection = Connection::new(use_tls, valid_until, stream); let connection = Connection::new(opt_tls_acceptor, valid_until, stream);
connections.insert(token, connection); connections.insert(token, connection);
}, },
@ -65,7 +66,6 @@ fn accept_new_streams(
pub fn run_handshake_and_read_requests( pub fn run_handshake_and_read_requests(
socket_worker_index: usize, socket_worker_index: usize,
request_channel_sender: &RequestChannelSender, request_channel_sender: &RequestChannelSender,
opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
connections: &mut ConnectionMap, connections: &mut ConnectionMap,
poll_token: Token, poll_token: Token,
valid_until: ValidUntil, valid_until: ValidUntil,
@ -115,9 +115,7 @@ pub fn run_handshake_and_read_requests(
} else if let Some(handshake_machine) = connections.remove(&poll_token) } else if let Some(handshake_machine) = connections.remove(&poll_token)
.and_then(|c| c.inner.right()) .and_then(|c| c.inner.right())
{ {
let (opt_inner, stop_loop) = handshake_machine.advance( let (opt_inner, would_block) = handshake_machine.advance();
opt_tls_acceptor
);
if let Some(inner) = opt_inner { if let Some(inner) = opt_inner {
let connection = Connection { let connection = Connection {
@ -128,7 +126,7 @@ pub fn run_handshake_and_read_requests(
connections.insert(poll_token, connection); connections.insert(poll_token, connection);
} }
if stop_loop { if would_block {
break; break;
} }
} }
@ -204,6 +202,7 @@ pub fn run_poll_loop(
.unwrap(); .unwrap();
let mut connections: ConnectionMap = HashMap::new(); let mut connections: ConnectionMap = HashMap::new();
let opt_tls_acceptor = opt_tls_acceptor.map(Arc::new);
let mut poll_token_counter = Token(0usize); let mut poll_token_counter = Token(0usize);
let mut iter_counter = 0usize; let mut iter_counter = 0usize;
@ -224,13 +223,12 @@ pub fn run_poll_loop(
&mut connections, &mut connections,
valid_until, valid_until,
&mut poll_token_counter, &mut poll_token_counter,
opt_tls_acceptor.is_some(), &opt_tls_acceptor,
); );
} else { } else {
run_handshake_and_read_requests( run_handshake_and_read_requests(
socket_worker_index, socket_worker_index,
&request_channel_sender, &request_channel_sender,
&opt_tls_acceptor,
&mut connections, &mut connections,
token, token,
valid_until, valid_until,