Skip to content

Commit

Permalink
aurel/connect-plan: Split insertion as plan/apply (#969)
Browse files Browse the repository at this point in the history
* aurel/connect-plan: Split insertion as plan/apply

* aurel/connect-plan: Collect and compare the graph changes

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Jan 27, 2025
1 parent a6de9cb commit 1932a45
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 22 deletions.
69 changes: 47 additions & 22 deletions iris-mpc-cpu/src/execution/hawk_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
session::{BootSession, Session, SessionId},
},
hawkers::aby3_store::{Aby3Store, SharedIrisesRef},
hnsw::HnswSearcher,
hnsw::{searcher::ConnectPlanV, HnswSearcher},
network::grpc::{GrpcConfig, GrpcNetworking},
proto_generated::party_node::party_node_server::PartyNodeServer,
protocol::ops::setup_replicated_prf,
Expand All @@ -18,7 +18,13 @@ use eyre::Result;
use hawk_pack::{graph_store::GraphMem, hawk_searcher::FurthestQueue, VectorStore};
use itertools::{izip, Itertools};
use rand::{thread_rng, Rng, SeedableRng};
use std::{collections::HashMap, ops::DerefMut, sync::Arc, time::Duration};
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
vec,
};
use tokio::{
sync::{mpsc, oneshot, RwLock},
task::JoinSet,
Expand Down Expand Up @@ -77,6 +83,7 @@ pub type SearchResult = (
);

pub type InsertPlan = InsertPlanV<Aby3Store>;
pub type ConnectPlan = ConnectPlanV<Aby3Store>;

#[derive(Debug)]
pub struct InsertPlanV<V: VectorStore> {
Expand Down Expand Up @@ -274,29 +281,42 @@ impl HawkActor {
&mut self,
sessions: &[HawkSessionRef],
plans: Vec<InsertPlan>,
) -> Result<()> {
let plans = join_plans(plans);
for plan in plans {
) -> Result<Vec<ConnectPlan>> {
let insert_plans = join_plans(plans);
let mut connect_plans = vec![];
for plan in insert_plans {
let mut session = sessions[0].write().await;
self.insert_one(&mut session, plan).await?;
let cp = self.insert_one(&mut session, plan).await?;
connect_plans.push(cp);
}
Ok(())
Ok(connect_plans)
}

// TODO: Remove `&mut self` requirement to support parallel sessions.
async fn insert_one(&mut self, session: &mut HawkSession, plan: InsertPlan) -> Result<()> {
let inserted = session.aby3_store.insert(&plan.query).await;
async fn insert_one(
&mut self,
session: &mut HawkSession,
insert_plan: InsertPlan,
) -> Result<ConnectPlan> {
let inserted = session.aby3_store.insert(&insert_plan.query).await;
let mut graph_store = self.graph_store.write().await;
self.search_params
.insert_from_search_results(

let connect_plan = self
.search_params
.insert_prepare(
&mut session.aby3_store,
graph_store.deref_mut(),
graph_store.deref(),
inserted,
plan.links,
plan.set_ep,
insert_plan.links,
insert_plan.set_ep,
)
.await;
Ok(())

self.search_params
.insert_apply(graph_store.deref_mut(), connect_plan.clone())
.await;

Ok(connect_plan)
}
}

Expand All @@ -305,7 +325,7 @@ struct HawkJob {
return_channel: oneshot::Sender<Result<HawkResult>>,
}

type HawkResult = ();
type HawkResult = Vec<ConnectPlan>;

/// HawkHandle is a handle to the HawkActor managing concurrency.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -340,11 +360,11 @@ impl HawkHandle {
.search_to_insert(&sessions, to_insert)
.await
.unwrap();
hawk_actor.insert(&sessions, plans).await.unwrap();
let connect_plans = hawk_actor.insert(&sessions, plans).await.unwrap();

println!("🎉 Inserted items into the database");

let _ = job.return_channel.send(Ok(()));
let _ = job.return_channel.send(Ok(connect_plans));
}
});

Expand Down Expand Up @@ -462,20 +482,25 @@ mod tests {
})
.collect_vec();

izip!(irises, handles.clone())
let all_plans = izip!(irises, handles.clone())
.map(|(share, handle)| async move {
handle
let plans = handle
.submit(HawkRequest {
my_iris_shares: share,
})
.await?;
Ok(())
Ok(plans)
})
.collect::<JoinSet<_>>()
.join_all()
.await
.into_iter()
.collect::<Result<()>>()?;
.collect::<Result<Vec<HawkResult>>>()?;

assert!(
all_plans.iter().all_equal(),
"All parties must agree on the graph changes"
);

Ok(())
}
Expand Down
109 changes: 109 additions & 0 deletions iris-mpc-cpu/src/hnsw/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub use hawk_pack::data_structures::queue::{
FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV,
};
use hawk_pack::{GraphStore, VectorStore};
use itertools::izip;
use rand::RngCore;
use rand_distr::{Distribution, Geometric};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -169,6 +170,24 @@ pub struct HnswSearcher {
pub params: HnswParams,
}

