Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POP-2034] LUC implementation on iric-mpc #951

Merged
merged 47 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
09fa7b1
WIP
carlomazzaferro Jan 16, 2025
11ea27f
Merge branch 'main' of github.com:worldcoin/iris-mpc into feat/LUC-wi…
carlomazzaferro Jan 16, 2025
80f4018
reduce args, calculate on kernel
carlomazzaferro Jan 17, 2025
6be106d
populate or policy bitmap and reduce argument list
carlomazzaferro Jan 17, 2025
dfe2829
address PR comments
carlomazzaferro Jan 17, 2025
49ff959
use htod_on_stream_sync for mem allocation, remove pre-allocation
carlomazzaferro Jan 18, 2025
be2fc7b
do no fork streams, use existing
carlomazzaferro Jan 19, 2025
c1d9a16
clippy
carlomazzaferro Jan 19, 2025
d73a77a
generate masks and templates for left and right
carlomazzaferro Jan 20, 2025
a8a2b8e
wrap into structs
carlomazzaferro Jan 20, 2025
af65bfb
doc
carlomazzaferro Jan 20, 2025
62b812d
Merge branch 'chore/improve-gpu-tests' into feat/LUC-with-bitmaps
carlomazzaferro Jan 20, 2025
60bd7b1
implement OR rule tests
carlomazzaferro Jan 20, 2025
03d6af3
merge main
carlomazzaferro Jan 20, 2025
d3fbd87
add option, e2e on or rule
carlomazzaferro Jan 20, 2025
fac30ea
merge main
carlomazzaferro Jan 21, 2025
f51c122
debug log
carlomazzaferro Jan 21, 2025
5a5dce8
remove unneded division, make it pretty
carlomazzaferro Jan 21, 2025
a79cb51
remove unneded division, make it pretty
carlomazzaferro Jan 21, 2025
424926d
remove leftovers
carlomazzaferro Jan 21, 2025
5a90799
address test pr comments
carlomazzaferro Jan 22, 2025
fec1b79
do not assert left = right partial matches for or ruling
carlomazzaferro Jan 22, 2025
6baeef9
simplications, dead code removal, test improvement
carlomazzaferro Jan 22, 2025
8512396
test improvements and debug logging (#961)
carlomazzaferro Jan 23, 2025
b970cb6
bump cudarc
carlomazzaferro Jan 24, 2025
9222246
bump cudarc
carlomazzaferro Jan 24, 2025
6fd7475
remove borrows
carlomazzaferro Jan 24, 2025
e4cb609
lock
carlomazzaferro Jan 24, 2025
6ba1406
Merge branch 'bump/cudarc-0.13.3' into feat/LUC-with-bitmaps
carlomazzaferro Jan 24, 2025
844c663
Merge branch 'main' of github.com:worldcoin/iris-mpc into feat/LUC-wi…
carlomazzaferro Jan 24, 2025
8d8f764
Merge branch 'main' of github.com:worldcoin/iris-mpc into feat/LUC-wi…
carlomazzaferro Jan 28, 2025
52350a6
merge main
carlomazzaferro Jan 29, 2025
7c6c170
bump cudarc
carlomazzaferro Jan 29, 2025
8121c22
Merge branch 'chore/bump-cudarc-0.13.4' into feat/LUC-with-bitmaps
carlomazzaferro Jan 29, 2025
fbd1062
fmt
carlomazzaferro Jan 29, 2025
6b43193
Merge branch 'main' of github.com:worldcoin/iris-mpc into feat/LUC-wi…
carlomazzaferro Jan 29, 2025
5149c2c
[POP-2024] Feat/luc with bitmaps test improvements (#974)
carlomazzaferro Jan 30, 2025
f94f8e0
merge e2e refactors
carlomazzaferro Jan 31, 2025
d20b477
merge e2e refactors
carlomazzaferro Jan 31, 2025
a020105
full mask only
carlomazzaferro Jan 31, 2025
881cc1e
fix tests for or rule
carlomazzaferro Jan 31, 2025
3a40e3f
clippy
carlomazzaferro Jan 31, 2025
9531a21
re-add check for bucket counds
carlomazzaferro Jan 31, 2025
ce02b47
typo
carlomazzaferro Jan 31, 2025
75a4ea8
record stream time for or policy
carlomazzaferro Jan 31, 2025
3ef0aa5
Revert "record stream time for or policy"
carlomazzaferro Jan 31, 2025
808f373
record stream time for or policy
carlomazzaferro Jan 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ where
}

pub const IDENTITY_DELETION_MESSAGE_TYPE: &str = "identity_deletion";
pub const ANONYMIZED_STATISTICS_MESSAGE_TYPE: &str = "anonymized_statstics";
pub const ANONYMIZED_STATISTICS_MESSAGE_TYPE: &str = "anonymized_statistics";
pub const CIRCUIT_BREAKER_MESSAGE_TYPE: &str = "circuit_breaker";
pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness";

Expand Down
97 changes: 81 additions & 16 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,27 @@ const OPEN_RESULTS_FUNCTION: &str = "openResults";
const OPEN_RESULTS_BATCH_FUNCTION: &str = "openResultsBatch";
const MERGE_DB_RESULTS_FUNCTION: &str = "mergeDbResults";
const MERGE_BATCH_RESULTS_FUNCTION: &str = "mergeBatchResults";
const MERGE_BATCH_RESULTS_WITH_OR_POLICY_BITMAP_FUNCTION: &str = "mergeDbResultsWithOrPolicyBitmap";
const ALL_MATCHES_LEN: usize = 256;

pub struct DistanceComparator {
pub device_manager: Arc<DeviceManager>,
pub open_kernels: Vec<CudaFunction>,
pub open_batch_kernels: Vec<CudaFunction>,
pub merge_db_kernels: Vec<CudaFunction>,
pub merge_batch_kernels: Vec<CudaFunction>,
pub query_length: usize,
pub opened_results: Vec<CudaSlice<u32>>,
pub final_results: Vec<CudaSlice<u32>>,
pub results_init_host: Vec<u32>,
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
pub device_manager: Arc<DeviceManager>,
pub open_kernels: Vec<CudaFunction>,
pub open_batch_kernels: Vec<CudaFunction>,
pub merge_db_kernels: Vec<CudaFunction>,
pub merge_batch_kernels: Vec<CudaFunction>,
pub merge_batch_with_bitmap_kernels: Vec<CudaFunction>,
pub query_length: usize,
pub opened_results: Vec<CudaSlice<u32>>,
pub final_results: Vec<CudaSlice<u32>>,
pub results_init_host: Vec<u32>,
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
}

impl DistanceComparator {
Expand All @@ -48,6 +50,7 @@ impl DistanceComparator {
let mut open_batch_kernels: Vec<CudaFunction> = Vec::new();
let mut merge_db_kernels = Vec::new();
let mut merge_batch_kernels = Vec::new();
let mut merge_batch_with_bitmap_kernels: Vec<CudaFunction> = Vec::new();
let mut opened_results = vec![];
let mut final_results = vec![];
let mut match_counters = vec![];
Expand All @@ -70,6 +73,7 @@ impl DistanceComparator {
OPEN_RESULTS_BATCH_FUNCTION,
MERGE_DB_RESULTS_FUNCTION,
MERGE_BATCH_RESULTS_FUNCTION,
MERGE_BATCH_RESULTS_WITH_OR_POLICY_BITMAP_FUNCTION,
])
.unwrap();

Expand All @@ -79,6 +83,9 @@ impl DistanceComparator {
let merge_db_results_function = device.get_func("", MERGE_DB_RESULTS_FUNCTION).unwrap();
let merge_batch_results_function =
device.get_func("", MERGE_BATCH_RESULTS_FUNCTION).unwrap();
let merge_batch_results_with_bitmap_function = device
.get_func("", MERGE_BATCH_RESULTS_WITH_OR_POLICY_BITMAP_FUNCTION)
.unwrap();

opened_results.push(device.htod_copy(results_init_host.clone()).unwrap());
final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap());
Expand All @@ -105,6 +112,7 @@ impl DistanceComparator {
open_batch_kernels.push(open_results_batch_function);
merge_db_kernels.push(merge_db_results_function);
merge_batch_kernels.push(merge_batch_results_function);
merge_batch_with_bitmap_kernels.push(merge_batch_results_with_bitmap_function);
}

Self {
Expand All @@ -113,6 +121,7 @@ impl DistanceComparator {
open_batch_kernels,
merge_db_kernels,
merge_batch_kernels,
merge_batch_with_bitmap_kernels,
query_length,
opened_results,
final_results,
Expand Down Expand Up @@ -252,6 +261,62 @@ impl DistanceComparator {
}
}

#[allow(clippy::too_many_arguments)]
pub fn join_db_matches_with_bitmaps(
&self,
max_db_size: usize,
matches_bitmap_left: &[CudaSlice<u64>],
matches_bitmap_right: &[CudaSlice<u64>],
final_results: &[CudaSlice<u32>],
db_sizes: &[usize],
streams: &[CudaStream],
or_policies_bitmap: &[CudaSlice<u64>],
) {
for i in 0..self.device_manager.device_count() {
if db_sizes[i] == 0 {
continue;
}

let num_elements = (db_sizes[i] * self.query_length / ROTATIONS).div_ceil(64);

let threads_per_block = DEFAULT_LAUNCH_CONFIG_THREADS; // ON CHANGE: sync with kernel
let cfg = launch_config_from_elements_and_threads(
num_elements as u32,
threads_per_block,
&self.device_manager.devices()[i],
);

self.device_manager.device(i).bind_to_thread().unwrap();

unsafe {
self.merge_batch_with_bitmap_kernels[i]
.clone()
.launch_on_stream(
&streams[i],
cfg,
(
&matches_bitmap_left[i],
&matches_bitmap_right[i],
&final_results[i],
self.query_length,
db_sizes[i] as u64,
num_elements,
max_db_size,
&self.match_counters[i],
&self.all_matches[i],
&self.match_counters_left[i],
&self.match_counters_right[i],
&self.partial_results_left[i],
&self.partial_results_right[i],
// Additional args
&or_policies_bitmap[i],
),
)
.unwrap();
}
}
}

pub fn join_db_matches(
&self,
matches_bitmap_left: &[CudaSlice<u64>],
Expand Down
62 changes: 62 additions & 0 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon

extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight)
{

size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
{
Expand Down Expand Up @@ -129,6 +130,67 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
}
}

extern "C" __global__ void mergeDbResultsWithOrPolicyBitmap(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, size_t maxDbLength, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight, const unsigned long long *orPolicyBitmap) // 2D bitmap stored as 1D
{

size_t rowStride64 = (maxDbLength + 63) / 64;
carlomazzaferro marked this conversation as resolved.
Show resolved Hide resolved
size_t totalBits = maxDbLength * queryLength;

size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
{
for (int i = 0; i < 64; i++)
{

size_t globalBit = idx * 64 + i;
// Protect against any leftover bits if totalBits not multiple of 64
if (globalBit >= totalBits) break;

unsigned int queryIdx = globalBit / dbLength;
unsigned int dbIdx = globalBit % dbLength;
bool matchLeft = (matchResultsLeft[idx] & (1ULL << i));
bool matchRight = (matchResultsRight[idx] & (1ULL << i));

// Check bounds
if (queryIdx >= queryLength || dbIdx >= dbLength)
continue;

// Check for partial results (only used for debugging)
if (matchLeft)
{
unsigned int qmcL = atomicAdd(&matchCounterLeft[queryIdx], 1);
if (qmcL < MAX_MATCHES_LEN)
partialResultsLeft[MAX_MATCHES_LEN * queryIdx + qmcL] = dbIdx;
}
if (matchRight)
{
unsigned int qmcR = atomicAdd(&matchCounterRight[queryIdx], 1);
if (qmcR < MAX_MATCHES_LEN)
partialResultsRight[MAX_MATCHES_LEN * queryIdx + qmcR] = dbIdx;
}
size_t rowIndex = queryIdx * rowStride64;
carlomazzaferro marked this conversation as resolved.
Show resolved Hide resolved
size_t orPolicyBitmapIdx = rowIndex + (dbIdx / 64);

bool useOr = (orPolicyBitmap[orPolicyBitmapIdx]
& (1ULL << (dbIdx % 64))) != 0ULL;

// If useOr is true => (matchLeft || matchRight),
// else => (matchLeft && matchRight).
bool finalMatch = useOr ? (matchLeft || matchRight)
: (matchLeft && matchRight);

if (finalMatch)
{
atomicMin(&finalResults[queryIdx], dbIdx);
unsigned int qmc = atomicAdd(&matchCounter[queryIdx], 1);
if (qmc < MAX_MATCHES_LEN)
allMatches[MAX_MATCHES_LEN * queryIdx + qmc] = dbIdx;
}
}
}
}


extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *__matchCounterLeft, unsigned int *__matchCounterRight, unsigned int *__partialResultsLeft, unsigned int *__partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down
88 changes: 81 additions & 7 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,9 @@ impl ServerActor {
&& batch_size * ROTATIONS == batch.db_right_preprocessed.len(),
"Query batch sizes mismatch"
);
if !batch.or_rule_serial_ids.is_empty() {
assert_eq!(batch.or_rule_serial_ids.len(), batch_size);
};

///////////////////////////////////////////////////////////////////
// PERFORM DELETIONS (IF ANY)
Expand Down Expand Up @@ -844,16 +847,51 @@ impl ServerActor {
///////////////////////////////////////////////////////////////////
// MERGE LEFT & RIGHT results
///////////////////////////////////////////////////////////////////

tracing::info!("Joining both sides");
// Merge results and fetch matching indices
// Format: host_results[device_index][query_index]
self.distance_comparator.join_db_matches(
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
);

// Initialize bitmap with OR rule, if exists
if !batch.or_rule_serial_ids.is_empty()
&& !batch
.or_rule_serial_ids
.iter()
.all(|inner| inner.is_empty())
{
assert_eq!(batch.or_rule_serial_ids.len(), batch_size);

let now = Instant::now();
tracing::info!("Preparing and allocating OR policy bitmap");
// Populate the pre-allocated OR policy bitmap with the serial ids
let host_or_policy_bitmap = prepare_or_policy_bitmap(
self.max_db_size,
batch.or_rule_serial_ids.clone(),
batch_size,
);

let device_or_policy_bitmap =
self.allocate_or_policy_bitmap(host_or_policy_bitmap.clone());
Comment on lines +866 to +874
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's maybe add some timing around this, so we now if this is worth improving more? (we had some ideas to improve this)

tracing::info!("OR policy bitmap prepared in {:?}", now.elapsed());

self.distance_comparator.join_db_matches_with_bitmaps(
self.max_db_size,
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
&device_or_policy_bitmap,
);
} else {
self.distance_comparator.join_db_matches(
&self.db_match_list_left,
&self.db_match_list_right,
&self.final_results,
&self.current_db_sizes,
&self.streams[0],
);
}

self.distance_comparator.join_batch_matches(
&self.batch_match_list_left,
Expand Down Expand Up @@ -1810,6 +1848,20 @@ impl ServerActor {

Ok((compact_device_queries, compact_device_sums))
}

fn allocate_or_policy_bitmap(&mut self, bitmap: Vec<u64>) -> Vec<CudaSlice<u64>> {
let devices = self.device_manager.devices();

let mut or_policy_bitmap = Vec::with_capacity(devices.len());

for (device_idx, dev) in devices.iter().enumerate() {
// Transfer the bitmap to the device. It will be the same for each of the
// devices
let _bitmap = htod_on_stream_sync(&bitmap, dev, &self.streams[0][device_idx]).unwrap();
or_policy_bitmap.push(_bitmap);
}
or_policy_bitmap
}
}

/// Internal helper function to log the timers of measured cuda streams.
Expand Down Expand Up @@ -2232,3 +2284,25 @@ fn sort_shares_by_indices(
})
.collect::<Vec<_>>()
}

pub fn prepare_or_policy_bitmap(
max_db_size: usize,
or_rule_serial_ids: Vec<Vec<u32>>,
batch_size: usize,
) -> Vec<u64> {
let row_stride64 = (max_db_size + 63) / 64;
let total_size = row_stride64 * batch_size;

// Create the bitmap on the host
let mut bitmap = vec![0u64; total_size];

for (query_idx, db_indices) in or_rule_serial_ids.iter().enumerate() {
for &db_idx in db_indices {
let row_start = query_idx * row_stride64;
let word_idx = row_start + (db_idx as usize / 64);
let bit_offset = db_idx as usize % 64;
bitmap[word_idx] |= 1 << bit_offset;
}
}
bitmap
}
6 changes: 5 additions & 1 deletion iris-mpc-gpu/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ mod actor;
pub mod sync_nccl;

use crate::dot::{share_db::preprocess_query, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, ROTATIONS};
pub use actor::{get_dummy_shares_for_deletion, ServerActor, ServerActorHandle};
pub use actor::{
get_dummy_shares_for_deletion, prepare_or_policy_bitmap, ServerActor, ServerActorHandle,
};
use iris_mpc_common::{
galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
helpers::statistics::BucketStatistics,
Expand Down Expand Up @@ -80,6 +82,7 @@ pub struct BatchQuery {
pub db_right_preprocessed: BatchQueryEntriesPreprocessed,
pub deletion_requests_indices: Vec<u32>, // 0-indexed indicies in of entries to be deleted
pub deletion_requests_metadata: Vec<BatchMetadata>,
pub or_rule_serial_ids: Vec<Vec<u32>>,
pub valid_entries: Vec<bool>,
}

Expand Down Expand Up @@ -125,6 +128,7 @@ impl BatchQuery {
filter_by_indices!(self.store_left.mask, indices_set);
filter_by_indices!(self.store_right.code, indices_set);
filter_by_indices!(self.store_right.mask, indices_set);
filter_by_indices!(self.or_rule_serial_ids, indices_set);
filter_by_indices_with_rotations!(self.query_left.code, indices_set);
filter_by_indices_with_rotations!(self.query_left.mask, indices_set);
filter_by_indices_with_rotations!(self.db_left.code, indices_set);
Expand Down
Loading
Loading