Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POP-2034] LUC implementation on iric-mpc #951

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
09fa7b1
WIP
carlomazzaferro Jan 16, 2025
11ea27f
Merge branch 'main' of github.com:worldcoin/iris-mpc into feat/LUC-wi…
carlomazzaferro Jan 16, 2025
80f4018
reduce args, calculate on kernel
carlomazzaferro Jan 17, 2025
6be106d
populate or policy bitmap and reduce argument list
carlomazzaferro Jan 17, 2025
dfe2829
address PR comments
carlomazzaferro Jan 17, 2025
49ff959
use htod_on_stream_sync for mem allocation, remove pre-allocation
carlomazzaferro Jan 18, 2025
be2fc7b
do no fork streams, use existing
carlomazzaferro Jan 19, 2025
c1d9a16
clippy
carlomazzaferro Jan 19, 2025
d73a77a
generate masks and templates for left and right
carlomazzaferro Jan 20, 2025
a8a2b8e
wrap into structs
carlomazzaferro Jan 20, 2025
af65bfb
doc
carlomazzaferro Jan 20, 2025
62b812d
Merge branch 'chore/improve-gpu-tests' into feat/LUC-with-bitmaps
carlomazzaferro Jan 20, 2025
60bd7b1
implement OR rule tests
carlomazzaferro Jan 20, 2025
03d6af3
merge main
carlomazzaferro Jan 20, 2025
d3fbd87
add option, e2e on or rule
carlomazzaferro Jan 20, 2025
fac30ea
merge main
carlomazzaferro Jan 21, 2025
f51c122
debug log
carlomazzaferro Jan 21, 2025
5a5dce8
remove unneded division, make it pretty
carlomazzaferro Jan 21, 2025
a79cb51
remove unneded division, make it pretty
carlomazzaferro Jan 21, 2025
424926d
remove leftovers
carlomazzaferro Jan 21, 2025
5a90799
address test pr comments
carlomazzaferro Jan 22, 2025
fec1b79
do not assert left = right partial matches for or ruling
carlomazzaferro Jan 22, 2025
6baeef9
simplications, dead code removal, test improvement
carlomazzaferro Jan 22, 2025
8512396
test improvements and debug logging (#961)
carlomazzaferro Jan 23, 2025
b970cb6
bump cudarc
carlomazzaferro Jan 24, 2025
9222246
bump cudarc
carlomazzaferro Jan 24, 2025
6fd7475
remove borrows
carlomazzaferro Jan 24, 2025
e4cb609
lock
carlomazzaferro Jan 24, 2025
6ba1406
Merge branch 'bump/cudarc-0.13.3' into feat/LUC-with-bitmaps
carlomazzaferro Jan 24, 2025
844c663
Merge branch 'main' of github.com:worldcoin/iris-mpc into feat/LUC-wi…
carlomazzaferro Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 80 additions & 16 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,26 @@ const PTX_SRC: &str = include_str!("kernel.cu");
const OPEN_RESULTS_FUNCTION: &str = "openResults";
const MERGE_DB_RESULTS_FUNCTION: &str = "mergeDbResults";
const MERGE_BATCH_RESULTS_FUNCTION: &str = "mergeBatchResults";
const MERGE_BATCH_RESULTS_WITH_OR_POLICY_BITMAP_FUNCTION: &str = "mergeDbResultsWithOrPolicyBitmap";
const ALL_MATCHES_LEN: usize = 256;

pub struct DistanceComparator {
pub device_manager: Arc<DeviceManager>,
pub open_kernels: Vec<CudaFunction>,
pub merge_db_kernels: Vec<CudaFunction>,
pub merge_batch_kernels: Vec<CudaFunction>,
pub query_length: usize,
pub opened_results: Vec<CudaSlice<u32>>,
pub final_results: Vec<CudaSlice<u32>>,
pub results_init_host: Vec<u32>,
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
pub device_manager: Arc<DeviceManager>,
pub open_kernels: Vec<CudaFunction>,
pub merge_db_kernels: Vec<CudaFunction>,
pub merge_batch_kernels: Vec<CudaFunction>,
pub merge_batch_with_bitmap_kernels: Vec<CudaFunction>,
pub query_length: usize,
pub opened_results: Vec<CudaSlice<u32>>,
pub final_results: Vec<CudaSlice<u32>>,
pub results_init_host: Vec<u32>,
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
}

impl DistanceComparator {
Expand All @@ -39,6 +41,7 @@ impl DistanceComparator {
let mut open_kernels: Vec<CudaFunction> = Vec::new();
let mut merge_db_kernels = Vec::new();
let mut merge_batch_kernels = Vec::new();
let mut merge_batch_with_bitmap_kernels: Vec<CudaFunction> = Vec::new();
let mut opened_results = vec![];
let mut final_results = vec![];
let mut match_counters = vec![];
Expand All @@ -60,13 +63,17 @@ impl DistanceComparator {
OPEN_RESULTS_FUNCTION,
MERGE_DB_RESULTS_FUNCTION,
MERGE_BATCH_RESULTS_FUNCTION,
MERGE_BATCH_RESULTS_WITH_OR_POLICY_BITMAP_FUNCTION,
])
.unwrap();

let open_results_function = device.get_func("", OPEN_RESULTS_FUNCTION).unwrap();
let merge_db_results_function = device.get_func("", MERGE_DB_RESULTS_FUNCTION).unwrap();
let merge_batch_results_function =
device.get_func("", MERGE_BATCH_RESULTS_FUNCTION).unwrap();
let merge_batch_results_with_bitmap_function = device
.get_func("", MERGE_BATCH_RESULTS_WITH_OR_POLICY_BITMAP_FUNCTION)
.unwrap();

opened_results.push(device.htod_copy(results_init_host.clone()).unwrap());
final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap());
Expand All @@ -88,17 +95,18 @@ impl DistanceComparator {
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);

open_kernels.push(open_results_function);
merge_db_kernels.push(merge_db_results_function);
merge_batch_kernels.push(merge_batch_results_function);
merge_batch_with_bitmap_kernels.push(merge_batch_results_with_bitmap_function);
}

Self {
device_manager,
open_kernels,
merge_db_kernels,
merge_batch_kernels,
merge_batch_with_bitmap_kernels,
query_length,
opened_results,
final_results,
Expand Down Expand Up @@ -166,6 +174,62 @@ impl DistanceComparator {
}
}

#[allow(clippy::too_many_arguments)]
pub fn join_db_matches_with_bitmaps(
&self,
max_db_size: usize,
matches_bitmap_left: &[CudaSlice<u64>],
matches_bitmap_right: &[CudaSlice<u64>],
final_results: &[CudaSlice<u32>],
db_sizes: &[usize],
streams: &[CudaStream],
or_policies_bitmap: &[CudaSlice<u64>],
) {
for i in 0..self.device_manager.device_count() {
if db_sizes[i] == 0 {
continue;
}

let num_elements = (db_sizes[i] * self.query_length / ROTATIONS).div_ceil(64);

let threads_per_block = DEFAULT_LAUNCH_CONFIG_THREADS; // ON CHANGE: sync with kernel
let cfg = launch_config_from_elements_and_threads(
num_elements as u32,
threads_per_block,
&self.device_manager.devices()[i],
);

self.device_manager.device(i).bind_to_thread().unwrap();

unsafe {
self.merge_batch_with_bitmap_kernels[i]
.clone()
.launch_on_stream(
&streams[i],
cfg,
(
&matches_bitmap_left[i],
&matches_bitmap_right[i],
&final_results[i],
self.query_length,
db_sizes[i] as u64,
num_elements,
max_db_size,
&self.match_counters[i],
&self.all_matches[i],
&self.match_counters_left[i],
&self.match_counters_right[i],
&self.partial_results_left[i],
&self.partial_results_right[i],
// Additional args
&or_policies_bitmap[i],
),
)
.unwrap();
}
}
}

pub fn join_db_matches(
&self,
matches_bitmap_left: &[CudaSlice<u64>],
Expand Down
62 changes: 62 additions & 0 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon

extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight)
{

size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
{
Expand Down Expand Up @@ -93,6 +94,67 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
}
}

extern "C" __global__ void mergeDbResultsWithOrPolicyBitmap(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, size_t maxDbLength, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight, const unsigned long long *orPolicyBitmap) // 2D bitmap stored as 1D
{

size_t rowStride64 = (maxDbLength + 63) / 64;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Row stride needs to be calculated in the same way as it is in the host: using the max_db_size

size_t totalBits = maxDbLength * queryLength;

size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
{
for (int i = 0; i < 64; i++)
{

size_t globalBit = idx * 64 + i;
// Protect against any leftover bits if totalBits not multiple of 64
if (globalBit >= totalBits) break;

unsigned int queryIdx = globalBit / dbLength;
unsigned int dbIdx = globalBit % dbLength;
bool matchLeft = (matchResultsLeft[idx] & (1ULL << i));
bool matchRight = (matchResultsRight[idx] & (1ULL << i));

// Check bounds
if (queryIdx >= queryLength || dbIdx >= dbLength)
continue;

// Check for partial results (only used for debugging)
if (matchLeft)
{
unsigned int qmcL = atomicAdd(&matchCounterLeft[queryIdx], 1);
if (qmcL < MAX_MATCHES_LEN)
partialResultsLeft[MAX_MATCHES_LEN * queryIdx + qmcL] = dbIdx;
}
if (matchRight)
{
unsigned int qmcR = atomicAdd(&matchCounterRight[queryIdx], 1);
if (qmcR < MAX_MATCHES_LEN)
partialResultsRight[MAX_MATCHES_LEN * queryIdx + qmcR] = dbIdx;
}
size_t rowIndex = queryIdx * rowStride64;
carlomazzaferro marked this conversation as resolved.
Show resolved Hide resolved
size_t orPolicyBitmapIdx = rowIndex + (dbIdx / 64);

bool useOr = (orPolicyBitmap[orPolicyBitmapIdx]
& (1ULL << (dbIdx % 64))) != 0ULL;

// If useOr is true => (matchLeft || matchRight),
// else => (matchLeft && matchRight).
bool finalMatch = useOr ? (matchLeft || matchRight)
: (matchLeft && matchRight);

if (finalMatch)
{
atomicMin(&finalResults[queryIdx], dbIdx);
unsigned int qmc = atomicAdd(&matchCounter[queryIdx], 1);
if (qmc < MAX_MATCHES_LEN)
allMatches[MAX_MATCHES_LEN * queryIdx + qmc] = dbIdx;
}
}
}
}


extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *__matchCounterLeft, unsigned int *__matchCounterRight, unsigned int *__partialResultsLeft, unsigned int *__partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down
86 changes: 79 additions & 7 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
self,
comm::NcclComm,
device_manager::DeviceManager,
htod_on_stream_sync,
query_processor::{
CompactQuery, CudaVec2DSlicerRawPointer, DeviceCompactQuery, DeviceCompactSums,
},
Expand Down Expand Up @@ -606,6 +607,9 @@ impl ServerActor {
&& batch_size * ROTATIONS == batch.db_right_preprocessed.len(),
"Query batch sizes mismatch"
);
if !batch.or_rule_serial_ids.is_empty() {
assert_eq!(batch.or_rule_serial_ids.len(), batch_size);
};