/// `ConnectPlan` specialized to a vector store's vector and distance ref types.
pub type ConnectPlanV<V> =
    ConnectPlan<<V as VectorStore>::VectorRef, <V as VectorStore>::DistanceRef>;
/// Per-layer `ConnectPlanLayer` specialized the same way (module-internal).
type ConnectPlanLayerV<V> =
    ConnectPlanLayer<<V as VectorStore>::VectorRef, <V as VectorStore>::DistanceRef>;

/// Planned graph mutations for inserting one vector into the HNSW graph.
///
/// Produced by `HnswSearcher::insert_prepare` (read-only on the graph) and
/// written out by `HnswSearcher::insert_apply`. `Eq`/`PartialEq` allow
/// parties to compare plans and check they agree on the graph changes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConnectPlan<Vector, Distance> {
    /// The vector being inserted.
    inserted_vector: Vector,
    /// Planned link updates, one entry per layer (index = layer number).
    layers: Vec<ConnectPlanLayer<Vector, Distance>>,
    /// Whether the inserted vector becomes the new entry point.
    set_ep: bool,
}

/// Planned link updates for a single graph layer.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ConnectPlanLayer<Vector, Distance> {
    /// The neighbors of the inserted vector in this layer (q -> n links).
    neighbors: FurthestQueue<Vector, Distance>,
    /// `n_links[i]` is the updated neighbor queue of `neighbors[i]`
    /// after inserting q into it (the n -> q back-links).
    n_links: Vec<FurthestQueue<Vector, Distance>>,
}

// TODO remove default value; this varies too much between applications
// to make sense to specify something "obvious"
impl Default for HnswSearcher {
Expand All @@ -181,6 +200,49 @@ impl Default for HnswSearcher {

#[allow(non_snake_case)]
impl HnswSearcher {
/// Two-step variant of `connect_bidir`: prepare.
async fn connect_prepare<V: VectorStore, G: GraphStore<V>>(
&self,
vector_store: &mut V,
graph_store: &G,
q: &V::VectorRef,
mut neighbors: FurthestQueueV<V>,
lc: usize,
) -> ConnectPlanLayerV<V> {
let M = self.params.get_M(lc);
let max_links = self.params.get_M_max(lc);

neighbors.trim_to_k_nearest(M);

// Connect all n -> q.
let mut n_links = vec![];
for (n, nq) in neighbors.iter() {
let mut links = graph_store.get_links(n, lc).await;
links.insert(vector_store, q.clone(), nq.clone()).await;
links.trim_to_k_nearest(max_links);
n_links.push(links);
}

ConnectPlanLayer { neighbors, n_links }
}

/// Two-step variant of `connect_bidir`: execute.
async fn connect_apply<V: VectorStore, G: GraphStore<V>>(
&self,
graph_store: &mut G,
q: V::VectorRef,
lc: usize,
plan: ConnectPlanLayerV<V>,
) {
// Connect all n -> q.
for ((n, _nq), links) in izip!(plan.neighbors.iter(), plan.n_links) {
graph_store.set_links(n.clone(), links, lc).await;
}

// Connect q -> all n.
graph_store.set_links(q, plan.neighbors, lc).await;
}

async fn connect_bidir<V: VectorStore, G: GraphStore<V>>(
&self,
vector_store: &mut V,
Expand Down Expand Up @@ -422,6 +484,53 @@ impl HnswSearcher {
(links, set_ep)
}

/// Two-step variant of `insert_from_search_results`: prepare.
pub async fn insert_prepare<V: VectorStore, G: GraphStore<V>>(
&self,
vector_store: &mut V,
graph_store: &G,
inserted_vector: V::VectorRef,
links: Vec<FurthestQueueV<V>>,
set_ep: bool,
) -> ConnectPlanV<V> {
let mut plan = ConnectPlan {
inserted_vector: inserted_vector.clone(),
layers: vec![],
set_ep,
};

// Connect the new vector to its neighbors in each layer.
for (lc, layer_links) in links.into_iter().enumerate() {
let lp = self
.connect_prepare(vector_store, graph_store, &inserted_vector, layer_links, lc)
.await;
plan.layers.push(lp);
}

plan
}

/// Two-step variant of `insert_from_search_results`: execute.
pub async fn insert_apply<V: VectorStore, G: GraphStore<V>>(
&self,
graph_store: &mut G,
plan: ConnectPlanV<V>,
) {
// If required, set vector as new entry point
if plan.set_ep {
let insertion_layer = plan.layers.len() - 1;
graph_store
.set_entry_point(plan.inserted_vector.clone(), insertion_layer)
.await;
}

// Connect the new vector to its neighbors in each layer.
for (lc, layer_plan) in plan.layers.into_iter().enumerate() {
self.connect_apply(graph_store, plan.inserted_vector.clone(), lc, layer_plan)
.await;
}
}

/// Insert a vector using the search results from `search_to_insert`,
/// that is the nearest neighbor links at each insertion layer, and a flag
/// indicating whether the vector is to be inserted as the new entry point.
Expand Down

0 comments on commit 1932a45

Please sign in to comment.