aquatic_ws: create HandshakeMachine from ConnectionStage

This commit is contained in:
Joakim Frostegård 2020-05-13 19:17:33 +02:00
parent 2967129c1f
commit 0bcfffb2bd
5 changed files with 122 additions and 105 deletions

7
Cargo.lock generated
View file

@ -93,6 +93,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"aquatic_common", "aquatic_common",
"cli_helpers", "cli_helpers",
"either",
"flume", "flume",
"hashbrown", "hashbrown",
"indexmap", "indexmap",
@ -330,6 +331,12 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "either"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
[[package]] [[package]]
name = "encode_unicode" name = "encode_unicode"
version = "0.3.6" version = "0.3.6"

View file

@ -16,6 +16,7 @@ path = "src/bin/main.rs"
[dependencies] [dependencies]
aquatic_common = { path = "../aquatic_common" } aquatic_common = { path = "../aquatic_common" }
cli_helpers = { path = "../cli_helpers" } cli_helpers = { path = "../cli_helpers" }
either = "1"
flume = "0.7" flume = "0.7"
hashbrown = { version = "0.7", features = ["serde"] } hashbrown = { version = "0.7", features = ["serde"] }
indexmap = "1" indexmap = "1"

View file

@ -1,12 +1,14 @@
use std::net::{SocketAddr}; use std::net::{SocketAddr};
use std::io::{Read, Write}; use std::io::{Read, Write};
use either::Either;
use hashbrown::HashMap; use hashbrown::HashMap;
use mio::Token; use mio::Token;
use mio::net::TcpStream; use mio::net::TcpStream;
use native_tls::{TlsStream, MidHandshakeTlsStream}; use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream};
use tungstenite::WebSocket; use tungstenite::WebSocket;
use tungstenite::handshake::{MidHandshake, server::ServerHandshake}; use tungstenite::handshake::{MidHandshake, HandshakeError};
use tungstenite::server::ServerHandshake;
use crate::common::*; use crate::common::*;
@ -71,26 +73,109 @@ impl Write for Stream {
} }
pub struct EstablishedWs<S> { pub struct EstablishedWs {
pub ws: WebSocket<S>, pub ws: WebSocket<Stream>,
pub peer_addr: SocketAddr, pub peer_addr: SocketAddr,
} }
pub enum ConnectionStage { pub enum HandshakeMachine {
TcpStream(TcpStream), TcpStream(TcpStream),
TlsStream(TlsStream<TcpStream>), TlsStream(TlsStream<TcpStream>),
TlsMidHandshake(MidHandshakeTlsStream<TcpStream>), TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
WsMidHandshake(MidHandshake<ServerHandshake<Stream, DebugCallback>>), WsMidHandshake(MidHandshake<ServerHandshake<Stream, DebugCallback>>),
EstablishedWs(EstablishedWs<Stream>),
} }
impl ConnectionStage { impl HandshakeMachine {
pub fn is_established(&self) -> bool { pub fn new(tcp_stream: TcpStream) -> Self {
Self::TcpStream(tcp_stream)
}
pub fn advance(
self,
opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
) -> (Option<Either<EstablishedWs, Self>>, bool) { // bool = stop looping
match self { match self {
Self::EstablishedWs(_) => true, HandshakeMachine::TcpStream(stream) => {
_ => false, if let Some(tls_acceptor) = opt_tls_acceptor {
Self::handle_tls_handshake_result(
tls_acceptor.accept(stream)
)
} else {
let handshake_result = ::tungstenite::server::accept_hdr(
Stream::TcpStream(stream),
DebugCallback
);
Self::handle_ws_handshake_result(handshake_result)
}
},
HandshakeMachine::TlsStream(stream) => {
let handshake_result = ::tungstenite::server::accept_hdr(
Stream::TlsStream(stream),
DebugCallback
);
Self::handle_ws_handshake_result(handshake_result)
},
HandshakeMachine::TlsMidHandshake(handshake) => {
Self::handle_tls_handshake_result(handshake.handshake())
},
HandshakeMachine::WsMidHandshake(handshake) => {
Self::handle_ws_handshake_result(handshake.handshake())
},
}
}
fn handle_tls_handshake_result(
result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
) -> (Option<Either<EstablishedWs, Self>>, bool) {
match result {
Ok(stream) => {
println!("handshake established");
(Some(Either::Right(Self::TlsStream(stream))), false)
},
Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
println!("interrupted");
(Some(Either::Right(Self::TlsMidHandshake(handshake))), true)
},
Err(native_tls::HandshakeError::Failure(err)) => {
dbg!(err);
(None, false)
}
}
}
fn handle_ws_handshake_result(
result: Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, DebugCallback>>> ,
) -> (Option<Either<EstablishedWs, Self>>, bool) {
match result {
Ok(mut ws) => {
println!("handshake established");
let peer_addr = ws.get_mut().get_peer_addr();
let established_ws = EstablishedWs {
ws,
peer_addr,
};
(Some(Either::Left(established_ws)), false)
},
Err(HandshakeError::Interrupted(handshake)) => {
println!("interrupted");
(Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), true)
},
Err(HandshakeError::Failure(err)) => {
dbg!(err);
(None, false)
}
} }
} }
} }
@ -98,7 +183,7 @@ impl ConnectionStage {
pub struct Connection { pub struct Connection {
pub valid_until: ValidUntil, pub valid_until: ValidUntil,
pub stage: ConnectionStage, pub inner: Either<EstablishedWs, HandshakeMachine>,
} }

View file

@ -1,13 +1,12 @@
use std::time::Duration; use std::time::Duration;
use std::io::ErrorKind; use std::io::ErrorKind;
use tungstenite::WebSocket; use either::Either;
use tungstenite::handshake::{HandshakeError, server::ServerHandshake};
use hashbrown::HashMap; use hashbrown::HashMap;
use native_tls::{TlsAcceptor, TlsStream}; use native_tls::TlsAcceptor;
use mio::{Events, Poll, Interest, Token}; use mio::{Events, Poll, Interest, Token};
use mio::net::{TcpListener, TcpStream}; use mio::net::TcpListener;
use crate::common::*; use crate::common::*;
use crate::config::Config; use crate::config::Config;
@ -119,7 +118,7 @@ fn accept_new_streams(
let connection = Connection { let connection = Connection {
valid_until, valid_until,
stage: ConnectionStage::TcpStream(stream) inner: Either::Right(HandshakeMachine::new(stream))
}; };
connections.insert(token, connection); connections.insert(token, connection);
@ -136,56 +135,6 @@ fn accept_new_streams(
} }
pub fn handle_tls_handshake_result(
result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
) -> (Option<ConnectionStage>, bool) {
match result {
Ok(stream) => {
println!("handshake established");
(Some(ConnectionStage::TlsStream(stream)), false)
},
Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
println!("interrupted");
(Some(ConnectionStage::TlsMidHandshake(handshake)), true)
},
Err(native_tls::HandshakeError::Failure(err)) => {
(None, false)
}
}
}
pub fn handle_ws_handshake_result(
result: Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, DebugCallback>>> ,
) -> (Option<ConnectionStage>, bool) {
match result {
Ok(mut ws) => {
println!("handshake established");
let peer_addr = ws.get_mut().get_peer_addr();
let established_ws = EstablishedWs {
ws,
peer_addr,
};
(Some(ConnectionStage::EstablishedWs(established_ws)), false)
},
Err(HandshakeError::Interrupted(handshake)) => {
println!("interrupted");
(Some(ConnectionStage::WsMidHandshake(handshake)), true)
},
Err(HandshakeError::Failure(err)) => {
dbg!(err);
(None, false)
}
}
}
/// On the stream given by poll_token, get TLS (if requested) and tungstenite /// On the stream given by poll_token, get TLS (if requested) and tungstenite
/// up and running, then read messages and pass on through channel. /// up and running, then read messages and pass on through channel.
@ -199,7 +148,7 @@ pub fn run_handshakes_and_read_messages(
){ ){
loop { loop {
if let Some(Connection { if let Some(Connection {
stage: ConnectionStage::EstablishedWs(established_ws), inner: Either::Left(established_ws),
.. ..
}) = connections.get_mut(&poll_token){ }) = connections.get_mut(&poll_token){
use ::tungstenite::Error::Io; use ::tungstenite::Error::Io;
@ -231,43 +180,17 @@ pub fn run_handshakes_and_read_messages(
break; break;
} }
} }
} else if let Some(connection) = connections.remove(&poll_token) { } else if let Some(Connection {
let (opt_new_stage, stop_loop) = match connection.stage { inner: Either::Right(machine),
ConnectionStage::TcpStream(stream) => { ..
if let Some(tls_acceptor) = opt_tls_acceptor { }) = connections.remove(&poll_token) {
handle_tls_handshake_result( let (result, stop_loop) = machine
tls_acceptor.accept(stream) .advance(opt_tls_acceptor);
)
} else {
let handshake_result = ::tungstenite::server::accept_hdr(
Stream::TcpStream(stream),
DebugCallback
);
handle_ws_handshake_result(handshake_result) if let Some(inner) = result {
}
},
ConnectionStage::TlsStream(stream) => {
let handshake_result = ::tungstenite::server::accept_hdr(
Stream::TlsStream(stream),
DebugCallback
);
handle_ws_handshake_result(handshake_result)
},
ConnectionStage::TlsMidHandshake(handshake) => {
handle_tls_handshake_result(handshake.handshake())
},
ConnectionStage::WsMidHandshake(handshake) => {
handle_ws_handshake_result(handshake.handshake())
},
ConnectionStage::EstablishedWs(_) => unreachable!(),
};
if let Some(stage) = opt_new_stage {
let connection = Connection { let connection = Connection {
valid_until, valid_until,
stage, inner,
}; };
connections.insert(poll_token, connection); connections.insert(poll_token, connection);
@ -287,11 +210,11 @@ pub fn send_out_messages(
connections: &mut ConnectionMap, connections: &mut ConnectionMap,
){ ){
for (meta, out_message) in out_message_receiver { for (meta, out_message) in out_message_receiver {
let opt_stage = connections let opt_inner = connections
.get_mut(&meta.poll_token) .get_mut(&meta.poll_token)
.map(|v| &mut v.stage); .map(|v| &mut v.inner);
if let Some(ConnectionStage::EstablishedWs(connection)) = opt_stage { if let Some(Either::Left(connection)) = opt_inner {
if connection.peer_addr != meta.peer_addr { if connection.peer_addr != meta.peer_addr {
eprintln!("socket worker: peer socket addrs didn't match"); eprintln!("socket worker: peer socket addrs didn't match");

View file

@ -2,6 +2,7 @@ use std::fs::File;
use std::io::Read; use std::io::Read;
use std::time::Instant; use std::time::Instant;
use either::Either;
use mio::Token; use mio::Token;
use native_tls::{Identity, TlsAcceptor}; use native_tls::{Identity, TlsAcceptor};
use net2::{TcpBuilder, unix::UnixTcpBuilderExt}; use net2::{TcpBuilder, unix::UnixTcpBuilderExt};
@ -58,7 +59,7 @@ pub fn create_tls_acceptor(
pub fn close_connection(connection: &mut Connection){ pub fn close_connection(connection: &mut Connection){
if let ConnectionStage::EstablishedWs(ref mut ews) = connection.stage { if let Either::Left(ref mut ews) = connection.inner {
if ews.ws.can_read(){ if ews.ws.can_read(){
ews.ws.close(None).unwrap(); ews.ws.close(None).unwrap();