diff --git a/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml index a9a1337c0..a57487c59 100644 --- a/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml @@ -92,6 +92,12 @@ env: - name: SMPC__MAX_BATCH_SIZE value: "64" + - name: SMPC__MATCH_DISTANCES_BUFFER_SIZE + value: "128" + + - name: SMPC__N_BUCKETS + value: "10" + - name: SMPC__SERVICE__METRICS__HOST valueFrom: fieldRef: diff --git a/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml index 05937004e..b74472502 100644 --- a/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml @@ -92,6 +92,12 @@ env: - name: SMPC__MAX_BATCH_SIZE value: "64" + - name: SMPC__MATCH_DISTANCES_BUFFER_SIZE + value: "128" + + - name: SMPC__N_BUCKETS + value: "10" + - name: SMPC__SERVICE__METRICS__HOST valueFrom: fieldRef: diff --git a/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml index 4730fe0f0..f76d5b2fe 100644 --- a/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml @@ -92,6 +92,12 @@ env: - name: SMPC__MAX_BATCH_SIZE value: "64" + - name: SMPC__MATCH_DISTANCES_BUFFER_SIZE + value: "128" + + - name: SMPC__N_BUCKETS + value: "10" + - name: SMPC__SERVICE__METRICS__HOST valueFrom: fieldRef: diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index a6a5c6484..1146920ef 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -111,6 +111,12 @@ pub struct Config { #[serde(default)] pub fixed_shared_secrets: bool, + + #[serde(default = "default_match_distances_buffer_size")] + pub match_distances_buffer_size: usize, + + #[serde(default = "default_n_buckets")] + pub n_buckets: usize, } fn default_load_chunks_parallelism() -> usize { @@ -145,6 +151,14 @@ fn default_db_load_safety_overlap_seconds() -> i64 { 60 } +fn default_match_distances_buffer_size() -> usize { + 1 << 20 +} + +fn default_n_buckets() -> usize { + 10 +} + impl Config { pub fn load_config(prefix: &str) -> eyre::Result { let settings = config::Config::builder(); diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 68ec6d0bf..53ae5c129 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -77,8 +77,6 @@ impl ServerActorHandle { const DB_CHUNK_SIZE: usize = 1 << 15; const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e676ae3091"; // Random 32 byte salt const SUPERMATCH_THRESHOLD: usize = 4_000; -const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 20; -const N_BUCKETS: usize = 10; pub struct ServerActor { job_queue: mpsc::Receiver, @@ -112,6 +110,8 @@ pub struct ServerActor { query_db_size: Vec, max_batch_size: usize, max_db_size: usize, + match_distances_buffer_size: usize, + n_buckets: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -140,6 +140,8 @@ impl ServerActor { job_queue_size: usize, max_db_size: usize, max_batch_size: usize, + match_distances_buffer_size: usize, + n_buckets: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -152,6 +154,8 @@ impl ServerActor { job_queue_size, max_db_size, max_batch_size, + match_distances_buffer_size, + n_buckets, return_partial_results, disable_persistence, enable_debug_timing, @@ -165,6 +169,8 @@ impl ServerActor { job_queue_size: usize, max_db_size: usize, max_batch_size: usize, + match_distances_buffer_size: usize, + n_buckets: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -179,6 +185,8 @@ impl ServerActor { job_queue_size, max_db_size, max_batch_size, + match_distances_buffer_size, + n_buckets, return_partial_results, disable_persistence, enable_debug_timing, @@ -194,6 +202,8 @@ impl ServerActor { job_queue_size: usize, max_db_size: usize, max_batch_size: usize, + match_distances_buffer_size: usize, + n_buckets: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -207,6 +217,8 @@ impl ServerActor { rx, max_db_size, max_batch_size, + match_distances_buffer_size, + n_buckets, return_partial_results, disable_persistence, enable_debug_timing, @@ -223,11 +235,13 @@ impl ServerActor { job_queue: mpsc::Receiver, max_db_size: usize, max_batch_size: usize, + match_distances_buffer_size: usize, + n_buckets: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, ) -> eyre::Result { - assert!(max_batch_size != 0); + assert_ne!(max_batch_size, 0); let mut kdf_nonce = 0; let kdf_salt: Salt = Salt::new(HKDF_SHA256, &hex::decode(KDF_SALT)?); let n_queries = max_batch_size * ROTATIONS; @@ -333,8 +347,8 @@ impl ServerActor { let phase2_buckets = Circuits::new( party_id, - MATCH_DISTANCES_BUFFER_SIZE, - MATCH_DISTANCES_BUFFER_SIZE / 64, + match_distances_buffer_size, + match_distances_buffer_size / 64, next_chacha_seeds(chacha_seeds)?, device_manager.clone(), comms.clone(), @@ -373,20 +387,20 @@ impl ServerActor { // Buffers and counters for match distribution let match_distances_buffer_codes_left = - distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE); + distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size); let match_distances_buffer_codes_right = - distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE); + distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size); let match_distances_buffer_masks_left = - distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE); + distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size); let match_distances_buffer_masks_right = - distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE); + distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size); let match_distances_counter_left = distance_comparator.prepare_match_distances_counter(); let match_distances_counter_right = distance_comparator.prepare_match_distances_counter(); let match_distances_indices_left = - distance_comparator.prepare_match_distances_index(MATCH_DISTANCES_BUFFER_SIZE); + distance_comparator.prepare_match_distances_index(match_distances_buffer_size); let match_distances_indices_right = - distance_comparator.prepare_match_distances_index(MATCH_DISTANCES_BUFFER_SIZE); - let buckets = distance_comparator.prepare_match_distances_buckets(N_BUCKETS); + distance_comparator.prepare_match_distances_index(match_distances_buffer_size); + let buckets = distance_comparator.prepare_match_distances_buckets(n_buckets); for dev in device_manager.devices() { dev.synchronize().unwrap(); @@ -423,6 +437,8 @@ impl ServerActor { batch_match_list_right, max_batch_size, max_db_size, + match_distances_buffer_size, + n_buckets, return_partial_results, disable_persistence, enable_debug_timing, @@ -1191,7 +1207,7 @@ impl ServerActor { tracing::info!("Matching distances collected: {}", total_distance_counter); - if total_distance_counter >= MATCH_DISTANCES_BUFFER_SIZE as u32 { + if total_distance_counter >= self.match_distances_buffer_size as u32 { let now = std::time::Instant::now(); tracing::info!("Collected enough match distances, starting bucket calculation"); @@ -1236,10 +1252,10 @@ impl ServerActor { &match_distances_buffers_codes_view, &match_distances_buffers_masks_view, batch_streams, - &(1..=N_BUCKETS) + &(1..=self.n_buckets) .map(|x: usize| { Circuits::translate_threshold_a( - MATCH_THRESHOLD_RATIO / (N_BUCKETS as f64) * (x as f64), + MATCH_THRESHOLD_RATIO / (self.n_buckets as f64) * (x as f64), ) as u16 }) .collect::>(), @@ -1252,7 +1268,7 @@ impl ServerActor { let mut results = String::new(); for i in 0..buckets.len() { - let step = MATCH_THRESHOLD_RATIO / (N_BUCKETS as f64); + let step = MATCH_THRESHOLD_RATIO / (self.n_buckets as f64); let previous_threshold = step * (i as f64); let threshold = step * (i as f64 + 1.0); let previous_count = if i == 0 { 0 } else { buckets[i - 1] }; @@ -1613,7 +1629,7 @@ impl ServerActor { &code_dots, &mask_dots, batch_size, - MATCH_DISTANCES_BUFFER_SIZE, + self.match_distances_buffer_size, request_streams, ); self.phase2.return_result_buffer(res); diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index 638dcab22..ab2e332b6 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -25,6 +25,8 @@ mod e2e_test { const INTERNAL_RNG_SEED: u64 = 0xdeadbeef; const NUM_BATCHES: usize = 10; const MAX_BATCH_SIZE: usize = 64; + const N_BUCKETS: usize = 10; + const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 7; const MAX_DELETIONS_PER_BATCH: usize = 10; const THRESHOLD_ABSOLUTE: usize = 4800; // 0.375 * 12800 @@ -141,6 +143,8 @@ mod e2e_test { 8, DB_SIZE + DB_BUFFER, MAX_BATCH_SIZE, + MATCH_DISTANCES_BUFFER_SIZE, + N_BUCKETS, true, false, false, @@ -170,6 +174,8 @@ mod e2e_test { 8, DB_SIZE + DB_BUFFER, MAX_BATCH_SIZE, + MATCH_DISTANCES_BUFFER_SIZE, + N_BUCKETS, true, false, false, @@ -199,6 +205,8 @@ mod e2e_test { 8, DB_SIZE + DB_BUFFER, MAX_BATCH_SIZE, + MATCH_DISTANCES_BUFFER_SIZE, + N_BUCKETS, true, false, false, diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index e067e19b1..0060122fb 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -1047,6 +1047,8 @@ async fn server_main(config: Config) -> eyre::Result<()> { 8, config.max_db_size, config.max_batch_size, + config.match_distances_buffer_size, + config.n_buckets, config.return_partial_results, config.disable_persistence, config.enable_debug_timing,