From 1932a457f1fcb17053eeac4822f52e5131ce38ce Mon Sep 17 00:00:00 2001
From: naure
Date: Mon, 27 Jan 2025 13:02:34 +0100
Subject: [PATCH] aurel/connect-plan: Split insertion as plan/apply (#969)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* aurel/connect-plan: Split insertion as plan/execute

* aurel/connect-plan: Collect and compare the graph changes

---------

Co-authored-by: Aurélien Nicolas
---
 iris-mpc-cpu/src/execution/hawk_main.rs |  69 ++++++++++-----
 iris-mpc-cpu/src/hnsw/searcher.rs       | 109 ++++++++++++++++++++++++
 2 files changed, 156 insertions(+), 22 deletions(-)

diff --git a/iris-mpc-cpu/src/execution/hawk_main.rs b/iris-mpc-cpu/src/execution/hawk_main.rs
index 851db42d2..52398148d 100644
--- a/iris-mpc-cpu/src/execution/hawk_main.rs
+++ b/iris-mpc-cpu/src/execution/hawk_main.rs
@@ -7,7 +7,7 @@ use crate::{
         session::{BootSession, Session, SessionId},
     },
     hawkers::aby3_store::{Aby3Store, SharedIrisesRef},
-    hnsw::HnswSearcher,
+    hnsw::{searcher::ConnectPlanV, HnswSearcher},
     network::grpc::{GrpcConfig, GrpcNetworking},
     proto_generated::party_node::party_node_server::PartyNodeServer,
     protocol::ops::setup_replicated_prf,
@@ -18,7 +18,13 @@ use eyre::Result;
 use hawk_pack::{graph_store::GraphMem, hawk_searcher::FurthestQueue, VectorStore};
 use itertools::{izip, Itertools};
 use rand::{thread_rng, Rng, SeedableRng};
-use std::{collections::HashMap, ops::DerefMut, sync::Arc, time::Duration};
+use std::{
+    collections::HashMap,
+    ops::{Deref, DerefMut},
+    sync::Arc,
+    time::Duration,
+    vec,
+};
 use tokio::{
     sync::{mpsc, oneshot, RwLock},
     task::JoinSet,
@@ -77,6 +83,7 @@ pub type SearchResult = (
 );
 
 pub type InsertPlan = InsertPlanV<Aby3Store>;
+pub type ConnectPlan = ConnectPlanV<Aby3Store>;
 
 #[derive(Debug)]
 pub struct InsertPlanV<V: VectorStore> {
@@ -274,29 +281,42 @@ impl HawkActor {
         &mut self,
         sessions: &[HawkSessionRef],
         plans: Vec<InsertPlan>,
-    ) -> Result<()> {
-        let plans = join_plans(plans);
-        for plan in plans {
+    ) -> Result<Vec<ConnectPlan>> {
+        let insert_plans = join_plans(plans);
+        let mut connect_plans = vec![];
+        for plan in insert_plans {
             let mut session = sessions[0].write().await;
-            self.insert_one(&mut session, plan).await?;
+            let cp = self.insert_one(&mut session, plan).await?;
+            connect_plans.push(cp);
         }
-        Ok(())
+        Ok(connect_plans)
     }
 
     // TODO: Remove `&mut self` requirement to support parallel sessions.
-    async fn insert_one(&mut self, session: &mut HawkSession, plan: InsertPlan) -> Result<()> {
-        let inserted = session.aby3_store.insert(&plan.query).await;
+    async fn insert_one(
+        &mut self,
+        session: &mut HawkSession,
+        insert_plan: InsertPlan,
+    ) -> Result<ConnectPlan> {
+        let inserted = session.aby3_store.insert(&insert_plan.query).await;
         let mut graph_store = self.graph_store.write().await;
-        self.search_params
-            .insert_from_search_results(
+
+        let connect_plan = self
+            .search_params
+            .insert_prepare(
                 &mut session.aby3_store,
-                graph_store.deref_mut(),
+                graph_store.deref(),
                 inserted,
-                plan.links,
-                plan.set_ep,
+                insert_plan.links,
+                insert_plan.set_ep,
             )
             .await;
-        Ok(())
+
+        self.search_params
+            .insert_apply(graph_store.deref_mut(), connect_plan.clone())
+            .await;
+
+        Ok(connect_plan)
     }
 }
 
@@ -305,7 +325,7 @@ struct HawkJob {
     return_channel: oneshot::Sender<Result<HawkResult>>,
 }
 
-type HawkResult = ();
+type HawkResult = Vec<ConnectPlan>;
 
 /// HawkHandle is a handle to the HawkActor managing concurrency.
 #[derive(Clone, Debug)]
@@ -340,11 +360,11 @@ impl HawkHandle {
                     .search_to_insert(&sessions, to_insert)
                     .await
                     .unwrap();
-                hawk_actor.insert(&sessions, plans).await.unwrap();
+                let connect_plans = hawk_actor.insert(&sessions, plans).await.unwrap();
 
                 println!("🎉 Inserted items into the database");
 
-                let _ = job.return_channel.send(Ok(()));
+                let _ = job.return_channel.send(Ok(connect_plans));
             }
         });
 
@@ -462,20 +482,25 @@ mod tests {
             })
            .collect_vec();
 
-        izip!(irises, handles.clone())
+        let all_plans = izip!(irises, handles.clone())
             .map(|(share, handle)| async move {
-                handle
+                let plans = handle
                     .submit(HawkRequest {
                         my_iris_shares: share,
                     })
                     .await?;
-                Ok(())
+                Ok(plans)
             })
             .collect::<JoinSet<_>>()
             .join_all()
             .await
             .into_iter()
-            .collect::<Result<()>>()?;
+            .collect::<Result<Vec<_>>>()?;
+
+        assert!(
+            all_plans.iter().all_equal(),
+            "All parties must agree on the graph changes"
+        );
 
         Ok(())
     }
diff --git a/iris-mpc-cpu/src/hnsw/searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs
index 6ad92d1c6..e8492301b 100644
--- a/iris-mpc-cpu/src/hnsw/searcher.rs
+++ b/iris-mpc-cpu/src/hnsw/searcher.rs
@@ -8,6 +8,7 @@ pub use hawk_pack::data_structures::queue::{
     FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV,
 };
 use hawk_pack::{GraphStore, VectorStore};
+use itertools::izip;
 use rand::RngCore;
 use rand_distr::{Distribution, Geometric};
 use serde::{Deserialize, Serialize};
@@ -169,6 +170,24 @@ pub struct HnswSearcher {
     pub params: HnswParams,
 }
 
+pub type ConnectPlanV<V> =
+    ConnectPlan<<V as VectorStore>::VectorRef, <V as VectorStore>::DistanceRef>;
+type ConnectPlanLayerV<V> =
+    ConnectPlanLayer<<V as VectorStore>::VectorRef, <V as VectorStore>::DistanceRef>;
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct ConnectPlan<Vector, Distance> {
+    inserted_vector: Vector,
+    layers: Vec<ConnectPlanLayer<Vector, Distance>>,
+    set_ep: bool,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+struct ConnectPlanLayer<Vector, Distance> {
+    neighbors: FurthestQueue<Vector, Distance>,
+    n_links: Vec<FurthestQueue<Vector, Distance>>,
+}
+
 // TODO remove default value; this varies too much between applications
 // to make sense to specify something "obvious"
 impl Default for HnswSearcher {
@@ -181,6 +200,49 @@ impl Default for HnswSearcher {
 
 #[allow(non_snake_case)]
 impl HnswSearcher {
+    /// Two-step variant of `connect_bidir`: prepare.
+    async fn connect_prepare<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &G,
+        q: &V::VectorRef,
+        mut neighbors: FurthestQueueV<V>,
+        lc: usize,
+    ) -> ConnectPlanLayerV<V> {
+        let M = self.params.get_M(lc);
+        let max_links = self.params.get_M_max(lc);
+
+        neighbors.trim_to_k_nearest(M);
+
+        // Connect all n -> q.
+        let mut n_links = vec![];
+        for (n, nq) in neighbors.iter() {
+            let mut links = graph_store.get_links(n, lc).await;
+            links.insert(vector_store, q.clone(), nq.clone()).await;
+            links.trim_to_k_nearest(max_links);
+            n_links.push(links);
+        }
+
+        ConnectPlanLayer { neighbors, n_links }
+    }
+
+    /// Two-step variant of `connect_bidir`: execute.
+    async fn connect_apply<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        graph_store: &mut G,
+        q: V::VectorRef,
+        lc: usize,
+        plan: ConnectPlanLayerV<V>,
+    ) {
+        // Connect all n -> q.
+        for ((n, _nq), links) in izip!(plan.neighbors.iter(), plan.n_links) {
+            graph_store.set_links(n.clone(), links, lc).await;
+        }
+
+        // Connect q -> all n.
+        graph_store.set_links(q, plan.neighbors, lc).await;
+    }
+
     async fn connect_bidir<V: VectorStore, G: GraphStore<V>>(
         &self,
         vector_store: &mut V,
@@ -422,6 +484,53 @@ impl HnswSearcher {
         (links, set_ep)
     }
 
+    /// Two-step variant of `insert_from_search_results`: prepare.
+    pub async fn insert_prepare<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &G,
+        inserted_vector: V::VectorRef,
+        links: Vec<FurthestQueueV<V>>,
+        set_ep: bool,
+    ) -> ConnectPlanV<V> {
+        let mut plan = ConnectPlan {
+            inserted_vector: inserted_vector.clone(),
+            layers: vec![],
+            set_ep,
+        };
+
+        // Connect the new vector to its neighbors in each layer.
+        for (lc, layer_links) in links.into_iter().enumerate() {
+            let lp = self
+                .connect_prepare(vector_store, graph_store, &inserted_vector, layer_links, lc)
+                .await;
+            plan.layers.push(lp);
+        }
+
+        plan
+    }
+
+    /// Two-step variant of `insert_from_search_results`: execute.
+    pub async fn insert_apply<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        graph_store: &mut G,
+        plan: ConnectPlanV<V>,
+    ) {
+        // If required, set vector as new entry point
+        if plan.set_ep {
+            let insertion_layer = plan.layers.len() - 1;
+            graph_store
+                .set_entry_point(plan.inserted_vector.clone(), insertion_layer)
+                .await;
+        }
+
+        // Connect the new vector to its neighbors in each layer.
+        for (lc, layer_plan) in plan.layers.into_iter().enumerate() {
+            self.connect_apply(graph_store, plan.inserted_vector.clone(), lc, layer_plan)
+                .await;
+        }
+    }
+
     /// Insert a vector using the search results from `search_to_insert`,
     /// that is the nearest neighbor links at each insertion layer, and a flag
     /// indicating whether the vector is to be inserted as the new entry point.
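
For illustration only, here is a minimal standalone sketch of the plan/apply pattern
this patch introduces: the prepare step is read-only and returns a comparable value,
so each party can check agreement before anything is written, as the new test does
with `ConnectPlan`. `Graph`, `Plan`, `plan_insert` and `apply_insert` below are
simplified stand-ins invented for this sketch, not the crate's `GraphMem`,
`ConnectPlan`, `insert_prepare` or `insert_apply`.

use std::collections::HashMap;

// Tiny stand-in for the HNSW graph: adjacency lists plus an entry point.
#[derive(Default)]
struct Graph {
    links: HashMap<u32, Vec<u32>>,
    entry_point: Option<u32>,
}

// Mirrors the role of `ConnectPlan`: a comparable record of every change
// the insertion would make, computed without mutating the graph.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Plan {
    inserted: u32,
    new_links: Vec<(u32, Vec<u32>)>,
}

// Step 1 (prepare): read-only, returns the planned changes.
fn plan_insert(graph: &Graph, inserted: u32, neighbors: &[u32]) -> Plan {
    let mut new_links = vec![(inserted, neighbors.to_vec())];
    for &n in neighbors {
        let mut links = graph.links.get(&n).cloned().unwrap_or_default();
        links.push(inserted); // connect n -> inserted
        new_links.push((n, links));
    }
    Plan { inserted, new_links }
}

// Step 2 (apply): only writes what the plan already decided.
fn apply_insert(graph: &mut Graph, plan: Plan) {
    for (v, links) in plan.new_links {
        graph.links.insert(v, links);
    }
    // Analogue of `set_ep`: record the inserted vertex as the entry point.
    graph.entry_point = Some(plan.inserted);
}

fn main() {
    // Three "parties" each hold their own copy of the graph.
    let mut graphs = vec![Graph::default(), Graph::default(), Graph::default()];

    // Every party plans the same insertion independently.
    let plans: Vec<Plan> = graphs
        .iter()
        .map(|g| plan_insert(g, 42, &[1, 2, 3]))
        .collect();

    // All parties must agree on the graph changes before anything is written.
    assert!(plans.windows(2).all(|w| w[0] == w[1]));

    for (g, p) in graphs.iter_mut().zip(plans) {
        apply_insert(g, p);
    }
    println!(
        "party 0 after insertion: entry point {:?}, links {:?}",
        graphs[0].entry_point, graphs[0].links
    );
}

Keeping the write path (`apply_insert` here, `insert_apply` in the patch) free of
searching and distance evaluation is what makes the cross-party equality check
meaningful.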