http_protocol: implement axum IntoResponse, use in http_private

This commit is contained in:
Joakim Frostegård 2022-04-03 20:20:51 +02:00
parent 5e79df8e7e
commit 98e7e5cc13
6 changed files with 70 additions and 41 deletions

1
Cargo.lock generated
View file

@ -171,6 +171,7 @@ name = "aquatic_http_protocol"
version = "0.2.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"axum",
"bendy", "bendy",
"criterion", "criterion",
"hex", "hex",

View file

@ -11,7 +11,6 @@
* site will likely want num_seeders and num_leechers for all torrents.. * site will likely want num_seeders and num_leechers for all torrents..
* aquatic_http_protocol * aquatic_http_protocol
* Implement axum IntoResponse behind feature gate
* don't require compact=1? * don't require compact=1?
## Medium priority ## Medium priority

View file

@ -12,7 +12,7 @@ name = "aquatic_http_private"
[dependencies] [dependencies]
aquatic_cli_helpers = "0.2.0" aquatic_cli_helpers = "0.2.0"
aquatic_common = { version = "0.2.0", features = ["rustls-config"] } aquatic_common = { version = "0.2.0", features = ["rustls-config"] }
aquatic_http_protocol = "0.2.0" aquatic_http_protocol = { version = "0.2.0", features = ["with-axum"] }
aquatic_toml_config = "0.2.0" aquatic_toml_config = "0.2.0"
anyhow = "1" anyhow = "1"

View file

@ -2,12 +2,10 @@ use aquatic_common::CanonicalSocketAddr;
use axum::{ use axum::{
extract::{ConnectInfo, Path, RawQuery}, extract::{ConnectInfo, Path, RawQuery},
headers::UserAgent, headers::UserAgent,
http::StatusCode,
response::IntoResponse,
Extension, TypedHeader, Extension, TypedHeader,
}; };
use sqlx::mysql::MySqlPool; use sqlx::mysql::MySqlPool;
use std::{borrow::Cow, net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use aquatic_http_protocol::{ use aquatic_http_protocol::{
request::AnnounceRequest, request::AnnounceRequest,
@ -29,11 +27,11 @@ pub async fn announce(
opt_user_agent: Option<TypedHeader<UserAgent>>, opt_user_agent: Option<TypedHeader<UserAgent>>,
Path(user_token): Path<String>, Path(user_token): Path<String>,
RawQuery(query): RawQuery, RawQuery(query): RawQuery,
) -> Result<axum::response::Response, axum::response::Response> { ) -> Result<Response, FailureResponse> {
let query = query.ok_or_else(|| create_failure_response("Empty query string"))?; let query = query.ok_or_else(|| FailureResponse::new("Empty query string"))?;
let request = AnnounceRequest::from_query_string(&query) let request = AnnounceRequest::from_query_string(&query)
.map_err(|_| create_failure_response("Malformed request"))?; .map_err(|_| FailureResponse::new("Malformed request"))?;
let request_worker_index = RequestWorkerIndex::from_info_hash(&config, request.info_hash); let request_worker_index = RequestWorkerIndex::from_info_hash(&config, request.info_hash);
let opt_user_agent = opt_user_agent.map(|header| header.as_str().to_owned()); let opt_user_agent = opt_user_agent.map(|header| header.as_str().to_owned());
@ -42,8 +40,7 @@ pub async fn announce(
let (validated_request, opt_warning_message) = let (validated_request, opt_warning_message) =
db::validate_announce_request(&pool, source_addr, opt_user_agent, user_token, request) db::validate_announce_request(&pool, source_addr, opt_user_agent, user_token, request)
.await .await?;
.map_err(|r| create_response(Response::Failure(r)))?;
let response_receiver = request_sender let response_receiver = request_sender
.send_to(request_worker_index, validated_request, source_addr) .send_to(request_worker_index, validated_request, source_addr)
@ -58,39 +55,11 @@ pub async fn announce(
r.warning_message = opt_warning_message; r.warning_message = opt_warning_message;
} }
Ok(create_response(response)) Ok(response)
} }
fn create_response(response: Response) -> axum::response::Response { fn internal_error(error: String) -> FailureResponse {
let mut response_bytes = Vec::with_capacity(128);
response.write(&mut response_bytes).unwrap();
(
StatusCode::OK,
[("Content-type", "text/plain; charset=utf-8")],
response_bytes,
)
.into_response()
}
fn create_failure_response<R: Into<Cow<'static, str>>>(reason: R) -> axum::response::Response {
let mut response_bytes = Vec::with_capacity(64);
FailureResponse::new(reason)
.write(&mut response_bytes)
.unwrap();
(
StatusCode::OK,
[("Content-type", "text/plain; charset=utf-8")],
response_bytes,
)
.into_response()
}
fn internal_error(error: String) -> axum::response::Response {
::log::error!("{}", error); ::log::error!("{}", error);
create_failure_response("Internal error") FailureResponse::new("Internal error")
} }

View file

@ -22,8 +22,12 @@ name = "bench_announce_response_to_bytes"
path = "benches/bench_announce_response_to_bytes.rs" path = "benches/bench_announce_response_to_bytes.rs"
harness = false harness = false
[features]
with-axum = ["axum"]
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
axum = { version = "0.5", optional = true, default-features = false }
hex = { version = "0.4", default-features = false } hex = { version = "0.4", default-features = false }
httparse = "1" httparse = "1"
itoa = "1" itoa = "1"

View file

@ -112,6 +112,21 @@ impl AnnounceResponse {
} }
} }
#[cfg(feature = "with-axum")]
impl axum::response::IntoResponse for AnnounceResponse {
fn into_response(self) -> axum::response::Response {
let mut response_bytes = Vec::with_capacity(128);
self.write(&mut response_bytes).unwrap();
(
[("Content-type", "text/plain; charset=utf-8")],
response_bytes,
)
.into_response()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScrapeResponse { pub struct ScrapeResponse {
/// BTreeMap instead of HashMap since keys need to be serialized in order /// BTreeMap instead of HashMap since keys need to be serialized in order
@ -142,6 +157,21 @@ impl ScrapeResponse {
} }
} }
#[cfg(feature = "with-axum")]
impl axum::response::IntoResponse for ScrapeResponse {
fn into_response(self) -> axum::response::Response {
let mut response_bytes = Vec::with_capacity(128);
self.write(&mut response_bytes).unwrap();
(
[("Content-type", "text/plain; charset=utf-8")],
response_bytes,
)
.into_response()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureResponse { pub struct FailureResponse {
#[serde(rename = "failure reason")] #[serde(rename = "failure reason")]
@ -170,6 +200,21 @@ impl FailureResponse {
} }
} }
#[cfg(feature = "with-axum")]
impl axum::response::IntoResponse for FailureResponse {
fn into_response(self) -> axum::response::Response {
let mut response_bytes = Vec::with_capacity(64);
self.write(&mut response_bytes).unwrap();
(
[("Content-type", "text/plain; charset=utf-8")],
response_bytes,
)
.into_response()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum Response { pub enum Response {
@ -191,6 +236,17 @@ impl Response {
} }
} }
#[cfg(feature = "with-axum")]
impl axum::response::IntoResponse for Response {
fn into_response(self) -> axum::response::Response {
match self {
Self::Announce(r) => r.into_response(),
Self::Scrape(r) => r.into_response(),
Self::Failure(r) => r.into_response(),
}
}
}
#[cfg(test)] #[cfg(test)]
impl quickcheck::Arbitrary for ResponsePeer<Ipv4Addr> { impl quickcheck::Arbitrary for ResponsePeer<Ipv4Addr> {
fn arbitrary(g: &mut quickcheck::Gen) -> Self { fn arbitrary(g: &mut quickcheck::Gen) -> Self {