diff --git a/TODO.md b/TODO.md index 05c6274..dc9cac4 100644 --- a/TODO.md +++ b/TODO.md @@ -2,9 +2,6 @@ ## aquatic_ws * network - * think about if at least established connections could be boxed behind some - dyn trait so that there are not two different version for what is - essentially the same thing * actually run tls. probably add config fields for number of tls and non-tls workers, then run that amount of each * test tls! diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs index 7b71f6b..d4043e5 100644 --- a/aquatic_ws/src/lib/network/mod.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -4,10 +4,10 @@ use std::io::ErrorKind; use tungstenite::WebSocket; use tungstenite::handshake::{HandshakeError, server::ServerHandshake}; use hashbrown::HashMap; -use native_tls::TlsAcceptor; +use native_tls::{TlsAcceptor, TlsStream}; use mio::{Events, Poll, Interest, Token}; -use mio::net::TcpListener; +use mio::net::{TcpListener, TcpStream}; use crate::common::*; use crate::config::Config; @@ -137,6 +137,44 @@ fn accept_new_streams( } +pub fn handle_tls_handshake_result( + connections: &mut ConnectionMap, + poll_token: Token, + valid_until: ValidUntil, + result: Result, ::native_tls::HandshakeError>, +) -> bool { + match result { + Ok(stream) => { + println!("handshake established"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::TlsStream(stream) + }; + + connections.insert(poll_token, connection); + }, + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + println!("interrupted"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::TlsMidHandshake(handshake), + }; + + connections.insert(poll_token, connection); + + return true; + }, + Err(native_tls::HandshakeError::Failure(err)) => { + dbg!(err); + } + } + + false +} + + pub fn handle_ws_handshake_result( connections: &mut ConnectionMap, poll_token: Token, @@ -193,8 +231,6 @@ pub fn run_handshakes_and_read_messages( poll_token: Token, valid_until: ValidUntil, ){ - println!("poll_token: {}", poll_token.0); - loop { let established = if let Some(c) = connections.get(&poll_token){ c.stage.is_established() @@ -208,32 +244,15 @@ pub fn run_handshakes_and_read_messages( match conn.stage { ConnectionStage::TcpStream(stream) => { if let Some(tls_acceptor) = opt_tls_acceptor { - match tls_acceptor.accept(stream){ - Ok(stream) => { - println!("handshake established"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsStream(stream) - }; - - connections.insert(poll_token, connection); - }, - Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - println!("interrupted"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsMidHandshake(handshake), - }; - - connections.insert(poll_token, connection); - - break - }, - Err(native_tls::HandshakeError::Failure(err)) => { - dbg!(err); - } + let stop_loop = handle_tls_handshake_result( + connections, + poll_token, + valid_until, + tls_acceptor.accept(stream) + ); + + if stop_loop { + break } } else { let handshake_result = ::tungstenite::server::accept_hdr( @@ -271,32 +290,15 @@ pub fn run_handshakes_and_read_messages( } }, ConnectionStage::TlsMidHandshake(handshake) => { - match handshake.handshake() { - Ok(stream) => { - println!("handshake established"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsStream(stream) - }; - - connections.insert(poll_token, connection); - }, - Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - println!("interrupted"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsMidHandshake(handshake), - }; - - connections.insert(poll_token, connection); - - break - }, - Err(native_tls::HandshakeError::Failure(err)) => { - dbg!(err); - } + let stop_loop = handle_tls_handshake_result( + connections, + poll_token, + valid_until, + handshake.handshake() + ); + + if stop_loop { + break } }, ConnectionStage::WsMidHandshake(handshake) => {