ws: improve socket worker connection code

This commit is contained in:
Joakim Frostegård 2023-11-17 00:15:41 +01:00
parent fe5ccf6646
commit 7b2a7a4f46

View file

@ -229,11 +229,14 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
match &message { match &message {
tungstenite::Message::Text(_) | tungstenite::Message::Binary(_) => { tungstenite::Message::Text(_) | tungstenite::Message::Binary(_) => {
match InMessage::from_ws_message(message) { match InMessage::from_ws_message(message) {
Ok(in_message) => { Ok(InMessage::AnnounceRequest(request)) => {
self.handle_in_message(in_message).await?; self.handle_announce_request(request).await?;
}
Ok(InMessage::ScrapeRequest(request)) => {
self.handle_scrape_request(request).await?;
} }
Err(err) => { Err(err) => {
::log::debug!("Couldn't parse in_message: {:?}", err); ::log::debug!("Couldn't parse in_message: {:#}", err);
self.send_error_response("Invalid request".into(), None, None) self.send_error_response("Invalid request".into(), None, None)
.await?; .await?;
@ -261,9 +264,7 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
} }
} }
async fn handle_in_message(&mut self, in_message: InMessage) -> anyhow::Result<()> { async fn handle_announce_request(&mut self, request: AnnounceRequest) -> anyhow::Result<()> {
match in_message {
InMessage::AnnounceRequest(announce_request) => {
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
::metrics::increment_counter!( ::metrics::increment_counter!(
"aquatic_requests_total", "aquatic_requests_total",
@ -272,20 +273,19 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
"worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(),
); );
let info_hash = announce_request.info_hash; let info_hash = request.info_hash;
if self if self
.access_list_cache .access_list_cache
.load() .load()
.allows(self.config.access_list.mode, &info_hash.0) .allows(self.config.access_list.mode, &info_hash.0)
{ {
let mut announced_info_hashes = let mut announced_info_hashes = self.clean_up_data.announced_info_hashes.borrow_mut();
self.clean_up_data.announced_info_hashes.borrow_mut();
// Store peer id / check if stored peer id matches // Store peer id / check if stored peer id matches
match announced_info_hashes.entry(announce_request.info_hash) { match announced_info_hashes.entry(request.info_hash) {
Entry::Occupied(entry) => { Entry::Occupied(entry) => {
if *entry.get() != announce_request.peer_id { if *entry.get() != request.peer_id {
// Drop Rc borrow before awaiting // Drop Rc borrow before awaiting
drop(announced_info_hashes); drop(announced_info_hashes);
@ -302,7 +302,7 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
} }
} }
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
entry.insert(announce_request.peer_id); entry.insert(request.peer_id);
// Set peer client info if not set // Set peer client info if not set
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
@ -310,7 +310,7 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
&& self.config.metrics.peer_clients && self.config.metrics.peer_clients
&& self.clean_up_data.opt_peer_client.borrow().is_none() && self.clean_up_data.opt_peer_client.borrow().is_none()
{ {
let peer_id = aquatic_peer_id::PeerId(announce_request.peer_id.0); let peer_id = aquatic_peer_id::PeerId(request.peer_id.0);
let client = peer_id.client(); let client = peer_id.client();
let prefix = peer_id.first_8_bytes_hex().to_string(); let prefix = peer_id.first_8_bytes_hex().to_string();
@ -328,23 +328,21 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
); );
} }
*self.clean_up_data.opt_peer_client.borrow_mut() = *self.clean_up_data.opt_peer_client.borrow_mut() = Some((client, prefix));
Some((client, prefix));
}; };
} }
} }
if let Some(AnnounceEvent::Stopped) = announce_request.event { if let Some(AnnounceEvent::Stopped) = request.event {
announced_info_hashes.remove(&announce_request.info_hash); announced_info_hashes.remove(&request.info_hash);
} }
// Drop Rc borrow before awaiting // Drop Rc borrow before awaiting
drop(announced_info_hashes); drop(announced_info_hashes);
let in_message = InMessage::AnnounceRequest(announce_request); let in_message = InMessage::AnnounceRequest(request);
let consumer_index = let consumer_index = calculate_in_message_consumer_index(&self.config, info_hash);
calculate_in_message_consumer_index(&self.config, info_hash);
// Only fails when receiver is closed // Only fails when receiver is closed
self.in_message_senders self.in_message_senders
@ -362,8 +360,11 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
) )
.await?; .await?;
} }
Ok(())
} }
InMessage::ScrapeRequest(ScrapeRequest { info_hashes, .. }) => {
async fn handle_scrape_request(&mut self, request: ScrapeRequest) -> anyhow::Result<()> {
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
::metrics::increment_counter!( ::metrics::increment_counter!(
"aquatic_requests_total", "aquatic_requests_total",
@ -372,7 +373,7 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
"worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(), "worker_index" => WORKER_INDEX.with(|index| index.get()).to_string(),
); );
let info_hashes = if let Some(info_hashes) = info_hashes { let info_hashes = if let Some(info_hashes) = request.info_hashes {
info_hashes info_hashes
} else { } else {
// If request.info_hashes is empty, don't return scrape for all // If request.info_hashes is empty, don't return scrape for all
@ -425,8 +426,6 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionReader<S> {
.await .await
.unwrap(); .unwrap();
} }
}
}
Ok(()) Ok(())
} }
@ -485,31 +484,28 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionWriter<S> {
.pending_scrape_id .pending_scrape_id
.expect("meta.pending_scrape_id not set"); .expect("meta.pending_scrape_id not set");
let finished = if let Some(pending) = Slab::get_mut( let mut pending_responses = self.pending_scrape_slab.borrow_mut();
&mut RefCell::borrow_mut(&self.pending_scrape_slab),
pending_scrape_id.0 as usize,
) {
pending.stats.extend(out_message.files);
pending.pending_worker_out_messages -= 1;
pending.pending_worker_out_messages == 0 let pending_response = pending_responses
} else { .get_mut(pending_scrape_id.0 as usize)
return Err(anyhow::anyhow!("pending scrape not found in slab")); .ok_or(anyhow::anyhow!("pending scrape not found in slab"))?;
};
if finished { pending_response.stats.extend(out_message.files);
let out_message = { pending_response.pending_worker_out_messages -= 1;
let mut slab = RefCell::borrow_mut(&self.pending_scrape_slab);
let pending = slab.remove(pending_scrape_id.0 as usize); if pending_response.pending_worker_out_messages == 0 {
let pending_response =
pending_responses.remove(pending_scrape_id.0 as usize);
slab.shrink_to_fit(); pending_responses.shrink_to_fit();
OutMessage::ScrapeResponse(ScrapeResponse { let out_message = OutMessage::ScrapeResponse(ScrapeResponse {
action: ScrapeAction::Scrape, action: ScrapeAction::Scrape,
files: pending.stats, files: pending_response.stats,
}) });
};
// Drop Rc borrow before awaiting
drop(pending_responses);
self.send_out_message(&out_message).await?; self.send_out_message(&out_message).await?;
} }
@ -522,18 +518,16 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionWriter<S> {
} }
async fn send_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> { async fn send_out_message(&mut self, out_message: &OutMessage) -> anyhow::Result<()> {
let result = timeout(Duration::from_secs(10), async { timeout(Duration::from_secs(10), async {
let result = Ok(futures::SinkExt::send(&mut self.ws_out, out_message.to_ws_message()).await)
futures::SinkExt::send(&mut self.ws_out, out_message.to_ws_message()).await;
Ok(result)
}) })
.await; .await
.map_err(|err| {
anyhow::anyhow!("send_out_message: sending to peer took too long: {:#}", err)
})?
.with_context(|| "send_out_message")?;
match result { if let OutMessage::AnnounceResponse(_) | OutMessage::ScrapeResponse(_) = out_message {
Ok(Ok(())) => {
if let OutMessage::AnnounceResponse(_) | OutMessage::ScrapeResponse(_) = out_message
{
*self.connection_valid_until.borrow_mut() = ValidUntil::new( *self.connection_valid_until.borrow_mut() = ValidUntil::new(
self.server_start_instant, self.server_start_instant,
self.config.cleaning.max_connection_idle, self.config.cleaning.max_connection_idle,
@ -582,13 +576,6 @@ impl<S: futures::AsyncRead + futures::AsyncWrite + Unpin> ConnectionWriter<S> {
Ok(()) Ok(())
} }
Ok(Err(err)) => Err(err.into()),
Err(err) => Err(anyhow::anyhow!(
"send_out_message: sending to peer took too long: {}",
err
)),
}
}
} }
/// Data stored with connection needed for cleanup after it closes /// Data stored with connection needed for cleanup after it closes