From 0c93d170de2571eb6282d61c233592e68a7fca4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Tue, 12 May 2020 21:04:47 +0200 Subject: [PATCH] WIP: aquatic_ws network: organize into submodule, other small fixes --- aquatic_ws/src/lib/network/common.rs | 66 ++++ .../src/lib/{network.rs => network/mod.rs} | 303 ++++-------------- aquatic_ws/src/lib/network/utils.rs | 156 +++++++++ 3 files changed, 284 insertions(+), 241 deletions(-) create mode 100644 aquatic_ws/src/lib/network/common.rs rename aquatic_ws/src/lib/{network.rs => network/mod.rs} (66%) create mode 100644 aquatic_ws/src/lib/network/utils.rs diff --git a/aquatic_ws/src/lib/network/common.rs b/aquatic_ws/src/lib/network/common.rs new file mode 100644 index 0000000..d93e6f3 --- /dev/null +++ b/aquatic_ws/src/lib/network/common.rs @@ -0,0 +1,66 @@ +use std::net::{SocketAddr}; + +use hashbrown::HashMap; +use mio::Token; +use mio::net::TcpStream; +use native_tls::TlsStream; +use tungstenite::WebSocket; +use tungstenite::handshake::{MidHandshake, server::ServerHandshake}; + +use crate::common::*; + + +#[derive(Clone, Copy, Debug)] +pub struct DebugCallback; + +impl ::tungstenite::handshake::server::Callback for DebugCallback { + fn on_request( + self, + request: &::tungstenite::handshake::server::Request, + response: ::tungstenite::handshake::server::Response, + ) -> Result<::tungstenite::handshake::server::Response, ::tungstenite::handshake::server::ErrorResponse> { + println!("request: {:#?}", request); + println!("response: {:#?}", response); + + Ok(response) + } +} + + +pub type Stream = TlsStream; + + +pub struct EstablishedWs { + pub ws: WebSocket, + pub peer_addr: SocketAddr, +} + + +pub enum ConnectionStage { + TcpStream(TcpStream), + TlsMidHandshake(native_tls::MidHandshakeTlsStream), + TlsStream(Stream), + WsHandshakeNoTls(MidHandshake>), + WsHandshakeTls(MidHandshake>), + EstablishedWsNoTls(EstablishedWs), + EstablishedWsTls(EstablishedWs), +} + + +impl ConnectionStage { + pub fn is_established(&self) -> bool { + match self { + Self::EstablishedWsTls(_) | Self::EstablishedWsNoTls(_) => true, + _ => false, + } + } +} + + +pub struct Connection { + pub valid_until: ValidUntil, + pub stage: ConnectionStage, +} + + +pub type ConnectionMap = HashMap; \ No newline at end of file diff --git a/aquatic_ws/src/lib/network.rs b/aquatic_ws/src/lib/network/mod.rs similarity index 66% rename from aquatic_ws/src/lib/network.rs rename to aquatic_ws/src/lib/network/mod.rs index a885841..da112b8 100644 --- a/aquatic_ws/src/lib/network.rs +++ b/aquatic_ws/src/lib/network/mod.rs @@ -1,14 +1,10 @@ -use std::fs::File; -use std::io::{Read, Write}; -use std::net::{SocketAddr}; -use std::time::{Duration, Instant}; +use std::time::Duration; use std::io::ErrorKind; use tungstenite::WebSocket; -use tungstenite::handshake::{MidHandshake, HandshakeError, server::{ServerHandshake, NoCallback}}; +use tungstenite::handshake::{HandshakeError, server::ServerHandshake}; use hashbrown::HashMap; -use native_tls::{Identity, TlsAcceptor, TlsStream}; -use net2::{TcpBuilder, unix::UnixTcpBuilderExt}; +use native_tls::{TlsAcceptor, TlsStream}; use mio::{Events, Poll, Interest, Token}; use mio::net::{TcpListener, TcpStream}; @@ -17,207 +13,11 @@ use crate::common::*; use crate::config::Config; use crate::protocol::*; +pub mod common; +pub mod utils; -pub type Stream = TlsStream; - - -pub struct EstablishedWs { - pub ws: WebSocket, - pub peer_addr: SocketAddr, -} - - -pub enum ConnectionStage { - TcpStream(TcpStream), - TlsMidHandshake(native_tls::MidHandshakeTlsStream), - TlsStream(Stream), - WsHandshakeNoTls(MidHandshake>), - WsHandshakeTls(MidHandshake>), - EstablishedWsNoTls(EstablishedWs), - EstablishedWsTls(EstablishedWs), -} - - -impl ConnectionStage { - pub fn is_established(&self) -> bool { - match self { - Self::EstablishedWsTls(_) | Self::EstablishedWsNoTls(_) => true, - _ => false, - } - } -} - - -pub struct Connection { - valid_until: ValidUntil, - stage: ConnectionStage, -} - - -pub type ConnectionMap = HashMap; - - -#[derive(Clone, Copy, Debug)] -pub struct DebugCallback; - -impl ::tungstenite::handshake::server::Callback for DebugCallback { - fn on_request( - self, - request: &::tungstenite::handshake::server::Request, - response: ::tungstenite::handshake::server::Response, - ) -> Result<::tungstenite::handshake::server::Response, ::tungstenite::handshake::server::ErrorResponse> { - println!("request: {:#?}", request); - println!("response: {:#?}", response); - - Ok(response) - } -} - - -fn close_and_deregister_connection( - poll: &mut Poll, - connection: &mut Connection, -){ - match connection.stage { - ConnectionStage::TcpStream(ref mut stream) => { - poll.registry() - .deregister(stream) - .unwrap(); - }, - ConnectionStage::TlsMidHandshake(ref mut handshake) => { - poll.registry() - .deregister(handshake.get_mut()) - .unwrap(); - }, - ConnectionStage::TlsStream(ref mut stream) => { - poll.registry() - .deregister(stream.get_mut()) - .unwrap(); - }, - ConnectionStage::WsHandshakeNoTls(ref mut handshake) => { - poll.registry() - .deregister(handshake.get_mut().get_mut()) - .unwrap(); - }, - ConnectionStage::WsHandshakeTls(ref mut handshake) => { - poll.registry() - .deregister(handshake.get_mut().get_mut().get_mut()) - .unwrap(); - }, - ConnectionStage::EstablishedWsNoTls(ref mut established_ws) => { - if established_ws.ws.can_read(){ - established_ws.ws.close(None).unwrap(); - - // Needs to be done after ws.close() - if let Err(err) = established_ws.ws.write_pending(){ - dbg!(err); - } - } - - poll.registry() - .deregister(established_ws.ws.get_mut()) - .unwrap(); - }, - ConnectionStage::EstablishedWsTls(ref mut established_ws) => { - if established_ws.ws.can_read(){ - established_ws.ws.close(None).unwrap(); - - // Needs to be done after ws.close() - if let Err(err) = established_ws.ws.write_pending(){ - dbg!(err); - } - } - - poll.registry() - .deregister(established_ws.ws.get_mut().get_mut()) - .unwrap(); - }, - } -} - - -fn remove_connection_if_exists( - poll: &mut Poll, - connections: &mut ConnectionMap, - token: Token, -){ - if let Some(mut connection) = connections.remove(&token){ - close_and_deregister_connection(poll, &mut connection); - - connections.remove(&token); - } -} - - -// Close and remove inactive connections -pub fn remove_inactive_connections( - poll: &mut Poll, - connections: &mut ConnectionMap, -){ - let now = Instant::now(); - - connections.retain(|_, connection| { - if connection.valid_until.0 < now { - close_and_deregister_connection(poll, connection); - - println!("closing connection, it is inactive"); - - false - } else { - println!("keeping connection, it is still active"); - - true - } - }); - - connections.shrink_to_fit(); -} - - -fn create_listener(config: &Config) -> ::std::net::TcpListener { - let mut builder = &{ - if config.network.address.is_ipv4(){ - TcpBuilder::new_v4().expect("socket: build") - } else { - TcpBuilder::new_v6().expect("socket: build") - } - }; - - builder = builder.reuse_port(true) - .expect("socket: set reuse port"); - - builder = builder.bind(&config.network.address) - .expect(&format!("socket: bind to {}", &config.network.address)); - - let listener = builder.listen(128) - .expect("tcpbuilder to tcp listener"); - - listener.set_nonblocking(true) - .expect("socket: set nonblocking"); - - listener -} - - -fn create_tls_acceptor( - config: &Config, -) -> TlsAcceptor { - let mut identity_bytes = Vec::new(); - let mut file = File::open(&config.network.pkcs12_path) - .expect("open pkcs12 file"); - - file.read_to_end(&mut identity_bytes).expect("read pkcs12 file"); - - let identity = Identity::from_pkcs12( - &mut identity_bytes, - &config.network.pkcs12_password - ).expect("create pkcs12 identity"); - - let acceptor = TlsAcceptor::new(identity) - .expect("create TlsAcceptor"); - - acceptor -} +use common::*; +use utils::*; pub fn run_socket_worker( @@ -296,7 +96,6 @@ pub fn run_socket_worker( } - fn accept_new_streams( listener: &mut TcpListener, poll: &mut Poll, @@ -475,7 +274,8 @@ pub fn handle_ws_handshake_tls_result( // Macro hack to not have to write the following twice in -// `run_handshakes_and_read_messages` +// `run_handshakes_and_read_messages` (putting it in a function causes error +// because of multiple mutable references) macro_rules! read_ws_messages { ( $socket_worker_index: ident, @@ -531,6 +331,7 @@ macro_rules! read_ws_messages { } +/// Get TLS (if requested) and tungstenite up and running, then read messages pub fn run_handshakes_and_read_messages( socket_worker_index: usize, in_message_sender: &InMessageSender, @@ -673,52 +474,72 @@ pub fn run_handshakes_and_read_messages( } -/// Read messages from channel, send to peers FIXME: NoTls +/// Read messages from channel, send to peers pub fn send_out_messages( out_message_receiver: ::flume::Drain<(ConnectionMeta, OutMessage)>, poll: &mut Poll, connections: &mut ConnectionMap, ){ for (meta, out_message) in out_message_receiver { - let opt_connection = connections + let opt_stage = connections .get_mut(&meta.poll_token) .map(|v| &mut v.stage); + + use ::tungstenite::Error::Io; + + // Exactly the same for both established stages + match opt_stage { + Some(ConnectionStage::EstablishedWsNoTls(connection)) => { + if connection.peer_addr != meta.peer_addr { + eprintln!("socket worker: peer socket addrs didn't match"); - if let Some(ConnectionStage::EstablishedWsTls(connection)) = opt_connection { - if connection.peer_addr != meta.peer_addr { - eprintln!("socket worker: peer socket addrs didn't match"); + continue; + } - continue; - } + dbg!(out_message.clone()); - dbg!(out_message.clone()); - - match connection.ws.write_message(out_message.to_ws_message()){ - Ok(()) => {}, - Err(tungstenite::Error::Io(err)) => { - if err.kind() == ErrorKind::WouldBlock { + match connection.ws.write_message(out_message.to_ws_message()){ + Ok(()) => {}, + Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { continue; - } + }, + Err(err) => { + dbg!(err); - dbg!(err); + remove_connection_if_exists( + poll, + connections, + meta.poll_token + ); + }, + } + }, + Some(ConnectionStage::EstablishedWsTls(connection)) => { + if connection.peer_addr != meta.peer_addr { + eprintln!("socket worker: peer socket addrs didn't match"); - remove_connection_if_exists( - poll, - connections, - meta.poll_token - ); - }, - Err(tungstenite::Error::ConnectionClosed) => { - remove_connection_if_exists( - poll, - connections, - meta.poll_token - ); - }, - Err(err) => { - dbg!(err); - }, - } + continue; + } + + dbg!(out_message.clone()); + + match connection.ws.write_message(out_message.to_ws_message()){ + Ok(()) => {}, + Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => { + continue; + }, + Err(err) => { + dbg!(err); + + remove_connection_if_exists( + poll, + connections, + meta.poll_token + ); + }, + } + }, + _ => {}, } } } \ No newline at end of file diff --git a/aquatic_ws/src/lib/network/utils.rs b/aquatic_ws/src/lib/network/utils.rs new file mode 100644 index 0000000..5c8535d --- /dev/null +++ b/aquatic_ws/src/lib/network/utils.rs @@ -0,0 +1,156 @@ +use std::fs::File; +use std::io::Read; +use std::time::Instant; + +use mio::{Poll, Token}; +use native_tls::{Identity, TlsAcceptor}; +use net2::{TcpBuilder, unix::UnixTcpBuilderExt}; + +use crate::config::Config; + +use super::common::*; + + +pub fn create_listener(config: &Config) -> ::std::net::TcpListener { + let mut builder = &{ + if config.network.address.is_ipv4(){ + TcpBuilder::new_v4().expect("socket: build") + } else { + TcpBuilder::new_v6().expect("socket: build") + } + }; + + builder = builder.reuse_port(true) + .expect("socket: set reuse port"); + + builder = builder.bind(&config.network.address) + .expect(&format!("socket: bind to {}", &config.network.address)); + + let listener = builder.listen(128) + .expect("tcpbuilder to tcp listener"); + + listener.set_nonblocking(true) + .expect("socket: set nonblocking"); + + listener +} + + +pub fn create_tls_acceptor( + config: &Config, +) -> TlsAcceptor { + let mut identity_bytes = Vec::new(); + let mut file = File::open(&config.network.pkcs12_path) + .expect("open pkcs12 file"); + + file.read_to_end(&mut identity_bytes).expect("read pkcs12 file"); + + let identity = Identity::from_pkcs12( + &mut identity_bytes, + &config.network.pkcs12_password + ).expect("create pkcs12 identity"); + + let acceptor = TlsAcceptor::new(identity) + .expect("create TlsAcceptor"); + + acceptor +} + + +pub fn close_and_deregister_connection( + poll: &mut Poll, + connection: &mut Connection, +){ + match connection.stage { + ConnectionStage::TcpStream(ref mut stream) => { + poll.registry() + .deregister(stream) + .unwrap(); + }, + ConnectionStage::TlsMidHandshake(ref mut handshake) => { + poll.registry() + .deregister(handshake.get_mut()) + .unwrap(); + }, + ConnectionStage::TlsStream(ref mut stream) => { + poll.registry() + .deregister(stream.get_mut()) + .unwrap(); + }, + ConnectionStage::WsHandshakeNoTls(ref mut handshake) => { + poll.registry() + .deregister(handshake.get_mut().get_mut()) + .unwrap(); + }, + ConnectionStage::WsHandshakeTls(ref mut handshake) => { + poll.registry() + .deregister(handshake.get_mut().get_mut().get_mut()) + .unwrap(); + }, + ConnectionStage::EstablishedWsNoTls(ref mut established_ws) => { + if established_ws.ws.can_read(){ + established_ws.ws.close(None).unwrap(); + + // Needs to be done after ws.close() + if let Err(err) = established_ws.ws.write_pending(){ + dbg!(err); + } + } + + poll.registry() + .deregister(established_ws.ws.get_mut()) + .unwrap(); + }, + ConnectionStage::EstablishedWsTls(ref mut established_ws) => { + if established_ws.ws.can_read(){ + established_ws.ws.close(None).unwrap(); + + // Needs to be done after ws.close() + if let Err(err) = established_ws.ws.write_pending(){ + dbg!(err); + } + } + + poll.registry() + .deregister(established_ws.ws.get_mut().get_mut()) + .unwrap(); + }, + } +} + + +pub fn remove_connection_if_exists( + poll: &mut Poll, + connections: &mut ConnectionMap, + token: Token, +){ + if let Some(mut connection) = connections.remove(&token){ + close_and_deregister_connection(poll, &mut connection); + + connections.remove(&token); + } +} + +// Close and remove inactive connections +pub fn remove_inactive_connections( + poll: &mut Poll, + connections: &mut ConnectionMap, +){ + let now = Instant::now(); + + connections.retain(|_, connection| { + if connection.valid_until.0 < now { + close_and_deregister_connection(poll, connection); + + println!("closing connection, it is inactive"); + + false + } else { + println!("keeping connection, it is still active"); + + true + } + }); + + connections.shrink_to_fit(); +} \ No newline at end of file