aquatic_http: glommio: add access list, refactor networking slightly

This commit is contained in:
Joakim Frostegård 2021-10-27 13:55:17 +02:00
parent 1a4e5750a3
commit 6404be8ef8
2 changed files with 74 additions and 40 deletions

View file

@ -54,7 +54,7 @@ pub fn run(config: Config) -> anyhow::Result<()> {
request_mesh_builder, request_mesh_builder,
response_mesh_builder, response_mesh_builder,
num_bound_sockets, num_bound_sockets,
// access_list, access_list,
) )
.await .await
}); });

View file

@ -7,11 +7,13 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use aquatic_common::access_list::AccessList;
use aquatic_http_protocol::common::InfoHash; use aquatic_http_protocol::common::InfoHash;
use aquatic_http_protocol::request::{AnnounceRequest, Request, RequestParseError, ScrapeRequest}; use aquatic_http_protocol::request::{AnnounceRequest, Request, RequestParseError, ScrapeRequest};
use aquatic_http_protocol::response::{ use aquatic_http_protocol::response::{
FailureResponse, Response, ScrapeResponse, ScrapeStatistics, FailureResponse, Response, ScrapeResponse, ScrapeStatistics,
}; };
use either::Either;
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders}; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Role, Senders};
use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender}; use glommio::channels::local_channel::{new_bounded, LocalReceiver, LocalSender};
@ -42,6 +44,7 @@ struct ConnectionReference {
struct Connection { struct Connection {
config: Rc<Config>, config: Rc<Config>,
access_list: Rc<RefCell<AccessList>>,
request_senders: Rc<Senders<ChannelRequest>>, request_senders: Rc<Senders<ChannelRequest>>,
response_receiver: LocalReceiver<ChannelResponse>, response_receiver: LocalReceiver<ChannelResponse>,
response_consumer_id: ConsumerId, response_consumer_id: ConsumerId,
@ -50,7 +53,6 @@ struct Connection {
connection_id: ConnectionId, connection_id: ConnectionId,
request_buffer: Vec<u8>, request_buffer: Vec<u8>,
close_after_writing: bool, close_after_writing: bool,
pending_scrape_response: Option<PendingScrapeResponse>,
} }
pub async fn run_socket_worker( pub async fn run_socket_worker(
@ -59,8 +61,10 @@ pub async fn run_socket_worker(
request_mesh_builder: MeshBuilder<ChannelRequest, Partial>, request_mesh_builder: MeshBuilder<ChannelRequest, Partial>,
response_mesh_builder: MeshBuilder<ChannelResponse, Partial>, response_mesh_builder: MeshBuilder<ChannelResponse, Partial>,
num_bound_sockets: Arc<AtomicUsize>, num_bound_sockets: Arc<AtomicUsize>,
access_list: AccessList,
) { ) {
let config = Rc::new(config); let config = Rc::new(config);
let access_list = Rc::new(RefCell::new(access_list));
let listener = TcpListener::bind(config.network.address).expect("bind socket"); let listener = TcpListener::bind(config.network.address).expect("bind socket");
num_bound_sockets.fetch_add(1, Ordering::SeqCst); num_bound_sockets.fetch_add(1, Ordering::SeqCst);
@ -115,6 +119,7 @@ pub async fn run_socket_worker(
let mut conn = Connection { let mut conn = Connection {
config: config.clone(), config: config.clone(),
access_list: access_list.clone(),
request_senders: request_senders.clone(), request_senders: request_senders.clone(),
response_receiver, response_receiver,
response_consumer_id, response_consumer_id,
@ -123,7 +128,6 @@ pub async fn run_socket_worker(
connection_id: ConnectionId(entry.key()), connection_id: ConnectionId(entry.key()),
request_buffer: Vec::new(), request_buffer: Vec::new(),
close_after_writing: false, close_after_writing: false,
pending_scrape_response: None,
}; };
let connections_to_remove = connections_to_remove.clone(); let connections_to_remove = connections_to_remove.clone();
@ -172,8 +176,23 @@ impl Connection {
let opt_request = self.read_tls().await?; let opt_request = self.read_tls().await?;
if let Some(request) = opt_request { if let Some(request) = opt_request {
self.handle_request(request).await?; let response = match self.handle_request(request).await? {
self.wait_for_and_send_response().await?; Some(Either::Left(response)) => {
response
}
Some(Either::Right(pending_scrape_response)) => {
self.wait_for_response(Some(pending_scrape_response)).await?
},
None => {
self.wait_for_response(None).await?
}
};
self.queue_response(&response)?;
if !self.config.network.keep_alive {
self.close_after_writing = true;
}
} }
self.write_tls().await?; self.write_tls().await?;
@ -292,26 +311,43 @@ impl Connection {
Ok(()) Ok(())
} }
/// Send on request to proper request worker/workers /// Take a request and:
async fn handle_request(&mut self, request: Request) -> anyhow::Result<()> { /// - Return error response if request is not allowed
/// - If it is an announce requests, pass it on to request workers and return None
/// - If it is a scrape requests, split it up and pass on parts to
/// relevant request workers, and return PendingScrapeResponse struct.
async fn handle_request(
&self,
request: Request
) -> anyhow::Result<Option<Either<Response, PendingScrapeResponse>>> {
let peer_addr = self.get_peer_addr()?; let peer_addr = self.get_peer_addr()?;
match request { match request {
Request::Announce(request @ AnnounceRequest { info_hash, .. }) => { Request::Announce(request @ AnnounceRequest { info_hash, .. }) => {
let request = ChannelRequest::Announce { if self.access_list.borrow().allows(self.config.access_list.mode, &info_hash.0) {
request, let request = ChannelRequest::Announce {
connection_id: self.connection_id, request,
response_consumer_id: self.response_consumer_id, connection_id: self.connection_id,
peer_addr, response_consumer_id: self.response_consumer_id,
}; peer_addr,
};
let consumer_index = calculate_request_consumer_index(&self.config, info_hash); let consumer_index = calculate_request_consumer_index(&self.config, info_hash);
// Only fails when receiver is closed // Only fails when receiver is closed
self.request_senders self.request_senders
.send_to(consumer_index, request) .send_to(consumer_index, request)
.await .await
.unwrap(); .unwrap();
Ok(None)
} else {
let response = Response::Failure(FailureResponse {
failure_reason: "Info hash not allowed".into(),
});
Ok(Some(Either::Left(response)))
}
} }
Request::Scrape(ScrapeRequest { info_hashes }) => { Request::Scrape(ScrapeRequest { info_hashes }) => {
let mut info_hashes_by_worker: BTreeMap<usize, Vec<InfoHash>> = BTreeMap::new(); let mut info_hashes_by_worker: BTreeMap<usize, Vec<InfoHash>> = BTreeMap::new();
@ -324,10 +360,7 @@ impl Connection {
info_hashes.push(info_hash); info_hashes.push(info_hash);
} }
self.pending_scrape_response = Some(PendingScrapeResponse { let pending_worker_responses = info_hashes_by_worker.len();
pending_worker_responses: info_hashes_by_worker.len(),
stats: Default::default(),
});
for (consumer_index, info_hashes) in info_hashes_by_worker { for (consumer_index, info_hashes) in info_hashes_by_worker {
let request = ChannelRequest::Scrape { let request = ChannelRequest::Scrape {
@ -343,15 +376,24 @@ impl Connection {
.await .await
.unwrap(); .unwrap();
} }
let pending_scrape_response = PendingScrapeResponse {
pending_worker_responses,
stats: Default::default(),
};
Ok(Some(Either::Right(pending_scrape_response)))
} }
} }
Ok(())
} }
// Wait for response/responses to arrive, then queue response for sending to peer /// Wait for announce response or partial scrape responses to arrive,
async fn wait_for_and_send_response(&mut self) -> anyhow::Result<()> { /// return full response
let response = loop { async fn wait_for_response(
&self,
mut opt_pending_scrape_response: Option<PendingScrapeResponse>
) -> anyhow::Result<Response> {
loop {
if let Some(channel_response) = self.response_receiver.recv().await { if let Some(channel_response) = self.response_receiver.recv().await {
if channel_response.get_peer_addr() != self.get_peer_addr()? { if channel_response.get_peer_addr() != self.get_peer_addr()? {
return Err(anyhow::anyhow!("peer addressess didn't match")); return Err(anyhow::anyhow!("peer addressess didn't match"));
@ -359,10 +401,10 @@ impl Connection {
match channel_response { match channel_response {
ChannelResponse::Announce { response, .. } => { ChannelResponse::Announce { response, .. } => {
break Response::Announce(response); break Ok(Response::Announce(response));
} }
ChannelResponse::Scrape { response, .. } => { ChannelResponse::Scrape { response, .. } => {
if let Some(mut pending) = self.pending_scrape_response.take() { if let Some(mut pending) = opt_pending_scrape_response.take() {
pending.stats.extend(response.files); pending.stats.extend(response.files);
pending.pending_worker_responses -= 1; pending.pending_worker_responses -= 1;
@ -371,9 +413,9 @@ impl Connection {
files: pending.stats, files: pending.stats,
}); });
break response; break Ok(response);
} else { } else {
self.pending_scrape_response = Some(pending); opt_pending_scrape_response = Some(pending);
} }
} else { } else {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
@ -388,15 +430,7 @@ impl Connection {
"response receiver can't receive - sender is closed" "response receiver can't receive - sender is closed"
)); ));
} }
};
self.queue_response(&response)?;
if !self.config.network.keep_alive {
self.close_after_writing = true;
} }
Ok(())
} }
fn queue_response(&mut self, response: &Response) -> anyhow::Result<()> { fn queue_response(&mut self, response: &Response) -> anyhow::Result<()> {