diff --git a/Cargo.lock b/Cargo.lock index dc56638..481817b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,6 +93,7 @@ version = "0.1.0" dependencies = [ "aquatic_common", "cli_helpers", + "either", "flume", "hashbrown", "indexmap", @@ -330,6 +331,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "either" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" + [[package]] name = "encode_unicode" version = "0.3.6" diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index 5e89e52..2601720 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -16,6 +16,7 @@ path = "src/bin/main.rs" [dependencies] aquatic_common = { path = "../aquatic_common" } cli_helpers = { path = "../cli_helpers" } +either = "1" flume = "0.7" hashbrown = { version = "0.7", features = ["serde"] } indexmap = "1" diff --git a/aquatic_ws/src/lib/network/common.rs b/aquatic_ws/src/lib/network/common.rs index 4bc6318..c65da9a 100644 --- a/aquatic_ws/src/lib/network/common.rs +++ b/aquatic_ws/src/lib/network/common.rs @@ -1,12 +1,14 @@ use std::net::{SocketAddr}; use std::io::{Read, Write}; +use either::Either; use hashbrown::HashMap; use mio::Token; use mio::net::TcpStream; -use native_tls::{TlsStream, MidHandshakeTlsStream}; +use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream}; use tungstenite::WebSocket; -use tungstenite::handshake::{MidHandshake, server::ServerHandshake}; +use tungstenite::handshake::{MidHandshake, HandshakeError}; +use tungstenite::server::ServerHandshake; use crate::common::*; @@ -71,26 +73,109 @@ impl Write for Stream { } -pub struct EstablishedWs { - pub ws: WebSocket, +pub struct EstablishedWs { + pub ws: WebSocket, pub peer_addr: SocketAddr, } -pub enum ConnectionStage { +pub enum HandshakeMachine { TcpStream(TcpStream), TlsStream(TlsStream), TlsMidHandshake(MidHandshakeTlsStream), WsMidHandshake(MidHandshake>), - EstablishedWs(EstablishedWs), } -impl ConnectionStage { - pub fn is_established(&self) -> bool { +impl HandshakeMachine { + pub fn new(tcp_stream: TcpStream) -> Self { + Self::TcpStream(tcp_stream) + } + + pub fn advance( + self, + opt_tls_acceptor: &Option, // If set, run TLS + ) -> (Option>, bool) { // bool = stop looping match self { - Self::EstablishedWs(_) => true, - _ => false, + HandshakeMachine::TcpStream(stream) => { + if let Some(tls_acceptor) = opt_tls_acceptor { + Self::handle_tls_handshake_result( + tls_acceptor.accept(stream) + ) + } else { + let handshake_result = ::tungstenite::server::accept_hdr( + Stream::TcpStream(stream), + DebugCallback + ); + + Self::handle_ws_handshake_result(handshake_result) + } + }, + HandshakeMachine::TlsStream(stream) => { + let handshake_result = ::tungstenite::server::accept_hdr( + Stream::TlsStream(stream), + DebugCallback + ); + + Self::handle_ws_handshake_result(handshake_result) + }, + HandshakeMachine::TlsMidHandshake(handshake) => { + Self::handle_tls_handshake_result(handshake.handshake()) + }, + HandshakeMachine::WsMidHandshake(handshake) => { + Self::handle_ws_handshake_result(handshake.handshake()) + }, + } + } + + fn handle_tls_handshake_result( + result: Result, ::native_tls::HandshakeError>, + ) -> (Option>, bool) { + match result { + Ok(stream) => { + println!("handshake established"); + + (Some(Either::Right(Self::TlsStream(stream))), false) + }, + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + println!("interrupted"); + + (Some(Either::Right(Self::TlsMidHandshake(handshake))), true) + }, + Err(native_tls::HandshakeError::Failure(err)) => { + dbg!(err); + + (None, false) + } + } + } + + fn handle_ws_handshake_result( + result: Result, HandshakeError>> , + ) -> (Option>, bool) { + match result { + Ok(mut ws) => { + println!("handshake established"); + + let peer_addr = ws.get_mut().get_peer_addr(); + + let established_ws = EstablishedWs { + ws, + peer_addr, + }; + + (Some(Either::Left(established_ws)), false) + }, + Err(HandshakeError::Interrupted(handshake)) => { + println!("interrupted"); + + (Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), true) + }, + Err(HandshakeError::Failure(err)) => { + dbg!(err); + + (None, false) + } } } } @@ -98,7 +183,7 @@ impl ConnectionStage { pub struct Connection { pub valid_until: ValidUntil, - pub stage: ConnectionStage, + pub inner: Either, } diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs index c8cf59e..b813ab4 100644 --- a/aquatic_ws/src/lib/network/mod.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -1,13 +1,12 @@ use std::time::Duration; use std::io::ErrorKind; -use tungstenite::WebSocket; -use tungstenite::handshake::{HandshakeError, server::ServerHandshake}; +use either::Either; use hashbrown::HashMap; -use native_tls::{TlsAcceptor, TlsStream}; +use native_tls::TlsAcceptor; use mio::{Events, Poll, Interest, Token}; -use mio::net::{TcpListener, TcpStream}; +use mio::net::TcpListener; use crate::common::*; use crate::config::Config; @@ -119,7 +118,7 @@ fn accept_new_streams( let connection = Connection { valid_until, - stage: ConnectionStage::TcpStream(stream) + inner: Either::Right(HandshakeMachine::new(stream)) }; connections.insert(token, connection); @@ -136,56 +135,6 @@ fn accept_new_streams( } -pub fn handle_tls_handshake_result( - result: Result, ::native_tls::HandshakeError>, -) -> (Option, bool) { - match result { - Ok(stream) => { - println!("handshake established"); - - (Some(ConnectionStage::TlsStream(stream)), false) - }, - Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - println!("interrupted"); - - (Some(ConnectionStage::TlsMidHandshake(handshake)), true) - }, - Err(native_tls::HandshakeError::Failure(err)) => { - (None, false) - } - } -} - - -pub fn handle_ws_handshake_result( - result: Result, HandshakeError>> , -) -> (Option, bool) { - match result { - Ok(mut ws) => { - println!("handshake established"); - - let peer_addr = ws.get_mut().get_peer_addr(); - - let established_ws = EstablishedWs { - ws, - peer_addr, - }; - - (Some(ConnectionStage::EstablishedWs(established_ws)), false) - }, - Err(HandshakeError::Interrupted(handshake)) => { - println!("interrupted"); - - (Some(ConnectionStage::WsMidHandshake(handshake)), true) - }, - Err(HandshakeError::Failure(err)) => { - dbg!(err); - - (None, false) - } - } -} - /// On the stream given by poll_token, get TLS (if requested) and tungstenite /// up and running, then read messages and pass on through channel. @@ -199,7 +148,7 @@ pub fn run_handshakes_and_read_messages( ){ loop { if let Some(Connection { - stage: ConnectionStage::EstablishedWs(established_ws), + inner: Either::Left(established_ws), .. }) = connections.get_mut(&poll_token){ use ::tungstenite::Error::Io; @@ -231,43 +180,17 @@ pub fn run_handshakes_and_read_messages( break; } } - } else if let Some(connection) = connections.remove(&poll_token) { - let (opt_new_stage, stop_loop) = match connection.stage { - ConnectionStage::TcpStream(stream) => { - if let Some(tls_acceptor) = opt_tls_acceptor { - handle_tls_handshake_result( - tls_acceptor.accept(stream) - ) - } else { - let handshake_result = ::tungstenite::server::accept_hdr( - Stream::TcpStream(stream), - DebugCallback - ); + } else if let Some(Connection { + inner: Either::Right(machine), + .. + }) = connections.remove(&poll_token) { + let (result, stop_loop) = machine + .advance(opt_tls_acceptor); - handle_ws_handshake_result(handshake_result) - } - }, - ConnectionStage::TlsStream(stream) => { - let handshake_result = ::tungstenite::server::accept_hdr( - Stream::TlsStream(stream), - DebugCallback - ); - - handle_ws_handshake_result(handshake_result) - }, - ConnectionStage::TlsMidHandshake(handshake) => { - handle_tls_handshake_result(handshake.handshake()) - }, - ConnectionStage::WsMidHandshake(handshake) => { - handle_ws_handshake_result(handshake.handshake()) - }, - ConnectionStage::EstablishedWs(_) => unreachable!(), - }; - - if let Some(stage) = opt_new_stage { + if let Some(inner) = result { let connection = Connection { valid_until, - stage, + inner, }; connections.insert(poll_token, connection); @@ -287,11 +210,11 @@ pub fn send_out_messages( connections: &mut ConnectionMap, ){ for (meta, out_message) in out_message_receiver { - let opt_stage = connections + let opt_inner = connections .get_mut(&meta.poll_token) - .map(|v| &mut v.stage); + .map(|v| &mut v.inner); - if let Some(ConnectionStage::EstablishedWs(connection)) = opt_stage { + if let Some(Either::Left(connection)) = opt_inner { if connection.peer_addr != meta.peer_addr { eprintln!("socket worker: peer socket addrs didn't match"); diff --git a/aquatic_ws/src/lib/network/utils.rs b/aquatic_ws/src/lib/network/utils.rs index c7fdce5..1e370aa 100644 --- a/aquatic_ws/src/lib/network/utils.rs +++ b/aquatic_ws/src/lib/network/utils.rs @@ -2,6 +2,7 @@ use std::fs::File; use std::io::Read; use std::time::Instant; +use either::Either; use mio::Token; use native_tls::{Identity, TlsAcceptor}; use net2::{TcpBuilder, unix::UnixTcpBuilderExt}; @@ -58,7 +59,7 @@ pub fn create_tls_acceptor( pub fn close_connection(connection: &mut Connection){ - if let ConnectionStage::EstablishedWs(ref mut ews) = connection.stage { + if let Either::Left(ref mut ews) = connection.inner { if ews.ws.can_read(){ ews.ws.close(None).unwrap();