diff --git a/aquatic_common/src/cpu_pinning.rs b/aquatic_common/src/cpu_pinning.rs index b82d5bb..0e03e52 100644 --- a/aquatic_common/src/cpu_pinning.rs +++ b/aquatic_common/src/cpu_pinning.rs @@ -194,7 +194,9 @@ pub mod glommio { // 15 -> 14 and 15 // 14 -> 12 and 13 // 13 -> 10 and 11 - CpuPinningDirection::Descending => num_cpu_cores - 2 * (num_cpu_cores - core_index), + CpuPinningDirection::Descending => { + num_cpu_cores - 2 * (num_cpu_cores - core_index) + } }; get_cpu_set()? diff --git a/aquatic_common/src/privileges.rs b/aquatic_common/src/privileges.rs index 0e4a627..4475830 100644 --- a/aquatic_common/src/privileges.rs +++ b/aquatic_common/src/privileges.rs @@ -1,22 +1,22 @@ use std::{ - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - time::Duration, + path::PathBuf, + sync::{Arc, Barrier}, }; -use aquatic_toml_config::TomlConfig; use privdrop::PrivDrop; use serde::Deserialize; +use aquatic_toml_config::TomlConfig; + #[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] #[serde(default)] pub struct PrivilegeConfig { - /// Chroot and switch user after binding to sockets + /// Chroot and switch group and user after binding to sockets pub drop_privileges: bool, /// Chroot to this path - pub chroot_path: String, + pub chroot_path: PathBuf, + /// Group to switch to after chrooting + pub group: String, /// User to switch to after chrooting pub user: String, } @@ -25,41 +25,37 @@ impl Default for PrivilegeConfig { fn default() -> Self { Self { drop_privileges: false, - chroot_path: ".".to_string(), + chroot_path: ".".into(), user: "nobody".to_string(), + group: "nobody".to_string(), } } } -pub fn drop_privileges_after_socket_binding( - config: &PrivilegeConfig, - num_bound_sockets: Arc, - target_num: usize, -) -> anyhow::Result<()> { - if config.drop_privileges { - let mut counter = 0usize; +#[derive(Clone)] +pub struct PrivilegeDropper { + barrier: Arc, + config: Arc, +} - loop { - let num_bound = num_bound_sockets.load(Ordering::SeqCst); +impl PrivilegeDropper { + pub fn new(config: PrivilegeConfig, num_sockets: usize) -> Self { + Self { + barrier: Arc::new(Barrier::new(num_sockets)), + config: Arc::new(config), + } + } - if num_bound == target_num { + pub fn after_socket_creation(&self) { + if self.config.drop_privileges { + if self.barrier.wait().is_leader() { PrivDrop::default() - .chroot(config.chroot_path.clone()) - .user(config.user.clone()) - .apply()?; - - break; - } - - ::std::thread::sleep(Duration::from_millis(10)); - - counter += 1; - - if counter == 500 { - panic!("Sockets didn't bind in time for privilege drop."); + .chroot(self.config.chroot_path.clone()) + .user(self.config.user.clone()) + .user(self.config.user.clone()) + .apply() + .expect("drop privileges"); } } } - - Ok(()) } diff --git a/aquatic_http/src/config.rs b/aquatic_http/src/config.rs index c94cf1e..0cbcea9 100644 --- a/aquatic_http/src/config.rs +++ b/aquatic_http/src/config.rs @@ -1,6 +1,9 @@ use std::{net::SocketAddr, path::PathBuf}; -use aquatic_common::{access_list::AccessListConfig, privileges::PrivilegeConfig, cpu_pinning::asc::CpuPinningConfigAsc}; +use aquatic_common::{ + access_list::AccessListConfig, cpu_pinning::asc::CpuPinningConfigAsc, + privileges::PrivilegeConfig, +}; use aquatic_toml_config::TomlConfig; use serde::Deserialize; diff --git a/aquatic_http/src/lib.rs b/aquatic_http/src/lib.rs index 28ca996..7ce2065 100644 --- a/aquatic_http/src/lib.rs +++ b/aquatic_http/src/lib.rs @@ -4,13 +4,13 @@ use aquatic_common::{ glommio::{get_worker_placement, set_affinity_for_util_worker}, WorkerIndex, }, - privileges::drop_privileges_after_socket_binding, + privileges::PrivilegeDropper, rustls_config::create_rustls_config, }; use common::State; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; use signal_hook::{consts::SIGUSR1, iterator::Signals}; -use std::sync::{atomic::AtomicUsize, Arc}; +use std::sync::Arc; use crate::config::Config; @@ -63,7 +63,7 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); - let num_bound_sockets = Arc::new(AtomicUsize::new(0)); + let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let tls_config = Arc::new(create_rustls_config( &config.network.tls_certificate_path, @@ -78,7 +78,7 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { let tls_config = tls_config.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); - let num_bound_sockets = num_bound_sockets.clone(); + let priv_dropper = priv_dropper.clone(); let placement = get_worker_placement( &config.cpu_pinning, @@ -95,7 +95,7 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { tls_config, request_mesh_builder, response_mesh_builder, - num_bound_sockets, + priv_dropper, ) .await }); @@ -130,13 +130,6 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> { executors.push(executor); } - drop_privileges_after_socket_binding( - &config.privileges, - num_bound_sockets, - config.socket_workers, - ) - .unwrap(); - if config.cpu_pinning.active { set_affinity_for_util_worker( &config.cpu_pinning, diff --git a/aquatic_http/src/workers/socket.rs b/aquatic_http/src/workers/socket.rs index 3992551..50e6e98 100644 --- a/aquatic_http/src/workers/socket.rs +++ b/aquatic_http/src/workers/socket.rs @@ -2,11 +2,11 @@ use std::cell::RefCell; use std::collections::BTreeMap; use std::os::unix::prelude::{FromRawFd, IntoRawFd}; use std::rc::Rc; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; +use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::rustls_config::RustlsConfig; use aquatic_common::CanonicalSocketAddr; use aquatic_http_protocol::common::InfoHash; @@ -58,13 +58,12 @@ pub async fn run_socket_worker( tls_config: Arc, request_mesh_builder: MeshBuilder, response_mesh_builder: MeshBuilder, - num_bound_sockets: Arc, + priv_dropper: PrivilegeDropper, ) { let config = Rc::new(config); let access_list = state.access_list; - let listener = create_tcp_listener(&config); - num_bound_sockets.fetch_add(1, Ordering::SeqCst); + let listener = create_tcp_listener(&config, priv_dropper); let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); let request_senders = Rc::new(request_senders); @@ -485,7 +484,7 @@ fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usi (info_hash.0[0] as usize) % config.request_workers } -fn create_tcp_listener(config: &Config) -> TcpListener { +fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpListener { let domain = if config.network.address.is_ipv4() { socket2::Domain::IPV4 } else { @@ -509,5 +508,7 @@ fn create_tcp_listener(config: &Config) -> TcpListener { .listen(config.network.tcp_backlog) .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); + priv_dropper.after_socket_creation(); + unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } } diff --git a/aquatic_udp/src/lib.rs b/aquatic_udp/src/lib.rs index d56d743..5a805d8 100644 --- a/aquatic_udp/src/lib.rs +++ b/aquatic_udp/src/lib.rs @@ -5,13 +5,12 @@ pub mod workers; use config::Config; use std::collections::BTreeMap; -use std::sync::{atomic::AtomicUsize, Arc}; use std::thread::Builder; use anyhow::Context; #[cfg(feature = "cpu-pinning")] use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; -use aquatic_common::privileges::drop_privileges_after_socket_binding; +use aquatic_common::privileges::PrivilegeDropper; use crossbeam_channel::{bounded, unbounded}; use aquatic_common::access_list::update_access_list; @@ -32,7 +31,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let mut signals = Signals::new(::std::iter::once(SIGUSR1))?; - let num_bound_sockets = Arc::new(AtomicUsize::new(0)); + let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let mut request_senders = Vec::new(); let mut request_receivers = BTreeMap::new(); @@ -96,7 +95,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { let request_sender = ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone()); let response_receiver = response_receivers.remove(&i).unwrap(); - let num_bound_sockets = num_bound_sockets.clone(); + let priv_dropper = priv_dropper.clone(); Builder::new() .name(format!("socket-{:02}", i + 1)) @@ -115,7 +114,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { i, request_sender, response_receiver, - num_bound_sockets, + priv_dropper, ); }) .with_context(|| "spawn socket worker")?; @@ -141,13 +140,6 @@ pub fn run(config: Config) -> ::anyhow::Result<()> { .with_context(|| "spawn statistics worker")?; } - drop_privileges_after_socket_binding( - &config.privileges, - num_bound_sockets, - config.socket_workers, - ) - .unwrap(); - #[cfg(feature = "cpu-pinning")] pin_current_if_configured_to( &config.cpu_pinning, diff --git a/aquatic_udp/src/workers/socket.rs b/aquatic_udp/src/workers/socket.rs index 6bf3487..0354b58 100644 --- a/aquatic_udp/src/workers/socket.rs +++ b/aquatic_udp/src/workers/socket.rs @@ -1,12 +1,10 @@ use std::collections::BTreeMap; use std::io::{Cursor, ErrorKind}; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; +use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use std::vec::Drain; +use aquatic_common::privileges::PrivilegeDropper; use crossbeam_channel::Receiver; use mio::net::UdpSocket; use mio::{Events, Interest, Poll, Token}; @@ -157,12 +155,12 @@ pub fn run_socket_worker( token_num: usize, request_sender: ConnectedRequestSender, response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, - num_bound_sockets: Arc, + priv_dropper: PrivilegeDropper, ) { let mut rng = StdRng::from_entropy(); let mut buffer = [0u8; MAX_PACKET_SIZE]; - let mut socket = UdpSocket::from_std(create_socket(&config)); + let mut socket = UdpSocket::from_std(create_socket(&config, priv_dropper)); let mut poll = Poll::new().expect("create poll"); let interests = Interest::READABLE; @@ -171,8 +169,6 @@ pub fn run_socket_worker( .register(&mut socket, Token(token_num), interests) .unwrap(); - num_bound_sockets.fetch_add(1, Ordering::SeqCst); - let mut events = Events::with_capacity(config.network.poll_event_capacity); let mut connections = ConnectionMap::default(); let mut pending_scrape_responses = PendingScrapeResponseSlab::default(); @@ -520,7 +516,7 @@ fn send_response( } } -pub fn create_socket(config: &Config) -> ::std::net::UdpSocket { +pub fn create_socket(config: &Config, priv_dropper: PrivilegeDropper) -> ::std::net::UdpSocket { let socket = if config.network.address.is_ipv4() { Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) } else { @@ -542,6 +538,8 @@ pub fn create_socket(config: &Config) -> ::std::net::UdpSocket { .bind(&config.network.address.into()) .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); + priv_dropper.after_socket_creation(); + let recv_buffer_size = config.network.socket_recv_buffer_size; if recv_buffer_size != 0 { diff --git a/aquatic_ws/src/lib.rs b/aquatic_ws/src/lib.rs index 0ba6bbc..6671b00 100644 --- a/aquatic_ws/src/lib.rs +++ b/aquatic_ws/src/lib.rs @@ -2,7 +2,7 @@ pub mod common; pub mod config; pub mod workers; -use std::sync::{atomic::AtomicUsize, Arc}; +use std::sync::Arc; use aquatic_common::cpu_pinning::glommio::{get_worker_placement, set_affinity_for_util_worker}; use aquatic_common::cpu_pinning::WorkerIndex; @@ -11,7 +11,7 @@ use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; use signal_hook::{consts::SIGUSR1, iterator::Signals}; use aquatic_common::access_list::update_access_list; -use aquatic_common::privileges::drop_privileges_after_socket_binding; +use aquatic_common::privileges::PrivilegeDropper; use common::*; use config::Config; @@ -61,7 +61,7 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> { let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); - let num_bound_sockets = Arc::new(AtomicUsize::new(0)); + let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); let tls_config = Arc::new(create_rustls_config( &config.network.tls_certificate_path, @@ -76,7 +76,7 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> { let tls_config = tls_config.clone(); let request_mesh_builder = request_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone(); - let num_bound_sockets = num_bound_sockets.clone(); + let priv_dropper = priv_dropper.clone(); let placement = get_worker_placement( &config.cpu_pinning, @@ -93,7 +93,7 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> { tls_config, request_mesh_builder, response_mesh_builder, - num_bound_sockets, + priv_dropper, ) .await }); @@ -128,13 +128,6 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> { executors.push(executor); } - drop_privileges_after_socket_binding( - &config.privileges, - num_bound_sockets, - config.socket_workers, - ) - .unwrap(); - if config.cpu_pinning.active { set_affinity_for_util_worker( &config.cpu_pinning, diff --git a/aquatic_ws/src/workers/socket.rs b/aquatic_ws/src/workers/socket.rs index 7c121d4..0acdd56 100644 --- a/aquatic_ws/src/workers/socket.rs +++ b/aquatic_ws/src/workers/socket.rs @@ -3,11 +3,11 @@ use std::cell::RefCell; use std::collections::BTreeMap; use std::os::unix::prelude::{FromRawFd, IntoRawFd}; use std::rc::Rc; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; +use aquatic_common::privileges::PrivilegeDropper; use aquatic_common::rustls_config::RustlsConfig; use aquatic_common::CanonicalSocketAddr; use aquatic_ws_protocol::*; @@ -53,14 +53,12 @@ pub async fn run_socket_worker( tls_config: Arc, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, - num_bound_sockets: Arc, + priv_dropper: PrivilegeDropper, ) { let config = Rc::new(config); let access_list = state.access_list; - let listener = create_tcp_listener(&config); - - num_bound_sockets.fetch_add(1, Ordering::SeqCst); + let listener = create_tcp_listener(&config, priv_dropper); let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); let in_message_senders = Rc::new(in_message_senders); @@ -544,7 +542,7 @@ fn calculate_in_message_consumer_index(config: &Config, info_hash: InfoHash) -> (info_hash.0[0] as usize) % config.request_workers } -fn create_tcp_listener(config: &Config) -> TcpListener { +fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpListener { let domain = if config.network.address.is_ipv4() { socket2::Domain::IPV4 } else { @@ -568,5 +566,7 @@ fn create_tcp_listener(config: &Config) -> TcpListener { .listen(config.network.tcp_backlog) .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); + priv_dropper.after_socket_creation(); + unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } }