diff --git a/aquatic_http/src/lib/glommio/network.rs b/aquatic_http/src/lib/glommio/network.rs index 8bc1ff6..7f1fcdf 100644 --- a/aquatic_http/src/lib/glommio/network.rs +++ b/aquatic_http/src/lib/glommio/network.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::io::{BufReader, Cursor, Read}; use std::rc::Rc; use std::sync::Arc; @@ -5,8 +6,8 @@ use std::sync::Arc; use aquatic_http_protocol::request::{Request, RequestParseError}; use aquatic_http_protocol::response::Response; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; -use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; -use glommio::channels::shared_channel::{ConnectedSender, SharedSender}; +use glommio::channels::channel_mesh::{MeshBuilder, Partial, Receivers, Role, Senders}; +use glommio::channels::shared_channel::{ConnectedReceiver, ConnectedSender, SharedSender}; use glommio::prelude::*; use glommio::net::{TcpListener, TcpStream}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; @@ -16,32 +17,43 @@ use slab::Slab; use crate::config::Config; +#[derive(Clone, Copy, Debug)] +pub struct ConnectionId(pub usize); + struct ConnectionReference { response_sender: LocalSender, handle: JoinHandle<()>, } struct Connection { - request_senders: Rc>, + request_senders: Rc>, response_receiver: LocalReceiver, tls: ServerConnection, stream: TcpStream, - index: usize, + index: ConnectionId, expects_request: bool, } pub async fn run_socket_worker( config: Config, - request_mesh_builder: MeshBuilder, + request_mesh_builder: MeshBuilder<(ConnectionId, Request), Partial>, + response_mesh_builder: MeshBuilder<(ConnectionId, Response), Partial>, ) { let tls_config = Arc::new(create_tls_config(&config)); let config = Rc::new(config); let listener = TcpListener::bind(config.network.address).expect("bind socket"); + + let (_, mut response_receivers) = response_mesh_builder.join(Role::Consumer).await.unwrap(); + let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); let request_senders = Rc::new(request_senders); - let mut connection_slab: Slab = Slab::new(); + let connection_slab = Rc::new(RefCell::new(Slab::new())); + + for (_, response_receiver) in response_receivers.streams() { + spawn_local(receive_responses(response_receiver, connection_slab.clone())).detach(); + } let mut incoming = listener.incoming(); @@ -50,14 +62,15 @@ pub async fn run_socket_worker( Ok(stream) => { let (response_sender, response_receiver) = new_bounded(1); - let entry = connection_slab.vacant_entry(); + let mut slab = connection_slab.borrow_mut(); + let entry = slab.vacant_entry(); let conn = Connection { request_senders: request_senders.clone(), response_receiver, tls: ServerConnection::new(tls_config.clone()).unwrap(), stream, - index: entry.key(), + index: ConnectionId(entry.key()), expects_request: true, }; @@ -84,6 +97,17 @@ pub async fn run_socket_worker( } } +async fn receive_responses( + mut response_receiver: ConnectedReceiver<(ConnectionId, Response)>, + connection_references: Rc>>, +) { + while let Some((connection_id, response)) = response_receiver.next().await { + if let Some(reference) = connection_references.borrow().get(connection_id.0) { + reference.response_sender.try_send(response); + } + } +} + impl Connection { async fn handle_stream(&mut self) -> anyhow::Result<()> { loop { @@ -94,7 +118,7 @@ impl Connection { if self.expects_request { let request = self.extract_request()?; - self.request_senders.try_send_to(0, request); + self.request_senders.try_send_to(0, (self.index, request)); self.expects_request = false; } else if let Some(response) = self.response_receiver.recv().await {