diff --git a/aquatic_http/src/lib/common/mod.rs b/aquatic_http/src/lib/common/mod.rs new file mode 100644 index 0000000..8002a13 --- /dev/null +++ b/aquatic_http/src/lib/common/mod.rs @@ -0,0 +1,31 @@ +pub fn num_digits_in_usize(mut number: usize) -> usize { + let mut num_digits = 1usize; + + while number >= 10 { + num_digits += 1; + + number /= 10; + } + + num_digits +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_num_digits_in_usize() { + let f = num_digits_in_usize; + + assert_eq!(f(0), 1); + assert_eq!(f(1), 1); + assert_eq!(f(9), 1); + assert_eq!(f(10), 2); + assert_eq!(f(11), 2); + assert_eq!(f(99), 2); + assert_eq!(f(100), 3); + assert_eq!(f(101), 3); + assert_eq!(f(1000), 4); + } +} diff --git a/aquatic_http/src/lib/glommio/network.rs b/aquatic_http/src/lib/glommio/network.rs index 2737234..589a907 100644 --- a/aquatic_http/src/lib/glommio/network.rs +++ b/aquatic_http/src/lib/glommio/network.rs @@ -1,11 +1,11 @@ use std::cell::RefCell; -use std::io::{BufReader, Cursor, ErrorKind, Read}; +use std::io::{BufReader, Cursor, ErrorKind, Read, Write}; use std::rc::Rc; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use aquatic_http_protocol::request::{Request, RequestParseError}; -use aquatic_http_protocol::response::Response; +use aquatic_http_protocol::response::{FailureResponse, Response}; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; use glommio::channels::channel_mesh::{MeshBuilder, Partial, Receivers, Role, Senders}; use glommio::channels::shared_channel::{ConnectedReceiver, ConnectedSender, SharedSender}; @@ -16,8 +16,11 @@ use glommio::task::JoinHandle; use rustls::{IoState, ServerConnection}; use slab::Slab; +use crate::common::num_digits_in_usize; use crate::config::Config; +const BUFFER_SIZE: usize = 1024; + #[derive(Clone, Copy, Debug)] pub struct ConnectionId(pub usize); @@ -28,12 +31,13 @@ struct ConnectionReference { struct Connection { // request_senders: Rc>, - // response_receiver: LocalReceiver, + response_receiver: LocalReceiver, tls: ServerConnection, stream: TcpStream, index: ConnectionId, - expects_request: bool, request_buffer: Vec, + wait_for_response: bool, + close_after_writing: bool, } pub async fn run_socket_worker( @@ -71,17 +75,18 @@ pub async fn run_socket_worker( let conn = Connection { // request_senders: request_senders.clone(), - // response_receiver, + response_receiver, tls: ServerConnection::new(tls_config.clone()).unwrap(), stream, index: ConnectionId(entry.key()), - expects_request: true, request_buffer: Vec::new(), + wait_for_response: false, + close_after_writing: false, }; async fn handle_stream(mut conn: Connection) { if let Err(err) = conn.handle_stream().await { - ::log::error!("conn.handle_stream() error: {:?}", err); + ::log::info!("conn.handle_stream() error: {:?}", err); } } @@ -115,43 +120,45 @@ async fn receive_responses( impl Connection { async fn handle_stream(&mut self) -> anyhow::Result<()> { - ::log::info!("incoming stream"); loop { - self.write_tls().await?; self.read_tls().await?; - /* - if !self.tls.is_handshaking() { - if self.expects_request { - let request = self.extract_request()?; + if self.wait_for_response { + if let Some(response) = self.response_receiver.recv().await { + self.queue_response(&response)?; - ::log::info!("request received: {:?}", request); + self.wait_for_response = false; - // self.request_senders.try_send_to(0, (self.index, request)); - self.expects_request = false; - - }/* - else if let Some(response) = self.response_receiver.recv().await { - response.write(&mut self.tls.writer())?; - - self.expects_request = true; - } */ + // TODO: trigger close here if keepalive is false + } + } + + self.write_tls().await?; + + if self.close_after_writing { + let _ = self.stream.shutdown(std::net::Shutdown::Both).await; + + break; } - */ } + + Ok(()) } async fn read_tls(&mut self) -> anyhow::Result<()> { loop { - ::log::info!("read_tls (wants read)"); + ::log::debug!("read_tls"); - let mut buf = [0u8; 1024]; + let mut buf = [0u8; BUFFER_SIZE]; let bytes_read = self.stream.read(&mut buf).await?; if bytes_read == 0 { - // Peer has closed connection. Remove it. - return Err(anyhow::anyhow!("peer has closed connection")); + ::log::debug!("peer has closed connection"); + + self.close_after_writing = true; + + break; } let _ = self.tls.read_tls(&mut &buf[..bytes_read]).unwrap(); @@ -160,23 +167,26 @@ impl Connection { let mut added_plaintext = false; - while io_state.plaintext_bytes_to_read() != 0 { - match self.tls.reader().read(&mut buf) { - Ok(0) => { - break; - } - Ok(amt) => { - self.request_buffer.extend_from_slice(&buf[..amt]); + if io_state.plaintext_bytes_to_read() != 0 { + loop { + match self.tls.reader().read(&mut buf) { + Ok(0) => { + break; + } + Ok(amt) => { + self.request_buffer.extend_from_slice(&buf[..amt]); - added_plaintext = true; - }, - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - ::log::info!("tls.reader().read error: {:?}", err); + added_plaintext = true; + }, + Err(err) if err.kind() == ErrorKind::WouldBlock => { + break; + } + Err(err) => { + // Should never happen + ::log::error!("tls.reader().read error: {:?}", err); - break; + break; + } } } } @@ -184,15 +194,24 @@ impl Connection { if added_plaintext { match Request::from_bytes(&self.request_buffer[..]) { Ok(request) => { - self.expects_request = false; + self.wait_for_response = true; - ::log::info!("received request: {:?}", request); + ::log::trace!("received request: {:?}", request); } Err(RequestParseError::NeedMoreData) => { - ::log::info!("need more request data. current data: {:?}", std::str::from_utf8(&self.request_buffer)); + ::log::debug!("need more request data. current data: {:?}", std::str::from_utf8(&self.request_buffer)); } Err(RequestParseError::Invalid(err)) => { - return Err(anyhow::anyhow!("request parse error: {:?}", err)); + ::log::debug!("invalid request: {:?}", err); + + let response = Response::Failure(FailureResponse { + failure_reason: "Invalid request".into(), + }); + + self.queue_response(&response)?; + self.close_after_writing = true; + + break; } } } @@ -210,7 +229,7 @@ impl Connection { return Ok(()); } - ::log::info!("write_tls (wants write)"); + ::log::debug!("write_tls (wants write)"); let mut buf = Vec::new(); let mut buf = Cursor::new(&mut buf); @@ -219,7 +238,29 @@ impl Connection { self.tls.write_tls(&mut buf).unwrap(); } - self.stream.write_all(&buf.into_inner()).await.unwrap(); + self.stream.write_all(&buf.into_inner()).await?; + self.stream.flush().await?; + + Ok(()) + } + + fn queue_response(&mut self, response: &Response) -> anyhow::Result<()> { + let mut body = Vec::new(); + + response.write(&mut body).unwrap(); + + let content_len = body.len() + 2; // 2 is for newlines at end + let content_len_num_digits = num_digits_in_usize(content_len); + + let mut response_bytes = Vec::with_capacity(39 + content_len_num_digits + body.len()); + + response_bytes.extend_from_slice(b"HTTP/1.1 200 OK\r\nContent-Length: "); + ::itoa::write(&mut response_bytes, content_len)?; + response_bytes.extend_from_slice(b"\r\n\r\n"); + response_bytes.append(&mut body); + response_bytes.extend_from_slice(b"\r\n"); + + self.tls.writer().write(&response_bytes[..])?; Ok(()) } diff --git a/aquatic_http/src/lib/lib.rs b/aquatic_http/src/lib/lib.rs index ee6f8e8..7f1456f 100644 --- a/aquatic_http/src/lib/lib.rs +++ b/aquatic_http/src/lib/lib.rs @@ -1,6 +1,7 @@ use cfg_if::cfg_if; pub mod config; +pub mod common; #[cfg(feature = "with-mio")] pub mod mio; diff --git a/aquatic_http/src/lib/mio/network/connection.rs b/aquatic_http/src/lib/mio/network/connection.rs index 57fee71..7ac3aa6 100644 --- a/aquatic_http/src/lib/mio/network/connection.rs +++ b/aquatic_http/src/lib/mio/network/connection.rs @@ -10,6 +10,7 @@ use native_tls::{MidHandshakeTlsStream, TlsAcceptor}; use aquatic_http_protocol::request::{Request, RequestParseError}; +use crate::common::num_digits_in_usize; use crate::mio::common::*; use super::stream::Stream; @@ -85,7 +86,7 @@ impl EstablishedConnection { pub fn send_response(&mut self, body: &[u8]) -> ::std::io::Result<()> { let content_len = body.len() + 2; // 2 is for newlines at end - let content_len_num_digits = Self::num_digits_in_usize(content_len); + let content_len_num_digits = num_digits_in_usize(content_len); let mut response = Vec::with_capacity(39 + content_len_num_digits + body.len()); @@ -110,18 +111,6 @@ impl EstablishedConnection { Ok(()) } - fn num_digits_in_usize(mut number: usize) -> usize { - let mut num_digits = 1usize; - - while number >= 10 { - num_digits += 1; - - number /= 10; - } - - num_digits - } - #[inline] pub fn clear_buffer(&mut self) { self.bytes_read = 0; @@ -269,23 +258,3 @@ impl Connection { } pub type ConnectionMap = HashMap; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_num_digits_in_usize() { - let f = EstablishedConnection::num_digits_in_usize; - - assert_eq!(f(0), 1); - assert_eq!(f(1), 1); - assert_eq!(f(9), 1); - assert_eq!(f(10), 2); - assert_eq!(f(11), 2); - assert_eq!(f(99), 2); - assert_eq!(f(100), 3); - assert_eq!(f(101), 3); - assert_eq!(f(1000), 4); - } -}