Improve privilege dropping; run cargo fmt

This commit is contained in:
Joakim Frostegård 2022-04-05 01:26:40 +02:00
parent 2ad1418175
commit c888017072
9 changed files with 70 additions and 92 deletions

View file

@ -194,7 +194,9 @@ pub mod glommio {
// 15 -> 14 and 15 // 15 -> 14 and 15
// 14 -> 12 and 13 // 14 -> 12 and 13
// 13 -> 10 and 11 // 13 -> 10 and 11
CpuPinningDirection::Descending => num_cpu_cores - 2 * (num_cpu_cores - core_index), CpuPinningDirection::Descending => {
num_cpu_cores - 2 * (num_cpu_cores - core_index)
}
}; };
get_cpu_set()? get_cpu_set()?

View file

@ -1,22 +1,22 @@
use std::{ use std::{
sync::{ path::PathBuf,
atomic::{AtomicUsize, Ordering}, sync::{Arc, Barrier},
Arc,
},
time::Duration,
}; };
use aquatic_toml_config::TomlConfig;
use privdrop::PrivDrop; use privdrop::PrivDrop;
use serde::Deserialize; use serde::Deserialize;
use aquatic_toml_config::TomlConfig;
#[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)] #[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize)]
#[serde(default)] #[serde(default)]
pub struct PrivilegeConfig { pub struct PrivilegeConfig {
/// Chroot and switch user after binding to sockets /// Chroot and switch group and user after binding to sockets
pub drop_privileges: bool, pub drop_privileges: bool,
/// Chroot to this path /// Chroot to this path
pub chroot_path: String, pub chroot_path: PathBuf,
/// Group to switch to after chrooting
pub group: String,
/// User to switch to after chrooting /// User to switch to after chrooting
pub user: String, pub user: String,
} }
@ -25,41 +25,37 @@ impl Default for PrivilegeConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
drop_privileges: false, drop_privileges: false,
chroot_path: ".".to_string(), chroot_path: ".".into(),
user: "nobody".to_string(), user: "nobody".to_string(),
group: "nobody".to_string(),
} }
} }
} }
pub fn drop_privileges_after_socket_binding( #[derive(Clone)]
config: &PrivilegeConfig, pub struct PrivilegeDropper {
num_bound_sockets: Arc<AtomicUsize>, barrier: Arc<Barrier>,
target_num: usize, config: Arc<PrivilegeConfig>,
) -> anyhow::Result<()> { }
if config.drop_privileges {
let mut counter = 0usize;
loop { impl PrivilegeDropper {
let num_bound = num_bound_sockets.load(Ordering::SeqCst); pub fn new(config: PrivilegeConfig, num_sockets: usize) -> Self {
Self {
barrier: Arc::new(Barrier::new(num_sockets)),
config: Arc::new(config),
}
}
if num_bound == target_num { pub fn after_socket_creation(&self) {
if self.config.drop_privileges {
if self.barrier.wait().is_leader() {
PrivDrop::default() PrivDrop::default()
.chroot(config.chroot_path.clone()) .chroot(self.config.chroot_path.clone())
.user(config.user.clone()) .user(self.config.user.clone())
.apply()?; .user(self.config.user.clone())
.apply()
break; .expect("drop privileges");
}
::std::thread::sleep(Duration::from_millis(10));
counter += 1;
if counter == 500 {
panic!("Sockets didn't bind in time for privilege drop.");
} }
} }
} }
Ok(())
} }

View file

@ -1,6 +1,9 @@
use std::{net::SocketAddr, path::PathBuf}; use std::{net::SocketAddr, path::PathBuf};
use aquatic_common::{access_list::AccessListConfig, privileges::PrivilegeConfig, cpu_pinning::asc::CpuPinningConfigAsc}; use aquatic_common::{
access_list::AccessListConfig, cpu_pinning::asc::CpuPinningConfigAsc,
privileges::PrivilegeConfig,
};
use aquatic_toml_config::TomlConfig; use aquatic_toml_config::TomlConfig;
use serde::Deserialize; use serde::Deserialize;

