From b221f3fc346c33866cd9c7629043968bf95e96f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Tue, 12 May 2020 20:20:00 +0200 Subject: [PATCH] WIP: aquatic_ws: support tls and no tls with same functions --- aquatic_ws/src/lib/lib.rs | 1 + aquatic_ws/src/lib/network.rs | 292 +++++++++++++++++++++++++--------- 2 files changed, 221 insertions(+), 72 deletions(-) diff --git a/aquatic_ws/src/lib/lib.rs b/aquatic_ws/src/lib/lib.rs index d81b803..844b45c 100644 --- a/aquatic_ws/src/lib/lib.rs +++ b/aquatic_ws/src/lib/lib.rs @@ -35,6 +35,7 @@ pub fn run(config: Config){ i, in_message_sender, out_message_receiver, + true ); }); } diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/network.rs index 5351e77..6ca011a 100644 --- a/aquatic_ws/src/lib/network.rs +++ b/aquatic_ws/src/lib/network.rs @@ -1,5 +1,5 @@ use std::fs::File; -use std::io::Read; +use std::io::{Read, Write}; use std::net::{SocketAddr}; use std::time::{Duration, Instant}; use std::io::ErrorKind; @@ -21,8 +21,8 @@ use crate::protocol::*; pub type Stream = TlsStream; -pub struct PeerConnection { - pub ws: WebSocket, +pub struct PeerConnection { + pub ws: WebSocket, pub peer_addr: SocketAddr, pub valid_until: ValidUntil, } @@ -32,8 +32,10 @@ pub enum ConnectionStage { TcpStream(TcpStream), TlsMidHandshake(native_tls::MidHandshakeTlsStream), TlsStream(Stream), - WsHandshake(MidHandshake>), - Established(PeerConnection), + WsHandshakeNoTls(MidHandshake>), + WsHandshakeTls(MidHandshake>), + EstablishedNoTls(PeerConnection), + EstablishedTls(PeerConnection), } @@ -83,12 +85,31 @@ fn close_and_deregister_connection( .deregister(stream.get_mut()) .unwrap(); }, - ConnectionStage::WsHandshake(ref mut handshake) => { + ConnectionStage::WsHandshakeNoTls(ref mut handshake) => { + poll.registry() + .deregister(handshake.get_mut().get_mut()) + .unwrap(); + }, + ConnectionStage::WsHandshakeTls(ref mut handshake) => { poll.registry() .deregister(handshake.get_mut().get_mut().get_mut()) .unwrap(); }, - ConnectionStage::Established(ref mut peer_connection) => { + ConnectionStage::EstablishedNoTls(ref mut peer_connection) => { + if peer_connection.ws.can_read(){ + peer_connection.ws.close(None).unwrap(); + + // Needs to be done after ws.close() + if let Err(err) = peer_connection.ws.write_pending(){ + dbg!(err); + } + } + + poll.registry() + .deregister(peer_connection.ws.get_mut()) + .unwrap(); + }, + ConnectionStage::EstablishedTls(ref mut peer_connection) => { if peer_connection.ws.can_read(){ peer_connection.ws.close(None).unwrap(); @@ -195,6 +216,7 @@ pub fn run_socket_worker( socket_worker_index: usize, in_message_sender: InMessageSender, out_message_receiver: OutMessageReceiver, + use_tls: bool ){ let poll_timeout = Duration::from_millis( config.network.poll_timeout_milliseconds @@ -230,7 +252,7 @@ pub fn run_socket_worker( &mut poll, &mut connections, valid_until, - &mut poll_token_counter + &mut poll_token_counter, ); } else if event.is_readable(){ run_handshakes_and_read_messages( @@ -240,7 +262,8 @@ pub fn run_socket_worker( &mut poll, &mut connections, token, - valid_until + valid_until, + use_tls ); } } @@ -345,7 +368,55 @@ pub fn handle_tls_handshake_result( } -pub fn handle_ws_handshake_result( +pub fn handle_ws_handshake_no_tls_result( + connections: &mut ConnectionMap, + poll_token: Token, + valid_until: ValidUntil, + result: Result, HandshakeError>> , +) -> bool { + match result { + Ok(mut ws) => { + println!("handshake established"); + + let peer_addr = ws.get_mut().peer_addr().unwrap(); + + let peer_connection = PeerConnection { + ws, + peer_addr, + valid_until, + }; + + let connection = Connection { + valid_until, + stage: ConnectionStage::EstablishedNoTls(peer_connection) + }; + + connections.insert(poll_token, connection); + + false + }, + Err(HandshakeError::Interrupted(handshake)) => { + println!("interrupted"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::WsHandshakeNoTls(handshake), + }; + + connections.insert(poll_token, connection); + + true + }, + Err(HandshakeError::Failure(err)) => { + dbg!(err); + + false + } + } +} + + +pub fn handle_ws_handshake_tls_result( connections: &mut ConnectionMap, poll_token: Token, valid_until: ValidUntil, @@ -365,7 +436,7 @@ pub fn handle_ws_handshake_result( let connection = Connection { valid_until, - stage: ConnectionStage::Established(peer_connection) + stage: ConnectionStage::EstablishedTls(peer_connection) }; connections.insert(poll_token, connection); @@ -377,7 +448,7 @@ pub fn handle_ws_handshake_result( let connection = Connection { valid_until, - stage: ConnectionStage::WsHandshake(handshake), + stage: ConnectionStage::WsHandshakeTls(handshake), }; connections.insert(poll_token, connection); @@ -393,6 +464,66 @@ pub fn handle_ws_handshake_result( } +// Macro hack to not have to write the following twice in +// `run_handshakes_and_read_messages` +macro_rules! read_ws_messages { + ( + $socket_worker_index: ident, + $in_message_sender: ident, + $poll: ident, + $connections: ident, + $poll_token: ident, + $valid_until: ident, + $peer_connection: ident + ) => { + println!("conn established"); + + match $peer_connection.ws.read_message(){ + Ok(ws_message) => { + dbg!(ws_message.clone()); + + if let Some(in_message) = InMessage::from_ws_message(ws_message){ + dbg!(in_message.clone()); + + let meta = ConnectionMeta { + worker_index: $socket_worker_index, + poll_token: $poll_token, + peer_addr: $peer_connection.peer_addr + }; + + $in_message_sender.send((meta, in_message)); + } + + $peer_connection.valid_until = $valid_until; + }, + Err(tungstenite::Error::Io(err)) => { + if err.kind() == ErrorKind::WouldBlock { + break; + } + + remove_connection_if_exists($poll, $connections, $poll_token); + + eprint!("{}", err); + + break; + }, + Err(tungstenite::Error::ConnectionClosed) => { + remove_connection_if_exists($poll, $connections, $poll_token); + + break; + }, + Err(err) => { + dbg!(err); + + remove_connection_if_exists($poll, $connections, $poll_token); + + break; + } + } + }; +} + + pub fn run_handshakes_and_read_messages( socket_worker_index: usize, in_message_sender: &InMessageSender, @@ -401,12 +532,14 @@ pub fn run_handshakes_and_read_messages( connections: &mut ConnectionMap, poll_token: Token, valid_until: ValidUntil, + use_tls: bool ){ println!("poll_token: {}", poll_token.0); loop { let established = match connections.get(&poll_token).map(|c| &c.stage){ - Some(ConnectionStage::Established(_)) => true, + Some(ConnectionStage::EstablishedTls(_)) => true, + Some(ConnectionStage::EstablishedNoTls(_)) => true, Some(_) => false, None => break, }; @@ -416,15 +549,33 @@ pub fn run_handshakes_and_read_messages( match conn.stage { ConnectionStage::TcpStream(stream) => { - let stop_loop = handle_tls_handshake_result( - connections, - poll_token, - valid_until, - tls_acceptor.accept(stream) - ); + if use_tls { + let stop_loop = handle_tls_handshake_result( + connections, + poll_token, + valid_until, + tls_acceptor.accept(stream) + ); + + if stop_loop { + break; + } + } else { + let handshake_result = ::tungstenite::server::accept_hdr( + stream, + DebugCallback + ); - if stop_loop { - break; + let stop_loop = handle_ws_handshake_no_tls_result( + connections, + poll_token, + valid_until, + handshake_result + ); + + if stop_loop { + break; + } } }, ConnectionStage::TlsMidHandshake(handshake) => { @@ -445,7 +596,7 @@ pub fn run_handshakes_and_read_messages( DebugCallback ); - let stop_loop = handle_ws_handshake_result( + let stop_loop = handle_ws_handshake_tls_result( connections, poll_token, valid_until, @@ -456,8 +607,8 @@ pub fn run_handshakes_and_read_messages( break; } }, - ConnectionStage::WsHandshake(handshake) => { - let stop_loop = handle_ws_handshake_result( + ConnectionStage::WsHandshakeNoTls(handshake) => { + let stop_loop = handle_ws_handshake_no_tls_result( connections, poll_token, valid_until, @@ -468,62 +619,59 @@ pub fn run_handshakes_and_read_messages( break; } }, - ConnectionStage::Established(_) => unreachable!(), + ConnectionStage::WsHandshakeTls(handshake) => { + let stop_loop = handle_ws_handshake_tls_result( + connections, + poll_token, + valid_until, + handshake.handshake() + ); + + if stop_loop { + break; + } + }, + ConnectionStage::EstablishedNoTls(_) => unreachable!(), + ConnectionStage::EstablishedTls(_) => unreachable!(), } - } else if let Some(Connection{ - stage: ConnectionStage::Established(peer_connection), - .. - }) = connections.get_mut(&poll_token){ - println!("conn established"); - - match peer_connection.ws.read_message(){ - Ok(ws_message) => { - dbg!(ws_message.clone()); - - if let Some(in_message) = InMessage::from_ws_message(ws_message){ - dbg!(in_message.clone()); - - let meta = ConnectionMeta { - worker_index: socket_worker_index, - poll_token, - peer_addr: peer_connection.peer_addr - }; - - in_message_sender.send((meta, in_message)); - } - - peer_connection.valid_until = valid_until; + } else { + match connections.get_mut(&poll_token){ + Some(Connection{ + stage: ConnectionStage::EstablishedNoTls(peer_connection), + .. + }) => { + read_ws_messages!( + socket_worker_index, + in_message_sender, + poll, + connections, + poll_token, + valid_until, + peer_connection + ); }, - Err(tungstenite::Error::Io(err)) => { - if err.kind() == ErrorKind::WouldBlock { - break - } - - remove_connection_if_exists(poll, connections, poll_token); - - eprint!("{}", err); - - break + Some(Connection{ + stage: ConnectionStage::EstablishedTls(peer_connection), + .. + }) => { + read_ws_messages!( + socket_worker_index, + in_message_sender, + poll, + connections, + poll_token, + valid_until, + peer_connection + ); }, - Err(tungstenite::Error::ConnectionClosed) => { - remove_connection_if_exists(poll, connections, poll_token); - - break; - }, - Err(err) => { - dbg!(err); - - remove_connection_if_exists(poll, connections, poll_token); - - break; - } + _ => () } } } } -/// Read messages from channel, send to peers +/// Read messages from channel, send to peers FIXME: NoTls pub fn send_out_messages( out_message_receiver: ::flume::Drain<(ConnectionMeta, OutMessage)>, poll: &mut Poll, @@ -534,7 +682,7 @@ pub fn send_out_messages( .get_mut(&meta.poll_token) .map(|v| &mut v.stage); - if let Some(ConnectionStage::Established(connection)) = opt_connection { + if let Some(ConnectionStage::EstablishedTls(connection)) = opt_connection { if connection.peer_addr != meta.peer_addr { eprintln!("socket worker: peer socket addrs didn't match");