aquatic_ws: create HandshakeMachine from ConnectionStage

This commit is contained in:
Joakim Frostegård 2020-05-13 19:17:33 +02:00
parent 2967129c1f
commit 0bcfffb2bd
5 changed files with 122 additions and 105 deletions

View file

@ -1,13 +1,12 @@
use std::time::Duration;
use std::io::ErrorKind;
use tungstenite::WebSocket;
use tungstenite::handshake::{HandshakeError, server::ServerHandshake};
use either::Either;
use hashbrown::HashMap;
use native_tls::{TlsAcceptor, TlsStream};
use native_tls::TlsAcceptor;
use mio::{Events, Poll, Interest, Token};
use mio::net::{TcpListener, TcpStream};
use mio::net::TcpListener;
use crate::common::*;
use crate::config::Config;
@ -119,7 +118,7 @@ fn accept_new_streams(
let connection = Connection {
valid_until,
stage: ConnectionStage::TcpStream(stream)
inner: Either::Right(HandshakeMachine::new(stream))
};
connections.insert(token, connection);
@ -136,56 +135,6 @@ fn accept_new_streams(
}
pub fn handle_tls_handshake_result(
result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
) -> (Option<ConnectionStage>, bool) {
match result {
Ok(stream) => {
println!("handshake established");
(Some(ConnectionStage::TlsStream(stream)), false)
},
Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
println!("interrupted");
(Some(ConnectionStage::TlsMidHandshake(handshake)), true)
},
Err(native_tls::HandshakeError::Failure(err)) => {
(None, false)
}
}
}
pub fn handle_ws_handshake_result(
result: Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, DebugCallback>>> ,
) -> (Option<ConnectionStage>, bool) {
match result {
Ok(mut ws) => {
println!("handshake established");
let peer_addr = ws.get_mut().get_peer_addr();
let established_ws = EstablishedWs {
ws,
peer_addr,
};
(Some(ConnectionStage::EstablishedWs(established_ws)), false)
},
Err(HandshakeError::Interrupted(handshake)) => {
println!("interrupted");
(Some(ConnectionStage::WsMidHandshake(handshake)), true)
},
Err(HandshakeError::Failure(err)) => {
dbg!(err);
(None, false)
}
}
}
/// On the stream given by poll_token, get TLS (if requested) and tungstenite
/// up and running, then read messages and pass on through channel.
@ -199,7 +148,7 @@ pub fn run_handshakes_and_read_messages(
){
loop {
if let Some(Connection {
stage: ConnectionStage::EstablishedWs(established_ws),
inner: Either::Left(established_ws),
..
}) = connections.get_mut(&poll_token){
use ::tungstenite::Error::Io;
@ -231,43 +180,17 @@ pub fn run_handshakes_and_read_messages(
break;
}
}
} else if let Some(connection) = connections.remove(&poll_token) {
let (opt_new_stage, stop_loop) = match connection.stage {
ConnectionStage::TcpStream(stream) => {
if let Some(tls_acceptor) = opt_tls_acceptor {
handle_tls_handshake_result(
tls_acceptor.accept(stream)
)
} else {
let handshake_result = ::tungstenite::server::accept_hdr(
Stream::TcpStream(stream),
DebugCallback
);
} else if let Some(Connection {
inner: Either::Right(machine),
..
}) = connections.remove(&poll_token) {
let (result, stop_loop) = machine
.advance(opt_tls_acceptor);
handle_ws_handshake_result(handshake_result)
}
},
ConnectionStage::TlsStream(stream) => {
let handshake_result = ::tungstenite::server::accept_hdr(
Stream::TlsStream(stream),
DebugCallback
);
handle_ws_handshake_result(handshake_result)
},
ConnectionStage::TlsMidHandshake(handshake) => {
handle_tls_handshake_result(handshake.handshake())
},
ConnectionStage::WsMidHandshake(handshake) => {
handle_ws_handshake_result(handshake.handshake())
},
ConnectionStage::EstablishedWs(_) => unreachable!(),
};
if let Some(stage) = opt_new_stage {
if let Some(inner) = result {
let connection = Connection {
valid_until,
stage,
inner,
};
connections.insert(poll_token, connection);
@ -287,11 +210,11 @@ pub fn send_out_messages(
connections: &mut ConnectionMap,
){
for (meta, out_message) in out_message_receiver {
let opt_stage = connections
let opt_inner = connections
.get_mut(&meta.poll_token)
.map(|v| &mut v.stage);
.map(|v| &mut v.inner);
if let Some(ConnectionStage::EstablishedWs(connection)) = opt_stage {
if let Some(Either::Left(connection)) = opt_inner {
if connection.peer_addr != meta.peer_addr {
eprintln!("socket worker: peer socket addrs didn't match");