diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs index 816a4bd..c8cf59e 100644 --- a/aquatic_ws/src/lib/network/mod.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -137,49 +137,29 @@ fn accept_new_streams( pub fn handle_tls_handshake_result( - connections: &mut ConnectionMap, - poll_token: Token, - valid_until: ValidUntil, result: Result, ::native_tls::HandshakeError>, -) -> bool { +) -> (Option, bool) { match result { Ok(stream) => { println!("handshake established"); - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsStream(stream) - }; - - connections.insert(poll_token, connection); + (Some(ConnectionStage::TlsStream(stream)), false) }, Err(native_tls::HandshakeError::WouldBlock(handshake)) => { println!("interrupted"); - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsMidHandshake(handshake), - }; - - connections.insert(poll_token, connection); - - return true; + (Some(ConnectionStage::TlsMidHandshake(handshake)), true) }, Err(native_tls::HandshakeError::Failure(err)) => { - dbg!(err); + (None, false) } } - - false } pub fn handle_ws_handshake_result( - connections: &mut ConnectionMap, - poll_token: Token, - valid_until: ValidUntil, result: Result, HandshakeError>> , -) -> bool { +) -> (Option, bool) { match result { Ok(mut ws) => { println!("handshake established"); @@ -191,31 +171,19 @@ pub fn handle_ws_handshake_result( peer_addr, }; - let connection = Connection { - valid_until, - stage: ConnectionStage::EstablishedWs(established_ws) - }; - - connections.insert(poll_token, connection); + (Some(ConnectionStage::EstablishedWs(established_ws)), false) }, Err(HandshakeError::Interrupted(handshake)) => { println!("interrupted"); - let connection = Connection { - valid_until, - stage: ConnectionStage::WsMidHandshake(handshake), - }; - - connections.insert(poll_token, connection); - - return true + (Some(ConnectionStage::WsMidHandshake(handshake)), true) }, Err(HandshakeError::Failure(err)) => { dbg!(err); + + (None, false) } } - - false } @@ -264,13 +232,10 @@ pub fn run_handshakes_and_read_messages( } } } else if let Some(connection) = connections.remove(&poll_token) { - let stop_loop = match connection.stage { + let (opt_new_stage, stop_loop) = match connection.stage { ConnectionStage::TcpStream(stream) => { if let Some(tls_acceptor) = opt_tls_acceptor { handle_tls_handshake_result( - connections, - poll_token, - valid_until, tls_acceptor.accept(stream) ) } else { @@ -279,12 +244,7 @@ pub fn run_handshakes_and_read_messages( DebugCallback ); - handle_ws_handshake_result( - connections, - poll_token, - valid_until, - handshake_result - ) + handle_ws_handshake_result(handshake_result) } }, ConnectionStage::TlsStream(stream) => { @@ -293,32 +253,26 @@ pub fn run_handshakes_and_read_messages( DebugCallback ); - handle_ws_handshake_result( - connections, - poll_token, - valid_until, - handshake_result - ) + handle_ws_handshake_result(handshake_result) }, ConnectionStage::TlsMidHandshake(handshake) => { - handle_tls_handshake_result( - connections, - poll_token, - valid_until, - handshake.handshake() - ) + handle_tls_handshake_result(handshake.handshake()) }, ConnectionStage::WsMidHandshake(handshake) => { - handle_ws_handshake_result( - connections, - poll_token, - valid_until, - handshake.handshake() - ) + handle_ws_handshake_result(handshake.handshake()) }, ConnectionStage::EstablishedWs(_) => unreachable!(), }; + if let Some(stage) = opt_new_stage { + let connection = Connection { + valid_until, + stage, + }; + + connections.insert(poll_token, connection); + } + if stop_loop { break; }