diff --git a/aquatic_common/src/lib.rs b/aquatic_common/src/lib.rs index 96804ff..17a647d 100644 --- a/aquatic_common/src/lib.rs +++ b/aquatic_common/src/lib.rs @@ -94,71 +94,6 @@ impl Drop for PanicSentinel { } } -/// Extract response peers -/// -/// If there are more peers in map than `max_num_peers_to_take`, do a -/// half-random selection of peers from first and second halves of map, -/// in order to avoid returning too homogeneous peers. -/// -/// Might return one less peer than wanted since sender is filtered out. -#[inline] -pub fn extract_response_peers( - rng: &mut impl Rng, - peer_map: &IndexMap, - max_num_peers_to_take: usize, - sender_peer_map_key: K, - peer_conversion_function: F, -) -> Vec -where - K: Eq + ::std::hash::Hash, - F: Fn(&V) -> R, -{ - let peer_map_len = peer_map.len(); - - if peer_map_len <= max_num_peers_to_take + 1 { - let mut peers = Vec::with_capacity(peer_map_len); - - peers.extend(peer_map.iter().filter_map(|(k, v)| { - if *k == sender_peer_map_key { - None - } else { - Some(peer_conversion_function(v)) - } - })); - - peers - } else { - let half_num_to_take = max_num_peers_to_take / 2; - let half_peer_map_len = peer_map_len / 2; - - let offset_first_half = - rng.gen_range(0..(half_peer_map_len + (peer_map_len % 2)) - half_num_to_take); - let offset_second_half = rng.gen_range(half_peer_map_len..peer_map_len - half_num_to_take); - - let end_first_half = offset_first_half + half_num_to_take; - let end_second_half = offset_second_half + half_num_to_take + (max_num_peers_to_take % 2); - - let mut peers: Vec = Vec::with_capacity(max_num_peers_to_take); - - for i in offset_first_half..end_first_half { - if let Some((k, peer)) = peer_map.get_index(i) { - if *k != sender_peer_map_key { - peers.push(peer_conversion_function(peer)) - } - } - } - for i in offset_second_half..end_second_half { - if let Some((k, peer)) = peer_map.get_index(i) { - if *k != sender_peer_map_key { - peers.push(peer_conversion_function(peer)) - } - } - } - - peers - } -} - /// SocketAddr that is not an IPv6-mapped IPv4 address #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] pub struct CanonicalSocketAddr(SocketAddr); @@ -205,3 +140,167 @@ impl CanonicalSocketAddr { self.0.is_ipv4() } } + +/// Extract response peers +/// +/// If there are more peers in map than `max_num_peers_to_take`, do a random +/// selection of peers from first and second halves of map in order to avoid +/// returning too homogeneous peers. +#[inline] +pub fn extract_response_peers( + rng: &mut impl Rng, + peer_map: &IndexMap, + max_num_peers_to_take: usize, + sender_peer_map_key: K, + peer_conversion_function: F, +) -> Vec +where + K: Eq + ::std::hash::Hash, + F: Fn(&V) -> R, +{ + if peer_map.len() <= max_num_peers_to_take + 1 { + // First case: number of peers in map (minus sender peer) is less than + // or equal to number of peers to take, so return all except sender + // peer. + let mut peers = Vec::with_capacity(peer_map.len()); + + let mut opt_sender_peer_index = None; + + for (i, (k, v)) in peer_map.iter().enumerate() { + opt_sender_peer_index = + opt_sender_peer_index.or((*k == sender_peer_map_key).then_some(i)); + + peers.push(peer_conversion_function(v)); + } + + if let Some(index) = opt_sender_peer_index { + peers.swap_remove(index); + } + + // Handle the case when sender peer is not in peer list. Typically, + // this function will not be called when this is the case. + if peers.len() > max_num_peers_to_take { + peers.swap_remove(0); + } + + peers + } else { + // If this branch is taken, the peer map contains at least two more + // peers than max_num_peers_to_take + + let middle_index = peer_map.len() / 2; + // Add one to take two extra peers in case sender peer is among + // selected peers and will need to be filtered out + let half_num_to_take = (max_num_peers_to_take / 2) + 1; + + let offset_half_one = rng.gen_range(0..(middle_index - half_num_to_take).max(1)); + let offset_half_two = + rng.gen_range(middle_index..(peer_map.len() - half_num_to_take).max(middle_index + 1)); + + let end_half_one = offset_half_one + half_num_to_take; + let end_half_two = offset_half_two + half_num_to_take; + + let mut peers = Vec::with_capacity(max_num_peers_to_take + 2); + + // Extract first range + { + let mut opt_sender_peer_index = None; + + if let Some(slice) = peer_map.get_range(offset_half_one..end_half_one) { + for (i, (k, v)) in slice.into_iter().enumerate() { + opt_sender_peer_index = + opt_sender_peer_index.or((*k == sender_peer_map_key).then_some(i)); + + peers.push(peer_conversion_function(v)); + } + } + + if let Some(index) = opt_sender_peer_index { + peers.swap_remove(index); + } + } + + // Extract second range + { + let initial_peers_len = peers.len(); + + let mut opt_sender_peer_index = None; + + if let Some(slice) = peer_map.get_range(offset_half_two..end_half_two) { + for (i, (k, v)) in slice.into_iter().enumerate() { + opt_sender_peer_index = + opt_sender_peer_index.or((*k == sender_peer_map_key).then_some(i)); + + peers.push(peer_conversion_function(v)); + } + } + + if let Some(index) = opt_sender_peer_index { + peers.swap_remove(initial_peers_len + index); + } + } + + while peers.len() > max_num_peers_to_take { + peers.swap_remove(0); + } + + peers + } +} + +#[cfg(test)] +mod tests { + use ahash::HashSet; + + use rand::{rngs::SmallRng, SeedableRng}; + + use super::*; + + #[test] + fn test_extract_response_peers() { + let mut rng = SmallRng::from_entropy(); + + for num_peers_in_map in 0..1000 { + for max_num_peers_to_take in [0, 1, 2, 5, 10, 50] { + for sender_peer_map_key in [0, 1, 4, 5, 10, 44, 500] { + test_extract_response_peers_helper( + &mut rng, + num_peers_in_map, + max_num_peers_to_take, + sender_peer_map_key, + ); + } + } + } + } + + fn test_extract_response_peers_helper( + rng: &mut SmallRng, + num_peers_in_map: usize, + max_num_peers_to_take: usize, + sender_peer_map_key: usize, + ) { + let peer_map = IndexMap::from_iter((0..num_peers_in_map).map(|i| (i, i))); + + let response_peers = extract_response_peers( + rng, + &peer_map, + max_num_peers_to_take, + sender_peer_map_key, + |p| *p, + ); + + if num_peers_in_map > max_num_peers_to_take + 1 { + assert_eq!(response_peers.len(), max_num_peers_to_take); + } else { + assert!(response_peers.len() <= max_num_peers_to_take); + } + + assert!(!response_peers.contains(&sender_peer_map_key)); + + assert_eq!( + response_peers.len(), + HashSet::from_iter(response_peers.iter().copied()).len() + ); + } +}