diff --git a/aquatic_http/src/lib/network/connection.rs b/aquatic_http/src/lib/network/connection.rs index 2514061..1ad7069 100644 --- a/aquatic_http/src/lib/network/connection.rs +++ b/aquatic_http/src/lib/network/connection.rs @@ -1,13 +1,14 @@ use std::net::{SocketAddr}; use std::io::ErrorKind; use std::io::{Read, Write}; +use std::sync::Arc; use either::Either; use hashbrown::HashMap; use log::info; use mio::Token; use mio::net::TcpStream; -use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream}; +use native_tls::{TlsAcceptor, MidHandshakeTlsStream}; use aquatic_common_tcp::network::stream::Stream; @@ -109,55 +110,58 @@ impl EstablishedConnection { } -pub enum HandshakeMachine { +enum HandshakeMachineInner { TcpStream(TcpStream), TlsMidHandshake(MidHandshakeTlsStream), } -impl <'a>HandshakeMachine { +pub struct TlsHandshakeMachine { + tls_acceptor: Arc, + inner: HandshakeMachineInner, +} + + +impl <'a>TlsHandshakeMachine { #[inline] - fn new(tcp_stream: TcpStream) -> Self { - Self::TcpStream(tcp_stream) + fn new( + tls_acceptor: Arc, + tcp_stream: TcpStream + ) -> Self { + Self { + tls_acceptor, + inner: HandshakeMachineInner::TcpStream(tcp_stream) + } } #[inline] pub fn advance( self, - opt_tls_acceptor: &Option, // If set, run TLS - ) -> (Option>, bool) { // bool = stop looping - match self { - HandshakeMachine::TcpStream(stream) => { - if let Some(tls_acceptor) = opt_tls_acceptor { - Self::handle_tls_handshake_result( - tls_acceptor.accept(stream) - ) - } else { - log::debug!("established connection"); - - let established_connection = EstablishedConnection::new( - Stream::TcpStream(stream) - ); - - (Some(Either::Left(established_connection)), false) - } + ) -> (Option>, bool) { // bool = would block + let handshake_result = match self.inner { + HandshakeMachineInner::TcpStream(stream) => { + self.tls_acceptor.accept(stream) }, - HandshakeMachine::TlsMidHandshake(handshake) => { - Self::handle_tls_handshake_result(handshake.handshake()) + HandshakeMachineInner::TlsMidHandshake(handshake) => { + handshake.handshake() }, - } - } + }; - #[inline] - fn handle_tls_handshake_result( - result: Result, ::native_tls::HandshakeError>, - ) -> (Option>, bool) { - match result { + match handshake_result { Ok(stream) => { - (Some(Either::Left(EstablishedConnection::new(Stream::TlsStream(stream)))), false) + let established = EstablishedConnection::new( + Stream::TlsStream(stream) + ); + + (Some(Either::Left(established)), false) }, Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - (Some(Either::Right(Self::TlsMidHandshake(handshake))), true) + let machine = Self { + tls_acceptor: self.tls_acceptor, + inner: HandshakeMachineInner::TlsMidHandshake(handshake), + }; + + (Some(Either::Right(machine)), true) }, Err(native_tls::HandshakeError::Failure(err)) => { info!("tls handshake error: {}", err); @@ -171,22 +175,21 @@ impl <'a>HandshakeMachine { pub struct Connection { pub valid_until: ValidUntil, - pub inner: Either, + pub inner: Either, } impl Connection { #[inline] pub fn new( - use_tls: bool, + opt_tls_acceptor: &Option>, valid_until: ValidUntil, tcp_stream: TcpStream, ) -> Self { - let inner = if use_tls { - Either::Right(HandshakeMachine::new(tcp_stream)) + // Setup handshake machine if TLS is requested + let inner = if let Some(tls_acceptor) = opt_tls_acceptor { + Either::Right(TlsHandshakeMachine::new(tls_acceptor.clone(), tcp_stream)) } else { - // If no TLS should be used, just go directly to established - // connection Either::Left(EstablishedConnection::new(Stream::TcpStream(tcp_stream))) }; diff --git a/aquatic_http/src/lib/network/mod.rs b/aquatic_http/src/lib/network/mod.rs index fe02a57..c3e4ee1 100644 --- a/aquatic_http/src/lib/network/mod.rs +++ b/aquatic_http/src/lib/network/mod.rs @@ -2,6 +2,7 @@ pub mod connection; use std::time::{Duration, Instant}; use std::io::ErrorKind; +use std::sync::Arc; use hashbrown::HashMap; use log::{info, debug, error}; @@ -24,7 +25,7 @@ fn accept_new_streams( connections: &mut ConnectionMap, valid_until: ValidUntil, poll_token_counter: &mut Token, - use_tls: bool, + opt_tls_acceptor: &Option>, ){ loop { match listener.accept(){ @@ -44,7 +45,7 @@ fn accept_new_streams( .register(&mut stream, token, Interest::READABLE) .unwrap(); - let connection = Connection::new(use_tls, valid_until, stream); + let connection = Connection::new(opt_tls_acceptor, valid_until, stream); connections.insert(token, connection); }, @@ -65,7 +66,6 @@ fn accept_new_streams( pub fn run_handshake_and_read_requests( socket_worker_index: usize, request_channel_sender: &RequestChannelSender, - opt_tls_acceptor: &Option, // If set, run TLS connections: &mut ConnectionMap, poll_token: Token, valid_until: ValidUntil, @@ -115,9 +115,7 @@ pub fn run_handshake_and_read_requests( } else if let Some(handshake_machine) = connections.remove(&poll_token) .and_then(|c| c.inner.right()) { - let (opt_inner, stop_loop) = handshake_machine.advance( - opt_tls_acceptor - ); + let (opt_inner, would_block) = handshake_machine.advance(); if let Some(inner) = opt_inner { let connection = Connection { @@ -128,7 +126,7 @@ pub fn run_handshake_and_read_requests( connections.insert(poll_token, connection); } - if stop_loop { + if would_block { break; } } @@ -204,6 +202,7 @@ pub fn run_poll_loop( .unwrap(); let mut connections: ConnectionMap = HashMap::new(); + let opt_tls_acceptor = opt_tls_acceptor.map(Arc::new); let mut poll_token_counter = Token(0usize); let mut iter_counter = 0usize; @@ -224,13 +223,12 @@ pub fn run_poll_loop( &mut connections, valid_until, &mut poll_token_counter, - opt_tls_acceptor.is_some(), + &opt_tls_acceptor, ); } else { run_handshake_and_read_requests( socket_worker_index, &request_channel_sender, - &opt_tls_acceptor, &mut connections, token, valid_until,