From 96e128bb9099fab7c4f80dd38cfc1be9aecf041a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Sun, 3 Apr 2022 00:36:35 +0200 Subject: [PATCH] http_private: get basic announce route working --- aquatic_http_private/src/common.rs | 48 +++++++++++ aquatic_http_private/src/config.rs | 2 + aquatic_http_private/src/lib.rs | 21 ++++- aquatic_http_private/src/workers/common.rs | 5 +- aquatic_http_private/src/workers/request.rs | 37 ++++++++- .../src/workers/socket/mod.rs | 24 ++++-- .../src/workers/socket/routes.rs | 80 +++++++++++++------ 7 files changed, 181 insertions(+), 36 deletions(-) create mode 100644 aquatic_http_private/src/common.rs diff --git a/aquatic_http_private/src/common.rs b/aquatic_http_private/src/common.rs new file mode 100644 index 0000000..6b00a47 --- /dev/null +++ b/aquatic_http_private/src/common.rs @@ -0,0 +1,48 @@ +use tokio::sync::{mpsc, oneshot}; + +use aquatic_common::CanonicalSocketAddr; +use aquatic_http_protocol::{common::InfoHash, response::Response}; + +use crate::{ + config::Config, + workers::{common::ChannelAnnounceRequest, socket::db::ValidatedAnnounceRequest}, +}; + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct RequestWorkerIndex(pub usize); + +impl RequestWorkerIndex { + pub fn from_info_hash(config: &Config, info_hash: InfoHash) -> Self { + Self(info_hash.0[0] as usize % config.request_workers) + } +} + +pub struct ChannelRequestSender(Vec>); + +impl ChannelRequestSender { + pub fn new(senders: Vec>) -> Self { + Self(senders) + } + + pub async fn send_to( + &self, + index: RequestWorkerIndex, + request: ValidatedAnnounceRequest, + source_addr: CanonicalSocketAddr, + ) -> anyhow::Result> { + let (response_sender, response_receiver) = oneshot::channel(); + + let request = ChannelAnnounceRequest { + request, + source_addr, + response_sender, + }; + + match self.0[index.0].send(request).await { + Ok(()) => Ok(response_receiver), + Err(err) => { + Err(anyhow::Error::new(err).context("error sending ChannelAnnounceRequest")) + } + } + } +} diff --git a/aquatic_http_private/src/config.rs b/aquatic_http_private/src/config.rs index bad3d72..713ea8c 100644 --- a/aquatic_http_private/src/config.rs +++ b/aquatic_http_private/src/config.rs @@ -17,6 +17,7 @@ pub struct Config { /// Request workers receive a number of requests from socket workers, /// generate responses and send them back to the socket workers. pub request_workers: usize, + pub worker_channel_size: usize, pub log_level: LogLevel, pub network: NetworkConfig, pub protocol: ProtocolConfig, @@ -29,6 +30,7 @@ impl Default for Config { Self { socket_workers: 1, request_workers: 1, + worker_channel_size: 128, log_level: LogLevel::default(), network: NetworkConfig::default(), protocol: ProtocolConfig::default(), diff --git a/aquatic_http_private/src/lib.rs b/aquatic_http_private/src/lib.rs index f0ce764..23053d1 100644 --- a/aquatic_http_private/src/lib.rs +++ b/aquatic_http_private/src/lib.rs @@ -1,7 +1,12 @@ +mod common; pub mod config; mod workers; +use std::collections::VecDeque; + +use common::ChannelRequestSender; use dotenv::dotenv; +use tokio::sync::mpsc::channel; pub const APP_NAME: &str = "aquatic_http_private: private HTTP/TLS BitTorrent tracker"; pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -9,24 +14,36 @@ pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn run(config: config::Config) -> anyhow::Result<()> { dotenv().ok(); + let mut request_senders = Vec::new(); + let mut request_receivers = VecDeque::new(); + + for _ in 0..config.request_workers { + let (request_sender, request_receiver) = channel(config.worker_channel_size); + + request_senders.push(request_sender); + request_receivers.push_back(request_receiver); + } + let mut handles = Vec::new(); for _ in 0..config.socket_workers { let config = config.clone(); + let request_sender = ChannelRequestSender::new(request_senders.clone()); let handle = ::std::thread::Builder::new() .name("socket".into()) - .spawn(move || workers::socket::run_socket_worker(config))?; + .spawn(move || workers::socket::run_socket_worker(config, request_sender))?; handles.push(handle); } for _ in 0..config.request_workers { let config = config.clone(); + let request_receiver = request_receivers.pop_front().unwrap(); let handle = ::std::thread::Builder::new() .name("request".into()) - .spawn(move || workers::request::run_request_worker(config))?; + .spawn(move || workers::request::run_request_worker(config, request_receiver))?; handles.push(handle); } diff --git a/aquatic_http_private/src/workers/common.rs b/aquatic_http_private/src/workers/common.rs index ae73ace..558bf07 100644 --- a/aquatic_http_private/src/workers/common.rs +++ b/aquatic_http_private/src/workers/common.rs @@ -1,11 +1,12 @@ use aquatic_common::CanonicalSocketAddr; -use aquatic_http_protocol::response::AnnounceResponse; +use aquatic_http_protocol::response::Response; use tokio::sync::oneshot::Sender; use super::socket::db::ValidatedAnnounceRequest; +#[derive(Debug)] pub struct ChannelAnnounceRequest { pub request: ValidatedAnnounceRequest, pub source_addr: CanonicalSocketAddr, - pub response_sender: Sender, + pub response_sender: Sender, } diff --git a/aquatic_http_private/src/workers/request.rs b/aquatic_http_private/src/workers/request.rs index 068f293..8627706 100644 --- a/aquatic_http_private/src/workers/request.rs +++ b/aquatic_http_private/src/workers/request.rs @@ -1,5 +1,40 @@ +use tokio::sync::mpsc::Receiver; + +use aquatic_http_protocol::response::{FailureResponse, Response}; + use crate::config::Config; -pub fn run_request_worker(config: Config) -> anyhow::Result<()> { +use super::common::ChannelAnnounceRequest; + +pub fn run_request_worker( + config: Config, + request_receiver: Receiver, +) -> anyhow::Result<()> { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + runtime.block_on(run_inner(config, request_receiver))?; + Ok(()) } + +async fn run_inner( + config: Config, + mut request_receiver: Receiver, +) -> anyhow::Result<()> { + loop { + let request = request_receiver + .recv() + .await + .ok_or_else(|| anyhow::anyhow!("request channel closed"))?; + + println!("{:?}", request); + + let _ = request + .response_sender + .send(Response::Failure(FailureResponse::new( + "successful actually", + ))); + } +} diff --git a/aquatic_http_private/src/workers/socket/mod.rs b/aquatic_http_private/src/workers/socket/mod.rs index 7d5d0cd..10fcf6c 100644 --- a/aquatic_http_private/src/workers/socket/mod.rs +++ b/aquatic_http_private/src/workers/socket/mod.rs @@ -1,27 +1,37 @@ pub mod db; mod routes; -use std::net::{SocketAddr, TcpListener}; +use std::{ + net::{SocketAddr, TcpListener}, + sync::Arc, +}; use anyhow::Context; use axum::{routing::get, Extension, Router}; use sqlx::mysql::MySqlPoolOptions; -use crate::config::Config; +use crate::{common::ChannelRequestSender, config::Config}; -pub fn run_socket_worker(config: Config) -> anyhow::Result<()> { +pub fn run_socket_worker( + config: Config, + request_sender: ChannelRequestSender, +) -> anyhow::Result<()> { let tcp_listener = create_tcp_listener(config.network.address)?; let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - runtime.block_on(run_app(tcp_listener))?; + runtime.block_on(run_app(config, tcp_listener, request_sender))?; Ok(()) } -async fn run_app(tcp_listener: TcpListener) -> anyhow::Result<()> { +async fn run_app( + config: Config, + tcp_listener: TcpListener, + request_sender: ChannelRequestSender, +) -> anyhow::Result<()> { let db_url = ::std::env::var("DATABASE_URL").expect("env var DATABASE_URL"); let pool = MySqlPoolOptions::new() @@ -31,7 +41,9 @@ async fn run_app(tcp_listener: TcpListener) -> anyhow::Result<()> { let app = Router::new() .route("/:user_token/announce/", get(routes::announce)) - .layer(Extension(pool)); + .layer(Extension(Arc::new(config))) + .layer(Extension(pool)) + .layer(Extension(Arc::new(request_sender))); axum::Server::from_tcp(tcp_listener)? .serve(app.into_make_service_with_connect_info::()) diff --git a/aquatic_http_private/src/workers/socket/routes.rs b/aquatic_http_private/src/workers/socket/routes.rs index a18c23c..320c4a5 100644 --- a/aquatic_http_private/src/workers/socket/routes.rs +++ b/aquatic_http_private/src/workers/socket/routes.rs @@ -7,60 +7,90 @@ use axum::{ Extension, TypedHeader, }; use sqlx::mysql::MySqlPool; -use std::net::SocketAddr; -use tokio::sync::oneshot; +use std::{borrow::Cow, net::SocketAddr, sync::Arc}; use aquatic_http_protocol::{ request::AnnounceRequest, - response::{AnnounceResponse, FailureResponse, Response}, + response::{FailureResponse, Response}, }; -use crate::workers::common::ChannelAnnounceRequest; +use crate::{ + common::{ChannelRequestSender, RequestWorkerIndex}, + config::Config, +}; use super::db; pub async fn announce( + Extension(config): Extension>, Extension(pool): Extension, + Extension(request_sender): Extension>, ConnectInfo(peer_addr): ConnectInfo, opt_user_agent: Option>, Path(user_token): Path, RawQuery(query): RawQuery, -) -> Result<(StatusCode, impl IntoResponse), (StatusCode, impl IntoResponse)> { - let request = AnnounceRequest::from_query_string(&query.unwrap_or_else(|| "".into())) - .map_err(|err| build_response(Response::Failure(FailureResponse::new("Internal error"))))?; +) -> Result { + let query = query.ok_or_else(|| create_failure_response("Empty query string"))?; + let request = AnnounceRequest::from_query_string(&query) + .map_err(|_| create_failure_response("Malformed request"))?; + + let request_worker_index = RequestWorkerIndex::from_info_hash(&config, request.info_hash); let opt_user_agent = opt_user_agent.map(|header| header.as_str().to_owned()); let validated_request = db::validate_announce_request(&pool, peer_addr, opt_user_agent, user_token, request) .await - .map_err(|r| build_response(Response::Failure(r)))?; - - let (response_sender, response_receiver) = oneshot::channel(); + .map_err(|r| create_response(Response::Failure(r)))?; let canonical_socket_addr = CanonicalSocketAddr::new(peer_addr); - let channel_request = ChannelAnnounceRequest { - request: validated_request, - source_addr: canonical_socket_addr, - response_sender, - }; - - // TODO: send request to request worker + let response_receiver = request_sender + .send_to( + request_worker_index, + validated_request, + canonical_socket_addr, + ) + .await + .map_err(|err| internal_error(format!("Sending request over channel failed: {:#}", err)))?; let response = response_receiver.await.map_err(|err| { - ::log::error!("channel response sender closed: {}", err); - - build_response(Response::Failure(FailureResponse::new("Internal error"))) + internal_error(format!("Receiving response over channel failed: {:#}", err)) })?; - Ok(build_response(Response::Announce(response))) + Ok(create_response(response)) } -fn build_response(response: Response) -> (StatusCode, impl IntoResponse) { - let mut response_bytes = Vec::with_capacity(512); +fn create_response(response: Response) -> axum::response::Response { + let mut response_bytes = Vec::with_capacity(64); - response.write(&mut response_bytes); + response.write(&mut response_bytes).unwrap(); - (StatusCode::OK, response_bytes) + ( + StatusCode::OK, + [("Content-type", "text/plain; charset=utf-8")], + response_bytes, + ) + .into_response() +} + +fn create_failure_response>>(reason: R) -> axum::response::Response { + let mut response_bytes = Vec::with_capacity(32); + + FailureResponse::new(reason) + .write(&mut response_bytes) + .unwrap(); + + ( + StatusCode::OK, + [("Content-type", "text/plain; charset=utf-8")], + response_bytes, + ) + .into_response() +} + +fn internal_error(error: String) -> axum::response::Response { + ::log::error!("{}", error); + + create_failure_response("Internal error") }