diff --git a/Cargo.lock b/Cargo.lock index a27992b..37e798c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -179,12 +179,15 @@ dependencies = [ "aquatic_cli_helpers", "aquatic_common", "aquatic_udp_protocol", + "bytemuck", "cfg-if", "crossbeam-channel", "futures-lite", "glommio", "hex", "histogram", + "io-uring", + "libc", "log", "mimalloc", "mio", @@ -194,6 +197,7 @@ dependencies = [ "rand", "serde", "signal-hook", + "slab", "socket2 0.4.2", ] @@ -446,6 +450,12 @@ version = "3.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9df67f7bf9ef8498769f994239c45613ef0c5899415fb58e9add412d2c1a538" +[[package]] +name = "bytemuck" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b" + [[package]] name = "byteorder" version = "1.4.3" @@ -1130,6 +1140,16 @@ dependencies = [ "memoffset 0.5.6", ] +[[package]] +name = "io-uring" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d75829ed9377bab6c90039fe47b9d84caceb4b5063266142e21bcce6550cda8" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "itertools" version = "0.10.1" diff --git a/aquatic_udp/Cargo.toml b/aquatic_udp/Cargo.toml index 7fc5164..7c55a5a 100644 --- a/aquatic_udp/Cargo.toml +++ b/aquatic_udp/Cargo.toml @@ -18,7 +18,7 @@ name = "aquatic_udp" default = ["with-mio"] cpu-pinning = ["aquatic_common/cpu-pinning"] with-glommio = ["cpu-pinning", "glommio", "futures-lite"] -with-mio = ["crossbeam-channel", "histogram", "mio", "socket2"] +with-mio = ["crossbeam-channel", "histogram", "mio", "socket2", "io-uring", "libc", "bytemuck"] [dependencies] anyhow = "1" @@ -32,6 +32,7 @@ mimalloc = { version = "0.1", default-features = false } parking_lot = "0.11" rand = { version = "0.8", features = ["small_rng"] } serde = { version = "1", features = ["derive"] } +slab = "0.4" signal-hook = { version = "0.3" } # mio @@ -39,6 +40,9 
@@ crossbeam-channel = { version = "0.5", optional = true } histogram = { version = "0.6", optional = true } mio = { version = "0.7", features = ["udp", "os-poll", "os-util"], optional = true } socket2 = { version = "0.4.1", features = ["all"], optional = true } +io-uring = { version = "0.5", optional = true } +libc = { version = "0.2", optional = true } +bytemuck = { version = "1", optional = true } # glommio glommio = { git = "https://github.com/DataDog/glommio.git", rev = "4e6b14772da2f4325271fbcf12d24cf91ed466e5", optional = true } diff --git a/aquatic_udp/src/lib/mio/mod.rs b/aquatic_udp/src/lib/mio/mod.rs index 0287f28..c7da3e1 100644 --- a/aquatic_udp/src/lib/mio/mod.rs +++ b/aquatic_udp/src/lib/mio/mod.rs @@ -17,6 +17,7 @@ use crate::config::Config; pub mod common; pub mod handlers; pub mod network; +pub mod network_uring; pub mod tasks; use common::State; @@ -98,7 +99,7 @@ pub fn run_inner(config: Config, state: State) -> ::anyhow::Result<()> { WorkerIndex::SocketWorker(i), ); - network::run_socket_worker( + network_uring::run_socket_worker( state, config, i, diff --git a/aquatic_udp/src/lib/mio/network_uring.rs b/aquatic_udp/src/lib/mio/network_uring.rs new file mode 100644 index 0000000..aa28fd5 --- /dev/null +++ b/aquatic_udp/src/lib/mio/network_uring.rs @@ -0,0 +1,521 @@ +use std::io::Cursor; +use std::mem::size_of_val; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::os::unix::prelude::{AsRawFd}; +use std::ptr::{null_mut}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::time::{Duration, Instant}; + +use aquatic_common::access_list::{AccessListCache, create_access_list_cache}; +use crossbeam_channel::{Receiver, Sender}; +use io_uring::SubmissionQueue; +use io_uring::types::{Fixed, Timespec}; +use libc::{c_void, in_addr, iovec, msghdr, sockaddr_in}; +use rand::prelude::{Rng, SeedableRng, StdRng}; +use slab::Slab; +use socket2::{Domain, Protocol, Socket, Type}; + +use aquatic_udp_protocol::{IpVersion, 
Request, Response};
+
+use crate::common::handlers::*;
+use crate::common::network::ConnectionMap;
+use crate::common::*;
+use crate::config::Config;
+
+use super::common::*;
+
+/// Number of entries requested for the io_uring instance.
+const RING_SIZE: usize = 128;
+/// Maximum number of `recvmsg` operations in flight at any time.
+const MAX_RECV_EVENTS: usize = 1;
+/// Maximum number of `sendmsg` operations in flight at any time. One ring
+/// entry is left over for the timeout operation.
+const MAX_SEND_EVENTS: usize = RING_SIZE - MAX_RECV_EVENTS - 1;
+/// One packet buffer (plus iovec, sockaddr and msghdr) per possible in-flight
+/// receive or send operation.
+const NUM_BUFFERS: usize = MAX_RECV_EVENTS + MAX_SEND_EVENTS;
+
+/// Bit position of the operation tag inside the encoded `u64` user data.
+const USER_DATA_TAG_SHIFT: u32 = 56;
+/// Mask selecting the slab key stored in the low 56 bits of the user data.
+const USER_DATA_KEY_MASK: u64 = (1 << USER_DATA_TAG_SHIFT) - 1;
+const USER_DATA_TAG_RECV_MSG: u64 = 0;
+const USER_DATA_TAG_SEND_MSG: u64 = 1;
+const USER_DATA_TAG_TIMEOUT: u64 = 2;
+
+/// Identifies which submitted operation a completion queue entry belongs to.
+///
+/// The value travels through the ring as a `u64`: the most significant byte
+/// holds the operation tag and the remaining 56 bits hold the slab key.
+#[derive(Clone, Copy, Debug, PartialEq)]
+enum UserData {
+    RecvMsg { slab_key: usize },
+    SendMsg { slab_key: usize },
+    Timeout,
+}
+
+impl UserData {
+    /// Index into the shared buffer/iovec/sockaddr/msghdr arrays used by this
+    /// operation. Receive operations occupy the first `MAX_RECV_EVENTS`
+    /// slots, send operations the slots after them.
+    ///
+    /// Must not be called on `UserData::Timeout`, which has no buffer.
+    fn get_buffer_index(&self) -> usize {
+        match self {
+            Self::RecvMsg { slab_key } => *slab_key,
+            Self::SendMsg { slab_key } => slab_key + MAX_RECV_EVENTS,
+            Self::Timeout => unreachable!("timeout operations have no buffer"),
+        }
+    }
+}
+
+impl From<u64> for UserData {
+    /// Decodes user data previously produced by `u64::from(UserData)`.
+    ///
+    /// Shifts and masks are used instead of poking at the in-memory bytes so
+    /// that the encoding does not depend on the target's endianness.
+    fn from(n: u64) -> UserData {
+        let slab_key = (n & USER_DATA_KEY_MASK) as usize;
+
+        match n >> USER_DATA_TAG_SHIFT {
+            USER_DATA_TAG_RECV_MSG => Self::RecvMsg { slab_key },
+            USER_DATA_TAG_SEND_MSG => Self::SendMsg { slab_key },
+            USER_DATA_TAG_TIMEOUT => Self::Timeout,
+            tag => unreachable!("invalid user data tag: {}", tag),
+        }
+    }
+}
+
+impl From<UserData> for u64 {
+    /// Encodes the operation tag in the top byte and the slab key (if any) in
+    /// the low 56 bits. Implementing `From` also provides `Into<u64>`.
+    fn from(user_data: UserData) -> u64 {
+        let (tag, slab_key) = match user_data {
+            UserData::RecvMsg { slab_key } => (USER_DATA_TAG_RECV_MSG, slab_key),
+            UserData::SendMsg { slab_key } => (USER_DATA_TAG_SEND_MSG, slab_key),
+            UserData::Timeout => (USER_DATA_TAG_TIMEOUT, 0),
+        };
+
+        (tag << USER_DATA_TAG_SHIFT) | ((slab_key as u64) & USER_DATA_KEY_MASK)
+    }
+}
+
+/// Runs a socket worker on top of io_uring. Never returns.
+///
+/// Each loop iteration:
+/// 1. drains the completion queue, parsing received packets and releasing the
+///    slab slots of finished operations,
+/// 2. re-arms `recvmsg` operations and the timeout,
+/// 3. queues locally generated responses and responses received over
+///    `response_receiver` as `sendmsg` operations,
+/// 4. periodically cleans the connection map,
+/// 5. submits the queued operations and waits for completions.
+///
+/// # Panics
+///
+/// Panics if the ring cannot be created, the socket cannot be registered, the
+/// submission queue is full, or submitting to the ring fails.
+pub fn run_socket_worker(
+    state: State,
+    config: Config,
+    token_num: usize,
+    request_sender: Sender<(ConnectedRequest, SocketAddr)>,
+    response_receiver: Receiver<(ConnectedResponse, SocketAddr)>,
+    num_bound_sockets: Arc<AtomicUsize>,
+) {
+    let mut rng = StdRng::from_entropy();
+
+    let socket = create_socket(&config);
+
+    num_bound_sockets.fetch_add(1, Ordering::SeqCst);
+
+    let mut connections = ConnectionMap::default();
+    let mut access_list_cache = create_access_list_cache(&state.access_list);
+    let mut local_responses: Vec<(Response, SocketAddr)> = Vec::new();
+
+    let cleaning_duration = Duration::from_secs(config.cleaning.connection_cleaning_interval);
+
+    let mut iter_counter = 0usize;
+    let mut last_cleaning = Instant::now();
+
+    // The four collections below are index-aligned: msghdrs[i] points at
+    // iovs[i] and sockaddrs_ipv4[i], and iovs[i] points at buffers[i]. They
+    // are never resized or moved after this point, which keeps the raw
+    // pointers stored in them valid.
+    let mut buffers: Vec<[u8; MAX_PACKET_SIZE]> =
+        (0..NUM_BUFFERS).map(|_| [0; MAX_PACKET_SIZE]).collect();
+
+    let mut sockaddrs_ipv4 = [sockaddr_in {
+        sin_addr: in_addr { s_addr: 0 },
+        sin_port: 0,
+        sin_family: 0,
+        sin_zero: Default::default(),
+    }; NUM_BUFFERS];
+
+    let mut iovs: Vec<iovec> = (0..NUM_BUFFERS)
+        .map(|i| {
+            let iov_base = buffers[i].as_mut_ptr() as *mut c_void;
+            let iov_len = MAX_PACKET_SIZE;
+
+            iovec { iov_base, iov_len }
+        })
+        .collect();
+
+    let mut msghdrs: Vec<msghdr> = (0..NUM_BUFFERS)
+        .map(|i| {
+            let msg_iov: *mut iovec = &mut iovs[i];
+            let msg_name: *mut sockaddr_in = &mut sockaddrs_ipv4[i];
+
+            msghdr {
+                msg_name: msg_name as *mut c_void,
+                msg_namelen: size_of_val(&sockaddrs_ipv4[i]) as u32,
+                msg_iov,
+                msg_iovlen: 1,
+                msg_control: null_mut(),
+                msg_controllen: 0,
+                msg_flags: 0,
+            }
+        })
+        .collect();
+
+    let timeout = Timespec::new().nsec(500_000_000);
+    let mut timeout_set = false;
+
+    // The slabs only hand out keys; a key doubles as the (relative) buffer
+    // index of the operation, see `UserData::get_buffer_index`.
+    let mut recv_entries = Slab::with_capacity(MAX_RECV_EVENTS);
+    let mut send_entries = Slab::with_capacity(MAX_SEND_EVENTS);
+
+    let mut ring = io_uring::IoUring::new(RING_SIZE as u32).unwrap();
+
+    let (submitter, mut sq, mut cq) = ring.split();
+
+    submitter.register_files(&[socket.as_raw_fd()]).unwrap();
+
+    // Index of the socket in the file table registered above.
+    let fd = Fixed(0);
+
+    loop {
+        while let Some(entry) = cq.next() {
+            let user_data: UserData = entry.user_data().into();
+
+            match user_data {
+                UserData::RecvMsg { slab_key } => {
+                    recv_entries.remove(slab_key);
+
+                    let result = entry.result();
+
+                    if result < 0 {
+                        ::log::info!(
+                            "recvmsg error {}: {:#}",
+                            result,
+                            ::std::io::Error::from_raw_os_error(-result)
+                        );
+                    } else if result == 0 {
+                        ::log::info!("recvmsg error: 0 bytes read");
+                    } else {
+                        let buffer_index = user_data.get_buffer_index();
+                        let buffer_len = result as usize;
+
+                        // NOTE(review): the source address is always decoded
+                        // as a sockaddr_in, although create_socket may bind an
+                        // IPv6 socket - confirm IPv6 is out of scope here.
+                        let src = SocketAddrV4::new(
+                            Ipv4Addr::from(u32::from_be(
+                                sockaddrs_ipv4[buffer_index].sin_addr.s_addr,
+                            )),
+                            u16::from_be(sockaddrs_ipv4[buffer_index].sin_port),
+                        );
+
+                        let res_request = Request::from_bytes(
+                            &buffers[buffer_index][..buffer_len],
+                            config.protocol.max_scrape_torrents,
+                        );
+
+                        handle_request(
+                            &config,
+                            &state,
+                            &mut connections,
+                            &mut access_list_cache,
+                            &mut rng,
+                            &request_sender,
+                            &mut local_responses,
+                            res_request,
+                            SocketAddr::V4(src),
+                        );
+                    }
+                }
+                UserData::SendMsg { slab_key } => {
+                    send_entries.remove(slab_key);
+
+                    if entry.result() < 0 {
+                        ::log::info!(
+                            "sendmsg error: {:#}",
+                            ::std::io::Error::from_raw_os_error(-entry.result())
+                        );
+                    }
+                }
+                UserData::Timeout => {
+                    timeout_set = false;
+                }
+            }
+        }
+
+        // Re-arm receive operations until MAX_RECV_EVENTS are in flight.
+        for _ in 0..(MAX_RECV_EVENTS - recv_entries.len()) {
+            let slab_key = recv_entries.insert(());
+            let user_data = UserData::RecvMsg { slab_key };
+
+            let buffer_index = user_data.get_buffer_index();
+
+            let buf_ptr: *mut msghdr = &mut msghdrs[buffer_index];
+
+            let entry = io_uring::opcode::RecvMsg::new(fd, buf_ptr)
+                .build()
+                .user_data(user_data.into());
+
+            // SAFETY: the msghdr and everything it points to live in
+            // collections that are neither moved nor dropped while this
+            // never-ending loop runs, and the slab key reserves the buffer
+            // slot until the completion is handled above.
+            unsafe {
+                sq.push(&entry).unwrap();
+            }
+        }
+
+        if !timeout_set {
+            let user_data = UserData::Timeout;
+
+            let timespec_ptr: *const Timespec = &timeout;
+
+            let entry = io_uring::opcode::Timeout::new(timespec_ptr)
+                .build()
+                .user_data(user_data.into());
+
+            // SAFETY: `timeout` is a local of this function, which never
+            // returns, so the pointer stays valid for the whole operation.
+            unsafe {
+                sq.push(&entry).unwrap();
+            }
+
+            timeout_set = true;
+        }
+
+        // Queue as many locally generated responses as there are free send
+        // slots, taking them from the end of the vector.
+        let num_local_to_queue = (MAX_SEND_EVENTS - send_entries.len()).min(local_responses.len());
+
+        for (response, addr) in local_responses.drain(local_responses.len() - num_local_to_queue..)
+        {
+            queue_response(
+                &mut sq,
+                fd,
+                &mut send_entries,
+                &mut buffers,
+                &mut iovs,
+                &mut sockaddrs_ipv4,
+                &mut msghdrs,
+                response,
+                addr,
+            );
+        }
+
+        // Fill the remaining send slots with responses from request workers.
+        for (response, addr) in response_receiver
+            .try_iter()
+            .take(MAX_SEND_EVENTS - send_entries.len())
+        {
+            queue_response(
+                &mut sq,
+                fd,
+                &mut send_entries,
+                &mut buffers,
+                &mut iovs,
+                &mut sockaddrs_ipv4,
+                &mut msghdrs,
+                response.into(),
+                addr,
+            );
+        }
+
+        if iter_counter % 32 == 0 {
+            let now = Instant::now();
+
+            if now > last_cleaning + cleaning_duration {
+                connections.clean();
+
+                last_cleaning = now;
+            }
+        }
+
+        let all_responses_sent = local_responses.is_empty() && response_receiver.is_empty();
+
+        // With nothing left to send, wait for as many completions as there
+        // are sends and receives in flight. Otherwise only wait for the
+        // in-flight sends, so that pending responses get queued promptly.
+        let wait_for_num = if all_responses_sent {
+            send_entries.len() + recv_entries.len()
+        } else {
+            send_entries.len()
+        };
+
+        sq.sync();
+
+        submitter.submit_and_wait(wait_for_num).unwrap();
+
+        sq.sync();
+        cq.sync();
+
+        iter_counter = iter_counter.wrapping_add(1);
+    }
+}
+
+/// Serializes `response` into a free send buffer and pushes a `sendmsg`
+/// operation addressed to `src` onto the submission queue.
+///
+/// The response is dropped without reserving a send slot when `src` is not an
+/// IPv4 address or when serialization fails. A slot that is reserved without
+/// a matching submitted operation would never be released, because only a
+/// completion removes it from `send_events`.
+///
+/// The caller must make sure `send_events` holds fewer than
+/// `MAX_SEND_EVENTS` entries, and that the slices are the index-aligned
+/// collections referenced by `msghdrs`.
+///
+/// # Panics
+///
+/// Panics if the submission queue is full.
+fn queue_response(
+    sq: &mut SubmissionQueue,
+    fd: Fixed,
+    send_events: &mut Slab<()>,
+    buffers: &mut [[u8; MAX_PACKET_SIZE]],
+    iovs: &mut [iovec],
+    sockaddrs: &mut [sockaddr_in],
+    msghdrs: &mut [msghdr],
+    response: Response,
+    src: SocketAddr,
+) {
+    // Validate the destination before reserving a slot.
+    let dst = match src {
+        SocketAddr::V4(dst) => dst,
+        SocketAddr::V6(_) => return, // FIXME: IPv6 peers are not supported yet
+    };
+
+    let slab_key = send_events.insert(());
+    let user_data = UserData::SendMsg { slab_key };
+
+    let buffer_index = user_data.get_buffer_index();
+
+    let mut cursor = Cursor::new(&mut buffers[buffer_index][..]);
+
+    if let Err(err) = response.write(&mut cursor, ip_version_from_ip(src.ip())) {
+        ::log::error!("Response::write error: {:?}", err);
+
+        // Nothing will be submitted, so release the slot right away instead
+        // of sending whatever the buffer happened to contain.
+        send_events.remove(slab_key);
+
+        return;
+    }
+
+    iovs[buffer_index].iov_len = cursor.position() as usize;
+
+    let sockaddr = &mut sockaddrs[buffer_index];
+
+    // The address structs start out zeroed, so set the family explicitly.
+    sockaddr.sin_family = libc::AF_INET as libc::sa_family_t;
+    sockaddr.sin_addr.s_addr = u32::from(*dst.ip()).to_be();
+    sockaddr.sin_port = dst.port().to_be();
+
+    let buf_ptr: *mut msghdr = &mut msghdrs[buffer_index];
+
+    let entry = io_uring::opcode::SendMsg::new(fd, buf_ptr)
+        .build()
+        .user_data(user_data.into());
+
+    // SAFETY: the msghdr and the iovec, sockaddr and buffer it refers to are
+    // elements of the caller's slices; the slab key reserves this index until
+    // the completion is handled, so nothing else writes to them meanwhile.
+    // The caller must keep the underlying storage alive and in place.
+    unsafe {
+        sq.push(&entry).unwrap();
+    }
+}
+
+/// Creates a non-blocking UDP socket with `SO_REUSEPORT` set, bound to
+/// `config.network.address`.
+///
+/// A failure to apply the configured receive buffer size is logged and
+/// otherwise ignored.
+///
+/// # Panics
+///
+/// Panics if the socket cannot be created, configured or bound.
+fn create_socket(config: &Config) -> ::std::net::UdpSocket {
+    let socket = if config.network.address.is_ipv4() {
+        Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+    } else {
+        Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
+    }
+    .expect("create socket");
+
+    socket.set_reuse_port(true).expect("socket: set reuse port");
+
+    socket
+        .set_nonblocking(true)
+        .expect("socket: set nonblocking");
+
+    socket
+        .bind(&config.network.address.into())
+        .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err));
+
+    let recv_buffer_size = config.network.socket_recv_buffer_size;
+
+    // A configured size of zero leaves the receive buffer size untouched.
+    if recv_buffer_size != 0 {
+        if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) {
+            ::log::error!(
+                "socket: failed setting recv buffer to {}: {:?}",
+                recv_buffer_size,
+                err
+            );
+        }
+    }
+
+    socket.into()
+}
+
+/// Handles one parsed (or unparseable) request received from `src`.
+///
+/// - Connect requests are answered directly: a new connection id is stored in
+///   `connections` and the response is pushed to `local_responses`.
+/// - Announce and scrape requests with a known connection id are forwarded
+///   through `request_sender`. Announces for info hashes rejected by the
+///   access list get an error response instead.
+/// - Parse errors are answered with an error response only when the error is
+///   sendable and the connection id is known.
+///
+/// Requests with an unknown connection id are dropped.
+#[inline]
+fn handle_request(
+    config: &Config,
+    state: &State,
+    connections: &mut ConnectionMap,
+    access_list_cache: &mut AccessListCache,
+    rng: &mut StdRng,
+    request_sender: &Sender<(ConnectedRequest, SocketAddr)>,
+    local_responses: &mut Vec<(Response, SocketAddr)>,
+    res_request: Result<Request, RequestParseError>,
+    src: SocketAddr,
+) {
+    let valid_until = ValidUntil::new(config.cleaning.max_connection_age);
+    let access_list_mode = config.access_list.mode;
+
+    match res_request {
+        Ok(Request::Connect(request)) => {
+            let connection_id = ConnectionId(rng.gen());
+
+            connections.insert(connection_id, src, valid_until);
+
+            let response = Response::Connect(ConnectResponse {
+                connection_id,
+                transaction_id: request.transaction_id,
+            });
+
+            local_responses.push((response, src))
+        }
+        Ok(Request::Announce(request)) => {
+            if connections.contains(request.connection_id, src) {
+                if access_list_cache
+                    .load()
+                    .allows(access_list_mode, &request.info_hash.0)
+                {
+                    if let Err(err) =
+                        request_sender.try_send((ConnectedRequest::Announce(request), src))
+                    {
+                        ::log::warn!("request_sender.try_send failed: {:?}", err)
+                    }
+                } else {
+                    let response = Response::Error(ErrorResponse {
+                        transaction_id: request.transaction_id,
+                        message: "Info hash not allowed".into(),
+                    });
+
+                    local_responses.push((response, src))
+                }
+            }
+        }
+        Ok(Request::Scrape(request)) => {
+            if connections.contains(request.connection_id, src) {
+                let request = ConnectedRequest::Scrape {
+                    request,
+                    original_indices: Vec::new(),
+                };
+
+                if let Err(err) = request_sender.try_send((request, src)) {
+                    ::log::warn!("request_sender.try_send failed: {:?}", err)
+                }
+            }
+        }
+        Err(err) => {
+            ::log::debug!("Request::from_bytes error: {:?}", err);
+
+            if let RequestParseError::Sendable {
+                connection_id,
+                transaction_id,
+                err,
+            } = err
+            {
+                if connections.contains(connection_id, src) {
+                    let response = ErrorResponse {
+                        transaction_id,
+                        message: err.right_or("Parse error").into(),
+                    };
+
+                    local_responses.push((response.into(), src));
+                }
+            }
+        }
+    }
+}
+
+/// Returns the IP version to use when serializing a response for `ip`.
+///
+/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are treated as IPv4.
+fn ip_version_from_ip(ip: IpAddr) -> IpVersion {
+    match ip {
+        IpAddr::V4(_) => IpVersion::IPv4,
+        IpAddr::V6(ip) => {
+            if let [0, 0, 0, 0, 0, 0xffff, ..] = ip.segments() {
+                IpVersion::IPv4
+            } else {
+                IpVersion::IPv6
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use quickcheck::Arbitrary;
+    use quickcheck_macros::quickcheck;
+
+    use super::*;
+
+    impl quickcheck::Arbitrary for UserData {
+        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
+            match (bool::arbitrary(g), bool::arbitrary(g)) {
+                (false, b) => {
+                    let slab_key: u32 = Arbitrary::arbitrary(g);
+                    let slab_key = slab_key as usize;
+
+                    if b {
+                        UserData::RecvMsg { slab_key }
+                    } else {
+                        UserData::SendMsg { slab_key }
+                    }
+                }
+                _ => UserData::Timeout,
+            }
+        }
+    }
+
+    /// Encoding user data as `u64` and decoding it again must be lossless.
+    #[quickcheck]
+    fn test_user_data_identity(a: UserData) -> bool {
+        let n: u64 = a.into();
+        let b = UserData::from(n);
+
+        a == b
+    }
+}
\ No newline at end of file