diff --git a/aquatic_http_protocol/src/request.rs b/aquatic_http_protocol/src/request.rs index 93041dc..d5b6845 100644 --- a/aquatic_http_protocol/src/request.rs +++ b/aquatic_http_protocol/src/request.rs @@ -129,23 +129,31 @@ impl Request { let mut headers = [httparse::EMPTY_HEADER; 16]; let mut http_request = httparse::Request::new(&mut headers); - match http_request.parse(bytes){ + let path = match http_request.parse(bytes){ Ok(httparse::Status::Complete(_)) => { if let Some(path) = http_request.path { - let res_request = Self::from_http_get_path(path); - - res_request.map_err(RequestParseError::Invalid) + path } else { - Err(RequestParseError::Invalid(anyhow::anyhow!("no http path"))) + return Err(RequestParseError::Invalid( + anyhow::anyhow!("no http path") + )) } }, Ok(httparse::Status::Partial) => { - Err(RequestParseError::NeedMoreData) - }, - Err(err) => { - Err(RequestParseError::Invalid(anyhow::anyhow!("httparse: {:?}", err))) + if let Some(path) = http_request.path { + path + } else { + return Err(RequestParseError::NeedMoreData) + } } - } + Err(err) => { + return Err(RequestParseError::Invalid( + anyhow::Error::from(err) + )) + }, + }; + + Self::from_http_get_path(path).map_err(RequestParseError::Invalid) } /// Parse Request from http path (GET `/announce?info_hash=...`) @@ -430,32 +438,21 @@ mod tests { bytes.extend_from_slice(&ANNOUNCE_REQUEST_PATH.as_bytes()); bytes.extend_from_slice(b" HTTP/1.1\r\n\r\n"); - let parsed_request = Request::from_bytes( - &bytes[..] - ).unwrap(); - + let parsed_request = Request::from_bytes(&bytes[..]).unwrap(); let reference_request = get_reference_announce_request(); assert_eq!(parsed_request, reference_request); } #[test] - fn test_announce_request_from_path(){ - let parsed_request = Request::from_http_get_path( - ANNOUNCE_REQUEST_PATH - ).unwrap(); + fn test_scrape_request_from_bytes(){ + let mut bytes = Vec::new(); - let reference_request = get_reference_announce_request(); - - assert_eq!(parsed_request, reference_request); - } - - #[test] - fn test_scrape_request_from_path(){ - let parsed_request = Request::from_http_get_path( - SCRAPE_REQUEST_PATH - ).unwrap(); + bytes.extend_from_slice(b"GET "); + bytes.extend_from_slice(&SCRAPE_REQUEST_PATH.as_bytes()); + bytes.extend_from_slice(b" HTTP/1.1\r\n\r\n"); + let parsed_request = Request::from_bytes(&bytes[..]).unwrap(); let reference_request = Request::Scrape(ScrapeRequest { info_hashes: vec![InfoHash(REFERENCE_INFO_HASH)], });