Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 committed Jan 22, 2025
1 parent fd6d3ec commit 0d6be5d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
20 changes: 10 additions & 10 deletions iris-mpc-gpu/tests/or_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,23 @@ mod or_tree_test {
}

fn open(party: &mut Circuits, result: &mut ChunkShare<u64>, streams: &[CudaStream]) -> bool {
let res = result.get_offset(0, 1);
let mut res_helper = result.get_offset(1, 1);
let dev = party.get_devices()[0].clone();
let stream = &streams[0];

let mut res = result.get_offset(0, 1);
let a = dtoh_on_stream_sync(&res.a, &dev, stream).unwrap();
let b = dtoh_on_stream_sync(&res.b, &dev, stream).unwrap();

cudarc::nccl::result::group_start().expect("group start should work");
party.comms()[0]
.send_view(&res.b, party.next_id(), &streams[0])
.send_view(&res.b, party.next_id(), stream)
.unwrap();
party.comms()[0]
.receive_view(&mut res_helper.a, party.prev_id(), &streams[0])
.receive_view(&mut res.a, party.prev_id(), stream)
.unwrap();
cudarc::nccl::result::group_end().expect("group end should work");

let dev = party.get_devices()[0].clone();
let stream = &streams[0];

let a = dtoh_on_stream_sync(&res.a, &dev, stream).unwrap();
let b = dtoh_on_stream_sync(&res.b, &dev, stream).unwrap();
let c = dtoh_on_stream_sync(&res_helper.a, &dev, stream).unwrap();
let c = dtoh_on_stream_sync(&res.a, &dev, stream).unwrap();
assert_eq!(a.len(), 1);
assert_eq!(b.len(), 1);
assert_eq!(c.len(), 1);
Expand Down
24 changes: 12 additions & 12 deletions iris-mpc-gpu/tests/threshold_and_or_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ mod test_threshold_and_or_tree_test {

const DB_RNG_SEED: u64 = 0xdeadbeef;
// ceil(930 * 125_000 / 2048) * 2048
const INPUTS_PER_GPU_SIZE: usize = 116_250_624;
// const INPUTS_PER_GPU_SIZE: usize = 12_507_136;
// const INPUTS_PER_GPU_SIZE: usize = 116_250_624;
const INPUTS_PER_GPU_SIZE: usize = 12_507_136;
const B_BITS: u64 = 16;
pub(crate) const B: u64 = 1 << B_BITS;
pub(crate) const A: u64 = ((1. - 2. * MATCH_THRESHOLD_RATIO) * B as f64) as u64;
Expand Down Expand Up @@ -104,23 +104,23 @@ mod test_threshold_and_or_tree_test {
}

fn open(party: &mut Circuits, result: &mut ChunkShare<u64>, streams: &[CudaStream]) -> bool {
let res = result.get_offset(0, 1);
let mut res_helper = result.get_offset(1, 1);
let dev = party.get_devices()[0].clone();
let stream = &streams[0];

let mut res = result.get_offset(0, 1);
let a = dtoh_on_stream_sync(&res.a, &dev, stream).unwrap();
let b = dtoh_on_stream_sync(&res.b, &dev, stream).unwrap();

cudarc::nccl::result::group_start().expect("group start should work");
party.comms()[0]
.send_view(&res.b, party.next_id(), &streams[0])
.send_view(&res.b, party.next_id(), stream)
.unwrap();
party.comms()[0]
.receive_view(&mut res_helper.a, party.prev_id(), &streams[0])
.receive_view(&mut res.a, party.prev_id(), stream)
.unwrap();
cudarc::nccl::result::group_end().expect("group end should work");

let dev = party.get_devices()[0].clone();
let stream = &streams[0];

let a = dtoh_on_stream_sync(&res.a, &dev, stream).unwrap();
let b = dtoh_on_stream_sync(&res.b, &dev, stream).unwrap();
let c = dtoh_on_stream_sync(&res_helper.a, &dev, stream).unwrap();
let c = dtoh_on_stream_sync(&res.a, &dev, stream).unwrap();
assert_eq!(a.len(), 1);
assert_eq!(b.len(), 1);
assert_eq!(c.len(), 1);
Expand Down

0 comments on commit 0d6be5d

Please sign in to comment.