diff --git a/aquatic_ws/src/lib/handler.rs b/aquatic_ws/src/lib/handler.rs index 80ee7ec..2e2f8e8 100644 --- a/aquatic_ws/src/lib/handler.rs +++ b/aquatic_ws/src/lib/handler.rs @@ -194,28 +194,23 @@ pub fn handle_scrape_requests( requests: Drain<(ConnectionMeta, ScrapeRequest)>, ){ messages_out.extend(requests.map(|(meta, request)| { - let num_info_hashes = request.info_hashes - .as_ref() - .map(|v| v.len()) - .unwrap_or(0); + let num_info_hashes = request.info_hashes.len(); let mut response = ScrapeResponse { files: HashMap::with_capacity(num_info_hashes), }; - // If request.info_hashes is None, don't return scrape for all + // If request.info_hashes is empty, don't return scrape for all // torrents, even though reference server does it. It is too expensive. - if let Some(info_hashes) = request.info_hashes { - for info_hash in info_hashes { - if let Some(torrent_data) = torrents.get(&info_hash){ - let stats = ScrapeStatistics { - complete: torrent_data.num_seeders, - downloaded: 0, // No implementation planned - incomplete: torrent_data.num_leechers, - }; + for info_hash in request.info_hashes { + if let Some(torrent_data) = torrents.get(&info_hash){ + let stats = ScrapeStatistics { + complete: torrent_data.num_seeders, + downloaded: 0, // No implementation planned + incomplete: torrent_data.num_leechers, + }; - response.files.insert(info_hash, stats); - } + response.files.insert(info_hash, stats); } } diff --git a/aquatic_ws/src/lib/protocol/deserialize.rs b/aquatic_ws/src/lib/protocol/deserialize.rs new file mode 100644 index 0000000..41b276e --- /dev/null +++ b/aquatic_ws/src/lib/protocol/deserialize.rs @@ -0,0 +1,199 @@ +use serde::{Deserializer, de::{Visitor, SeqAccess}}; + +use super::InfoHash; + + +struct TwentyByteVisitor; + +impl<'de> Visitor<'de> for TwentyByteVisitor { + type Value = [u8; 20]; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("string consisting of 20 bytes") + } + + fn visit_str(self, value: &str) -> Result + where E: ::serde::de::Error, + { + let mut arr = [0u8; 20]; + + let bytes = value.as_bytes(); + + if bytes.len() == 20 { + arr.copy_from_slice(&bytes); + + Ok(arr) + } else { + Err(E::custom(format!("not 20 bytes: {}", value))) + } + } +} + + +pub fn deserialize_20_bytes<'de, D>( + deserializer: D +) -> Result<[u8; 20], D::Error> + where D: Deserializer<'de> +{ + deserializer.deserialize_any(TwentyByteVisitor) +} + + +pub struct InfoHashVecVisitor; + + +impl<'de> Visitor<'de> for InfoHashVecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("string or array of strings consisting of 20 bytes") + } + + fn visit_str(self, value: &str) -> Result + where E: ::serde::de::Error, + { + match TwentyByteVisitor::visit_str::(TwentyByteVisitor, value){ + Ok(arr) => Ok(vec![InfoHash(arr)]), + Err(err) => Err(E::custom(format!("got string, but {}", err))) + } + } + + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> + { + let mut info_hashes: Self::Value = Vec::new(); + + while let Ok(Some(value)) = seq.next_element::<&str>(){ + let arr = TwentyByteVisitor::visit_str( + TwentyByteVisitor, value + )?; + + info_hashes.push(InfoHash(arr)); + } + + Ok(info_hashes) + } + + fn visit_none(self) -> Result + where E: ::serde::de::Error + { + Ok(vec![]) + } +} + + +/// Empty vector is returned if value is null or any invalid info hash +/// is present +pub fn deserialize_info_hashes<'de, D>( + deserializer: D +) -> Result, D::Error> + where D: Deserializer<'de>, +{ + Ok(deserializer.deserialize_any(InfoHashVecVisitor).unwrap_or_default()) +} + + +#[cfg(test)] +mod tests { + use serde::Deserialize; + + use super::*; + + fn info_hash_from_bytes(bytes: &[u8]) -> InfoHash { + let mut arr = [0u8; 20]; + + arr.copy_from_slice(bytes); + + InfoHash(arr) + } + + #[test] + fn test_deserialize_20_bytes(){ + let input = r#""aaaabbbbccccddddeeee""#; + + let expected = info_hash_from_bytes(b"aaaabbbbccccddddeeee"); + let observed: InfoHash = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + + let input = r#""1aaaabbbbccccddddeeee""#; + let res_info_hash: Result = serde_json::from_str(input); + + assert!(res_info_hash.is_err()); + + let input = r#""aaaabbbbccccddddeeeö""#; + let res_info_hash: Result = serde_json::from_str(input); + + assert!(res_info_hash.is_err()); + } + + #[derive(Debug, PartialEq, Eq, Deserialize)] + struct Test { + #[serde(deserialize_with = "deserialize_info_hashes", default)] + info_hashes: Vec, + } + + + #[test] + fn test_deserialize_info_hashes_vec(){ + let input = r#"{ + "info_hashes": ["aaaabbbbccccddddeeee", "aaaabbbbccccddddeeee"] + }"#; + + let expected = Test { + info_hashes: vec![ + info_hash_from_bytes(b"aaaabbbbccccddddeeee"), + info_hash_from_bytes(b"aaaabbbbccccddddeeee"), + ] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } + + #[test] + fn test_deserialize_info_hashes_str(){ + let input = r#"{ + "info_hashes": "aaaabbbbccccddddeeee" + }"#; + + let expected = Test { + info_hashes: vec![ + info_hash_from_bytes(b"aaaabbbbccccddddeeee"), + ] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } + + #[test] + fn test_deserialize_info_hashes_null(){ + let input = r#"{ + "info_hashes": null + }"#; + + let expected = Test { + info_hashes: vec![] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } + + #[test] + fn test_deserialize_info_hashes_missing(){ + let input = r#"{}"#; + + let expected = Test { + info_hashes: vec![] + }; + + let observed: Test = serde_json::from_str(input).unwrap(); + + assert_eq!(observed, expected); + } +} \ No newline at end of file diff --git a/aquatic_ws/src/lib/protocol.rs b/aquatic_ws/src/lib/protocol/mod.rs similarity index 83% rename from aquatic_ws/src/lib/protocol.rs rename to aquatic_ws/src/lib/protocol/mod.rs index fb9bbad..0da5c59 100644 --- a/aquatic_ws/src/lib/protocol.rs +++ b/aquatic_ws/src/lib/protocol/mod.rs @@ -1,64 +1,31 @@ use hashbrown::HashMap; -use serde::{Serialize, Deserialize, Deserializer, de::Visitor}; +use serde::{Serialize, Deserialize}; + +pub mod deserialize; + +use deserialize::*; -struct TwentyAsciiBytesVisitor; - -impl<'de> Visitor<'de> for TwentyAsciiBytesVisitor { - type Value = [u8; 20]; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("string consisting of 20 bytes") - } - - fn visit_str(self, value: &str) -> Result - where - E: ::serde::de::Error, - { - let mut arr = [0u8; 20]; - - let bytes = value.as_bytes(); - - if value.is_ascii() && bytes.len() == 20 { - arr.copy_from_slice(&bytes); - - Ok(arr) - } else { - Err(E::custom(format!("not 20 ascii bytes: {}", value))) - } - } -} - - -fn deserialize_20_ascii_bytes<'de, D>( - deserializer: D -) -> Result<[u8; 20], D::Error> - where D: Deserializer<'de> -{ - deserializer.deserialize_any(TwentyAsciiBytesVisitor) -} - - -#[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct PeerId( - #[serde(deserialize_with = "deserialize_20_ascii_bytes")] + #[serde(deserialize_with = "deserialize_20_bytes")] pub [u8; 20] ); -#[derive(Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct InfoHash( - #[serde(deserialize_with = "deserialize_20_ascii_bytes")] + #[serde(deserialize_with = "deserialize_20_bytes")] pub [u8; 20] ); -#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct OfferId( - #[serde(deserialize_with = "deserialize_20_ascii_bytes")] + #[serde(deserialize_with = "deserialize_20_bytes")] pub [u8; 20] ); @@ -173,7 +140,8 @@ pub struct ScrapeRequest { // If omitted, scrape for all torrents, apparently // There is some kind of parsing here too which accepts a single info hash // and puts it into a vector - pub info_hashes: Option>, + #[serde(deserialize_with = "deserialize_info_hashes", default)] + pub info_hashes: Vec, }