diff --git a/aquatic_common/src/privileges.rs b/aquatic_common/src/privileges.rs index d9db45f..d252996 100644 --- a/aquatic_common/src/privileges.rs +++ b/aquatic_common/src/privileges.rs @@ -3,6 +3,7 @@ use std::{ sync::{Arc, Barrier}, }; +use anyhow::Context; use privdrop::PrivDrop; use serde::Deserialize; @@ -46,7 +47,7 @@ impl PrivilegeDropper { } } - pub fn after_socket_creation(&self) { + pub fn after_socket_creation(&self) -> anyhow::Result<()> { if self.config.drop_privileges { if self.barrier.wait().is_leader() { PrivDrop::default() @@ -54,8 +55,10 @@ impl PrivilegeDropper { .group(self.config.group.clone()) .user(self.config.user.clone()) .apply() - .expect("drop privileges"); + .with_context(|| "drop privileges")?; } } + + Ok(()) } } diff --git a/aquatic_http/src/workers/socket.rs b/aquatic_http/src/workers/socket.rs index 50e6e98..aee1f97 100644 --- a/aquatic_http/src/workers/socket.rs +++ b/aquatic_http/src/workers/socket.rs @@ -5,6 +5,7 @@ use std::rc::Rc; use std::sync::Arc; use std::time::{Duration, Instant}; +use anyhow::Context; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::rustls_config::RustlsConfig; @@ -63,7 +64,7 @@ pub async fn run_socket_worker( let config = Rc::new(config); let access_list = state.access_list; - let listener = create_tcp_listener(&config, priv_dropper); + let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener"); let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); let request_senders = Rc::new(request_senders); @@ -484,31 +485,30 @@ fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usi (info_hash.0[0] as usize) % config.request_workers } -fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpListener { +fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> anyhow::Result { let domain = if config.network.address.is_ipv4() { socket2::Domain::IPV4 } else { socket2::Domain::IPV6 }; - let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) - .expect("create socket"); + let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?; if config.network.only_ipv6 { - socket.set_only_v6(true).expect("socket: set only ipv6"); + socket.set_only_v6(true).with_context(|| "socket: set only ipv6")?; } - socket.set_reuse_port(true).expect("socket: set reuse port"); + socket.set_reuse_port(true).with_context(|| "socket: set reuse port")?; socket .bind(&config.network.address.into()) - .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); + .with_context(|| format!("socket: bind to {}", config.network.address))?; socket .listen(config.network.tcp_backlog) - .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); + .with_context(|| format!("socket: listen on {}", config.network.address))?; - priv_dropper.after_socket_creation(); + priv_dropper.after_socket_creation()?; - unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } + Ok(unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) }) } diff --git a/aquatic_udp/src/workers/socket.rs b/aquatic_udp/src/workers/socket.rs index 0354b58..ac9c556 100644 --- a/aquatic_udp/src/workers/socket.rs +++ b/aquatic_udp/src/workers/socket.rs @@ -4,6 +4,7 @@ use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use std::vec::Drain; +use anyhow::Context; use aquatic_common::privileges::PrivilegeDropper; use crossbeam_channel::Receiver; use mio::net::UdpSocket; @@ -160,7 +161,7 @@ pub fn run_socket_worker( let mut rng = StdRng::from_entropy(); let mut buffer = [0u8; MAX_PACKET_SIZE]; - let mut socket = UdpSocket::from_std(create_socket(&config, priv_dropper)); + let mut socket = UdpSocket::from_std(create_socket(&config, priv_dropper).expect("create socket")); let mut poll = Poll::new().expect("create poll"); let interests = Interest::READABLE; @@ -516,29 +517,22 @@ fn send_response( } } -pub fn create_socket(config: &Config, priv_dropper: PrivilegeDropper) -> ::std::net::UdpSocket { +pub fn create_socket(config: &Config, priv_dropper: PrivilegeDropper) -> anyhow::Result<::std::net::UdpSocket> { let socket = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))? } else { - Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) - } - .expect("create socket"); + Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))? + }; if config.network.only_ipv6 { - socket.set_only_v6(true).expect("socket: set only ipv6"); + socket.set_only_v6(true).with_context(|| "socket: set only ipv6")?; } - socket.set_reuse_port(true).expect("socket: set reuse port"); + socket.set_reuse_port(true).with_context(|| "socket: set reuse port")?; socket .set_nonblocking(true) - .expect("socket: set nonblocking"); - - socket - .bind(&config.network.address.into()) - .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); - - priv_dropper.after_socket_creation(); + .with_context(|| "socket: set nonblocking")?; let recv_buffer_size = config.network.socket_recv_buffer_size; @@ -552,7 +546,13 @@ pub fn create_socket(config: &Config, priv_dropper: PrivilegeDropper) -> ::std:: } } - socket.into() + socket + .bind(&config.network.address.into()) + .with_context(|| format!("socket: bind to {}", config.network.address))?; + + priv_dropper.after_socket_creation()?; + + Ok(socket.into()) } #[cfg(test)] diff --git a/aquatic_ws/src/workers/socket.rs b/aquatic_ws/src/workers/socket.rs index 0acdd56..02974cd 100644 --- a/aquatic_ws/src/workers/socket.rs +++ b/aquatic_ws/src/workers/socket.rs @@ -6,6 +6,7 @@ use std::rc::Rc; use std::sync::Arc; use std::time::{Duration, Instant}; +use anyhow::Context; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::rustls_config::RustlsConfig; @@ -58,7 +59,7 @@ pub async fn run_socket_worker( let config = Rc::new(config); let access_list = state.access_list; - let listener = create_tcp_listener(&config, priv_dropper); + let listener = create_tcp_listener(&config, priv_dropper).expect("create tcp listener"); let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); let in_message_senders = Rc::new(in_message_senders); @@ -542,7 +543,7 @@ fn calculate_in_message_consumer_index(config: &Config, info_hash: InfoHash) -> (info_hash.0[0] as usize) % config.request_workers } -fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpListener { +fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> anyhow::Result { let domain = if config.network.address.is_ipv4() { socket2::Domain::IPV4 } else { @@ -550,23 +551,23 @@ fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpLi }; let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) - .expect("create socket"); + .with_context(|| "create socket")?; if config.network.only_ipv6 { - socket.set_only_v6(true).expect("socket: set only ipv6"); + socket.set_only_v6(true).with_context(|| "socket: set only ipv6")?; } - socket.set_reuse_port(true).expect("socket: set reuse port"); + socket.set_reuse_port(true).with_context(|| "socket: set reuse port")?; socket .bind(&config.network.address.into()) - .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); + .with_context(|| format!("socket: bind to {}", config.network.address))?; socket .listen(config.network.tcp_backlog) - .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); + .with_context(|| format!("socket: listen {}", config.network.address))?; - priv_dropper.after_socket_creation(); + priv_dropper.after_socket_creation()?; - unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } + Ok(unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) }) }