From 380ca222de1c174f9b0bbdc1d43147b324dea699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Thu, 3 Feb 2022 18:59:51 +0100 Subject: [PATCH] http: socket workers: simplify request buffering --- aquatic_http/src/workers/socket.rs | 75 +++++++++++++----------------- 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/aquatic_http/src/workers/socket.rs b/aquatic_http/src/workers/socket.rs index 7dc7e6e..1f1dea4 100644 --- a/aquatic_http/src/workers/socket.rs +++ b/aquatic_http/src/workers/socket.rs @@ -28,9 +28,8 @@ use slab::Slab; use crate::common::*; use crate::config::Config; -const INTERMEDIATE_BUFFER_SIZE: usize = 1024; -const MAX_REQUEST_SIZE: usize = 2048; -const MAX_RESPONSE_SIZE: usize = 4096; +const REQUEST_BUFFER_SIZE: usize = 2048; +const RESPONSE_BUFFER_SIZE: usize = 4096; const RESPONSE_HEADER_A: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: "; const RESPONSE_HEADER_B: &[u8] = b" "; @@ -173,9 +172,9 @@ struct Connection { stream: TlsStream, peer_addr: SocketAddr, connection_id: ConnectionId, - request_buffer: [u8; MAX_REQUEST_SIZE], + request_buffer: [u8; REQUEST_BUFFER_SIZE], request_buffer_position: usize, - response_buffer: [u8; MAX_RESPONSE_SIZE], + response_buffer: [u8; RESPONSE_BUFFER_SIZE], } impl Connection { @@ -196,7 +195,7 @@ impl Connection { let tls_acceptor: TlsAcceptor = tls_config.into(); let stream = tls_acceptor.accept(stream).await?; - let mut response_buffer = [0; MAX_RESPONSE_SIZE]; + let mut response_buffer = [0; RESPONSE_BUFFER_SIZE]; response_buffer[..RESPONSE_HEADER.len()].copy_from_slice(&RESPONSE_HEADER); @@ -209,7 +208,7 @@ impl Connection { stream, peer_addr, connection_id, - request_buffer: [0; MAX_REQUEST_SIZE], + request_buffer: [0; REQUEST_BUFFER_SIZE], request_buffer_position: 0, response_buffer, }; @@ -244,54 +243,44 @@ impl Connection { } async fn read_request(&mut self) -> anyhow::Result> { - let mut buf = [0u8; INTERMEDIATE_BUFFER_SIZE]; + self.request_buffer_position = 0; loop { - ::log::debug!("read"); + if self.request_buffer_position == self.request_buffer.len() { + return Err(anyhow::anyhow!("request buffer is full")); + } - let bytes_read = self.stream.read(&mut buf).await?; + let bytes_read = self + .stream + .read(&mut self.request_buffer[self.request_buffer_position..]) + .await?; if bytes_read == 0 { return Err(anyhow::anyhow!("peer closed connection")); } - let request_buffer_end = self.request_buffer_position + bytes_read; + self.request_buffer_position += bytes_read; - if request_buffer_end > self.request_buffer.len() { - return Err(anyhow::anyhow!("request too large")); - } else { - let request_buffer_slice = - &mut self.request_buffer[self.request_buffer_position..request_buffer_end]; + match Request::from_bytes(&self.request_buffer[..self.request_buffer_position]) { + Ok(request) => { + ::log::debug!("received request: {:?}", request); - request_buffer_slice.copy_from_slice(&buf[..bytes_read]); + return Ok(Either::Right(request)); + } + Err(RequestParseError::Invalid(err)) => { + ::log::debug!("invalid request: {:?}", err); - self.request_buffer_position = request_buffer_end; + let response = FailureResponse { + failure_reason: "Invalid request".into(), + }; - match Request::from_bytes(&self.request_buffer[..self.request_buffer_position]) { - Ok(request) => { - ::log::debug!("received request: {:?}", request); - - self.request_buffer_position = 0; - - return Ok(Either::Right(request)); - } - Err(RequestParseError::Invalid(err)) => { - ::log::debug!("invalid request: {:?}", err); - - let response = FailureResponse { - failure_reason: "Invalid request".into(), - }; - - return Ok(Either::Left(response)); - } - Err(RequestParseError::NeedMoreData) => { - ::log::debug!( - "need more request data. current data: {:?}", - std::str::from_utf8( - &self.request_buffer[..self.request_buffer_position] - ) - ); - } + return Ok(Either::Left(response)); + } + Err(RequestParseError::NeedMoreData) => { + ::log::debug!( + "need more request data. current data: {:?}", + std::str::from_utf8(&self.request_buffer[..self.request_buffer_position]) + ); } } }