mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-04-01 02:05:30 +00:00
aquatic_http: glommio: clean up Connection code, stop storing handle
This commit is contained in:
parent
afce23e321
commit
edec526d41
1 changed files with 76 additions and 66 deletions
|
|
@ -14,13 +14,12 @@ use aquatic_http_protocol::response::{
|
|||
};
|
||||
use either::Either;
|
||||
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
|
||||
use futures_rustls::TlsAcceptor;
|
||||
use futures_rustls::server::TlsStream;
|
||||
use futures_rustls::TlsAcceptor;
|
||||
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
|
||||
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
|
||||
use glommio::channels::shared_channel::ConnectedReceiver;
|
||||
use glommio::net::{TcpListener, TcpStream};
|
||||
use glommio::task::JoinHandle;
|
||||
use glommio::timer::TimerActionRepeat;
|
||||
use glommio::{enclose, prelude::*};
|
||||
use slab::Slab;
|
||||
|
|
@ -40,7 +39,6 @@ struct PendingScrapeResponse {
|
|||
|
||||
struct ConnectionReference {
|
||||
response_sender: LocalSender<ChannelResponse>,
|
||||
handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
struct Connection {
|
||||
|
|
@ -49,7 +47,8 @@ struct Connection {
|
|||
request_senders: Rc<Senders<ChannelRequest>>,
|
||||
response_receiver: LocalReceiver<ChannelResponse>,
|
||||
response_consumer_id: ConsumerId,
|
||||
tls_acceptor: TlsAcceptor,
|
||||
stream: TlsStream<TcpStream>,
|
||||
peer_addr: SocketAddr,
|
||||
connection_id: ConnectionId,
|
||||
request_buffer: [u8; MAX_REQUEST_SIZE],
|
||||
request_buffer_position: usize,
|
||||
|
|
@ -123,40 +122,27 @@ pub async fn run_socket_worker(
|
|||
match stream {
|
||||
Ok(stream) => {
|
||||
let (response_sender, response_receiver) = new_bounded(config.request_workers);
|
||||
let key = connection_slab
|
||||
.borrow_mut()
|
||||
.insert(ConnectionReference { response_sender });
|
||||
|
||||
let mut slab = connection_slab.borrow_mut();
|
||||
let entry = slab.vacant_entry();
|
||||
let key = entry.key();
|
||||
|
||||
let mut conn = Connection {
|
||||
config: config.clone(),
|
||||
access_list: access_list.clone(),
|
||||
request_senders: request_senders.clone(),
|
||||
response_receiver,
|
||||
response_consumer_id,
|
||||
tls_acceptor: tls_config.clone().into(),
|
||||
connection_id: ConnectionId(entry.key()),
|
||||
request_buffer: [0u8; MAX_REQUEST_SIZE],
|
||||
request_buffer_position: 0,
|
||||
};
|
||||
|
||||
let connections_to_remove = connections_to_remove.clone();
|
||||
|
||||
let handle = spawn_local(async move {
|
||||
if let Err(err) = conn.handle_stream(stream).await {
|
||||
::log::info!("conn.handle_stream() error: {:?}", err);
|
||||
spawn_local(enclose!((config, access_list, request_senders, tls_config, connections_to_remove) async move {
|
||||
if let Err(err) = Connection::run(
|
||||
config,
|
||||
access_list,
|
||||
request_senders,
|
||||
response_receiver,
|
||||
response_consumer_id,
|
||||
ConnectionId(key),
|
||||
tls_config,
|
||||
stream
|
||||
).await {
|
||||
::log::info!("Connection::run() error: {:?}", err);
|
||||
}
|
||||
|
||||
connections_to_remove.borrow_mut().push(key);
|
||||
})
|
||||
}))
|
||||
.detach();
|
||||
|
||||
let connection_reference = ConnectionReference {
|
||||
response_sender,
|
||||
handle,
|
||||
};
|
||||
|
||||
entry.insert(connection_reference);
|
||||
}
|
||||
Err(err) => {
|
||||
::log::error!("accept connection: {:?}", err);
|
||||
|
|
@ -182,27 +168,62 @@ async fn receive_responses(
|
|||
}
|
||||
|
||||
impl Connection {
|
||||
async fn handle_stream(&mut self, stream: TcpStream) -> anyhow::Result<()> {
|
||||
let peer_addr = Self::get_peer_addr(&stream)?;
|
||||
let mut stream = self.tls_acceptor.accept(stream).await?;
|
||||
async fn run(
|
||||
config: Rc<Config>,
|
||||
access_list: Rc<RefCell<AccessList>>,
|
||||
request_senders: Rc<Senders<ChannelRequest>>,
|
||||
response_receiver: LocalReceiver<ChannelResponse>,
|
||||
response_consumer_id: ConsumerId,
|
||||
connection_id: ConnectionId,
|
||||
tls_config: Arc<TlsConfig>,
|
||||
stream: TcpStream,
|
||||
) -> anyhow::Result<()> {
|
||||
let peer_addr = stream
|
||||
.peer_addr()
|
||||
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
|
||||
|
||||
let tls_acceptor: TlsAcceptor = tls_config.into();
|
||||
let stream = tls_acceptor.accept(stream).await?;
|
||||
|
||||
let mut conn = Connection {
|
||||
config: config.clone(),
|
||||
access_list: access_list.clone(),
|
||||
request_senders: request_senders.clone(),
|
||||
response_receiver,
|
||||
response_consumer_id,
|
||||
stream,
|
||||
peer_addr,
|
||||
connection_id,
|
||||
request_buffer: [0u8; MAX_REQUEST_SIZE],
|
||||
request_buffer_position: 0,
|
||||
};
|
||||
|
||||
conn.run_request_response_loop().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_request_response_loop(&mut self) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let response = match self.read_request(&mut stream).await? {
|
||||
let response = match self.read_request().await? {
|
||||
Either::Left(response) => Response::Failure(response),
|
||||
Either::Right(request) => {
|
||||
match self.handle_request(request, peer_addr).await? {
|
||||
Either::Left(response) => Response::Failure(response),
|
||||
Either::Right(opt_pending_scrape_response) => {
|
||||
self.wait_for_response(peer_addr, opt_pending_scrape_response).await?
|
||||
}
|
||||
Either::Right(request) => match self.handle_request(request).await? {
|
||||
Either::Left(response) => Response::Failure(response),
|
||||
Either::Right(opt_pending_scrape_response) => {
|
||||
self.wait_for_response(opt_pending_scrape_response).await?
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
self.write_response(&mut stream, &response).await?;
|
||||
self.write_response(&response).await?;
|
||||
|
||||
if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive {
|
||||
Self::close_stream(stream).await;
|
||||
let _ = self
|
||||
.stream
|
||||
.get_ref()
|
||||
.0
|
||||
.shutdown(std::net::Shutdown::Both)
|
||||
.await;
|
||||
|
||||
break;
|
||||
}
|
||||
|
|
@ -211,23 +232,13 @@ impl Connection {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_peer_addr(stream: &TcpStream) -> anyhow::Result<SocketAddr> {
|
||||
stream
|
||||
.peer_addr()
|
||||
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))
|
||||
}
|
||||
|
||||
async fn close_stream(stream: TlsStream<TcpStream>) {
|
||||
let _ = stream.get_ref().0.shutdown(std::net::Shutdown::Both).await;
|
||||
}
|
||||
|
||||
async fn read_request(&mut self, stream: &mut TlsStream<TcpStream>) -> anyhow::Result<Either<FailureResponse, Request>> {
|
||||
async fn read_request(&mut self) -> anyhow::Result<Either<FailureResponse, Request>> {
|
||||
let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE];
|
||||
|
||||
loop {
|
||||
::log::debug!("read");
|
||||
|
||||
let bytes_read = stream.read(&mut buf).await?;
|
||||
let bytes_read = self.stream.read(&mut buf).await?;
|
||||
let request_buffer_end = self.request_buffer_position + bytes_read;
|
||||
|
||||
if request_buffer_end > self.request_buffer.len() {
|
||||
|
|
@ -276,7 +287,6 @@ impl Connection {
|
|||
async fn handle_request(
|
||||
&self,
|
||||
request: Request,
|
||||
peer_addr: SocketAddr,
|
||||
) -> anyhow::Result<Either<FailureResponse, Option<PendingScrapeResponse>>> {
|
||||
match request {
|
||||
Request::Announce(request) => {
|
||||
|
|
@ -291,7 +301,7 @@ impl Connection {
|
|||
request,
|
||||
connection_id: self.connection_id,
|
||||
response_consumer_id: self.response_consumer_id,
|
||||
peer_addr,
|
||||
peer_addr: self.peer_addr,
|
||||
};
|
||||
|
||||
let consumer_index = calculate_request_consumer_index(&self.config, info_hash);
|
||||
|
|
@ -327,7 +337,7 @@ impl Connection {
|
|||
for (consumer_index, info_hashes) in info_hashes_by_worker {
|
||||
let request = ChannelRequest::Scrape {
|
||||
request: ScrapeRequest { info_hashes },
|
||||
peer_addr,
|
||||
peer_addr: self.peer_addr,
|
||||
response_consumer_id: self.response_consumer_id,
|
||||
connection_id: self.connection_id,
|
||||
};
|
||||
|
|
@ -353,13 +363,12 @@ impl Connection {
|
|||
/// return full response
|
||||
async fn wait_for_response(
|
||||
&self,
|
||||
peer_addr: SocketAddr,
|
||||
mut opt_pending_scrape_response: Option<PendingScrapeResponse>,
|
||||
) -> anyhow::Result<Response> {
|
||||
loop {
|
||||
if let Some(channel_response) = self.response_receiver.recv().await {
|
||||
if channel_response.get_peer_addr() != peer_addr {
|
||||
return Err(anyhow::anyhow!("peer addressess didn't match"));
|
||||
if channel_response.get_peer_addr() != self.peer_addr {
|
||||
return Err(anyhow::anyhow!("peer addresses didn't match"));
|
||||
}
|
||||
|
||||
match channel_response {
|
||||
|
|
@ -396,7 +405,7 @@ impl Connection {
|
|||
}
|
||||
}
|
||||
|
||||
async fn write_response(&mut self, stream: &mut TlsStream<TcpStream>, response: &Response) -> anyhow::Result<()> {
|
||||
async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> {
|
||||
let mut body = Vec::new();
|
||||
|
||||
response.write(&mut body).unwrap();
|
||||
|
|
@ -412,7 +421,8 @@ impl Connection {
|
|||
response_bytes.append(&mut body);
|
||||
response_bytes.extend_from_slice(b"\r\n");
|
||||
|
||||
stream.write(&response_bytes[..]).await?;
|
||||
self.stream.write(&response_bytes[..]).await?;
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue