Skip to content

Commit

Permalink
Ps/buckets config improvements (#965)
Browse files Browse the repository at this point in the history
* buckets as config

* stage config
  • Loading branch information
carlomazzaferro authored Jan 24, 2025
1 parent 1681616 commit a2c9626
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 17 deletions.
6 changes: 6 additions & 0 deletions deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ env:
- name: SMPC__MAX_BATCH_SIZE
value: "64"

- name: SMPC__MATCH_DISTANCES_BUFFER_SIZE
value: "128"

- name: SMPC__N_BUCKETS
value: "10"

- name: SMPC__SERVICE__METRICS__HOST
valueFrom:
fieldRef:
Expand Down
6 changes: 6 additions & 0 deletions deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ env:
- name: SMPC__MAX_BATCH_SIZE
value: "64"

- name: SMPC__MATCH_DISTANCES_BUFFER_SIZE
value: "128"

- name: SMPC__N_BUCKETS
value: "10"

- name: SMPC__SERVICE__METRICS__HOST
valueFrom:
fieldRef:
Expand Down
6 changes: 6 additions & 0 deletions deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ env:
- name: SMPC__MAX_BATCH_SIZE
value: "64"

- name: SMPC__MATCH_DISTANCES_BUFFER_SIZE
value: "128"

- name: SMPC__N_BUCKETS
value: "10"

- name: SMPC__SERVICE__METRICS__HOST
valueFrom:
fieldRef:
Expand Down
14 changes: 14 additions & 0 deletions iris-mpc-common/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ pub struct Config {

#[serde(default)]
pub fixed_shared_secrets: bool,

#[serde(default = "default_match_distances_buffer_size")]
pub match_distances_buffer_size: usize,

#[serde(default = "default_n_buckets")]
pub n_buckets: usize,
}

fn default_load_chunks_parallelism() -> usize {
Expand Down Expand Up @@ -145,6 +151,14 @@ fn default_db_load_safety_overlap_seconds() -> i64 {
60
}

fn default_match_distances_buffer_size() -> usize {
1 << 20
}

fn default_n_buckets() -> usize {
10
}

impl Config {
pub fn load_config(prefix: &str) -> eyre::Result<Config> {
let settings = config::Config::builder();
Expand Down
50 changes: 33 additions & 17 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ impl ServerActorHandle {
const DB_CHUNK_SIZE: usize = 1 << 15;
const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e676ae3091"; // Random 32 byte salt
const SUPERMATCH_THRESHOLD: usize = 4_000;
const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 20;
const N_BUCKETS: usize = 10;

pub struct ServerActor {
job_queue: mpsc::Receiver<ServerJob>,
Expand Down Expand Up @@ -112,6 +110,8 @@ pub struct ServerActor {
query_db_size: Vec<usize>,
max_batch_size: usize,
max_db_size: usize,
match_distances_buffer_size: usize,
n_buckets: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
Expand Down Expand Up @@ -140,6 +140,8 @@ impl ServerActor {
job_queue_size: usize,
max_db_size: usize,
max_batch_size: usize,
match_distances_buffer_size: usize,
n_buckets: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
Expand All @@ -152,6 +154,8 @@ impl ServerActor {
job_queue_size,
max_db_size,
max_batch_size,
match_distances_buffer_size,
n_buckets,
return_partial_results,
disable_persistence,
enable_debug_timing,
Expand All @@ -165,6 +169,8 @@ impl ServerActor {
job_queue_size: usize,
max_db_size: usize,
max_batch_size: usize,
match_distances_buffer_size: usize,
n_buckets: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
Expand All @@ -179,6 +185,8 @@ impl ServerActor {
job_queue_size,
max_db_size,
max_batch_size,
match_distances_buffer_size,
n_buckets,
return_partial_results,
disable_persistence,
enable_debug_timing,
Expand All @@ -194,6 +202,8 @@ impl ServerActor {
job_queue_size: usize,
max_db_size: usize,
max_batch_size: usize,
match_distances_buffer_size: usize,
n_buckets: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
Expand All @@ -207,6 +217,8 @@ impl ServerActor {
rx,
max_db_size,
max_batch_size,
match_distances_buffer_size,
n_buckets,
return_partial_results,
disable_persistence,
enable_debug_timing,
Expand All @@ -223,11 +235,13 @@ impl ServerActor {
job_queue: mpsc::Receiver<ServerJob>,
max_db_size: usize,
max_batch_size: usize,
match_distances_buffer_size: usize,
n_buckets: usize,
return_partial_results: bool,
disable_persistence: bool,
enable_debug_timing: bool,
) -> eyre::Result<Self> {
assert!(max_batch_size != 0);
assert_ne!(max_batch_size, 0);
let mut kdf_nonce = 0;
let kdf_salt: Salt = Salt::new(HKDF_SHA256, &hex::decode(KDF_SALT)?);
let n_queries = max_batch_size * ROTATIONS;
Expand Down Expand Up @@ -333,8 +347,8 @@ impl ServerActor {

let phase2_buckets = Circuits::new(
party_id,
MATCH_DISTANCES_BUFFER_SIZE,
MATCH_DISTANCES_BUFFER_SIZE / 64,
match_distances_buffer_size,
match_distances_buffer_size / 64,
next_chacha_seeds(chacha_seeds)?,
device_manager.clone(),
comms.clone(),
Expand Down Expand Up @@ -373,20 +387,20 @@ impl ServerActor {

// Buffers and counters for match distribution
let match_distances_buffer_codes_left =
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size);
let match_distances_buffer_codes_right =
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size);
let match_distances_buffer_masks_left =
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size);
let match_distances_buffer_masks_right =
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
distance_comparator.prepare_match_distances_buffer(match_distances_buffer_size);
let match_distances_counter_left = distance_comparator.prepare_match_distances_counter();
let match_distances_counter_right = distance_comparator.prepare_match_distances_counter();
let match_distances_indices_left =
distance_comparator.prepare_match_distances_index(MATCH_DISTANCES_BUFFER_SIZE);
distance_comparator.prepare_match_distances_index(match_distances_buffer_size);
let match_distances_indices_right =
distance_comparator.prepare_match_distances_index(MATCH_DISTANCES_BUFFER_SIZE);
let buckets = distance_comparator.prepare_match_distances_buckets(N_BUCKETS);
distance_comparator.prepare_match_distances_index(match_distances_buffer_size);
let buckets = distance_comparator.prepare_match_distances_buckets(n_buckets);

for dev in device_manager.devices() {
dev.synchronize().unwrap();
Expand Down Expand Up @@ -423,6 +437,8 @@ impl ServerActor {
batch_match_list_right,
max_batch_size,
max_db_size,
match_distances_buffer_size,
n_buckets,
return_partial_results,
disable_persistence,
enable_debug_timing,
Expand Down Expand Up @@ -1191,7 +1207,7 @@ impl ServerActor {

tracing::info!("Matching distances collected: {}", total_distance_counter);

if total_distance_counter >= MATCH_DISTANCES_BUFFER_SIZE as u32 {
if total_distance_counter >= self.match_distances_buffer_size as u32 {
let now = std::time::Instant::now();
tracing::info!("Collected enough match distances, starting bucket calculation");

Expand Down Expand Up @@ -1236,10 +1252,10 @@ impl ServerActor {
&match_distances_buffers_codes_view,
&match_distances_buffers_masks_view,
batch_streams,
&(1..=N_BUCKETS)
&(1..=self.n_buckets)
.map(|x: usize| {
Circuits::translate_threshold_a(
MATCH_THRESHOLD_RATIO / (N_BUCKETS as f64) * (x as f64),
MATCH_THRESHOLD_RATIO / (self.n_buckets as f64) * (x as f64),
) as u16
})
.collect::<Vec<_>>(),
Expand All @@ -1252,7 +1268,7 @@ impl ServerActor {

let mut results = String::new();
for i in 0..buckets.len() {
let step = MATCH_THRESHOLD_RATIO / (N_BUCKETS as f64);
let step = MATCH_THRESHOLD_RATIO / (self.n_buckets as f64);
let previous_threshold = step * (i as f64);
let threshold = step * (i as f64 + 1.0);
let previous_count = if i == 0 { 0 } else { buckets[i - 1] };
Expand Down Expand Up @@ -1613,7 +1629,7 @@ impl ServerActor {
&code_dots,
&mask_dots,
batch_size,
MATCH_DISTANCES_BUFFER_SIZE,
self.match_distances_buffer_size,
request_streams,
);
self.phase2.return_result_buffer(res);
Expand Down
8 changes: 8 additions & 0 deletions iris-mpc-gpu/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ mod e2e_test {
const INTERNAL_RNG_SEED: u64 = 0xdeadbeef;
const NUM_BATCHES: usize = 10;
const MAX_BATCH_SIZE: usize = 64;
const N_BUCKETS: usize = 10;
const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 7;
const MAX_DELETIONS_PER_BATCH: usize = 10;
const THRESHOLD_ABSOLUTE: usize = 4800; // 0.375 * 12800

Expand Down Expand Up @@ -141,6 +143,8 @@ mod e2e_test {
8,
DB_SIZE + DB_BUFFER,
MAX_BATCH_SIZE,
MATCH_DISTANCES_BUFFER_SIZE,
N_BUCKETS,
true,
false,
false,
Expand Down Expand Up @@ -170,6 +174,8 @@ mod e2e_test {
8,
DB_SIZE + DB_BUFFER,
MAX_BATCH_SIZE,
MATCH_DISTANCES_BUFFER_SIZE,
N_BUCKETS,
true,
false,
false,
Expand Down Expand Up @@ -199,6 +205,8 @@ mod e2e_test {
8,
DB_SIZE + DB_BUFFER,
MAX_BATCH_SIZE,
MATCH_DISTANCES_BUFFER_SIZE,
N_BUCKETS,
true,
false,
false,
Expand Down
2 changes: 2 additions & 0 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,8 @@ async fn server_main(config: Config) -> eyre::Result<()> {
8,
config.max_db_size,
config.max_batch_size,
config.match_distances_buffer_size,
config.n_buckets,
config.return_partial_results,
config.disable_persistence,
config.enable_debug_timing,
Expand Down

0 comments on commit a2c9626

Please sign in to comment.