From 7b46a2ede5d1a00233511c1beb188a6bff79fb15 Mon Sep 17 00:00:00 2001
From: Thomas Eizinger <thomas@eizinger.io>
Date: Mon, 4 Dec 2023 16:29:28 +1100
Subject: [PATCH] Don't treat `QueryId`s as unique

---
 protocols/kad/src/handler.rs | 67 +++++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs
index 318261d8d21..5e7c2e21b8b 100644
--- a/protocols/kad/src/handler.rs
+++ b/protocols/kad/src/handler.rs
@@ -60,7 +60,8 @@ pub struct Handler {
     next_connec_unique_id: UniqueConnecId,
 
     /// List of active outbound streams.
-    outbound_substreams: futures_bounded::FuturesMap<QueryId, io::Result<Option<KadResponseMsg>>>,
+    outbound_substreams:
+        futures_bounded::FuturesTupleSet<io::Result<Option<KadResponseMsg>>, QueryId>,
 
     /// Contains one [`oneshot::Sender`] per outbound stream that we have requested.
     pending_streams:
@@ -453,7 +454,7 @@ impl Handler {
             remote_peer_id,
             next_connec_unique_id: UniqueConnecId(0),
             inbound_substreams: Default::default(),
-            outbound_substreams: futures_bounded::FuturesMap::new(
+            outbound_substreams: futures_bounded::FuturesTupleSet::new(
                 Duration::from_secs(10),
                 MAX_NUM_STREAMS,
             ),
@@ -552,32 +553,36 @@ impl Handler {
         let (sender, receiver) = oneshot::channel();
 
         self.pending_streams.push_back(sender);
-        let result = self.outbound_substreams.try_push(id, async move {
-            let mut stream = receiver
-                .await
-                .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
-                .map_err(|e| match e {
-                    StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
-                    StreamUpgradeError::Apply(e) => e,
-                    StreamUpgradeError::NegotiationFailed => {
-                        io::Error::new(io::ErrorKind::ConnectionRefused, "protocol not supported")
-                    }
-                    StreamUpgradeError::Io(e) => e,
-                })?;
-
-            let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });
-
-            stream.send(msg).await?;
-            stream.close().await?;
-
-            if !has_answer {
-                return Ok(None);
-            }
+        let result = self.outbound_substreams.try_push(
+            async move {
+                let mut stream = receiver
+                    .await
+                    .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
+                    .map_err(|e| match e {
+                        StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
+                        StreamUpgradeError::Apply(e) => e,
+                        StreamUpgradeError::NegotiationFailed => io::Error::new(
+                            io::ErrorKind::ConnectionRefused,
+                            "protocol not supported",
+                        ),
+                        StreamUpgradeError::Io(e) => e,
+                    })?;
+
+                let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });
+
+                stream.send(msg).await?;
+                stream.close().await?;
+
+                if !has_answer {
+                    return Ok(None);
+                }
 
-            let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;
+                let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;
 
-            Ok(Some(msg))
-        });
+                Ok(Some(msg))
+            },
+            id,
+        );
 
         debug_assert!(
             result.is_ok(),
@@ -728,15 +733,15 @@ impl ConnectionHandler for Handler {
             }
 
             match self.outbound_substreams.poll_unpin(cx) {
-                Poll::Ready((query, Ok(Ok(Some(response))))) => {
+                Poll::Ready((Ok(Ok(Some(response))), query_id)) => {
                     return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
-                        process_kad_response(response, query),
+                        process_kad_response(response, query_id),
                     ))
                 }
-                Poll::Ready((_, Ok(Ok(None)))) => {
+                Poll::Ready((Ok(Ok(None)), _)) => {
                     continue;
                 }
-                Poll::Ready((query_id, Ok(Err(e)))) => {
+                Poll::Ready((Ok(Err(e)), query_id)) => {
                     return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                         HandlerEvent::QueryError {
                             error: HandlerQueryErr::Io(e),
@@ -744,7 +749,7 @@ impl ConnectionHandler for Handler {
                         },
                     ))
                 }
-                Poll::Ready((query_id, Err(_timeout))) => {
+                Poll::Ready((Err(_timeout), query_id)) => {
                     return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                         HandlerEvent::QueryError {
                             error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()),