use zerocopy in udp protocol, easy running transfer CI locally

This commit is contained in:
Joakim Frostegård 2023-12-02 12:24:41 +01:00
parent af16a9e682
commit 0e12dd1b13
24 changed files with 783 additions and 652 deletions

View file

@ -1,6 +1,5 @@
use std::collections::BTreeMap;
use std::hash::Hash;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
@ -35,8 +34,8 @@ pub enum ConnectedRequest {
#[derive(Debug)]
pub enum ConnectedResponse {
AnnounceIpv4(AnnounceResponse<Ipv4Addr>),
AnnounceIpv6(AnnounceResponse<Ipv6Addr>),
AnnounceIpv4(AnnounceResponse<Ipv4AddrBytes>),
AnnounceIpv6(AnnounceResponse<Ipv6AddrBytes>),
Scrape(PendingScrapeResponse),
}
@ -125,7 +124,7 @@ impl PeerStatus {
pub fn from_event_and_bytes_left(event: AnnounceEvent, bytes_left: NumberOfBytes) -> Self {
if event == AnnounceEvent::Stopped {
Self::Stopped
} else if bytes_left.0 == 0 {
} else if bytes_left.0.get() == 0 {
Self::Seeding
} else {
Self::Leeching
@ -207,17 +206,17 @@ mod tests {
let f = PeerStatus::from_event_and_bytes_left;
assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes(0)));
assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes(1)));
assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes::new(0)));
assert_eq!(Stopped, f(AnnounceEvent::Stopped, NumberOfBytes::new(1)));
assert_eq!(Seeding, f(AnnounceEvent::Started, NumberOfBytes(0)));
assert_eq!(Leeching, f(AnnounceEvent::Started, NumberOfBytes(1)));
assert_eq!(Seeding, f(AnnounceEvent::Started, NumberOfBytes::new(0)));
assert_eq!(Leeching, f(AnnounceEvent::Started, NumberOfBytes::new(1)));
assert_eq!(Seeding, f(AnnounceEvent::Completed, NumberOfBytes(0)));
assert_eq!(Leeching, f(AnnounceEvent::Completed, NumberOfBytes(1)));
assert_eq!(Seeding, f(AnnounceEvent::Completed, NumberOfBytes::new(0)));
assert_eq!(Leeching, f(AnnounceEvent::Completed, NumberOfBytes::new(1)));
assert_eq!(Seeding, f(AnnounceEvent::None, NumberOfBytes(0)));
assert_eq!(Leeching, f(AnnounceEvent::None, NumberOfBytes(1)));
assert_eq!(Seeding, f(AnnounceEvent::None, NumberOfBytes::new(0)));
assert_eq!(Leeching, f(AnnounceEvent::None, NumberOfBytes::new(1)));
}
// Assumes that announce response with maximum amount of ipv6 peers will
@ -229,17 +228,19 @@ mod tests {
let config = Config::default();
let peers = ::std::iter::repeat(ResponsePeer {
ip_address: Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1),
port: Port(1),
ip_address: Ipv6AddrBytes(Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1).octets()),
port: Port::new(1),
})
.take(config.protocol.max_response_peers)
.collect();
let response = Response::AnnounceIpv6(AnnounceResponse {
transaction_id: TransactionId(1),
announce_interval: AnnounceInterval(1),
seeders: NumberOfPeers(1),
leechers: NumberOfPeers(1),
fixed: AnnounceResponseFixedData {
transaction_id: TransactionId::new(1),
announce_interval: AnnounceInterval::new(1),
seeders: NumberOfPeers::new(1),
leechers: NumberOfPeers::new(1),
},
peers,
});

View file

@ -195,6 +195,8 @@ impl SocketWorker {
let src = CanonicalSocketAddr::new(src);
::log::trace!("received request bytes: {}", hex_slice(&self.buffer[..bytes_read]));
let request_parsable = match Request::from_bytes(
&self.buffer[..bytes_read],
self.config.protocol.max_scrape_torrents,
@ -221,7 +223,7 @@ impl SocketWorker {
if self.validator.connection_id_valid(src, connection_id) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
message: err.into(),
};
local_responses.push((response.into(), src));
@ -285,6 +287,8 @@ impl SocketWorker {
match request {
Request::Connect(request) => {
::log::trace!("received {:?} from {:?}", request, src);
let connection_id = self.validator.create_connection_id(src);
let response = Response::Connect(ConnectResponse {
@ -295,6 +299,8 @@ impl SocketWorker {
local_responses.push((response, src))
}
Request::Announce(request) => {
::log::trace!("received {:?} from {:?}", request, src);
if self
.validator
.connection_id_valid(src, request.connection_id)
@ -323,6 +329,8 @@ impl SocketWorker {
}
}
Request::Scrape(request) => {
::log::trace!("received {:?} from {:?}", request, src);
if self
.validator
.connection_id_valid(src, request.connection_id)
@ -372,6 +380,8 @@ impl SocketWorker {
canonical_addr.get_ipv6_mapped()
};
::log::trace!("sending {:?} to {}, bytes: {}", response, addr, hex_slice(&cursor.get_ref()[..bytes_written]));
match socket.send_to(&cursor.get_ref()[..bytes_written], addr) {
Ok(amt) if config.statistics.active() => {
let stats = if canonical_addr.is_ipv4() {
@ -430,3 +440,14 @@ impl SocketWorker {
}
}
}
fn hex_slice(bytes: &[u8]) -> String {
let mut output = String::with_capacity(bytes.len() * 3);
for chunk in bytes.chunks(4) {
output.push_str(&hex::encode(chunk));
output.push(' ');
}
output
}

View file

@ -156,8 +156,8 @@ mod tests {
}
let request = ScrapeRequest {
transaction_id: TransactionId(t),
connection_id: ConnectionId(c),
transaction_id: TransactionId::new(t),
connection_id: ConnectionId::new(c),
info_hashes,
};
@ -192,9 +192,9 @@ mod tests {
(
i,
TorrentScrapeStatistics {
seeders: NumberOfPeers((info_hash.0[0]) as i32),
leechers: NumberOfPeers(0),
completed: NumberOfDownloads(0),
seeders: NumberOfPeers::new((info_hash.0[0]) as i32),
leechers: NumberOfPeers::new(0),
completed: NumberOfDownloads::new(0),
},
)
})

View file

@ -393,7 +393,7 @@ impl SocketWorker {
if self.validator.connection_id_valid(addr, connection_id) {
let response = ErrorResponse {
transaction_id,
message: err.right_or("Parse error").into(),
message: err.into(),
};
self.local_responses.push_back((response.into(), addr));

View file

@ -59,7 +59,7 @@ impl ConnectionValidator {
(&mut connection_id_bytes[..4]).copy_from_slice(&valid_until);
(&mut connection_id_bytes[4..]).copy_from_slice(&hash);
ConnectionId(i64::from_ne_bytes(connection_id_bytes))
ConnectionId::new(i64::from_ne_bytes(connection_id_bytes))
}
pub fn connection_id_valid(
@ -67,7 +67,7 @@ impl ConnectionValidator {
source_addr: CanonicalSocketAddr,
connection_id: ConnectionId,
) -> bool {
let bytes = connection_id.0.to_ne_bytes();
let bytes = connection_id.0.get().to_ne_bytes();
let (valid_until, hash) = bytes.split_at(4);
let valid_until: [u8; 4] = valid_until.try_into().unwrap();

View file

@ -53,7 +53,7 @@ pub fn run_swarm_worker(
&statistics_sender,
&mut torrents.ipv4,
request,
ip,
ip.into(),
peer_valid_until,
);
@ -66,7 +66,7 @@ pub fn run_swarm_worker(
&statistics_sender,
&mut torrents.ipv6,
request,
ip,
ip.into(),
peer_valid_until,
);
@ -126,18 +126,19 @@ fn handle_announce_request<I: Ip>(
peer_ip: I,
peer_valid_until: ValidUntil,
) -> AnnounceResponse<I> {
let max_num_peers_to_take: usize = if request.peers_wanted.0 <= 0 {
let max_num_peers_to_take: usize = if request.peers_wanted.0.get() <= 0 {
config.protocol.max_response_peers
} else {
::std::cmp::min(
config.protocol.max_response_peers,
request.peers_wanted.0.try_into().unwrap(),
request.peers_wanted.0.get().try_into().unwrap(),
)
};
let torrent_data = torrents.0.entry(request.info_hash).or_default();
let peer_status = PeerStatus::from_event_and_bytes_left(request.event, request.bytes_left);
let peer_status =
PeerStatus::from_event_and_bytes_left(request.event.into(), request.bytes_left);
torrent_data.update_peer(
config,
@ -156,10 +157,14 @@ fn handle_announce_request<I: Ip>(
};
AnnounceResponse {
transaction_id: request.transaction_id,
announce_interval: AnnounceInterval(config.protocol.peer_announce_interval),
leechers: NumberOfPeers(torrent_data.num_leechers().try_into().unwrap_or(i32::MAX)),
seeders: NumberOfPeers(torrent_data.num_seeders().try_into().unwrap_or(i32::MAX)),
fixed: AnnounceResponseFixedData {
transaction_id: request.transaction_id,
announce_interval: AnnounceInterval::new(config.protocol.peer_announce_interval),
leechers: NumberOfPeers::new(
torrent_data.num_leechers().try_into().unwrap_or(i32::MAX),
),
seeders: NumberOfPeers::new(torrent_data.num_seeders().try_into().unwrap_or(i32::MAX)),
},
peers: response_peers,
}
}
@ -168,8 +173,6 @@ fn handle_scrape_request<I: Ip>(
torrents: &mut TorrentMap<I>,
request: PendingScrapeRequest,
) -> PendingScrapeResponse {
const EMPTY_STATS: TorrentScrapeStatistics = create_torrent_scrape_statistics(0, 0);
let torrent_stats = request
.info_hashes
.into_iter()
@ -178,7 +181,7 @@ fn handle_scrape_request<I: Ip>(
.0
.get(&info_hash)
.map(|torrent_data| torrent_data.scrape_statistics())
.unwrap_or(EMPTY_STATS);
.unwrap_or_else(|| create_torrent_scrape_statistics(0, 0));
(i, stats)
})
@ -191,10 +194,10 @@ fn handle_scrape_request<I: Ip>(
}
#[inline(always)]
const fn create_torrent_scrape_statistics(seeders: i32, leechers: i32) -> TorrentScrapeStatistics {
fn create_torrent_scrape_statistics(seeders: i32, leechers: i32) -> TorrentScrapeStatistics {
TorrentScrapeStatistics {
seeders: NumberOfPeers(seeders),
completed: NumberOfDownloads(0), // No implementation planned
leechers: NumberOfPeers(leechers),
seeders: NumberOfPeers::new(seeders),
completed: NumberOfDownloads::new(0), // No implementation planned
leechers: NumberOfPeers::new(leechers),
}
}

View file

@ -1,5 +1,3 @@
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
@ -256,8 +254,8 @@ impl<I: Ip> TorrentMap<I> {
}
pub struct TorrentMaps {
pub ipv4: TorrentMap<Ipv4Addr>,
pub ipv6: TorrentMap<Ipv6Addr>,
pub ipv4: TorrentMap<Ipv4AddrBytes>,
pub ipv6: TorrentMap<Ipv6AddrBytes>,
}
impl Default for TorrentMaps {
@ -312,7 +310,6 @@ impl TorrentMaps {
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use std::net::Ipv4Addr;
use quickcheck::{quickcheck, TestResult};
use rand::thread_rng;
@ -326,10 +323,10 @@ mod tests {
peer_id
}
fn gen_peer(i: u32) -> Peer<Ipv4Addr> {
fn gen_peer(i: u32) -> Peer<Ipv4AddrBytes> {
Peer {
ip_address: Ipv4Addr::from(i.to_be_bytes()),
port: Port(1),
ip_address: Ipv4AddrBytes(i.to_be_bytes()),
port: Port::new(1),
is_seeder: false,
valid_until: ValidUntil::new(ServerStartInstant::new(), 0),
}
@ -341,7 +338,7 @@ mod tests {
let gen_num_peers = data.0 as u32;
let req_num_peers = data.1 as usize;
let mut peer_map: PeerMap<Ipv4Addr> = Default::default();
let mut peer_map: PeerMap<Ipv4AddrBytes> = Default::default();
let mut opt_sender_key = None;
let mut opt_sender_peer = None;

View file

@ -10,8 +10,8 @@ use anyhow::Context;
use aquatic_udp::{common::BUFFER_SIZE, config::Config};
use aquatic_udp_protocol::{
common::PeerId, AnnounceEvent, AnnounceRequest, ConnectRequest, ConnectionId, InfoHash,
NumberOfBytes, NumberOfPeers, PeerKey, Port, Request, Response, ScrapeRequest, ScrapeResponse,
TransactionId,
Ipv4AddrBytes, NumberOfBytes, NumberOfPeers, PeerKey, Port, Request, Response, ScrapeRequest,
ScrapeResponse, TransactionId,
};
// FIXME: should ideally try different ports and use sync primitives to find
@ -26,7 +26,7 @@ pub fn run_tracker(config: Config) {
pub fn connect(socket: &UdpSocket, tracker_addr: SocketAddr) -> anyhow::Result<ConnectionId> {
let request = Request::Connect(ConnectRequest {
transaction_id: TransactionId(0),
transaction_id: TransactionId::new(0),
});
let response = request_and_response(&socket, tracker_addr, request)?;
@ -55,17 +55,18 @@ pub fn announce(
let request = Request::Announce(AnnounceRequest {
connection_id,
transaction_id: TransactionId(0),
action_placeholder: Default::default(),
transaction_id: TransactionId::new(0),
info_hash,
peer_id,
bytes_downloaded: NumberOfBytes(0),
bytes_uploaded: NumberOfBytes(0),
bytes_left: NumberOfBytes(if seeder { 0 } else { 1 }),
event: AnnounceEvent::Started,
ip_address: None,
key: PeerKey(0),
peers_wanted: NumberOfPeers(peers_wanted as i32),
port: Port(peer_port),
bytes_downloaded: NumberOfBytes::new(0),
bytes_uploaded: NumberOfBytes::new(0),
bytes_left: NumberOfBytes::new(if seeder { 0 } else { 1 }),
event: AnnounceEvent::Started.into(),
ip_address: Ipv4AddrBytes([0; 4]),
key: PeerKey::new(0),
peers_wanted: NumberOfPeers::new(peers_wanted as i32),
port: Port::new(peer_port),
});
Ok(request_and_response(&socket, tracker_addr, request)?)
@ -79,7 +80,7 @@ pub fn scrape(
) -> anyhow::Result<ScrapeResponse> {
let request = Request::Scrape(ScrapeRequest {
connection_id,
transaction_id: TransactionId(0),
transaction_id: TransactionId::new(0),
info_hashes,
});

View file

@ -11,8 +11,8 @@ use std::{
use anyhow::Context;
use aquatic_udp::{common::BUFFER_SIZE, config::Config};
use aquatic_udp_protocol::{
common::PeerId, AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, NumberOfBytes,
NumberOfPeers, PeerKey, Port, Request, ScrapeRequest, TransactionId,
common::PeerId, AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, Ipv4AddrBytes,
NumberOfBytes, NumberOfPeers, PeerKey, Port, Request, ScrapeRequest, TransactionId,
};
#[test]
@ -40,22 +40,23 @@ fn test_invalid_connection_id() -> anyhow::Result<()> {
let announce_request = Request::Announce(AnnounceRequest {
connection_id: invalid_connection_id,
transaction_id: TransactionId(0),
action_placeholder: Default::default(),
transaction_id: TransactionId::new(0),
info_hash: InfoHash([0; 20]),
peer_id: PeerId([0; 20]),
bytes_downloaded: NumberOfBytes(0),
bytes_uploaded: NumberOfBytes(0),
bytes_left: NumberOfBytes(0),
event: AnnounceEvent::Started,
ip_address: None,
key: PeerKey(0),
peers_wanted: NumberOfPeers(10),
port: Port(1),
bytes_downloaded: NumberOfBytes::new(0),
bytes_uploaded: NumberOfBytes::new(0),
bytes_left: NumberOfBytes::new(0),
event: AnnounceEvent::Started.into(),
ip_address: Ipv4AddrBytes([0; 4]),
key: PeerKey::new(0),
peers_wanted: NumberOfPeers::new(10),
port: Port::new(1),
});
let scrape_request = Request::Scrape(ScrapeRequest {
connection_id: invalid_connection_id,
transaction_id: TransactionId(0),
transaction_id: TransactionId::new(0),
info_hashes: vec![InfoHash([0; 20])],
});

View file

@ -67,11 +67,11 @@ fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> {
assert_eq!(announce_response.peers.len(), i.min(PEERS_WANTED));
assert_eq!(announce_response.seeders.0, num_seeders);
assert_eq!(announce_response.leechers.0, num_leechers);
assert_eq!(announce_response.fixed.seeders.0.get(), num_seeders);
assert_eq!(announce_response.fixed.leechers.0.get(), num_leechers);
let response_peer_ports: HashSet<u16, RandomState> =
HashSet::from_iter(announce_response.peers.iter().map(|p| p.port.0));
HashSet::from_iter(announce_response.peers.iter().map(|p| p.port.0.get()));
let expected_peer_ports: HashSet<u16, RandomState> =
HashSet::from_iter((0..i).map(|i| PEER_PORT_START + i as u16));
@ -89,10 +89,16 @@ fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> {
)
.with_context(|| "scrape")?;
assert_eq!(scrape_response.torrent_stats[0].seeders.0, num_seeders);
assert_eq!(scrape_response.torrent_stats[0].leechers.0, num_leechers);
assert_eq!(scrape_response.torrent_stats[1].seeders.0, 0);
assert_eq!(scrape_response.torrent_stats[1].leechers.0, 0);
assert_eq!(
scrape_response.torrent_stats[0].seeders.0.get(),
num_seeders
);
assert_eq!(
scrape_response.torrent_stats[0].leechers.0.get(),
num_leechers
);
assert_eq!(scrape_response.torrent_stats[1].seeders.0.get(), 0);
assert_eq!(scrape_response.torrent_stats[1].leechers.0.get(), 0);
}
Ok(())