diff --git a/aquatic_ws/src/lib/network/common.rs b/aquatic_ws/src/lib/network/common.rs index d93e6f3..f950376 100644 --- a/aquatic_ws/src/lib/network/common.rs +++ b/aquatic_ws/src/lib/network/common.rs @@ -1,4 +1,5 @@ use std::net::{SocketAddr}; +use std::io::{Read, Write}; use hashbrown::HashMap; use mio::Token; @@ -27,7 +28,47 @@ impl ::tungstenite::handshake::server::Callback for DebugCallback { } -pub type Stream = TlsStream; +pub enum Stream { + TcpStream(TcpStream), + TlsStream(TlsStream), +} + + +impl Stream { + pub fn get_peer_addr(&self) -> SocketAddr { + match self { + Self::TcpStream(stream) => stream.peer_addr().unwrap(), + Self::TlsStream(stream) => stream.get_ref().peer_addr().unwrap(), + } + } +} + + +impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> Result { + match self { + Self::TcpStream(stream) => stream.read(buf), + Self::TlsStream(stream) => stream.read(buf), + } + } +} + + +impl Write for Stream { + fn write(&mut self, buf: &[u8]) -> ::std::io::Result { + match self { + Self::TcpStream(stream) => stream.write(buf), + Self::TlsStream(stream) => stream.write(buf), + } + } + + fn flush(&mut self) -> ::std::io::Result<()> { + match self { + Self::TcpStream(stream) => stream.flush(), + Self::TlsStream(stream) => stream.flush(), + } + } +} pub struct EstablishedWs { @@ -37,20 +78,17 @@ pub struct EstablishedWs { pub enum ConnectionStage { - TcpStream(TcpStream), + Stream(Stream), TlsMidHandshake(native_tls::MidHandshakeTlsStream), - TlsStream(Stream), - WsHandshakeNoTls(MidHandshake>), - WsHandshakeTls(MidHandshake>), - EstablishedWsNoTls(EstablishedWs), - EstablishedWsTls(EstablishedWs), + WsHandshake(MidHandshake>), + EstablishedWs(EstablishedWs), } impl ConnectionStage { pub fn is_established(&self) -> bool { match self { - Self::EstablishedWsTls(_) | Self::EstablishedWsNoTls(_) => true, + Self::EstablishedWs(_) => true, _ => false, } } diff --git a/aquatic_ws/src/lib/network/mod.rs b/aquatic_ws/src/lib/network/mod.rs index da112b8..e1a28bc 100644 --- a/aquatic_ws/src/lib/network/mod.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -119,10 +119,12 @@ fn accept_new_streams( poll.registry() .register(&mut stream, token, Interest::READABLE) .unwrap(); + + let stream = Stream::TcpStream(stream); let connection = Connection { valid_until, - stage: ConnectionStage::TcpStream(stream) + stage: ConnectionStage::Stream(stream) }; connections.insert(token, connection); @@ -139,94 +141,7 @@ fn accept_new_streams( } -pub fn handle_tls_handshake_result( - connections: &mut ConnectionMap, - poll_token: Token, - valid_until: ValidUntil, - result: Result, native_tls::HandshakeError>, -) -> bool { - match result { - Ok(stream) => { - println!("handshake established"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsStream(stream) - }; - - connections.insert(poll_token, connection); - - false - }, - Err(native_tls::HandshakeError::WouldBlock(handshake)) => { - println!("interrupted"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::TlsMidHandshake(handshake), - }; - - connections.insert(poll_token, connection); - - true - }, - Err(native_tls::HandshakeError::Failure(err)) => { - dbg!(err); - - false - } - } -} - - -pub fn handle_ws_handshake_no_tls_result( - connections: &mut ConnectionMap, - poll_token: Token, - valid_until: ValidUntil, - result: Result, HandshakeError>> , -) -> bool { - match result { - Ok(mut ws) => { - println!("handshake established"); - - let peer_addr = ws.get_mut().peer_addr().unwrap(); - - let established_ws = EstablishedWs { - ws, - peer_addr, - }; - - let connection = Connection { - valid_until, - stage: ConnectionStage::EstablishedWsNoTls(established_ws) - }; - - connections.insert(poll_token, connection); - - false - }, - Err(HandshakeError::Interrupted(handshake)) => { - println!("interrupted"); - - let connection = Connection { - valid_until, - stage: ConnectionStage::WsHandshakeNoTls(handshake), - }; - - connections.insert(poll_token, connection); - - true - }, - Err(HandshakeError::Failure(err)) => { - dbg!(err); - - false - } - } -} - - -pub fn handle_ws_handshake_tls_result( +pub fn handle_ws_handshake_result( connections: &mut ConnectionMap, poll_token: Token, valid_until: ValidUntil, @@ -236,7 +151,7 @@ pub fn handle_ws_handshake_tls_result( Ok(mut ws) => { println!("handshake established"); - let peer_addr = ws.get_mut().get_mut().peer_addr().unwrap(); + let peer_addr = ws.get_mut().get_peer_addr(); let established_ws = EstablishedWs { ws, @@ -245,7 +160,7 @@ pub fn handle_ws_handshake_tls_result( let connection = Connection { valid_until, - stage: ConnectionStage::EstablishedWsTls(established_ws) + stage: ConnectionStage::EstablishedWs(established_ws) }; connections.insert(poll_token, connection); @@ -257,7 +172,7 @@ pub fn handle_ws_handshake_tls_result( let connection = Connection { valid_until, - stage: ConnectionStage::WsHandshakeTls(handshake), + stage: ConnectionStage::WsHandshake(handshake), }; connections.insert(poll_token, connection); @@ -353,25 +268,42 @@ pub fn run_handshakes_and_read_messages( let conn = connections.remove(&poll_token).unwrap(); match conn.stage { - ConnectionStage::TcpStream(stream) => { + ConnectionStage::Stream(Stream::TcpStream(stream)) => { if let Some(tls_acceptor) = opt_tls_acceptor { - let stop_loop = handle_tls_handshake_result( - connections, - poll_token, - valid_until, - tls_acceptor.accept(stream) - ); - - if stop_loop { - break; + match tls_acceptor.accept(stream){ + Ok(stream) => { + println!("handshake established"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::Stream(Stream::TlsStream(stream)) + }; + + connections.insert(poll_token, connection); + }, + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + println!("interrupted"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::TlsMidHandshake(handshake), + }; + + connections.insert(poll_token, connection); + + break + }, + Err(native_tls::HandshakeError::Failure(err)) => { + dbg!(err); + } } } else { let handshake_result = ::tungstenite::server::accept_hdr( - stream, + Stream::TcpStream(stream), DebugCallback ); - let stop_loop = handle_ws_handshake_no_tls_result( + let stop_loop = handle_ws_handshake_result( connections, poll_token, valid_until, @@ -383,25 +315,13 @@ pub fn run_handshakes_and_read_messages( } } }, - ConnectionStage::TlsMidHandshake(handshake) => { - let stop_loop = handle_tls_handshake_result( - connections, - poll_token, - valid_until, - handshake.handshake() - ); - - if stop_loop { - break; - } - }, - ConnectionStage::TlsStream(stream) => { + ConnectionStage::Stream(Stream::TlsStream(stream)) => { let handshake_result = ::tungstenite::server::accept_hdr( - stream, + Stream::TlsStream(stream), DebugCallback ); - let stop_loop = handle_ws_handshake_tls_result( + let stop_loop = handle_ws_handshake_result( connections, poll_token, valid_until, @@ -412,8 +332,37 @@ pub fn run_handshakes_and_read_messages( break; } }, - ConnectionStage::WsHandshakeNoTls(handshake) => { - let stop_loop = handle_ws_handshake_no_tls_result( + ConnectionStage::TlsMidHandshake(handshake) => { + match handshake.handshake() { + Ok(stream) => { + println!("handshake established"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::Stream(Stream::TlsStream(stream)) + }; + + connections.insert(poll_token, connection); + }, + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + println!("interrupted"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::TlsMidHandshake(handshake), + }; + + connections.insert(poll_token, connection); + + break + }, + Err(native_tls::HandshakeError::Failure(err)) => { + dbg!(err); + } + } + }, + ConnectionStage::WsHandshake(handshake) => { + let stop_loop = handle_ws_handshake_result( connections, poll_token, valid_until, @@ -424,38 +373,12 @@ pub fn run_handshakes_and_read_messages( break; } }, - ConnectionStage::WsHandshakeTls(handshake) => { - let stop_loop = handle_ws_handshake_tls_result( - connections, - poll_token, - valid_until, - handshake.handshake() - ); - - if stop_loop { - break; - } - }, - ConnectionStage::EstablishedWsNoTls(_) => unreachable!(), - ConnectionStage::EstablishedWsTls(_) => unreachable!(), + ConnectionStage::EstablishedWs(_) => unreachable!(), } } else { match connections.get_mut(&poll_token){ Some(Connection{ - stage: ConnectionStage::EstablishedWsNoTls(established_ws), - .. - }) => { - read_ws_messages!( - socket_worker_index, - in_message_sender, - poll, - connections, - poll_token, - established_ws - ); - }, - Some(Connection{ - stage: ConnectionStage::EstablishedWsTls(established_ws), + stage: ConnectionStage::EstablishedWs(established_ws), .. }) => { read_ws_messages!( @@ -489,32 +412,7 @@ pub fn send_out_messages( // Exactly the same for both established stages match opt_stage { - Some(ConnectionStage::EstablishedWsNoTls(connection)) => { - if connection.peer_addr != meta.peer_addr { - eprintln!("socket worker: peer socket addrs didn't match"); - - continue; - } - - dbg!(out_message.clone()); - - match connection.ws.write_message(out_message.to_ws_message()){ - Ok(()) => {}, - Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { - continue; - }, - Err(err) => { - dbg!(err); - - remove_connection_if_exists( - poll, - connections, - meta.poll_token - ); - }, - } - }, - Some(ConnectionStage::EstablishedWsTls(connection)) => { + Some(ConnectionStage::EstablishedWs(connection)) => { if connection.peer_addr != meta.peer_addr { eprintln!("socket worker: peer socket addrs didn't match"); diff --git a/aquatic_ws/src/lib/network/utils.rs b/aquatic_ws/src/lib/network/utils.rs index 5c8535d..ba8feda 100644 --- a/aquatic_ws/src/lib/network/utils.rs +++ b/aquatic_ws/src/lib/network/utils.rs @@ -57,37 +57,34 @@ pub fn create_tls_acceptor( } +/// FIXME pub fn close_and_deregister_connection( poll: &mut Poll, connection: &mut Connection, ){ match connection.stage { - ConnectionStage::TcpStream(ref mut stream) => { + ConnectionStage::Stream(ref mut stream) => { + /* poll.registry() .deregister(stream) .unwrap(); + */ }, ConnectionStage::TlsMidHandshake(ref mut handshake) => { + /* poll.registry() .deregister(handshake.get_mut()) .unwrap(); + */ }, - ConnectionStage::TlsStream(ref mut stream) => { - poll.registry() - .deregister(stream.get_mut()) - .unwrap(); - }, - ConnectionStage::WsHandshakeNoTls(ref mut handshake) => { + ConnectionStage::WsHandshake(ref mut handshake) => { + /* poll.registry() .deregister(handshake.get_mut().get_mut()) .unwrap(); + */ }, - ConnectionStage::WsHandshakeTls(ref mut handshake) => { - poll.registry() - .deregister(handshake.get_mut().get_mut().get_mut()) - .unwrap(); - }, - ConnectionStage::EstablishedWsNoTls(ref mut established_ws) => { + ConnectionStage::EstablishedWs(ref mut established_ws) => { if established_ws.ws.can_read(){ established_ws.ws.close(None).unwrap(); @@ -97,23 +94,11 @@ pub fn close_and_deregister_connection( } } + /* poll.registry() .deregister(established_ws.ws.get_mut()) .unwrap(); - }, - ConnectionStage::EstablishedWsTls(ref mut established_ws) => { - if established_ws.ws.can_read(){ - established_ws.ws.close(None).unwrap(); - - // Needs to be done after ws.close() - if let Err(err) = established_ws.ws.write_pending(){ - dbg!(err); - } - } - - poll.registry() - .deregister(established_ws.ws.get_mut().get_mut()) - .unwrap(); + */ }, } }