aquatic_http: refactor TlsHandshakeMachine, adding error type

This commit is contained in:
Joakim Frostegård 2020-07-04 13:10:57 +02:00
parent acf5ee5af1
commit 73b1646c71
2 changed files with 45 additions and 25 deletions

View file

@ -5,7 +5,6 @@ use std::sync::Arc;
use either::Either; use either::Either;
use hashbrown::HashMap; use hashbrown::HashMap;
use log::info;
use mio::Token; use mio::Token;
use mio::net::TcpStream; use mio::net::TcpStream;
use native_tls::{TlsAcceptor, MidHandshakeTlsStream}; use native_tls::{TlsAcceptor, MidHandshakeTlsStream};
@ -122,7 +121,13 @@ impl EstablishedConnection {
} }
enum HandshakeMachineInner { pub enum TlsHandshakeMachineError {
WouldBlock(TlsHandshakeMachine),
Failure(native_tls::Error)
}
enum TlsHandshakeMachineInner {
TcpStream(TcpStream), TcpStream(TcpStream),
TlsMidHandshake(MidHandshakeTlsStream<TcpStream>), TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
} }
@ -130,7 +135,7 @@ enum HandshakeMachineInner {
pub struct TlsHandshakeMachine { pub struct TlsHandshakeMachine {
tls_acceptor: Arc<TlsAcceptor>, tls_acceptor: Arc<TlsAcceptor>,
inner: HandshakeMachineInner, inner: TlsHandshakeMachineInner,
} }
@ -142,19 +147,19 @@ impl <'a>TlsHandshakeMachine {
) -> Self { ) -> Self {
Self { Self {
tls_acceptor, tls_acceptor,
inner: HandshakeMachineInner::TcpStream(tcp_stream) inner: TlsHandshakeMachineInner::TcpStream(tcp_stream)
} }
} }
/// Attempt to establish a TLS connection. On a WouldBlock error, return
/// the machine wrapped in an error for later attempts.
#[inline] #[inline]
pub fn advance( pub fn establish_tls(self) -> Result<EstablishedConnection, TlsHandshakeMachineError> {
self,
) -> (Option<Either<EstablishedConnection, Self>>, bool) { // bool = would block
let handshake_result = match self.inner { let handshake_result = match self.inner {
HandshakeMachineInner::TcpStream(stream) => { TlsHandshakeMachineInner::TcpStream(stream) => {
self.tls_acceptor.accept(stream) self.tls_acceptor.accept(stream)
}, },
HandshakeMachineInner::TlsMidHandshake(handshake) => { TlsHandshakeMachineInner::TlsMidHandshake(handshake) => {
handshake.handshake() handshake.handshake()
}, },
}; };
@ -165,20 +170,22 @@ impl <'a>TlsHandshakeMachine {
Stream::TlsStream(stream) Stream::TlsStream(stream)
); );
(Some(Either::Left(established)), false) Ok(established)
}, },
Err(native_tls::HandshakeError::WouldBlock(handshake)) => { Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
let inner = TlsHandshakeMachineInner::TlsMidHandshake(
handshake
);
let machine = Self { let machine = Self {
tls_acceptor: self.tls_acceptor, tls_acceptor: self.tls_acceptor,
inner: HandshakeMachineInner::TlsMidHandshake(handshake), inner,
}; };
(Some(Either::Right(machine)), true) Err(TlsHandshakeMachineError::WouldBlock(machine))
}, },
Err(native_tls::HandshakeError::Failure(err)) => { Err(native_tls::HandshakeError::Failure(err)) => {
info!("tls handshake error: {}", err); Err(TlsHandshakeMachineError::Failure(err))
(None, false)
} }
} }
} }

View file

@ -5,6 +5,7 @@ use std::io::ErrorKind;
use std::sync::Arc; use std::sync::Arc;
use std::vec::Drain; use std::vec::Drain;
use either::Either;
use hashbrown::HashMap; use hashbrown::HashMap;
use log::{info, debug, error}; use log::{info, debug, error};
use native_tls::TlsAcceptor; use native_tls::TlsAcceptor;
@ -247,19 +248,31 @@ pub fn run_handshakes_and_read_requests(
} else if let Some(handshake_machine) = connections.remove(&poll_token) } else if let Some(handshake_machine) = connections.remove(&poll_token)
.and_then(|c| c.inner.right()) .and_then(|c| c.inner.right())
{ {
let (opt_inner, would_block) = handshake_machine.advance(); match handshake_machine.establish_tls(){
Ok(established) => {
let connection = Connection {
valid_until,
inner: Either::Left(established)
};
if let Some(inner) = opt_inner { connections.insert(poll_token, connection);
let connection = Connection { },
valid_until, Err(TlsHandshakeMachineError::WouldBlock(machine)) => {
inner let connection = Connection {
}; valid_until,
inner: Either::Right(machine)
};
connections.insert(poll_token, connection); connections.insert(poll_token, connection);
}
if would_block { break
break; },
Err(TlsHandshakeMachineError::Failure(err)) => {
info!("tls handshake error: {}", err);
// TLS negotiation error occured
break
}
} }
} }
} }