diff --git a/crates/udp_protocol/src/request.rs b/crates/udp_protocol/src/request.rs index 8cdd10d..ef0f5ac 100644 --- a/crates/udp_protocol/src/request.rs +++ b/crates/udp_protocol/src/request.rs @@ -19,27 +19,12 @@ pub enum Request { } impl Request { - pub fn write_bytes(self, bytes: &mut impl Write) -> Result<(), io::Error> { + pub fn write_bytes(&self, bytes: &mut impl Write) -> Result<(), io::Error> { match self { - Request::Connect(r) => { - bytes.write_i64::(PROTOCOL_IDENTIFIER)?; - bytes.write_i32::(0)?; - bytes.write_all(r.transaction_id.as_bytes())?; - } - - Request::Announce(r) => { - bytes.write_all(r.as_bytes())?; - } - - Request::Scrape(r) => { - bytes.write_all(r.connection_id.as_bytes())?; - bytes.write_i32::(2)?; - bytes.write_all(r.transaction_id.as_bytes())?; - bytes.write_all((*r.info_hashes.as_slice()).as_bytes())?; - } + Request::Connect(r) => r.write_bytes(bytes), + Request::Announce(r) => r.write_bytes(bytes), + Request::Scrape(r) => r.write_bytes(bytes), } - - Ok(()) } pub fn parse_bytes(bytes: &[u8], max_scrape_torrents: u8) -> Result { @@ -150,6 +135,16 @@ pub struct ConnectRequest { pub transaction_id: TransactionId, } +impl ConnectRequest { + pub fn write_bytes(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_i64::(PROTOCOL_IDENTIFIER)?; + bytes.write_i32::(0)?; + bytes.write_all(self.transaction_id.as_bytes())?; + + Ok(()) + } +} + #[derive(PartialEq, Eq, Clone, Debug, AsBytes, FromBytes, FromZeroes)] #[repr(C, packed)] pub struct AnnounceRequest { @@ -170,6 +165,12 @@ pub struct AnnounceRequest { pub port: Port, } +impl AnnounceRequest { + pub fn write_bytes(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_all(self.as_bytes()) + } +} + /// Note: Request::from_bytes only creates this struct with value 1 #[derive(PartialEq, Eq, Clone, Copy, Debug, AsBytes, FromBytes, FromZeroes)] #[repr(transparent)] @@ -223,6 +224,17 @@ pub struct ScrapeRequest { pub info_hashes: Vec, } +impl ScrapeRequest { + pub fn write_bytes(&self, bytes: &mut impl Write) -> Result<(), io::Error> { + bytes.write_all(self.connection_id.as_bytes())?; + bytes.write_i32::(2)?; + bytes.write_all(self.transaction_id.as_bytes())?; + bytes.write_all((*self.info_hashes.as_slice()).as_bytes())?; + + Ok(()) + } +} + #[derive(Debug)] pub enum RequestParseError { Sendable {