mirror of
https://github.com/YGGverse/aquatic.git
synced 2026-04-02 10:45:30 +00:00
aquatic_http: glommio: clean up Connection code, stop storing handle
This commit is contained in:
parent
afce23e321
commit
edec526d41
1 changed files with 76 additions and 66 deletions
|
|
@ -14,13 +14,12 @@ use aquatic_http_protocol::response::{
|
||||||
};
|
};
|
||||||
use either::Either;
|
use either::Either;
|
||||||
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
|
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
|
||||||
use futures_rustls::TlsAcceptor;
|
|
||||||
use futures_rustls::server::TlsStream;
|
use futures_rustls::server::TlsStream;
|
||||||
|
use futures_rustls::TlsAcceptor;
|
||||||
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
|
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
|
||||||
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
|
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
|
||||||
use glommio::channels::shared_channel::ConnectedReceiver;
|
use glommio::channels::shared_channel::ConnectedReceiver;
|
||||||
use glommio::net::{TcpListener, TcpStream};
|
use glommio::net::{TcpListener, TcpStream};
|
||||||
use glommio::task::JoinHandle;
|
|
||||||
use glommio::timer::TimerActionRepeat;
|
use glommio::timer::TimerActionRepeat;
|
||||||
use glommio::{enclose, prelude::*};
|
use glommio::{enclose, prelude::*};
|
||||||
use slab::Slab;
|
use slab::Slab;
|
||||||
|
|
@ -40,7 +39,6 @@ struct PendingScrapeResponse {
|
||||||
|
|
||||||
struct ConnectionReference {
|
struct ConnectionReference {
|
||||||
response_sender: LocalSender<ChannelResponse>,
|
response_sender: LocalSender<ChannelResponse>,
|
||||||
handle: JoinHandle<()>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Connection {
|
struct Connection {
|
||||||
|
|
@ -49,7 +47,8 @@ struct Connection {
|
||||||
request_senders: Rc<Senders<ChannelRequest>>,
|
request_senders: Rc<Senders<ChannelRequest>>,
|
||||||
response_receiver: LocalReceiver<ChannelResponse>,
|
response_receiver: LocalReceiver<ChannelResponse>,
|
||||||
response_consumer_id: ConsumerId,
|
response_consumer_id: ConsumerId,
|
||||||
tls_acceptor: TlsAcceptor,
|
stream: TlsStream<TcpStream>,
|
||||||
|
peer_addr: SocketAddr,
|
||||||
connection_id: ConnectionId,
|
connection_id: ConnectionId,
|
||||||
request_buffer: [u8; MAX_REQUEST_SIZE],
|
request_buffer: [u8; MAX_REQUEST_SIZE],
|
||||||
request_buffer_position: usize,
|
request_buffer_position: usize,
|
||||||
|
|
@ -123,40 +122,27 @@ pub async fn run_socket_worker(
|
||||||
match stream {
|
match stream {
|
||||||
Ok(stream) => {
|
Ok(stream) => {
|
||||||
let (response_sender, response_receiver) = new_bounded(config.request_workers);
|
let (response_sender, response_receiver) = new_bounded(config.request_workers);
|
||||||
|
let key = connection_slab
|
||||||
|
.borrow_mut()
|
||||||
|
.insert(ConnectionReference { response_sender });
|
||||||
|
|
||||||
let mut slab = connection_slab.borrow_mut();
|
spawn_local(enclose!((config, access_list, request_senders, tls_config, connections_to_remove) async move {
|
||||||
let entry = slab.vacant_entry();
|
if let Err(err) = Connection::run(
|
||||||
let key = entry.key();
|
config,
|
||||||
|
access_list,
|
||||||
let mut conn = Connection {
|
request_senders,
|
||||||
config: config.clone(),
|
|
||||||
access_list: access_list.clone(),
|
|
||||||
request_senders: request_senders.clone(),
|
|
||||||
response_receiver,
|
response_receiver,
|
||||||
response_consumer_id,
|
response_consumer_id,
|
||||||
tls_acceptor: tls_config.clone().into(),
|
ConnectionId(key),
|
||||||
connection_id: ConnectionId(entry.key()),
|
tls_config,
|
||||||
request_buffer: [0u8; MAX_REQUEST_SIZE],
|
stream
|
||||||
request_buffer_position: 0,
|
).await {
|
||||||
};
|
::log::info!("Connection::run() error: {:?}", err);
|
||||||
|
|
||||||
let connections_to_remove = connections_to_remove.clone();
|
|
||||||
|
|
||||||
let handle = spawn_local(async move {
|
|
||||||
if let Err(err) = conn.handle_stream(stream).await {
|
|
||||||
::log::info!("conn.handle_stream() error: {:?}", err);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
connections_to_remove.borrow_mut().push(key);
|
connections_to_remove.borrow_mut().push(key);
|
||||||
})
|
}))
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
let connection_reference = ConnectionReference {
|
|
||||||
response_sender,
|
|
||||||
handle,
|
|
||||||
};
|
|
||||||
|
|
||||||
entry.insert(connection_reference);
|
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
::log::error!("accept connection: {:?}", err);
|
::log::error!("accept connection: {:?}", err);
|
||||||
|
|
@ -182,27 +168,62 @@ async fn receive_responses(
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Connection {
|
impl Connection {
|
||||||
async fn handle_stream(&mut self, stream: TcpStream) -> anyhow::Result<()> {
|
async fn run(
|
||||||
let peer_addr = Self::get_peer_addr(&stream)?;
|
config: Rc<Config>,
|
||||||
let mut stream = self.tls_acceptor.accept(stream).await?;
|
access_list: Rc<RefCell<AccessList>>,
|
||||||
|
request_senders: Rc<Senders<ChannelRequest>>,
|
||||||
|
response_receiver: LocalReceiver<ChannelResponse>,
|
||||||
|
response_consumer_id: ConsumerId,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
tls_config: Arc<TlsConfig>,
|
||||||
|
stream: TcpStream,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let peer_addr = stream
|
||||||
|
.peer_addr()
|
||||||
|
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))?;
|
||||||
|
|
||||||
loop {
|
let tls_acceptor: TlsAcceptor = tls_config.into();
|
||||||
let response = match self.read_request(&mut stream).await? {
|
let stream = tls_acceptor.accept(stream).await?;
|
||||||
Either::Left(response) => Response::Failure(response),
|
|
||||||
Either::Right(request) => {
|
let mut conn = Connection {
|
||||||
match self.handle_request(request, peer_addr).await? {
|
config: config.clone(),
|
||||||
Either::Left(response) => Response::Failure(response),
|
access_list: access_list.clone(),
|
||||||
Either::Right(opt_pending_scrape_response) => {
|
request_senders: request_senders.clone(),
|
||||||
self.wait_for_response(peer_addr, opt_pending_scrape_response).await?
|
response_receiver,
|
||||||
}
|
response_consumer_id,
|
||||||
}
|
stream,
|
||||||
}
|
peer_addr,
|
||||||
|
connection_id,
|
||||||
|
request_buffer: [0u8; MAX_REQUEST_SIZE],
|
||||||
|
request_buffer_position: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.write_response(&mut stream, &response).await?;
|
conn.run_request_response_loop().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_request_response_loop(&mut self) -> anyhow::Result<()> {
|
||||||
|
loop {
|
||||||
|
let response = match self.read_request().await? {
|
||||||
|
Either::Left(response) => Response::Failure(response),
|
||||||
|
Either::Right(request) => match self.handle_request(request).await? {
|
||||||
|
Either::Left(response) => Response::Failure(response),
|
||||||
|
Either::Right(opt_pending_scrape_response) => {
|
||||||
|
self.wait_for_response(opt_pending_scrape_response).await?
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
self.write_response(&response).await?;
|
||||||
|
|
||||||
if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive {
|
if matches!(response, Response::Failure(_)) || !self.config.network.keep_alive {
|
||||||
Self::close_stream(stream).await;
|
let _ = self
|
||||||
|
.stream
|
||||||
|
.get_ref()
|
||||||
|
.0
|
||||||
|
.shutdown(std::net::Shutdown::Both)
|
||||||
|
.await;
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -211,23 +232,13 @@ impl Connection {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_peer_addr(stream: &TcpStream) -> anyhow::Result<SocketAddr> {
|
async fn read_request(&mut self) -> anyhow::Result<Either<FailureResponse, Request>> {
|
||||||
stream
|
|
||||||
.peer_addr()
|
|
||||||
.map_err(|err| anyhow::anyhow!("Couldn't get peer addr: {:?}", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn close_stream(stream: TlsStream<TcpStream>) {
|
|
||||||
let _ = stream.get_ref().0.shutdown(std::net::Shutdown::Both).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn read_request(&mut self, stream: &mut TlsStream<TcpStream>) -> anyhow::Result<Either<FailureResponse, Request>> {
|
|
||||||
let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE];
|
let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE];
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
::log::debug!("read");
|
::log::debug!("read");
|
||||||
|
|
||||||
let bytes_read = stream.read(&mut buf).await?;
|
let bytes_read = self.stream.read(&mut buf).await?;
|
||||||
let request_buffer_end = self.request_buffer_position + bytes_read;
|
let request_buffer_end = self.request_buffer_position + bytes_read;
|
||||||
|
|
||||||
if request_buffer_end > self.request_buffer.len() {
|
if request_buffer_end > self.request_buffer.len() {
|
||||||
|
|
@ -276,7 +287,6 @@ impl Connection {
|
||||||
async fn handle_request(
|
async fn handle_request(
|
||||||
&self,
|
&self,
|
||||||
request: Request,
|
request: Request,
|
||||||
peer_addr: SocketAddr,
|
|
||||||
) -> anyhow::Result<Either<FailureResponse, Option<PendingScrapeResponse>>> {
|
) -> anyhow::Result<Either<FailureResponse, Option<PendingScrapeResponse>>> {
|
||||||
match request {
|
match request {
|
||||||
Request::Announce(request) => {
|
Request::Announce(request) => {
|
||||||
|
|
@ -291,7 +301,7 @@ impl Connection {
|
||||||
request,
|
request,
|
||||||
connection_id: self.connection_id,
|
connection_id: self.connection_id,
|
||||||
response_consumer_id: self.response_consumer_id,
|
response_consumer_id: self.response_consumer_id,
|
||||||
peer_addr,
|
peer_addr: self.peer_addr,
|
||||||
};
|
};
|
||||||
|
|
||||||
let consumer_index = calculate_request_consumer_index(&self.config, info_hash);
|
let consumer_index = calculate_request_consumer_index(&self.config, info_hash);
|
||||||
|
|
@ -327,7 +337,7 @@ impl Connection {
|
||||||
for (consumer_index, info_hashes) in info_hashes_by_worker {
|
for (consumer_index, info_hashes) in info_hashes_by_worker {
|
||||||
let request = ChannelRequest::Scrape {
|
let request = ChannelRequest::Scrape {
|
||||||
request: ScrapeRequest { info_hashes },
|
request: ScrapeRequest { info_hashes },
|
||||||
peer_addr,
|
peer_addr: self.peer_addr,
|
||||||
response_consumer_id: self.response_consumer_id,
|
response_consumer_id: self.response_consumer_id,
|
||||||
connection_id: self.connection_id,
|
connection_id: self.connection_id,
|
||||||
};
|
};
|
||||||
|
|
@ -353,13 +363,12 @@ impl Connection {
|
||||||
/// return full response
|
/// return full response
|
||||||
async fn wait_for_response(
|
async fn wait_for_response(
|
||||||
&self,
|
&self,
|
||||||
peer_addr: SocketAddr,
|
|
||||||
mut opt_pending_scrape_response: Option<PendingScrapeResponse>,
|
mut opt_pending_scrape_response: Option<PendingScrapeResponse>,
|
||||||
) -> anyhow::Result<Response> {
|
) -> anyhow::Result<Response> {
|
||||||
loop {
|
loop {
|
||||||
if let Some(channel_response) = self.response_receiver.recv().await {
|
if let Some(channel_response) = self.response_receiver.recv().await {
|
||||||
if channel_response.get_peer_addr() != peer_addr {
|
if channel_response.get_peer_addr() != self.peer_addr {
|
||||||
return Err(anyhow::anyhow!("peer addressess didn't match"));
|
return Err(anyhow::anyhow!("peer addresses didn't match"));
|
||||||
}
|
}
|
||||||
|
|
||||||
match channel_response {
|
match channel_response {
|
||||||
|
|
@ -396,7 +405,7 @@ impl Connection {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn write_response(&mut self, stream: &mut TlsStream<TcpStream>, response: &Response) -> anyhow::Result<()> {
|
async fn write_response(&mut self, response: &Response) -> anyhow::Result<()> {
|
||||||
let mut body = Vec::new();
|
let mut body = Vec::new();
|
||||||
|
|
||||||
response.write(&mut body).unwrap();
|
response.write(&mut body).unwrap();
|
||||||
|
|
@ -412,7 +421,8 @@ impl Connection {
|
||||||
response_bytes.append(&mut body);
|
response_bytes.append(&mut body);
|
||||||
response_bytes.extend_from_slice(b"\r\n");
|
response_bytes.extend_from_slice(b"\r\n");
|
||||||
|
|
||||||
stream.write(&response_bytes[..]).await?;
|
self.stream.write(&response_bytes[..]).await?;
|
||||||
|
self.stream.flush().await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue