aquatic_udp: glommio: mostly implement scrape request support

This commit is contained in:
Joakim Frostegård 2021-10-23 14:42:42 +02:00
parent 08920fce5f
commit 96196239f5
2 changed files with 88 additions and 28 deletions

View file

@ -42,7 +42,7 @@ impl Default for AccessListConfig {
} }
} }
#[derive(Default)] #[derive(Default, Clone)]
pub struct AccessList(HashSet<[u8; 20]>); pub struct AccessList(HashSet<[u8; 20]>);
impl AccessList { impl AccessList {

View file

@ -1,6 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::io::Cursor; use std::io::Cursor;
use std::iter::FromIterator;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::rc::Rc; use std::rc::Rc;
use std::sync::{ use std::sync::{
@ -31,7 +30,6 @@ use crate::config::Config;
struct PendingScrapeResponse { struct PendingScrapeResponse {
pending_worker_responses: usize, pending_worker_responses: usize,
valid_until: ValidUntil, valid_until: ValidUntil,
src: SocketAddr,
stats: Vec<TorrentScrapeStatistics>, stats: Vec<TorrentScrapeStatistics>,
} }
@ -39,16 +37,39 @@ struct PendingScrapeResponse {
struct PendingScrapeResponses(HashMap<TransactionId, PendingScrapeResponse>); struct PendingScrapeResponses(HashMap<TransactionId, PendingScrapeResponse>);
impl PendingScrapeResponses { impl PendingScrapeResponses {
fn insert_empty(&mut self, transaction_id: TransactionId, src: SocketAddr, pending_worker_responses: usize, valid_until: ValidUntil) { fn prepare(&mut self, transaction_id: TransactionId, pending_worker_responses: usize, valid_until: ValidUntil) {
let pending = PendingScrapeResponse { let pending = PendingScrapeResponse {
pending_worker_responses, pending_worker_responses,
valid_until, valid_until,
src,
stats: Vec::new(), stats: Vec::new(),
}; };
self.0.insert(transaction_id, pending); self.0.insert(transaction_id, pending);
} }
fn add_and_get_finished(&mut self, mut response: ScrapeResponse) -> Option<ScrapeResponse> {
let finished = if let Some(r) = self.0.get_mut(&response.transaction_id) {
r.pending_worker_responses -= 1;
r.stats.append(&mut response.torrent_stats);
r.pending_worker_responses == 0
} else {
::log::warn!("PendingScrapeResponses.add didn't find PendingScrapeResponse in map");
false
};
if finished {
let r = self.0.remove(&response.transaction_id).unwrap();
Some(ScrapeResponse {
transaction_id: response.transaction_id,
torrent_stats: r.stats,
})
} else {
None
}
}
} }
pub async fn run_socket_worker( pub async fn run_socket_worker(
@ -78,9 +99,10 @@ pub async fn run_socket_worker(
let response_consumer_index = response_receivers.consumer_id().unwrap(); let response_consumer_index = response_receivers.consumer_id().unwrap();
// FIXME: needs cleaning
let pending_scrape_responses = Rc::new(RefCell::new(PendingScrapeResponses::default())); let pending_scrape_responses = Rc::new(RefCell::new(PendingScrapeResponses::default()));
spawn_local(read_requests( spawn_local(enclose!((pending_scrape_responses) read_requests(
config.clone(), config.clone(),
request_senders, request_senders,
response_consumer_index, response_consumer_index,
@ -88,18 +110,19 @@ pub async fn run_socket_worker(
socket.clone(), socket.clone(),
pending_scrape_responses, pending_scrape_responses,
access_list, access_list,
)) )))
.detach(); .detach();
for (_, receiver) in response_receivers.streams().into_iter() { for (_, receiver) in response_receivers.streams().into_iter() {
spawn_local(send_responses( spawn_local(enclose!((pending_scrape_responses) handle_shared_responses(
socket.clone(), socket.clone(),
receiver.map(|(response, addr)| (response.into(), addr)), pending_scrape_responses,
)) receiver,
)))
.detach(); .detach();
} }
send_responses(socket, local_receiver.stream()).await; send_local_responses(socket, local_receiver.stream()).await;
} }
async fn read_requests( async fn read_requests(
@ -215,9 +238,8 @@ async fn read_requests(
.info_hashes.push(info_hash); .info_hashes.push(info_hash);
} }
pending_scrape_responses.borrow_mut().insert_empty( pending_scrape_responses.borrow_mut().prepare(
request.transaction_id, request.transaction_id,
src,
consumer_requests.len(), consumer_requests.len(),
connection_valid_until.borrow().to_owned(), // FIXME: use seperate ValidUntil connection_valid_until.borrow().to_owned(), // FIXME: use seperate ValidUntil
); );
@ -262,32 +284,70 @@ async fn read_requests(
} }
} }
async fn send_responses<S>(socket: Rc<UdpSocket>, mut stream: S) async fn handle_shared_responses<S>(
socket: Rc<UdpSocket>,
pending_scrape_responses: Rc<RefCell<PendingScrapeResponses>>,
mut stream: S,
) where
S: Stream<Item = (ConnectedResponse, SocketAddr)> + ::std::marker::Unpin,
{
let mut buf = [0u8; MAX_PACKET_SIZE];
let mut buf = Cursor::new(&mut buf[..]);
while let Some((response, addr)) = stream.next().await {
let opt_response = match response {
ConnectedResponse::Announce(response) => Some((Response::Announce(response), addr)),
ConnectedResponse::Scrape(response) => {
pending_scrape_responses
.borrow_mut()
.add_and_get_finished(response)
.map(|response| (Response::Scrape(response), addr))
},
};
if let Some((response, addr)) = opt_response {
write_response_to_socket(&socket, &mut buf, addr, response).await;
}
yield_if_needed().await;
}
}
async fn send_local_responses<S>(socket: Rc<UdpSocket>, mut stream: S)
where where
S: Stream<Item = (Response, SocketAddr)> + ::std::marker::Unpin, S: Stream<Item = (Response, SocketAddr)> + ::std::marker::Unpin,
{ {
let mut buf = [0u8; MAX_PACKET_SIZE]; let mut buf = [0u8; MAX_PACKET_SIZE];
let mut buf = Cursor::new(&mut buf[..]); let mut buf = Cursor::new(&mut buf[..]);
while let Some((response, src)) = stream.next().await { while let Some((response, addr)) = stream.next().await {
buf.set_position(0); write_response_to_socket(&socket, &mut buf, addr, response).await;
::log::debug!("preparing to send response: {:?}", response.clone());
response
.write(&mut buf, ip_version_from_ip(src.ip()))
.expect("write response");
let position = buf.position() as usize;
if let Err(err) = socket.send_to(&buf.get_ref()[..position], src).await {
::log::info!("send_to failed: {:?}", err);
}
yield_if_needed().await; yield_if_needed().await;
} }
} }
async fn write_response_to_socket(
socket: &Rc<UdpSocket>,
buf: &mut Cursor<&mut [u8]>,
addr: SocketAddr,
response: Response,
) {
buf.set_position(0);
::log::debug!("preparing to send response: {:?}", response.clone());
response
.write(buf, ip_version_from_ip(addr.ip()))
.expect("write response");
let position = buf.position() as usize;
if let Err(err) = socket.send_to(&buf.get_ref()[..position], addr).await {
::log::info!("send_to failed: {:?}", err);
}
}
fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize { fn calculate_request_consumer_index(config: &Config, info_hash: InfoHash) -> usize {
(info_hash.0[0] as usize) % config.request_workers (info_hash.0[0] as usize) % config.request_workers
} }