aquatic_http: glommio: clean up Connection code, stop storing handle

This commit is contained in:
Joakim Frostegård 2021-11-01 01:53:00 +01:00
parent afce23e321
commit edec526d41

View file

@ -14,13 +14,12 @@ use aquatic_http_protocol::response::{
}; };
use either::Either; use either::Either;
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
use futures_rustls::TlsAcceptor;
use futures_rustls::server::TlsStream; use futures_rustls::server::TlsStream;
use futures_rustls::TlsAcceptor;
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
use glommio::channels::shared_channel::ConnectedReceiver; use glommio::channels::shared_channel::ConnectedReceiver;
use glommio::net::{TcpListener, TcpStream}; use glommio::net::{TcpListener, TcpStream};
use glommio::task::JoinHandle;
use glommio::timer::TimerActionRepeat; use glommio::timer::TimerActionRepeat;
use glommio::{enclose, prelude::*}; use glommio::{enclose, prelude::*};
use slab::Slab; use slab::Slab;
@ -40,7 +39,6 @@ struct PendingScrapeResponse {
struct ConnectionReference { struct ConnectionReference {
response_sender: LocalSender<ChannelResponse>, response_sender: LocalSender<ChannelResponse>,
handle: JoinHandle<()>,
} }
struct Connection { struct Connection {
@ -49,7 +47,8 @@ struct Connection {
request_senders: Rc<Senders<ChannelRequest>>, request_senders: Rc<Senders<ChannelRequest>>,
response_receiver: LocalReceiver<ChannelResponse>, response_receiver: LocalReceiver<ChannelResponse>,
response_consumer_id: ConsumerId, response_consumer_id: ConsumerId,
tls_acceptor: TlsAcceptor, stream: TlsStream<TcpStream>,
peer_addr: SocketAddr,
connection_id: ConnectionId, connection_id: ConnectionId,
request_buffer: [u8; MAX_REQUEST_SIZE], request_buffer: [u8; MAX_REQUEST_SIZE],
request_buffer_position: usize, request_buffer_position: usize,
@ -123,40 +122,27 @@ pub async fn run_socket_worker(
match stream { match stream {
Ok(stream) => { Ok(stream) => {
let (response_sender, response_receiver) = new_bounded(config.request_workers); let (response_sender, response_receiver) = new_bounded(config.request_workers);
let key = connection_slab
.borrow_mut()
.insert(ConnectionReference { response_sender });
let mut slab = connection_slab.borrow_mut(); spawn_local(enclose!((config, access_list, request_senders, tls_config, connections_to_remove) async move {
let entry = slab.vacant_entry(); if let Err(err) = Connection::run(
let key = entry.key(); config,
access_list,
let mut conn = Connection { request_senders,
config: config.clone(), response_receiver,
access_list: access_list.clone(), response_consumer_id,
request_senders: request_senders.clone(), ConnectionId(key),
response_receiver, tls_config,
response_consumer_id, stream
tls_acceptor: tls_config.clone().into(), ).await {
connection_id: ConnectionId(entry.key()), ::log::info!("Connection::run() error: {:?}", err);
request_buffer: [0u8; MAX_REQUEST_SIZE],
request_buffer_position: 0,
};
let connections_to_remove = connections_to_remove.clone();
let handle = spawn_local(async move {
if let Err(err) = conn.handle_stream(stream).await {
::log::info!("conn.handle_stream() error: {:?}", err);
} }
connections_to_remove.borrow_mut().push(key); connections_to_remove.borrow_mut().push(key);
}) }))
.detach(); .detach();
let connection_reference = ConnectionReference {
response_sender,
handle,
};
entry.insert(connection_reference);
} }
Err(err) => { Err(err) => {
::log::error!("accept connection: {:?}", err); ::log::error!("accept connection: {:?}", err);
@ -182,27 +168,62 @@ async fn receive_responses(
} }
impl Connection { impl Connection {
async fn handle_stream(&mut self, stream: TcpStream) -> anyhow::Result<()> { async fn run(
let peer_addr = Self::get_peer_addr(&stream)?; config: Rc<Config>,
let mut stream = self.tls_acceptor.accept(stream).await?; access_list: Rc<RefCell<AccessList>>,
request_senders: Rc<Senders<ChannelRequest>>,
response_receiver: LocalReceiver<ChannelResponse>,
response_consumer_id: ConsumerId,
connection_id: ConnectionId,
tls_config: Arc<TlsConfig>,
stream: TcpStream,
) -> anyhow::Result<()> {
let peer_addr = stream
.peer_addr()
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
let tls_acceptor: TlsAcceptor = tls_config.into();
let stream = tls_acceptor.accept(stream).await?;
let mut conn = Connection {
config: config.clone(),
access_list: access_list.clone(),
request_senders: request_senders.clone(),
response_receiver,
response_consumer_id,
stream,
peer_addr,
connection_id,
request_buffer: [0u8; MAX_REQUEST_SIZE],
request_buffer_position: 0,
};
conn.run_request_response_loop().await?;
Ok(())
}
async fn run_request_response_loop(&mut self) -> anyhow::Result<()> {
loop { loop {
let response = match self.read_request(&mut stream).await? { let response = match self.read_request().await? {
Either::Left(response) => Response::Failure(response), Either::Left(response) => Response::Failure(response),
Either::Right(request) => { Either::Right(request) => match self.handle_request(request).await? {
match self.handle_request(request, peer_addr).await? { Either::Left(response) => Response::Failure(response),
Either::Left(response) => Response::Failure(response), Either::Right(opt_pending_scrape_response) => {
Either::Right(opt_pending_scrape_response) => { self.wait_for_response(opt_pending_scrape_response).await?
self.wait_for_response(peer_addr, opt_pending_scrape_response).await?
}
} }
} },
}; };
self.write_response(&mut stream, &response).await?; self.write_response(&response).await?;
if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive { if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive {
Self::close_stream(stream).await; let _ = self
.stream
.get_ref()
.0
.shutdown(std::net::Shutdown::Both)
.await;
break; break;
} }
@ -211,23 +232,13 @@ impl Connection {
Ok(()) Ok(())
} }
fn get_peer_addr(stream: &TcpStream) -> anyhow::Result<SocketAddr> { async fn read_request(&mut self) -> anyhow::Result<Either<FailureResponse, Request>> {
stream
.peer_addr()
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))
}
async fn close_stream(stream: TlsStream<TcpStream>) {
let _ = stream.get_ref().0.shutdown(std::net::Shutdown::Both).await;
}
async fn read_request(&mut self, stream: &mut TlsStream<TcpStream>) -> anyhow::Result<Either<FailureResponse, Request>> {
let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE];
loop { loop {
::log::debug!("read"); ::log::debug!("read");
let bytes_read = stream.read(&mut buf).await?; let bytes_read = self.stream.read(&mut buf).await?;
let request_buffer_end = self.request_buffer_position + bytes_read; let request_buffer_end = self.request_buffer_position + bytes_read;
if request_buffer_end > self.request_buffer.len() { if request_buffer_end > self.request_buffer.len() {
@ -276,7 +287,6 @@ impl Connection {
async fn handle_request( async fn handle_request(
&self, &self,
request: Request, request: Request,
peer_addr: SocketAddr,
) -> anyhow::Result<Either<FailureResponse, Option<PendingScrapeResponse>>> { ) -> anyhow::Result<Either<FailureResponse, Option<PendingScrapeResponse>>> {
match request { match request {
Request::Announce(request) => { Request::Announce(request) => {
@ -291,7 +301,7 @@ impl Connection {
request, request,
connection_id: self.connection_id, connection_id: self.connection_id,
response_consumer_id: self.response_consumer_id, response_consumer_id: self.response_consumer_id,
peer_addr, peer_addr: self.peer_addr,
}; };
let consumer_index = calculate_request_consumer_index(&self.config, info_hash); let consumer_index = calculate_request_consumer_index(&self.config, info_hash);
@ -327,7 +337,7 @@ impl Connection {
for (consumer_index, info_hashes) in info_hashes_by_worker { for (consumer_index, info_hashes) in info_hashes_by_worker {
let request = ChannelRequest::Scrape { let request = ChannelRequest::Scrape {
request: ScrapeRequest { info_hashes }, request: ScrapeRequest { info_hashes },
peer_addr, peer_addr: self.peer_addr,
response_consumer_id: self.response_consumer_id, response_consumer_id: self.response_consumer_id,
connection_id: self.connection_id, connection_id: self.connection_id,
}; };
@ -353,13 +363,12 @@ impl Connection {
/// return full response /// return full response
async fn wait_for_response( async fn wait_for_response(
&self, &self,
peer_addr: SocketAddr,
mut opt_pending_scrape_response: Option<PendingScrapeResponse>, mut opt_pending_scrape_response: Option<PendingScrapeResponse>,
) -> anyhow::Result<Response> { ) -> anyhow::Result<Response> {
loop { loop {
if let Some(channel_response) = self.response_receiver.recv().await { if let Some(channel_response) = self.response_receiver.recv().await {
if channel_response.get_peer_addr() != peer_addr { if channel_response.get_peer_addr() != self.peer_addr {
return Err(anyhow::anyhow!("peer addressess didn't match")); return Err(anyhow::anyhow!("peer addresses didn't match"));
} }
match channel_response { match channel_response {
@ -396,7 +405,7 @@ impl Connection {
} }
} }
async fn write_response(&mut self, stream: &mut TlsStream<TcpStream>, response: &Response) -> anyhow::Result<()> { async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> {
let mut body = Vec::new(); let mut body = Vec::new();
response.write(&mut body).unwrap(); response.write(&mut body).unwrap();
@ -412,7 +421,8 @@ impl Connection {
response_bytes.append(&mut body); response_bytes.append(&mut body);
response_bytes.extend_from_slice(b"\r\n"); response_bytes.extend_from_slice(b"\r\n");
stream.write(&response_bytes[..]).await?; self.stream.write(&response_bytes[..]).await?;
self.stream.flush().await?;
Ok(()) Ok(())
} }