mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-03-31 17:55:36 +00:00
aquatic_ws: create HandshakeMachine from ConnectionStage
This commit is contained in:
parent
2967129c1f
commit
0bcfffb2bd
5 changed files with 122 additions and 105 deletions
7
Cargo.lock
generated
7
Cargo.lock
generated
|
|
@ -93,6 +93,7 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aquatic_common",
|
"aquatic_common",
|
||||||
"cli_helpers",
|
"cli_helpers",
|
||||||
|
"either",
|
||||||
"flume",
|
"flume",
|
||||||
"hashbrown",
|
"hashbrown",
|
||||||
"indexmap",
|
"indexmap",
|
||||||
|
|
@ -330,6 +331,12 @@ dependencies = [
|
||||||
"generic-array",
|
"generic-array",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "encode_unicode"
|
name = "encode_unicode"
|
||||||
version = "0.3.6"
|
version = "0.3.6"
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ path = "src/bin/main.rs"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
aquatic_common = { path = "../aquatic_common" }
|
aquatic_common = { path = "../aquatic_common" }
|
||||||
cli_helpers = { path = "../cli_helpers" }
|
cli_helpers = { path = "../cli_helpers" }
|
||||||
|
either = "1"
|
||||||
flume = "0.7"
|
flume = "0.7"
|
||||||
hashbrown = { version = "0.7", features = ["serde"] }
|
hashbrown = { version = "0.7", features = ["serde"] }
|
||||||
indexmap = "1"
|
indexmap = "1"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
use std::net::{SocketAddr};
|
use std::net::{SocketAddr};
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
|
|
||||||
|
use either::Either;
|
||||||
use hashbrown::HashMap;
|
use hashbrown::HashMap;
|
||||||
use mio::Token;
|
use mio::Token;
|
||||||
use mio::net::TcpStream;
|
use mio::net::TcpStream;
|
||||||
use native_tls::{TlsStream, MidHandshakeTlsStream};
|
use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream};
|
||||||
use tungstenite::WebSocket;
|
use tungstenite::WebSocket;
|
||||||
use tungstenite::handshake::{MidHandshake, server::ServerHandshake};
|
use tungstenite::handshake::{MidHandshake, HandshakeError};
|
||||||
|
use tungstenite::server::ServerHandshake;
|
||||||
|
|
||||||
use crate::common::*;
|
use crate::common::*;
|
||||||
|
|
||||||
|
|
@ -71,26 +73,109 @@ impl Write for Stream {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub struct EstablishedWs<S> {
|
pub struct EstablishedWs {
|
||||||
pub ws: WebSocket<S>,
|
pub ws: WebSocket<Stream>,
|
||||||
pub peer_addr: SocketAddr,
|
pub peer_addr: SocketAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub enum ConnectionStage {
|
pub enum HandshakeMachine {
|
||||||
TcpStream(TcpStream),
|
TcpStream(TcpStream),
|
||||||
TlsStream(TlsStream<TcpStream>),
|
TlsStream(TlsStream<TcpStream>),
|
||||||
TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
|
TlsMidHandshake(MidHandshakeTlsStream<TcpStream>),
|
||||||
WsMidHandshake(MidHandshake<ServerHandshake<Stream, DebugCallback>>),
|
WsMidHandshake(MidHandshake<ServerHandshake<Stream, DebugCallback>>),
|
||||||
EstablishedWs(EstablishedWs<Stream>),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl ConnectionStage {
|
impl HandshakeMachine {
|
||||||
pub fn is_established(&self) -> bool {
|
pub fn new(tcp_stream: TcpStream) -> Self {
|
||||||
|
Self::TcpStream(tcp_stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn advance(
|
||||||
|
self,
|
||||||
|
opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
|
||||||
|
) -> (Option<Either<EstablishedWs, Self>>, bool) { // bool = stop looping
|
||||||
match self {
|
match self {
|
||||||
Self::EstablishedWs(_) => true,
|
HandshakeMachine::TcpStream(stream) => {
|
||||||
_ => false,
|
if let Some(tls_acceptor) = opt_tls_acceptor {
|
||||||
|
Self::handle_tls_handshake_result(
|
||||||
|
tls_acceptor.accept(stream)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
let handshake_result = ::tungstenite::server::accept_hdr(
|
||||||
|
Stream::TcpStream(stream),
|
||||||
|
DebugCallback
|
||||||
|
);
|
||||||
|
|
||||||
|
Self::handle_ws_handshake_result(handshake_result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
HandshakeMachine::TlsStream(stream) => {
|
||||||
|
let handshake_result = ::tungstenite::server::accept_hdr(
|
||||||
|
Stream::TlsStream(stream),
|
||||||
|
DebugCallback
|
||||||
|
);
|
||||||
|
|
||||||
|
Self::handle_ws_handshake_result(handshake_result)
|
||||||
|
},
|
||||||
|
HandshakeMachine::TlsMidHandshake(handshake) => {
|
||||||
|
Self::handle_tls_handshake_result(handshake.handshake())
|
||||||
|
},
|
||||||
|
HandshakeMachine::WsMidHandshake(handshake) => {
|
||||||
|
Self::handle_ws_handshake_result(handshake.handshake())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_tls_handshake_result(
|
||||||
|
result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
|
||||||
|
) -> (Option<Either<EstablishedWs, Self>>, bool) {
|
||||||
|
match result {
|
||||||
|
Ok(stream) => {
|
||||||
|
println!("handshake established");
|
||||||
|
|
||||||
|
(Some(Either::Right(Self::TlsStream(stream))), false)
|
||||||
|
},
|
||||||
|
Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
|
||||||
|
println!("interrupted");
|
||||||
|
|
||||||
|
(Some(Either::Right(Self::TlsMidHandshake(handshake))), true)
|
||||||
|
},
|
||||||
|
Err(native_tls::HandshakeError::Failure(err)) => {
|
||||||
|
dbg!(err);
|
||||||
|
|
||||||
|
(None, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_ws_handshake_result(
|
||||||
|
result: Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, DebugCallback>>> ,
|
||||||
|
) -> (Option<Either<EstablishedWs, Self>>, bool) {
|
||||||
|
match result {
|
||||||
|
Ok(mut ws) => {
|
||||||
|
println!("handshake established");
|
||||||
|
|
||||||
|
let peer_addr = ws.get_mut().get_peer_addr();
|
||||||
|
|
||||||
|
let established_ws = EstablishedWs {
|
||||||
|
ws,
|
||||||
|
peer_addr,
|
||||||
|
};
|
||||||
|
|
||||||
|
(Some(Either::Left(established_ws)), false)
|
||||||
|
},
|
||||||
|
Err(HandshakeError::Interrupted(handshake)) => {
|
||||||
|
println!("interrupted");
|
||||||
|
|
||||||
|
(Some(Either::Right(HandshakeMachine::WsMidHandshake(handshake))), true)
|
||||||
|
},
|
||||||
|
Err(HandshakeError::Failure(err)) => {
|
||||||
|
dbg!(err);
|
||||||
|
|
||||||
|
(None, false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -98,7 +183,7 @@ impl ConnectionStage {
|
||||||
|
|
||||||
pub struct Connection {
|
pub struct Connection {
|
||||||
pub valid_until: ValidUntil,
|
pub valid_until: ValidUntil,
|
||||||
pub stage: ConnectionStage,
|
pub inner: Either<EstablishedWs, HandshakeMachine>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::io::ErrorKind;
|
use std::io::ErrorKind;
|
||||||
|
|
||||||
use tungstenite::WebSocket;
|
use either::Either;
|
||||||
use tungstenite::handshake::{HandshakeError, server::ServerHandshake};
|
|
||||||
use hashbrown::HashMap;
|
use hashbrown::HashMap;
|
||||||
use native_tls::{TlsAcceptor, TlsStream};
|
use native_tls::TlsAcceptor;
|
||||||
|
|
||||||
use mio::{Events, Poll, Interest, Token};
|
use mio::{Events, Poll, Interest, Token};
|
||||||
use mio::net::{TcpListener, TcpStream};
|
use mio::net::TcpListener;
|
||||||
|
|
||||||
use crate::common::*;
|
use crate::common::*;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|
@ -119,7 +118,7 @@ fn accept_new_streams(
|
||||||
|
|
||||||
let connection = Connection {
|
let connection = Connection {
|
||||||
valid_until,
|
valid_until,
|
||||||
stage: ConnectionStage::TcpStream(stream)
|
inner: Either::Right(HandshakeMachine::new(stream))
|
||||||
};
|
};
|
||||||
|
|
||||||
connections.insert(token, connection);
|
connections.insert(token, connection);
|
||||||
|
|
@ -136,56 +135,6 @@ fn accept_new_streams(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn handle_tls_handshake_result(
|
|
||||||
result: Result<TlsStream<TcpStream>, ::native_tls::HandshakeError<TcpStream>>,
|
|
||||||
) -> (Option<ConnectionStage>, bool) {
|
|
||||||
match result {
|
|
||||||
Ok(stream) => {
|
|
||||||
println!("handshake established");
|
|
||||||
|
|
||||||
(Some(ConnectionStage::TlsStream(stream)), false)
|
|
||||||
},
|
|
||||||
Err(native_tls::HandshakeError::WouldBlock(handshake)) => {
|
|
||||||
println!("interrupted");
|
|
||||||
|
|
||||||
(Some(ConnectionStage::TlsMidHandshake(handshake)), true)
|
|
||||||
},
|
|
||||||
Err(native_tls::HandshakeError::Failure(err)) => {
|
|
||||||
(None, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pub fn handle_ws_handshake_result(
|
|
||||||
result: Result<WebSocket<Stream>, HandshakeError<ServerHandshake<Stream, DebugCallback>>> ,
|
|
||||||
) -> (Option<ConnectionStage>, bool) {
|
|
||||||
match result {
|
|
||||||
Ok(mut ws) => {
|
|
||||||
println!("handshake established");
|
|
||||||
|
|
||||||
let peer_addr = ws.get_mut().get_peer_addr();
|
|
||||||
|
|
||||||
let established_ws = EstablishedWs {
|
|
||||||
ws,
|
|
||||||
peer_addr,
|
|
||||||
};
|
|
||||||
|
|
||||||
(Some(ConnectionStage::EstablishedWs(established_ws)), false)
|
|
||||||
},
|
|
||||||
Err(HandshakeError::Interrupted(handshake)) => {
|
|
||||||
println!("interrupted");
|
|
||||||
|
|
||||||
(Some(ConnectionStage::WsMidHandshake(handshake)), true)
|
|
||||||
},
|
|
||||||
Err(HandshakeError::Failure(err)) => {
|
|
||||||
dbg!(err);
|
|
||||||
|
|
||||||
(None, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/// On the stream given by poll_token, get TLS (if requested) and tungstenite
|
/// On the stream given by poll_token, get TLS (if requested) and tungstenite
|
||||||
/// up and running, then read messages and pass on through channel.
|
/// up and running, then read messages and pass on through channel.
|
||||||
|
|
@ -199,7 +148,7 @@ pub fn run_handshakes_and_read_messages(
|
||||||
){
|
){
|
||||||
loop {
|
loop {
|
||||||
if let Some(Connection {
|
if let Some(Connection {
|
||||||
stage: ConnectionStage::EstablishedWs(established_ws),
|
inner: Either::Left(established_ws),
|
||||||
..
|
..
|
||||||
}) = connections.get_mut(&poll_token){
|
}) = connections.get_mut(&poll_token){
|
||||||
use ::tungstenite::Error::Io;
|
use ::tungstenite::Error::Io;
|
||||||
|
|
@ -231,43 +180,17 @@ pub fn run_handshakes_and_read_messages(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if let Some(connection) = connections.remove(&poll_token) {
|
} else if let Some(Connection {
|
||||||
let (opt_new_stage, stop_loop) = match connection.stage {
|
inner: Either::Right(machine),
|
||||||
ConnectionStage::TcpStream(stream) => {
|
..
|
||||||
if let Some(tls_acceptor) = opt_tls_acceptor {
|
}) = connections.remove(&poll_token) {
|
||||||
handle_tls_handshake_result(
|
let (result, stop_loop) = machine
|
||||||
tls_acceptor.accept(stream)
|
.advance(opt_tls_acceptor);
|
||||||
)
|
|
||||||
} else {
|
|
||||||
let handshake_result = ::tungstenite::server::accept_hdr(
|
|
||||||
Stream::TcpStream(stream),
|
|
||||||
DebugCallback
|
|
||||||
);
|
|
||||||
|
|
||||||
handle_ws_handshake_result(handshake_result)
|
if let Some(inner) = result {
|
||||||
}
|
|
||||||
},
|
|
||||||
ConnectionStage::TlsStream(stream) => {
|
|
||||||
let handshake_result = ::tungstenite::server::accept_hdr(
|
|
||||||
Stream::TlsStream(stream),
|
|
||||||
DebugCallback
|
|
||||||
);
|
|
||||||
|
|
||||||
handle_ws_handshake_result(handshake_result)
|
|
||||||
},
|
|
||||||
ConnectionStage::TlsMidHandshake(handshake) => {
|
|
||||||
handle_tls_handshake_result(handshake.handshake())
|
|
||||||
},
|
|
||||||
ConnectionStage::WsMidHandshake(handshake) => {
|
|
||||||
handle_ws_handshake_result(handshake.handshake())
|
|
||||||
},
|
|
||||||
ConnectionStage::EstablishedWs(_) => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(stage) = opt_new_stage {
|
|
||||||
let connection = Connection {
|
let connection = Connection {
|
||||||
valid_until,
|
valid_until,
|
||||||
stage,
|
inner,
|
||||||
};
|
};
|
||||||
|
|
||||||
connections.insert(poll_token, connection);
|
connections.insert(poll_token, connection);
|
||||||
|
|
@ -287,11 +210,11 @@ pub fn send_out_messages(
|
||||||
connections: &mut ConnectionMap,
|
connections: &mut ConnectionMap,
|
||||||
){
|
){
|
||||||
for (meta, out_message) in out_message_receiver {
|
for (meta, out_message) in out_message_receiver {
|
||||||
let opt_stage = connections
|
let opt_inner = connections
|
||||||
.get_mut(&meta.poll_token)
|
.get_mut(&meta.poll_token)
|
||||||
.map(|v| &mut v.stage);
|
.map(|v| &mut v.inner);
|
||||||
|
|
||||||
if let Some(ConnectionStage::EstablishedWs(connection)) = opt_stage {
|
if let Some(Either::Left(connection)) = opt_inner {
|
||||||
if connection.peer_addr != meta.peer_addr {
|
if connection.peer_addr != meta.peer_addr {
|
||||||
eprintln!("socket worker: peer socket addrs didn't match");
|
eprintln!("socket worker: peer socket addrs didn't match");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ use std::fs::File;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use either::Either;
|
||||||
use mio::Token;
|
use mio::Token;
|
||||||
use native_tls::{Identity, TlsAcceptor};
|
use native_tls::{Identity, TlsAcceptor};
|
||||||
use net2::{TcpBuilder, unix::UnixTcpBuilderExt};
|
use net2::{TcpBuilder, unix::UnixTcpBuilderExt};
|
||||||
|
|
@ -58,7 +59,7 @@ pub fn create_tls_acceptor(
|
||||||
|
|
||||||
|
|
||||||
pub fn close_connection(connection: &mut Connection){
|
pub fn close_connection(connection: &mut Connection){
|
||||||
if let ConnectionStage::EstablishedWs(ref mut ews) = connection.stage {
|
if let Either::Left(ref mut ews) = connection.inner {
|
||||||
if ews.ws.can_read(){
|
if ews.ws.can_read(){
|
||||||
ews.ws.close(None).unwrap();
|
ews.ws.close(None).unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue