aquatic_http: glommio: clean up Connection code, stop storing handle

This commit is contained in:
Joakim Frostegård 2021-11-01 01:53:00 +01:00
parent afce23e321
commit edec526d41

View file

@ -14,13 +14,12 @@ use aquatic_http_protocol::response::{
};
use either::Either;
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
use futures_rustls::TlsAcceptor;
use futures_rustls::server::TlsStream;
use futures_rustls::TlsAcceptor;
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
use glommio::channels::shared_channel::ConnectedReceiver;
use glommio::net::{TcpListener, TcpStream};
use glommio::task::JoinHandle;
use glommio::timer::TimerActionRepeat;
use glommio::{enclose, prelude::*};
use slab::Slab;
@ -40,7 +39,6 @@ struct PendingScrapeResponse {
struct ConnectionReference {
response_sender: LocalSender<ChannelResponse>,
handle: JoinHandle<()>,
}
struct Connection {
@ -49,7 +47,8 @@ struct Connection {
request_senders: Rc<Senders<ChannelRequest>>,
response_receiver: LocalReceiver<ChannelResponse>,
response_consumer_id: ConsumerId,
tls_acceptor: TlsAcceptor,
stream: TlsStream<TcpStream>,
peer_addr: SocketAddr,
connection_id: ConnectionId,
request_buffer: [u8; MAX_REQUEST_SIZE],
request_buffer_position: usize,
@ -123,40 +122,27 @@ pub async fn run_socket_worker(
match stream {
Ok(stream) => {
let (response_sender, response_receiver) = new_bounded(config.request_workers);
let key = connection_slab
.borrow_mut()
.insert(ConnectionReference { response_sender });
let mut slab = connection_slab.borrow_mut();
let entry = slab.vacant_entry();
let key = entry.key();
let mut conn = Connection {
config: config.clone(),
access_list: access_list.clone(),
request_senders: request_senders.clone(),
response_receiver,
response_consumer_id,
tls_acceptor: tls_config.clone().into(),
connection_id: ConnectionId(entry.key()),
request_buffer: [0u8; MAX_REQUEST_SIZE],
request_buffer_position: 0,
};
let connections_to_remove = connections_to_remove.clone();
let handle = spawn_local(async move {
if let Err(err) = conn.handle_stream(stream).await {
::log::info!("conn.handle_stream() error: {:?}", err);
spawn_local(enclose!((config, access_list, request_senders, tls_config, connections_to_remove) async move {
if let Err(err) = Connection::run(
config,
access_list,
request_senders,
response_receiver,
response_consumer_id,
ConnectionId(key),
tls_config,
stream
).await {
::log::info!("Connection::run() error: {:?}", err);
}
connections_to_remove.borrow_mut().push(key);
})
}))
.detach();
let connection_reference = ConnectionReference {
response_sender,
handle,
};
entry.insert(connection_reference);
}
Err(err) => {
::log::error!("accept connection: {:?}", err);
@ -182,27 +168,62 @@ async fn receive_responses(
}
impl Connection {
async fn handle_stream(&mut self, stream: TcpStream) -> anyhow::Result<()> {
let peer_addr = Self::get_peer_addr(&stream)?;
let mut stream = self.tls_acceptor.accept(stream).await?;
async fn run(
config: Rc<Config>,
access_list: Rc<RefCell<AccessList>>,
request_senders: Rc<Senders<ChannelRequest>>,
response_receiver: LocalReceiver<ChannelResponse>,
response_consumer_id: ConsumerId,
connection_id: ConnectionId,
tls_config: Arc<TlsConfig>,
stream: TcpStream,
) -> anyhow::Result<()> {
let peer_addr = stream
.peer_addr()
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
let tls_acceptor: TlsAcceptor = tls_config.into();
let stream = tls_acceptor.accept(stream).await?;
let mut conn = Connection {
config: config.clone(),
access_list: access_list.clone(),
request_senders: request_senders.clone(),
response_receiver,
response_consumer_id,
stream,
peer_addr,
connection_id,
request_buffer: [0u8; MAX_REQUEST_SIZE],
request_buffer_position: 0,
};
conn.run_request_response_loop().await?;
Ok(())
}
async fn run_request_response_loop(&mut self) -> anyhow::Result<()> {
loop {
let response = match self.read_request(&mut stream).await? {
let response = match self.read_request().await? {
Either::Left(response) => Response::Failure(response),
Either::Right(request) => {
match self.handle_request(request, peer_addr).await? {
Either::Left(response) => Response::Failure(response),
Either::Right(opt_pending_scrape_response) => {
self.wait_for_response(peer_addr, opt_pending_scrape_response).await?
}
Either::Right(request) => match self.handle_request(request).await? {
Either::Left(response) => Response::Failure(response),
Either::Right(opt_pending_scrape_response) => {
self.wait_for_response(opt_pending_scrape_response).await?
}
}
},
};
self.write_response(&mut stream, &response).await?;
self.write_response(&response).await?;
if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive {
Self::close_stream(stream).await;
let _ = self
.stream
.get_ref()
.0
.shutdown(std::net::Shutdown::Both)
.await;
break;
}
@ -211,23 +232,13 @@ impl Connection {
Ok(())
}
fn get_peer_addr(stream: &TcpStream) -> anyhow::Result<SocketAddr> {
stream
.peer_addr()
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))
}
async fn close_stream(stream: TlsStream<TcpStream>) {
let _ = stream.get_ref().0.shutdown(std::net::Shutdown::Both).await;
}
async fn read_request(&mut self, stream: &mut TlsStream<TcpStream>) -> anyhow::Result<Either<FailureResponse, Request>> {
async fn read_request(&mut self) -> anyhow::Result<Either<FailureResponse, Request>> {
let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE];
loop {
::log::debug!("read");
let bytes_read = stream.read(&mut buf).await?;
let bytes_read = self.stream.read(&mut buf).await?;
let request_buffer_end = self.request_buffer_position + bytes_read;
if request_buffer_end > self.request_buffer.len() {
@ -276,7 +287,6 @@ impl Connection {
async fn handle_request(
&self,
request: Request,
peer_addr: SocketAddr,
) -> anyhow::Result<Either<FailureResponse, Option<PendingScrapeResponse>>> {
match request {
Request::Announce(request) => {
@ -291,7 +301,7 @@ impl Connection {
request,
connection_id: self.connection_id,
response_consumer_id: self.response_consumer_id,
peer_addr,
peer_addr: self.peer_addr,
};
let consumer_index = calculate_request_consumer_index(&self.config, info_hash);
@ -327,7 +337,7 @@ impl Connection {
for (consumer_index, info_hashes) in info_hashes_by_worker {
let request = ChannelRequest::Scrape {
request: ScrapeRequest { info_hashes },
peer_addr,
peer_addr: self.peer_addr,
response_consumer_id: self.response_consumer_id,
connection_id: self.connection_id,
};
@ -353,13 +363,12 @@ impl Connection {
/// return full response
async fn wait_for_response(
&self,
peer_addr: SocketAddr,
mut opt_pending_scrape_response: Option<PendingScrapeResponse>,
) -> anyhow::Result<Response> {
loop {
if let Some(channel_response) = self.response_receiver.recv().await {
if channel_response.get_peer_addr() != peer_addr {
return Err(anyhow::anyhow!("peer addressess didn't match"));
if channel_response.get_peer_addr() != self.peer_addr {
return Err(anyhow::anyhow!("peer addresses didn't match"));
}
match channel_response {
@ -396,7 +405,7 @@ impl Connection {
}
}
async fn write_response(&mut self, stream: &mut TlsStream<TcpStream>, response: &Response) -> anyhow::Result<()> {
async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> {
let mut body = Vec::new();
response.write(&mut body).unwrap();
@ -412,7 +421,8 @@ impl Connection {
response_bytes.append(&mut body);
response_bytes.extend_from_slice(b"\r\n");
stream.write(&response_bytes[..]).await?;
self.stream.write(&response_bytes[..]).await?;
self.stream.flush().await?;
Ok(())
}