aquatic_ws: when removing connection, reregister from poll

This commit is contained in:
Joakim Frostegård 2020-08-10 03:51:08 +02:00
parent fbcd5aa7c9
commit 1a3ab54b3f
3 changed files with 52 additions and 10 deletions

View file

@ -4,7 +4,7 @@ use std::io::{Read, Write};
use either::Either;
use hashbrown::HashMap;
use log::info;
use mio::Token;
use mio::{Poll, Token};
use mio::net::TcpStream;
use native_tls::{TlsAcceptor, TlsStream, MidHandshakeTlsStream};
use tungstenite::WebSocket;
@ -29,6 +29,16 @@ impl Stream {
Self::TlsStream(stream) => stream.get_ref().peer_addr().unwrap(),
}
}
#[inline]
pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
match self {
Self::TcpStream(stream) =>
poll.registry().deregister(stream),
Self::TlsStream(stream) =>
poll.registry().deregister(stream.get_mut()),
}
}
}
@ -274,6 +284,28 @@ impl Connection {
}
}
}
pub fn deregister(&mut self, poll: &mut Poll) -> ::std::io::Result<()> {
use Either::{Left, Right};
match self.inner {
Left(EstablishedWs { ref mut ws, .. }) => {
ws.get_mut().deregister(poll)
},
Right(HandshakeMachine::TcpStream(ref mut stream)) => {
poll.registry().deregister(stream)
},
Right(HandshakeMachine::TlsMidHandshake(ref mut handshake)) => {
poll.registry().deregister(handshake.get_mut())
},
Right(HandshakeMachine::TlsStream(ref mut stream)) => {
poll.registry().deregister(stream.get_mut())
},
Right(HandshakeMachine::WsMidHandshake(ref mut handshake)) => {
handshake.get_mut().get_mut().deregister(poll)
},
}
}
}

View file

@ -105,6 +105,7 @@ pub fn run_poll_loop(
socket_worker_index,
&in_message_sender,
&opt_tls_acceptor,
&mut poll,
&mut connections,
token,
valid_until,
@ -114,6 +115,7 @@ pub fn run_poll_loop(
if !out_message_receiver.is_empty(){
send_out_messages(
&mut poll,
&out_message_receiver,
&mut connections
);
@ -148,7 +150,7 @@ fn accept_new_streams(
let token = *poll_token_counter;
remove_connection_if_exists(connections, token);
remove_connection_if_exists(poll, connections, token);
poll.registry()
.register(&mut stream, token, Interest::READABLE)
@ -176,6 +178,7 @@ pub fn run_handshakes_and_read_messages(
socket_worker_index: usize,
in_message_sender: &InMessageSender,
opt_tls_acceptor: &Option<TlsAcceptor>, // If set, run TLS
poll: &mut Poll,
connections: &mut ConnectionMap,
poll_token: Token,
valid_until: ValidUntil,
@ -222,14 +225,14 @@ pub fn run_handshakes_and_read_messages(
break;
},
Err(tungstenite::Error::ConnectionClosed) => {
remove_connection_if_exists(connections, poll_token);
remove_connection_if_exists(poll, connections, poll_token);
break
},
Err(err) => {
info!("error reading messages: {}", err);
remove_connection_if_exists(connections, poll_token);
remove_connection_if_exists(poll, connections, poll_token);
break;
}
@ -256,6 +259,7 @@ pub fn run_handshakes_and_read_messages(
/// Read messages from channel, send to peers
pub fn send_out_messages(
poll: &mut Poll,
out_message_receiver: &Receiver<(ConnectionMeta, OutMessage)>,
connections: &mut ConnectionMap,
){
@ -280,12 +284,17 @@ pub fn send_out_messages(
},
Err(Io(err)) if err.kind() == ErrorKind::WouldBlock => {},
Err(tungstenite::Error::ConnectionClosed) => {
remove_connection_if_exists(connections, meta.poll_token);
remove_connection_if_exists(
poll,
connections,
meta.poll_token
);
},
Err(err) => {
info!("error writing ws message: {}", err);
remove_connection_if_exists(
poll,
connections,
meta.poll_token
);

View file

@ -1,7 +1,7 @@
use std::time::Instant;
use anyhow::Context;
use mio::Token;
use mio::{Poll, Token};
use socket2::{Socket, Domain, Type, Protocol};
use crate::config::Config;
@ -37,16 +37,17 @@ pub fn create_listener(
}
/// Don't bother with deregistering from Poll. In my understanding, this is
/// done automatically when the stream is dropped, as long as there are no
/// other references to the file descriptor, such as when it is accessed
/// in multiple threads.
pub fn remove_connection_if_exists(
poll: &mut Poll,
connections: &mut ConnectionMap,
token: Token,
){
if let Some(mut connection) = connections.remove(&token){
connection.close();
if let Err(err) = connection.deregister(poll){
::log::error!("couldn't deregister stream: {}", err);
}
}
}