diff --git a/Cargo.lock b/Cargo.lock index 419b461..dc56638 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,6 +98,7 @@ dependencies = [ "indexmap", "mimalloc", "mio", + "native-tls", "net2", "parking_lot", "quickcheck", diff --git a/aquatic_ws/Cargo.toml b/aquatic_ws/Cargo.toml index aeddb1f..5e89e52 100644 --- a/aquatic_ws/Cargo.toml +++ b/aquatic_ws/Cargo.toml @@ -21,6 +21,7 @@ hashbrown = { version = "0.7", features = ["serde"] } indexmap = "1" mimalloc = { version = "0.1", default-features = false } mio = { version = "0.7", features = ["tcp", "os-poll", "os-util"] } +native-tls = "0.2" net2 = "0.2" parking_lot = "0.10" rand = { version = "0.7", features = ["small_rng"] } diff --git a/aquatic_ws/src/lib/config.rs b/aquatic_ws/src/lib/config.rs index e482e84..bd49438 100644 --- a/aquatic_ws/src/lib/config.rs +++ b/aquatic_ws/src/lib/config.rs @@ -32,6 +32,8 @@ pub struct NetworkConfig { pub peer_announce_interval: usize, // FIXME: should this really be in NetworkConfig? pub poll_event_capacity: usize, pub poll_timeout_milliseconds: u64, + pub pkcs12_path: String, + pub pkcs12_password: String, } @@ -92,6 +94,8 @@ impl Default for NetworkConfig { peer_announce_interval: 120, poll_event_capacity: 4096, poll_timeout_milliseconds: 50, + pkcs12_path: "".into(), + pkcs12_password: "".into(), } } } diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/network.rs index 8a2e187..5351e77 100644 --- a/aquatic_ws/src/lib/network.rs +++ b/aquatic_ws/src/lib/network.rs @@ -1,3 +1,5 @@ +use std::fs::File; +use std::io::Read; use std::net::{SocketAddr}; use std::time::{Duration, Instant}; use std::io::ErrorKind; @@ -5,6 +7,7 @@ use std::io::ErrorKind; use tungstenite::WebSocket; use tungstenite::handshake::{MidHandshake, HandshakeError, server::{ServerHandshake, NoCallback}}; use hashbrown::HashMap; +use native_tls::{Identity, TlsAcceptor, TlsStream}; use net2::{TcpBuilder, unix::UnixTcpBuilderExt}; use mio::{Events, Poll, Interest, Token}; @@ -15,23 +18,28 @@ use crate::config::Config; use crate::protocol::*; -pub struct Connection { - valid_until: ValidUntil, - stage: ConnectionStage, +pub type Stream = TlsStream; + + +pub struct PeerConnection { + pub ws: WebSocket, + pub peer_addr: SocketAddr, + pub valid_until: ValidUntil, } pub enum ConnectionStage { - Stream(TcpStream), - MidHandshake(MidHandshake>), + TcpStream(TcpStream), + TlsMidHandshake(native_tls::MidHandshakeTlsStream), + TlsStream(Stream), + WsHandshake(MidHandshake>), Established(PeerConnection), } -pub struct PeerConnection { - pub ws: WebSocket, - pub peer_addr: SocketAddr, - pub valid_until: ValidUntil, +pub struct Connection { + valid_until: ValidUntil, + stage: ConnectionStage, } @@ -55,6 +63,62 @@ impl ::tungstenite::handshake::server::Callback for DebugCallback { } +fn close_and_deregister_connection( + poll: &mut Poll, + connection: &mut Connection, +){ + match connection.stage { + ConnectionStage::TcpStream(ref mut stream) => { + poll.registry() + .deregister(stream) + .unwrap(); + }, + ConnectionStage::TlsMidHandshake(ref mut handshake) => { + poll.registry() + .deregister(handshake.get_mut()) + .unwrap(); + }, + ConnectionStage::TlsStream(ref mut stream) => { + poll.registry() + .deregister(stream.get_mut()) + .unwrap(); + }, + ConnectionStage::WsHandshake(ref mut handshake) => { + poll.registry() + .deregister(handshake.get_mut().get_mut().get_mut()) + .unwrap(); + }, + ConnectionStage::Established(ref mut peer_connection) => { + if peer_connection.ws.can_read(){ + peer_connection.ws.close(None).unwrap(); + + // Needs to be done after ws.close() + if let Err(err) = peer_connection.ws.write_pending(){ + dbg!(err); + } + } + + poll.registry() + .deregister(peer_connection.ws.get_mut().get_mut()) + .unwrap(); + }, + } +} + + +fn remove_connection_if_exists( + poll: &mut Poll, + connections: &mut ConnectionMap, + token: Token, +){ + if let Some(mut connection) = connections.remove(&token){ + close_and_deregister_connection(poll, &mut connection); + + connections.remove(&token); + } +} + + // Close and remove inactive connections pub fn remove_inactive_connections( poll: &mut Poll, @@ -64,30 +128,7 @@ pub fn remove_inactive_connections( connections.retain(|_, connection| { if connection.valid_until.0 < now { - match connection.stage { - ConnectionStage::Stream(ref mut stream) => { - poll.registry() - .deregister(stream) - .unwrap(); - }, - ConnectionStage::MidHandshake(ref mut handshake) => { - poll.registry() - .deregister(handshake.get_mut().get_mut()) - .unwrap(); - }, - ConnectionStage::Established(ref mut peer_connection) => { - peer_connection.ws.close(None).unwrap(); - - // Needs to be done after ws.close() - if let Err(err) = peer_connection.ws.write_pending(){ - dbg!(err); - } - - poll.registry() - .deregister(peer_connection.ws.get_mut()) - .unwrap(); - }, - } + close_and_deregister_connection(poll, connection); println!("closing connection, it is inactive"); @@ -128,6 +169,27 @@ fn create_listener(config: &Config) -> ::std::net::TcpListener { } +fn create_tls_acceptor( + config: &Config, +) -> TlsAcceptor { + let mut identity_bytes = Vec::new(); + let mut file = File::open(&config.network.pkcs12_path) + .expect("open pkcs12 file"); + + file.read_to_end(&mut identity_bytes).expect("read pkcs12 file"); + + let identity = Identity::from_pkcs12( + &mut identity_bytes, + &config.network.pkcs12_password + ).expect("create pkcs12 identity"); + + let acceptor = TlsAcceptor::new(identity) + .expect("create TlsAcceptor"); + + acceptor +} + + pub fn run_socket_worker( config: Config, socket_worker_index: usize, @@ -146,6 +208,8 @@ pub fn run_socket_worker( .register(&mut listener, Token(0), Interest::READABLE) .unwrap(); + let tls_acceptor = create_tls_acceptor(&config); + let mut connections: ConnectionMap = HashMap::new(); let mut poll_token_counter = Token(0usize); @@ -169,9 +233,10 @@ pub fn run_socket_worker( &mut poll_token_counter ); } else if event.is_readable(){ - run_handshake_and_read_messages( + run_handshakes_and_read_messages( socket_worker_index, &in_message_sender, + &tls_acceptor, &mut poll, &mut connections, token, @@ -196,37 +261,6 @@ pub fn run_socket_worker( } -/// FIXME: close ws if not closed? Would be good in some cases, like when -/// connection removed because token has wrapped back, and generally when the -/// reason for removal is not that ws is already closed. -fn remove_connection_if_exists( - poll: &mut Poll, - connections: &mut ConnectionMap, - token: Token, -){ - if let Some(connection) = connections.remove(&token){ - match connection.stage { - ConnectionStage::Stream(mut stream) => { - poll.registry() - .deregister(&mut stream) - .unwrap(); - }, - ConnectionStage::MidHandshake(mut handshake) => { - poll.registry() - .deregister(handshake.get_mut().get_mut()) - .unwrap(); - } - ConnectionStage::Established(mut peer_connection) => { - poll.registry() - .deregister(peer_connection.ws.get_mut()) - .unwrap(); - } - }; - - connections.remove(&token); - } -} - fn accept_new_streams( listener: &mut TcpListener, @@ -254,7 +288,7 @@ fn accept_new_streams( let connection = Connection { valid_until, - stage: ConnectionStage::Stream(stream) + stage: ConnectionStage::TcpStream(stream) }; connections.insert(token, connection); @@ -271,17 +305,57 @@ fn accept_new_streams( } -pub fn handle_handshake_result( +pub fn handle_tls_handshake_result( connections: &mut ConnectionMap, poll_token: Token, valid_until: ValidUntil, - result: Result, HandshakeError>> , + result: Result, native_tls::HandshakeError>, +) -> bool { + match result { + Ok(stream) => { + println!("handshake established"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::TlsStream(stream) + }; + + connections.insert(poll_token, connection); + + false + }, + Err(native_tls::HandshakeError::WouldBlock(handshake)) => { + println!("interrupted"); + + let connection = Connection { + valid_until, + stage: ConnectionStage::TlsMidHandshake(handshake), + }; + + connections.insert(poll_token, connection); + + true + }, + Err(native_tls::HandshakeError::Failure(err)) => { + dbg!(err); + + false + } + } +} + + +pub fn handle_ws_handshake_result( + connections: &mut ConnectionMap, + poll_token: Token, + valid_until: ValidUntil, + result: Result, HandshakeError>> , ) -> bool { match result { Ok(mut ws) => { println!("handshake established"); - let peer_addr = ws.get_mut().peer_addr().unwrap(); + let peer_addr = ws.get_mut().get_mut().peer_addr().unwrap(); let peer_connection = PeerConnection { ws, @@ -303,7 +377,7 @@ pub fn handle_handshake_result( let connection = Connection { valid_until, - stage: ConnectionStage::MidHandshake(handshake), + stage: ConnectionStage::WsHandshake(handshake), }; connections.insert(poll_token, connection); @@ -319,9 +393,10 @@ pub fn handle_handshake_result( } -pub fn run_handshake_and_read_messages( +pub fn run_handshakes_and_read_messages( socket_worker_index: usize, in_message_sender: &InMessageSender, + tls_acceptor: &TlsAcceptor, poll: &mut Poll, connections: &mut ConnectionMap, poll_token: Token, @@ -331,9 +406,8 @@ pub fn run_handshake_and_read_messages( loop { let established = match connections.get(&poll_token).map(|c| &c.stage){ - Some(ConnectionStage::Stream(_)) => false, - Some(ConnectionStage::MidHandshake(_)) => false, Some(ConnectionStage::Established(_)) => true, + Some(_) => false, None => break, }; @@ -341,13 +415,37 @@ pub fn run_handshake_and_read_messages( let conn = connections.remove(&poll_token).unwrap(); match conn.stage { - ConnectionStage::Stream(stream) => { + ConnectionStage::TcpStream(stream) => { + let stop_loop = handle_tls_handshake_result( + connections, + poll_token, + valid_until, + tls_acceptor.accept(stream) + ); + + if stop_loop { + break; + } + }, + ConnectionStage::TlsMidHandshake(handshake) => { + let stop_loop = handle_tls_handshake_result( + connections, + poll_token, + valid_until, + handshake.handshake() + ); + + if stop_loop { + break; + } + }, + ConnectionStage::TlsStream(stream) => { let handshake_result = ::tungstenite::server::accept_hdr( stream, DebugCallback ); - let stop_loop = handle_handshake_result( + let stop_loop = handle_ws_handshake_result( connections, poll_token, valid_until, @@ -358,8 +456,8 @@ pub fn run_handshake_and_read_messages( break; } }, - ConnectionStage::MidHandshake(handshake) => { - let stop_loop = handle_handshake_result( + ConnectionStage::WsHandshake(handshake) => { + let stop_loop = handle_ws_handshake_result( connections, poll_token, valid_until,