View file

@ -4,13 +4,13 @@ use aquatic_common::{
glommio::{get_worker_placement, set_affinity_for_util_worker}, glommio::{get_worker_placement, set_affinity_for_util_worker},
WorkerIndex, WorkerIndex,
}, },
privileges::drop_privileges_after_socket_binding, privileges::PrivilegeDropper,
rustls_config::create_rustls_config, rustls_config::create_rustls_config,
}; };
use common::State; use common::State;
use glommio::{channels::channel_mesh::MeshBuilder, prelude::*}; use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
use signal_hook::{consts::SIGUSR1, iterator::Signals}; use signal_hook::{consts::SIGUSR1, iterator::Signals};
use std::sync::{atomic::AtomicUsize, Arc}; use std::sync::Arc;
use crate::config::Config; use crate::config::Config;
@ -63,7 +63,7 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> {
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE); let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_CHANNEL_SIZE);
let num_bound_sockets = Arc::new(AtomicUsize::new(0)); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
let tls_config = Arc::new(create_rustls_config( let tls_config = Arc::new(create_rustls_config(
&config.network.tls_certificate_path, &config.network.tls_certificate_path,
@ -78,7 +78,7 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> {
let tls_config = tls_config.clone(); let tls_config = tls_config.clone();
let request_mesh_builder = request_mesh_builder.clone(); let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone();
let num_bound_sockets = num_bound_sockets.clone(); let priv_dropper = priv_dropper.clone();
let placement = get_worker_placement( let placement = get_worker_placement(
&config.cpu_pinning, &config.cpu_pinning,
@ -95,7 +95,7 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> {
tls_config, tls_config,
request_mesh_builder, request_mesh_builder,
response_mesh_builder, response_mesh_builder,
num_bound_sockets, priv_dropper,
) )
.await .await
}); });
@ -130,13 +130,6 @@ pub fn run_inner(config: Config, state: State) -> anyhow::Result<()> {
executors.push(executor); executors.push(executor);
} }
drop_privileges_after_socket_binding(
&config.privileges,
num_bound_sockets,
config.socket_workers,
)
.unwrap();
if config.cpu_pinning.active { if config.cpu_pinning.active {
set_affinity_for_util_worker( set_affinity_for_util_worker(
&config.cpu_pinning, &config.cpu_pinning,

View file

@ -2,11 +2,11 @@ use std::cell::RefCell;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::os::unix::prelude::{FromRawFd, IntoRawFd}; use std::os::unix::prelude::{FromRawFd, IntoRawFd};
use std::rc::Rc; use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache};
use aquatic_common::privileges::PrivilegeDropper;
use aquatic_common::rustls_config::RustlsConfig; use aquatic_common::rustls_config::RustlsConfig;
use aquatic_common::CanonicalSocketAddr; use aquatic_common::CanonicalSocketAddr;
use aquatic_http_protocol::common::InfoHash; use aquatic_http_protocol::common::InfoHash;
@ -58,13 +58,12 @@ pub async fn run_socket_worker(
tls_config: Arc<RustlsConfig>, tls_config: Arc<RustlsConfig>,
request_mesh_builder: MeshBuilder<ChannelRequest, Partial>, request_mesh_builder: MeshBuilder<ChannelRequest, Partial>,
response_mesh_builder: MeshBuilder<ChannelResponse, Partial>, response_mesh_builder: MeshBuilder<ChannelResponse, Partial>,
num_bound_sockets: Arc<AtomicUsize>, priv_dropper: PrivilegeDropper,
) { ) {
let config = Rc::new(config); let config = Rc::new(config);
let access_list = state.access_list; let access_list = state.access_list;
let listener = create_tcp_listener(&config); let listener = create_tcp_listener(&config, priv_dropper);
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap(); let (request_senders, _) = request_mesh_builder.join(Role::Producer).await.unwrap();
let request_senders = Rc::new(request_senders); let request_senders = Rc::new(request_senders);
@ -485,7 +484,7 @@ fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usi
(info_hash.0[0] as usize) % config.request_workers (info_hash.0[0] as usize) % config.request_workers
} }
fn create_tcp_listener(config: &Config) -> TcpListener { fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpListener {
let domain = if config.network.address.is_ipv4() { let domain = if config.network.address.is_ipv4() {
socket2::Domain::IPV4 socket2::Domain::IPV4
} else { } else {
@ -509,5 +508,7 @@ fn create_tcp_listener(config: &Config) -> TcpListener {
.listen(config.network.tcp_backlog) .listen(config.network.tcp_backlog)
.unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err));
priv_dropper.after_socket_creation();
unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) }
} }

View file

@ -5,13 +5,12 @@ pub mod workers;
use config::Config; use config::Config;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::{atomic::AtomicUsize, Arc};
use std::thread::Builder; use std::thread::Builder;
use anyhow::Context; use anyhow::Context;
#[cfg(feature = "cpu-pinning")] #[cfg(feature = "cpu-pinning")]
use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex}; use aquatic_common::cpu_pinning::{pin_current_if_configured_to, WorkerIndex};
use aquatic_common::privileges::drop_privileges_after_socket_binding; use aquatic_common::privileges::PrivilegeDropper;
use crossbeam_channel::{bounded, unbounded}; use crossbeam_channel::{bounded, unbounded};
use aquatic_common::access_list::update_access_list; use aquatic_common::access_list::update_access_list;
@ -32,7 +31,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let mut signals = Signals::new(::std::iter::once(SIGUSR1))?; let mut signals = Signals::new(::std::iter::once(SIGUSR1))?;
let num_bound_sockets = Arc::new(AtomicUsize::new(0)); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
let mut request_senders = Vec::new(); let mut request_senders = Vec::new();
let mut request_receivers = BTreeMap::new(); let mut request_receivers = BTreeMap::new();
@ -96,7 +95,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
let request_sender = let request_sender =
ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone()); ConnectedRequestSender::new(SocketWorkerIndex(i), request_senders.clone());
let response_receiver = response_receivers.remove(&i).unwrap(); let response_receiver = response_receivers.remove(&i).unwrap();
let num_bound_sockets = num_bound_sockets.clone(); let priv_dropper = priv_dropper.clone();
Builder::new() Builder::new()
.name(format!("socket-{:02}", i + 1)) .name(format!("socket-{:02}", i + 1))
@ -115,7 +114,7 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
i, i,
request_sender, request_sender,
response_receiver, response_receiver,
num_bound_sockets, priv_dropper,
); );
}) })
.with_context(|| "spawn socket worker")?; .with_context(|| "spawn socket worker")?;
@ -141,13 +140,6 @@ pub fn run(config: Config) -> ::anyhow::Result<()> {
.with_context(|| "spawn statistics worker")?; .with_context(|| "spawn statistics worker")?;
} }
drop_privileges_after_socket_binding(
&config.privileges,
num_bound_sockets,
config.socket_workers,
)
.unwrap();
#[cfg(feature = "cpu-pinning")] #[cfg(feature = "cpu-pinning")]
pin_current_if_configured_to( pin_current_if_configured_to(
&config.cpu_pinning, &config.cpu_pinning,

View file

@ -1,12 +1,10 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::io::{Cursor, ErrorKind}; use std::io::{Cursor, ErrorKind};
use std::sync::{ use std::sync::atomic::Ordering;
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::vec::Drain; use std::vec::Drain;
use aquatic_common::privileges::PrivilegeDropper;
use crossbeam_channel::Receiver; use crossbeam_channel::Receiver;
use mio::net::UdpSocket; use mio::net::UdpSocket;
use mio::{Events, Interest, Poll, Token}; use mio::{Events, Interest, Poll, Token};
@ -157,12 +155,12 @@ pub fn run_socket_worker(
token_num: usize, token_num: usize,
request_sender: ConnectedRequestSender, request_sender: ConnectedRequestSender,
response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>, response_receiver: Receiver<(ConnectedResponse, CanonicalSocketAddr)>,
num_bound_sockets: Arc<AtomicUsize>, priv_dropper: PrivilegeDropper,
) { ) {
let mut rng = StdRng::from_entropy(); let mut rng = StdRng::from_entropy();
let mut buffer = [0u8; MAX_PACKET_SIZE]; let mut buffer = [0u8; MAX_PACKET_SIZE];
let mut socket = UdpSocket::from_std(create_socket(&config)); let mut socket = UdpSocket::from_std(create_socket(&config, priv_dropper));
let mut poll = Poll::new().expect("create poll"); let mut poll = Poll::new().expect("create poll");
let interests = Interest::READABLE; let interests = Interest::READABLE;
@ -171,8 +169,6 @@ pub fn run_socket_worker(
.register(&mut socket, Token(token_num), interests) .register(&mut socket, Token(token_num), interests)
.unwrap(); .unwrap();
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
let mut events = Events::with_capacity(config.network.poll_event_capacity); let mut events = Events::with_capacity(config.network.poll_event_capacity);
let mut connections = ConnectionMap::default(); let mut connections = ConnectionMap::default();
let mut pending_scrape_responses = PendingScrapeResponseSlab::default(); let mut pending_scrape_responses = PendingScrapeResponseSlab::default();
@ -520,7 +516,7 @@ fn send_response(
} }
} }
pub fn create_socket(config: &Config) -> ::std::net::UdpSocket { pub fn create_socket(config: &Config, priv_dropper: PrivilegeDropper) -> ::std::net::UdpSocket {
let socket = if config.network.address.is_ipv4() { let socket = if config.network.address.is_ipv4() {
Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
} else { } else {
@ -542,6 +538,8 @@ pub fn create_socket(config: &Config) -> ::std::net::UdpSocket {
.bind(&config.network.address.into()) .bind(&config.network.address.into())
.unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err)); .unwrap_or_else(|err| panic!("socket: bind to {}: {:?}", config.network.address, err));
priv_dropper.after_socket_creation();
let recv_buffer_size = config.network.socket_recv_buffer_size; let recv_buffer_size = config.network.socket_recv_buffer_size;
if recv_buffer_size != 0 { if recv_buffer_size != 0 {

View file

@ -2,7 +2,7 @@ pub mod common;
pub mod config; pub mod config;
pub mod workers; pub mod workers;
use std::sync::{atomic::AtomicUsize, Arc}; use std::sync::Arc;
use aquatic_common::cpu_pinning::glommio::{get_worker_placement, set_affinity_for_util_worker}; use aquatic_common::cpu_pinning::glommio::{get_worker_placement, set_affinity_for_util_worker};
use aquatic_common::cpu_pinning::WorkerIndex; use aquatic_common::cpu_pinning::WorkerIndex;
@ -11,7 +11,7 @@ use glommio::{channels::channel_mesh::MeshBuilder, prelude::*};
use signal_hook::{consts::SIGUSR1, iterator::Signals}; use signal_hook::{consts::SIGUSR1, iterator::Signals};
use aquatic_common::access_list::update_access_list; use aquatic_common::access_list::update_access_list;
use aquatic_common::privileges::drop_privileges_after_socket_binding; use aquatic_common::privileges::PrivilegeDropper;
use common::*; use common::*;
use config::Config; use config::Config;
@ -61,7 +61,7 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> {
let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE); let request_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE);
let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16); let response_mesh_builder = MeshBuilder::partial(num_peers, SHARED_IN_CHANNEL_SIZE * 16);
let num_bound_sockets = Arc::new(AtomicUsize::new(0)); let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers);
let tls_config = Arc::new(create_rustls_config( let tls_config = Arc::new(create_rustls_config(
&config.network.tls_certificate_path, &config.network.tls_certificate_path,
@ -76,7 +76,7 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> {
let tls_config = tls_config.clone(); let tls_config = tls_config.clone();
let request_mesh_builder = request_mesh_builder.clone(); let request_mesh_builder = request_mesh_builder.clone();
let response_mesh_builder = response_mesh_builder.clone(); let response_mesh_builder = response_mesh_builder.clone();
let num_bound_sockets = num_bound_sockets.clone(); let priv_dropper = priv_dropper.clone();
let placement = get_worker_placement( let placement = get_worker_placement(
&config.cpu_pinning, &config.cpu_pinning,
@ -93,7 +93,7 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> {
tls_config, tls_config,
request_mesh_builder, request_mesh_builder,
response_mesh_builder, response_mesh_builder,
num_bound_sockets, priv_dropper,
) )
.await .await
}); });
@ -128,13 +128,6 @@ fn run_workers(config: Config, state: State) -> anyhow::Result<()> {
executors.push(executor); executors.push(executor);
} }
drop_privileges_after_socket_binding(
&config.privileges,
num_bound_sockets,
config.socket_workers,
)
.unwrap();
if config.cpu_pinning.active { if config.cpu_pinning.active {
set_affinity_for_util_worker( set_affinity_for_util_worker(
&config.cpu_pinning, &config.cpu_pinning,

View file

@ -3,11 +3,11 @@ use std::cell::RefCell;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::os::unix::prelude::{FromRawFd, IntoRawFd}; use std::os::unix::prelude::{FromRawFd, IntoRawFd};
use std::rc::Rc; use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache}; use aquatic_common::access_list::{create_access_list_cache, AccessListArcSwap, AccessListCache};
use aquatic_common::privileges::PrivilegeDropper;
use aquatic_common::rustls_config::RustlsConfig; use aquatic_common::rustls_config::RustlsConfig;
use aquatic_common::CanonicalSocketAddr; use aquatic_common::CanonicalSocketAddr;
use aquatic_ws_protocol::*; use aquatic_ws_protocol::*;
@ -53,14 +53,12 @@ pub async fn run_socket_worker(
tls_config: Arc<RustlsConfig>, tls_config: Arc<RustlsConfig>,
in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>, in_message_mesh_builder: MeshBuilder<(ConnectionMeta, InMessage), Partial>,
out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>, out_message_mesh_builder: MeshBuilder<(ConnectionMeta, OutMessage), Partial>,
num_bound_sockets: Arc<AtomicUsize>, priv_dropper: PrivilegeDropper,
) { ) {
let config = Rc::new(config); let config = Rc::new(config);
let access_list = state.access_list; let access_list = state.access_list;
let listener = create_tcp_listener(&config); let listener = create_tcp_listener(&config, priv_dropper);
num_bound_sockets.fetch_add(1, Ordering::SeqCst);
let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap(); let (in_message_senders, _) = in_message_mesh_builder.join(Role::Producer).await.unwrap();
let in_message_senders = Rc::new(in_message_senders); let in_message_senders = Rc::new(in_message_senders);
@ -544,7 +542,7 @@ fn calculate_in_message_consumer_index(config: &Config, info_hash: InfoHash) ->
(info_hash.0[0] as usize) % config.request_workers (info_hash.0[0] as usize) % config.request_workers
} }
fn create_tcp_listener(config: &Config) -> TcpListener { fn create_tcp_listener(config: &Config, priv_dropper: PrivilegeDropper) -> TcpListener {
let domain = if config.network.address.is_ipv4() { let domain = if config.network.address.is_ipv4() {
socket2::Domain::IPV4 socket2::Domain::IPV4
} else { } else {
@ -568,5 +566,7 @@ fn create_tcp_listener(config: &Config) -> TcpListener {
.listen(config.network.tcp_backlog) .listen(config.network.tcp_backlog)
.unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err)); .unwrap_or_else(|err| panic!("socket: listen {}: {:?}", config.network.address, err));
priv_dropper.after_socket_creation();
unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) } unsafe { TcpListener::from_raw_fd(socket.into_raw_fd()) }
} }