///////////////////////////////////////////////////////////////////
// PERFORM DELETIONS (IF ANY)
Expand Down Expand Up @@ -762,16 +766,48 @@ impl ServerActor {
///////////////////////////////////////////////////////////////////
// MERGE LEFT & RIGHT results
///////////////////////////////////////////////////////////////////

tracing::info!("Joining both sides");
// Merge results and fetch matching indices
// Format: host_results[device_index][query_index]
self.distance_comparator.join_db_matches(
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
);

// Initialize bitmap with OR rule, if exists
if !batch.or_rule_serial_ids.is_empty()
&& !batch
.or_rule_serial_ids
.iter()
.all(|inner| inner.is_empty())
{
assert_eq!(batch.or_rule_serial_ids.len(), batch_size);

// Populate the pre-allocated OR policy bitmap with the serial ids
let host_or_policy_bitmap = prepare_or_policy_bitmap(
self.max_db_size,
batch.or_rule_serial_ids.clone(),
batch_size,
);

let device_or_policy_bitmap =
self.allocate_or_policy_bitmap(host_or_policy_bitmap.clone());

self.distance_comparator.join_db_matches_with_bitmaps(
self.max_db_size,
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
&device_or_policy_bitmap,
);
} else {
self.distance_comparator.join_db_matches(
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
);
}

self.distance_comparator.join_batch_matches(
&self.batch_match_list_left,
Expand Down Expand Up @@ -1565,6 +1601,20 @@ impl ServerActor {

Ok((compact_device_queries, compact_device_sums))
}

fn allocate_or_policy_bitmap(&mut self, bitmap: Vec<u64>) -> Vec<CudaSlice<u64>> {
let devices = self.device_manager.devices();

let mut or_policy_bitmap = Vec::with_capacity(devices.len());

for (device_idx, dev) in devices.iter().enumerate() {
// Transfer the bitmap to the device. It will be the same for each of the
// devices
let _bitmap = htod_on_stream_sync(&bitmap, dev, &self.streams[0][device_idx]).unwrap();
or_policy_bitmap.push(_bitmap);
}
or_policy_bitmap
}
}

/// Internal helper function to log the timers of measured cuda streams.
Expand Down Expand Up @@ -1844,3 +1894,25 @@ pub fn get_dummy_shares_for_deletion(
.into();
(iris_share, mask_share)
}

pub fn prepare_or_policy_bitmap(
max_db_size: usize,
or_rule_serial_ids: Vec<Vec<u32>>,
batch_size: usize,
) -> Vec<u64> {
let row_stride64 = (max_db_size + 63) / 64;
let total_size = row_stride64 * batch_size;

// Create the bitmap on the host
let mut bitmap = vec![0u64; total_size];

for (query_idx, db_indices) in or_rule_serial_ids.iter().enumerate() {
for &db_idx in db_indices {
let row_start = query_idx * row_stride64;
let word_idx = row_start + (db_idx as usize / 64);
let bit_offset = db_idx as usize % 64;
bitmap[word_idx] |= 1 << bit_offset;
}
}
bitmap
}
6 changes: 5 additions & 1 deletion iris-mpc-gpu/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ mod actor;
pub mod sync_nccl;

use crate::dot::{share_db::preprocess_query, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, ROTATIONS};
pub use actor::{get_dummy_shares_for_deletion, ServerActor, ServerActorHandle};
pub use actor::{
get_dummy_shares_for_deletion, prepare_or_policy_bitmap, ServerActor, ServerActorHandle,
};
use iris_mpc_common::galois_engine::degree4::{
GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare,
};
Expand Down Expand Up @@ -79,6 +81,7 @@ pub struct BatchQuery {
pub db_right_preprocessed: BatchQueryEntriesPreprocessed,
pub deletion_requests_indices: Vec<u32>, // 0-indexed indicies in of entries to be deleted
pub deletion_requests_metadata: Vec<BatchMetadata>,
pub or_rule_serial_ids: Vec<Vec<u32>>,
pub valid_entries: Vec<bool>,
}

Expand Down Expand Up @@ -124,6 +127,7 @@ impl BatchQuery {
filter_by_indices!(self.store_left.mask, indices_set);
filter_by_indices!(self.store_right.code, indices_set);
filter_by_indices!(self.store_right.mask, indices_set);
filter_by_indices!(self.or_rule_serial_ids, indices_set);
filter_by_indices_with_rotations!(self.query_left.code, indices_set);
filter_by_indices_with_rotations!(self.query_left.mask, indices_set);
filter_by_indices_with_rotations!(self.db_left.code, indices_set);
Expand Down
Loading
Loading