make batch size configurable #326

Merged (4 commits) on Sep 7, 2024
1 change: 1 addition & 0 deletions .env.mpc1.dist
@@ -18,6 +18,7 @@ SMPC__REQUESTS_QUEUE_URL=https://sqs.eu-north-1.amazonaws.com/654654380399/smpc0
SMPC__RESULTS_TOPIC_ARN=arn:aws:sns:eu-north-1:767397716933:gpu-iris-mpc-results.fifo
SMPC__PROCESSING_TIMEOUT_SECS=60
SMPC__PUBLIC_KEY_BASE_URL=https://d2k2ck8dyw4s60.cloudfront.net
SMPC__MAX_BATCH_SIZE=64

# These can be either ARNs or IDs, in production multi account setup they are ARNs
SMPC__KMS_KEY_ARNS='["arn:aws:kms:eu-north-1:654654380399:key/a7dd6e20-18cb-4e72-8e1a-52de262affb6", "arn:aws:kms:eu-north-1:590183962074:key/ac3cfd34-e170-4f3d-bac2-979a791ccc3f", "arn:aws:kms:eu-north-1:767398084154:key/8f013838-b18f-46b6-8628-d3fd4b72243c"]'
1 change: 1 addition & 0 deletions .env.mpc2.dist
@@ -18,6 +18,7 @@ SMPC__REQUESTS_QUEUE_URL=https://sqs.eu-north-1.amazonaws.com/590183962074/smpc1
SMPC__RESULTS_TOPIC_ARN=arn:aws:sns:eu-north-1:767397716933:gpu-iris-mpc-results.fifo
SMPC__PROCESSING_TIMEOUT_SECS=60
SMPC__PUBLIC_KEY_BASE_URL=https://d2k2ck8dyw4s60.cloudfront.net
SMPC__MAX_BATCH_SIZE=64

# These can be either ARNs or IDs, in production multi account setup they are ARNs
SMPC__KMS_KEY_ARNS='["arn:aws:kms:eu-north-1:654654380399:key/a7dd6e20-18cb-4e72-8e1a-52de262affb6", "arn:aws:kms:eu-north-1:590183962074:key/ac3cfd34-e170-4f3d-bac2-979a791ccc3f", "arn:aws:kms:eu-north-1:767398084154:key/8f013838-b18f-46b6-8628-d3fd4b72243c"]'
1 change: 1 addition & 0 deletions .env.mpc3.dist
@@ -18,6 +18,7 @@ SMPC__REQUESTS_QUEUE_URL=https://sqs.eu-north-1.amazonaws.com/767398084154/smpc2
SMPC__RESULTS_TOPIC_ARN=arn:aws:sns:eu-north-1:767397716933:gpu-iris-mpc-results.fifo
SMPC__PROCESSING_TIMEOUT_SECS=60
SMPC__PUBLIC_KEY_BASE_URL=https://d2k2ck8dyw4s60.cloudfront.net
SMPC__MAX_BATCH_SIZE=64

# These can be either ARNs or IDs, in production multi account setup they are ARNs
SMPC__KMS_KEY_ARNS='["arn:aws:kms:eu-north-1:654654380399:key/a7dd6e20-18cb-4e72-8e1a-52de262affb6", "arn:aws:kms:eu-north-1:590183962074:key/ac3cfd34-e170-4f3d-bac2-979a791ccc3f", "arn:aws:kms:eu-north-1:767398084154:key/8f013838-b18f-46b6-8628-d3fd4b72243c"]'
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions iris-mpc-common/src/config/mod.rs
@@ -54,6 +54,9 @@ pub struct Config {

#[serde(default)]
pub init_db_size: usize,

#[serde(default)]
pub max_batch_size: usize,
}

fn default_processing_timeout_secs() -> u64 {
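
A note for context (not shown in this diff): the SMPC__MAX_BATCH_SIZE=64 lines added to the .env.*.dist files above are what populate the new max_batch_size field, and #[serde(default)] means the field falls back to 0 when the variable is absent. A minimal, self-contained sketch of that behaviour using only the standard library; the repo's actual config loader is not part of this diff and may differ:

use std::env;

// Hypothetical helper, not from this repo: read the variable directly and
// mimic #[serde(default)] by falling back to 0 when it is unset or invalid.
fn max_batch_size_from_env() -> usize {
    env::var("SMPC__MAX_BATCH_SIZE")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(0)
}

fn main() {
    // With the .env.*.dist values above this prints 64.
    println!("max_batch_size = {}", max_batch_size_from_env());
}
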
1 change: 1 addition & 0 deletions iris-mpc-gpu/Cargo.toml
@@ -5,6 +5,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bincode = "1.3.3"
cudarc = { version = "0.12", features = ["cuda-12020", "nccl"] }
eyre.workspace = true
tracing.workspace = true
59 changes: 36 additions & 23 deletions iris-mpc-gpu/src/server/actor.rs
@@ -1,4 +1,4 @@
use super::{BatchQuery, Eye, ServerJob, ServerJobResult, MAX_BATCH_SIZE};
use super::{BatchQuery, Eye, ServerJob, ServerJobResult};
use crate::{
dot::{
distance_comparator::DistanceComparator,
@@ -61,7 +61,6 @@ impl ServerActorHandle {
}

const DB_CHUNK_SIZE: usize = 512;
const QUERIES: usize = ROTATIONS * MAX_BATCH_SIZE;
pub struct ServerActor {
job_queue: mpsc::Receiver<ServerJob>,
device_manager: Arc<DeviceManager>,
@@ -91,6 +90,7 @@ pub struct ServerActor {
batch_match_list_right: Vec<CudaSlice<u64>>,
current_db_sizes: Vec<usize>,
query_db_size: Vec<usize>,
max_batch_size: usize,
}

const NON_MATCH_ID: u32 = u32::MAX;
@@ -105,6 +105,7 @@ impl ServerActor {
job_queue_size: usize,
db_size: usize,
db_buffer: usize,
max_batch_size: usize,
) -> eyre::Result<(Self, ServerActorHandle)> {
let device_manager = Arc::new(DeviceManager::init());
Self::new_with_device_manager(
@@ -116,6 +117,7 @@
job_queue_size,
db_size,
db_buffer,
max_batch_size,
)
}
#[allow(clippy::too_many_arguments)]
@@ -128,6 +130,7 @@
job_queue_size: usize,
db_size: usize,
db_buffer: usize,
max_batch_size: usize,
) -> eyre::Result<(Self, ServerActorHandle)> {
let ids = device_manager.get_ids_from_magic(0);
let comms = device_manager.instantiate_network_from_ids(party_id, &ids)?;
@@ -141,6 +144,7 @@
job_queue_size,
db_size,
db_buffer,
max_batch_size,
)
}

@@ -155,6 +159,7 @@
job_queue_size: usize,
db_size: usize,
db_buffer: usize,
max_batch_size: usize,
) -> eyre::Result<(Self, ServerActorHandle)> {
assert!(
[left_eye_db.0.len(), right_eye_db.0.len(),]
@@ -189,6 +194,7 @@ impl ServerActor {
rx,
db_size,
db_buffer,
max_batch_size,
)?;
Ok((actor, ServerActorHandle { job_queue: tx }))
}
@@ -204,9 +210,11 @@
job_queue: mpsc::Receiver<ServerJob>,
db_size: usize,
db_buffer: usize,
max_batch_size: usize,
) -> eyre::Result<Self> {
let mut kdf_nonce = 0;
let kdf_salt: Salt = Salt::new(HKDF_SHA256, b"IRIS_MPC");
let n_queries = max_batch_size * ROTATIONS;

// helper closure to generate the next chacha seeds
let mut next_chacha_seeds =
@@ -226,7 +234,7 @@
party_id,
device_manager.clone(),
DB_CHUNK_SIZE,
QUERIES,
n_queries,
IRIS_CODE_LENGTH,
next_chacha_seeds(chacha_seeds)?,
comms.clone(),
@@ -236,7 +244,7 @@
party_id,
device_manager.clone(),
DB_CHUNK_SIZE,
QUERIES,
n_queries,
MASK_CODE_LENGTH,
next_chacha_seeds(chacha_seeds)?,
comms.clone(),
@@ -264,8 +272,8 @@
let batch_codes_engine = ShareDB::init(
party_id,
device_manager.clone(),
QUERIES,
QUERIES,
n_queries,
n_queries,
IRIS_CODE_LENGTH,
next_chacha_seeds(chacha_seeds)?,
comms.clone(),
@@ -274,19 +282,19 @@
let batch_masks_engine = ShareDB::init(
party_id,
device_manager.clone(),
QUERIES,
QUERIES,
n_queries,
n_queries,
MASK_CODE_LENGTH,
next_chacha_seeds(chacha_seeds)?,
comms.clone(),
);

// Phase 2 Setup
let phase2_chunk_size = QUERIES * DB_CHUNK_SIZE;
let phase2_chunk_size = n_queries * DB_CHUNK_SIZE;

// Not divided by GPU_COUNT since we do the work on all GPUs for simplicity,
// also not padded to 2048 since we only require it to be a multiple of 64
let phase2_batch_chunk_size = QUERIES * QUERIES;
let phase2_batch_chunk_size = n_queries * n_queries;
assert_eq!(
phase2_batch_chunk_size % 64,
0,
@@ -316,7 +324,7 @@
comms.clone(),
);

let distance_comparator = DistanceComparator::init(QUERIES, device_manager.clone());
let distance_comparator = DistanceComparator::init(n_queries, device_manager.clone());
// Prepare streams etc.
let mut streams = vec![];
let mut cublas_handles = vec![];
@@ -334,10 +342,10 @@
.prepare_db_match_list((db_size + db_buffer) / device_manager.device_count());
let db_match_list_right = distance_comparator
.prepare_db_match_list((db_size + db_buffer) / device_manager.device_count());
let batch_match_list_left = distance_comparator.prepare_db_match_list(QUERIES);
let batch_match_list_right = distance_comparator.prepare_db_match_list(QUERIES);
let batch_match_list_left = distance_comparator.prepare_db_match_list(n_queries);
let batch_match_list_right = distance_comparator.prepare_db_match_list(n_queries);

let query_db_size = vec![QUERIES; device_manager.device_count()];
let query_db_size = vec![n_queries; device_manager.device_count()];

for dev in device_manager.devices() {
dev.synchronize().unwrap();
@@ -370,6 +378,7 @@
db_match_list_right,
batch_match_list_left,
batch_match_list_right,
max_batch_size,
})
}

@@ -394,7 +403,7 @@

let mut batch = batch;
let mut batch_size = batch.store_left.code.len();
assert!(batch_size > 0 && batch_size <= MAX_BATCH_SIZE);
assert!(batch_size > 0 && batch_size <= self.max_batch_size);
assert!(
batch_size == batch.store_left.mask.len()
&& batch_size == batch.request_ids.len()
@@ -470,12 +479,12 @@
};
let query_store_left = batch.store_left;

// THIS needs to be MAX_BATCH_SIZE, even though the query can be shorter to have
// THIS needs to be max_batch_size, even though the query can be shorter to have
// enough padding for GEMM
let compact_device_queries_left = compact_query_left.htod_transfer(
&self.device_manager,
&self.streams[0],
MAX_BATCH_SIZE,
self.max_batch_size,
)?;

let compact_device_sums_left = compact_device_queries_left.query_sums(
@@ -546,7 +555,7 @@
let compact_device_queries_right = compact_query_right.htod_transfer(
&self.device_manager,
&self.streams[0],
MAX_BATCH_SIZE,
self.max_batch_size,
)?;

let compact_device_sums_right = compact_device_queries_right.query_sums(
@@ -798,7 +807,9 @@
tracing::info!(
"Batch took {:?} [{:.2} Melems/s]",
now.elapsed(),
(MAX_BATCH_SIZE * previous_total_db_size) as f64 / now.elapsed().as_secs_f64() / 1e6
(self.max_batch_size * previous_total_db_size) as f64
/ now.elapsed().as_secs_f64()
/ 1e6
);
Ok(())
}
@@ -859,7 +870,8 @@
}
);

let db_sizes_batch = vec![QUERIES; self.device_manager.device_count()];
let db_sizes_batch =
vec![self.max_batch_size * ROTATIONS; self.device_manager.device_count()];
let code_dots_batch = self.batch_codes_engine.result_chunk_shares(&db_sizes_batch);
let mask_dots_batch = self.batch_masks_engine.result_chunk_shares(&db_sizes_batch);

@@ -1017,11 +1029,12 @@
let mask_dots = self.masks_engine.result_chunk_shares(&phase_2_chunk_sizes);
{
assert_eq!(
(max_chunk_size * QUERIES) % 64,
(max_chunk_size * self.max_batch_size * ROTATIONS) % 64,
0,
"Phase 2 input size must be a multiple of 64"
);
self.phase2.set_chunk_size(max_chunk_size * QUERIES / 64);
self.phase2
.set_chunk_size(max_chunk_size * self.max_batch_size * ROTATIONS / 64);
record_stream_time!(
&self.device_manager,
batch_streams,
@@ -1052,7 +1065,7 @@
&res,
&self.distance_comparator,
db_match_bitmap,
max_chunk_size * QUERIES / 64,
max_chunk_size * self.max_batch_size * ROTATIONS / 64,
&dot_chunk_size,
&chunk_size,
offset,
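
A side note on the sizing arithmetic that this file now performs at runtime instead of via the old compile-time QUERIES constant: every batch entry is expanded into its rotations, so the per-batch query count and the phase 2 chunk sizes all scale with the configured maximum. The sketch below restates that relationship; rotations and db_chunk_size stand in for the crate's ROTATIONS and DB_CHUNK_SIZE constants, whose values are not shown in this diff.

// Illustrative only; mirrors the derivations in ServerActor::init above.
fn batch_dimensions(max_batch_size: usize, rotations: usize, db_chunk_size: usize) -> (usize, usize, usize) {
    let n_queries = max_batch_size * rotations; // replaces the old QUERIES constant
    let phase2_chunk_size = n_queries * db_chunk_size; // queries against one DB chunk
    let phase2_batch_chunk_size = n_queries * n_queries; // batch-vs-batch comparisons
    // The PR keeps the assertion that this quantity packs into whole 64-bit words.
    assert_eq!(
        phase2_batch_chunk_size % 64,
        0,
        "Phase 2 input size must be a multiple of 64"
    );
    (n_queries, phase2_chunk_size, phase2_batch_chunk_size)
}

With the 64 configured in the dist files, n_queries is itself a multiple of 64, so the assertion holds regardless of the rotation count.
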
2 changes: 0 additions & 2 deletions iris-mpc-gpu/src/server/mod.rs
@@ -10,8 +10,6 @@ use iris_mpc_common::galois_engine::degree4::{
use std::collections::HashSet;
use tokio::sync::oneshot;

pub const MAX_BATCH_SIZE: usize = 64;

#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchQueryEntries {
pub code: Vec<GaloisRingIrisCodeShare>,
28 changes: 8 additions & 20 deletions iris-mpc-gpu/src/server/sync_nccl.rs
@@ -21,35 +21,25 @@ pub fn sync(comm: &NcclComm, state: &SyncState) -> Result<SyncResult> {
}

// Change these parameters together - see unittests below.
/// Maximum number of requests in SyncState.
pub const MAX_REQUESTS: usize = 128;
/// The fixed serialization size of SyncState.
const SERIAL_SIZE: usize = 16384;
pub const MAX_REQUESTS: usize = 256 * 2;
const MAX_REQUEST_ID_LEN: usize = 36; // uuidv4 string
const SERIAL_SIZE: usize =
MAX_REQUESTS * (size_of::<usize>() + MAX_REQUEST_ID_LEN) + 2 * size_of::<usize>();

/// Serialize the state to a fixed-size buffer suitable for all_gather.
fn serialize(state: &SyncState) -> Result<Vec<u8>> {
let mut state_ser = vec![0; 8];
serde_json::to_writer(&mut state_ser, state)?;
// Frame with the buffer length.
let buf_len = state_ser.len();
if buf_len > SERIAL_SIZE {
let mut state_ser = bincode::serialize(state)?;
if state_ser.len() > SERIAL_SIZE {
return Err(eyre!("State too large to serialize"));
}
state_ser[..8].copy_from_slice(&(buf_len as u64).to_le_bytes());
// Pad to fixed size.
state_ser.resize(SERIAL_SIZE, 0);
state_ser.extend(std::iter::repeat(0).take(SERIAL_SIZE - state_ser.len()));
Ok(state_ser)
}

/// Deserialize the state from a fixed-size buffer.
fn deserialize(state_ser: &[u8]) -> Result<SyncState> {
// Unframe the buffer.
let buf_len = u64::from_le_bytes(state_ser[..8].try_into().unwrap()) as usize;
if buf_len > SERIAL_SIZE {
return Err(eyre!("State too large to deserialize"));
}
let state = serde_json::from_slice(&state_ser[8..buf_len])?;
Ok(state)
Ok(bincode::deserialize(state_ser)?)
}

/// Deserialize all states concatenated in a buffer (the output of all_gather).
@@ -67,8 +57,6 @@ mod tests {
#[test]
#[cfg(feature = "gpu_dependent")]
fn test_serialize() -> Result<()> {
// Make sure we can serialize enough request IDs assuming a maximum length.
const MAX_REQUEST_ID_LEN: usize = 100;
// My state.
let state = SyncState {
db_len: 123,
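
The sync_nccl.rs change above replaces the hand-rolled, length-framed JSON encoding with bincode, zero-padded to the fixed SERIAL_SIZE so that every party contributes an equal-sized buffer to all_gather. A small round-trip sketch of the same pattern; the SyncState shown here is an illustrative stand-in, since the real struct is defined elsewhere in the crate:

use serde::{Deserialize, Serialize};

// Stand-in for the crate's SyncState; the field names here are illustrative.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct SyncState {
    db_len: u64,
    request_ids: Vec<String>,
}

const SERIAL_SIZE: usize = 16 * 1024; // fixed size chosen for this sketch only

fn round_trip(state: &SyncState) -> eyre::Result<SyncState> {
    // Encode, reject states that would overflow the fixed buffer, then zero-pad.
    // Decoding works on the padded buffer because bincode reads only the bytes it
    // needs and the trailing zeros are never interpreted.
    let mut buf = bincode::serialize(state)?;
    if buf.len() > SERIAL_SIZE {
        return Err(eyre::eyre!("State too large to serialize"));
    }
    buf.resize(SERIAL_SIZE, 0);
    Ok(bincode::deserialize(&buf)?)
}
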
6 changes: 5 additions & 1 deletion iris-mpc-gpu/tests/e2e.rs
@@ -6,7 +6,7 @@ use iris_mpc_common::{
};
use iris_mpc_gpu::{
helpers::device_manager::DeviceManager,
server::{BatchQuery, ServerActor, ServerJobResult, MAX_BATCH_SIZE},
server::{BatchQuery, ServerActor, ServerJobResult},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::{collections::HashMap, env, sync::Arc};
@@ -19,6 +19,7 @@ const DB_BUFFER: usize = 8 * 1000;
const DB_RNG_SEED: u64 = 0xdeadbeef;
const INTERNAL_RNG_SEED: u64 = 0xdeadbeef;
const NUM_BATCHES: usize = 10;
const MAX_BATCH_SIZE: usize = 64;

fn generate_db(party_id: usize) -> Result<(Vec<u16>, Vec<u16>)> {
let mut rng = StdRng::seed_from_u64(DB_RNG_SEED);
@@ -113,6 +114,7 @@ async fn e2e_test() -> Result<()> {
8,
DB_SIZE,
DB_BUFFER,
MAX_BATCH_SIZE,
) {
Ok((actor, handle)) => {
tx0.send(Ok(handle)).unwrap();
@@ -139,6 +141,7 @@
8,
DB_SIZE,
DB_BUFFER,
MAX_BATCH_SIZE,
) {
Ok((actor, handle)) => {
tx1.send(Ok(handle)).unwrap();
@@ -165,6 +168,7 @@
8,
DB_SIZE,
DB_BUFFER,
MAX_BATCH_SIZE,
) {
Ok((actor, handle)) => {
tx2.send(Ok(handle)).unwrap();