aquatic_udp: glommio: return Request in read_tls to reduce state

This commit is contained in:
Joakim Frostegård 2021-10-26 22:21:38 +02:00
parent 636a434ca6
commit 8a66b5ce69

View file

@ -1,5 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::io::{BufReader, Cursor, ErrorKind, Read, Write}; use std::io::{Cursor, ErrorKind, Read, Write};
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
@ -13,7 +13,7 @@ use glommio::prelude::*;
use glommio::net::{TcpListener, TcpStream}; use glommio::net::{TcpListener, TcpStream};
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
use glommio::task::JoinHandle; use glommio::task::JoinHandle;
use rustls::{IoState, ServerConnection}; use rustls::{ServerConnection};
use slab::Slab; use slab::Slab;
use crate::common::num_digits_in_usize; use crate::common::num_digits_in_usize;
@ -37,7 +37,6 @@ struct Connection {
stream: TcpStream, stream: TcpStream,
index: ConnectionId, index: ConnectionId,
request_buffer: Vec<u8>, request_buffer: Vec<u8>,
wait_for_response: bool,
close_after_writing: bool, close_after_writing: bool,
} }
@ -82,7 +81,6 @@ pub async fn run_socket_worker(
stream, stream,
index: ConnectionId(entry.key()), index: ConnectionId(entry.key()),
request_buffer: Vec::new(), request_buffer: Vec::new(),
wait_for_response: false,
close_after_writing: false, close_after_writing: false,
}; };
@ -123,13 +121,20 @@ async fn receive_responses(
impl Connection { impl Connection {
async fn handle_stream(&mut self) -> anyhow::Result<()> { async fn handle_stream(&mut self) -> anyhow::Result<()> {
loop { loop {
self.read_tls().await?; let opt_request = self.read_tls().await?;
if self.wait_for_response { if let Some(request) = opt_request {
let peer_addr = self.stream
.peer_addr()
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
// TODO: send request to channel
// Wait for response to arrive, then send it
if let Some(response) = self.response_receiver.recv().await { if let Some(response) = self.response_receiver.recv().await {
self.queue_response(&response)?; // TODO: compare IP addresses?
self.wait_for_response = false; self.queue_response(&response)?;
if !self.config.network.keep_alive { if !self.config.network.keep_alive {
self.close_after_writing = true; self.close_after_writing = true;
@ -149,7 +154,7 @@ impl Connection {
Ok(()) Ok(())
} }
async fn read_tls(&mut self) -> anyhow::Result<()> { async fn read_tls(&mut self) -> anyhow::Result<Option<Request>> {
loop { loop {
::log::debug!("read_tls"); ::log::debug!("read_tls");
@ -198,9 +203,9 @@ impl Connection {
if added_plaintext { if added_plaintext {
match Request::from_bytes(&self.request_buffer[..]) { match Request::from_bytes(&self.request_buffer[..]) {
Ok(request) => { Ok(request) => {
self.wait_for_response = true; ::log::debug!("received request: {:?}", request);
::log::trace!("received request: {:?}", request); return Ok(Some(request));
} }
Err(RequestParseError::NeedMoreData) => { Err(RequestParseError::NeedMoreData) => {
::log::debug!("need more request data. current data: {:?}", std::str::from_utf8(&self.request_buffer)); ::log::debug!("need more request data. current data: {:?}", std::str::from_utf8(&self.request_buffer));
@ -225,7 +230,7 @@ impl Connection {
} }
} }
Ok(()) Ok(None)
} }
async fn write_tls(&mut self) -> anyhow::Result<()> { async fn write_tls(&mut self) -> anyhow::Result<()> {