diff --git a/Cargo.lock b/Cargo.lock index 78cf4aa..376ec68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,7 @@ dependencies = [ "signal-hook", "slab", "smartstring", + "socket2 0.4.4", ] [[package]] diff --git a/aquatic_http/Cargo.toml b/aquatic_http/Cargo.toml index 8566590..fa08ec1 100644 --- a/aquatic_http/Cargo.toml +++ b/aquatic_http/Cargo.toml @@ -40,6 +40,7 @@ serde = { version = "1", features = ["derive"] } signal-hook = { version = "0.3" } slab = "0.4" smartstring = "1" +socket2 = { version = "0.4", features = ["all"] } [dev-dependencies] quickcheck = "1" diff --git a/aquatic_http/src/config.rs b/aquatic_http/src/config.rs index 095acce..3bb54b8 100644 --- a/aquatic_http/src/config.rs +++ b/aquatic_http/src/config.rs @@ -39,7 +39,9 @@ pub struct NetworkConfig { /// Bind to this address pub address: SocketAddr, /// Only allow access over IPv6 - pub ipv6_only: bool, + pub only_ipv6: bool, + /// Maximum number of pending TCP connections + pub tcp_backlog: i32, /// Path to TLS certificate (DER-encoded X.509) pub tls_certificate_path: PathBuf, /// Path to TLS private key (DER-encoded ASN.1 in PKCS#8 or PKCS#1 format) @@ -91,7 +93,8 @@ impl Default for NetworkConfig { address: SocketAddr::from(([0, 0, 0, 0], 3000)), tls_certificate_path: "".into(), tls_private_key_path: "".into(), - ipv6_only: false, + only_ipv6: false, + tcp_backlog: 1024, keep_alive: true, } } diff --git a/aquatic_http/src/workers/socket.rs b/aquatic_http/src/workers/socket.rs index 493efd9..e7e9d35 100644 --- a/aquatic_http/src/workers/socket.rs +++ b/aquatic_http/src/workers/socket.rs @@ -1,5 +1,6 @@ use std::cell::RefCell; use std::collections::BTreeMap; +use std::os::unix::prelude::{FromRawFd, IntoRawFd}; use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -58,7 +59,7 @@ pub async fn run_socket_worker( let config = Rc::new(config); let access_list = state.access_list; - let listener = TcpListener::bind(config.network.address).expect("bind socket"); + let listener = create_tcp_listener(&config); num_bound_sockets.fetch_add(1, Ordering::SeqCst); let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); @@ -464,3 +465,30 @@ impl Connection { fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize { (info_hash.0[0] as usize) % config.request_workers } + +fn create_tcp_listener(config: &Config) -> TcpListener { + let domain = if config.network.address.is_ipv4() { + socket2::Domain::IPV4 + } else { + socket2::Domain::IPV6 + }; + + let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) + .expect("create socket"); + + if config.network.only_ipv6 { + socket.set_only_v6(true).expect("socket: set only ipv6"); + } + + socket.set_reuse_port(true).expect("socket: set reuse port"); + + socket + .bind(&config.network.address.into()) + .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); + + socket + .listen(config.network.tcp_backlog) + .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); + + unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } +}