From fad8d51b85a429cb280a34143f01e508ee8c825f Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Sat, 25 Jan 2025 17:15:45 +0100 Subject: [PATCH] revert nccl changes --- Cargo.lock | 5 ++--- iris-mpc-gpu/Cargo.toml | 5 ++++- iris-mpc-gpu/src/bin/nccl.rs | 12 +++--------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index acb6d671c..0824f6878 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1473,9 +1473,8 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712390a76a83255b3d02760b8b3b3b7cb518f1f7cd83aafbbd65c382a7d6e8ef" +version = "0.12.2" +source = "git+https://github.com/worldcoin/cudarc-fork.git#b6ba8da720217736d262c43d39260c7380530567" dependencies = [ "libloading", ] diff --git a/iris-mpc-gpu/Cargo.toml b/iris-mpc-gpu/Cargo.toml index 726483ef9..33e016bb1 100644 --- a/iris-mpc-gpu/Cargo.toml +++ b/iris-mpc-gpu/Cargo.toml @@ -9,7 +9,10 @@ repository.workspace = true [dependencies] bincode = "1.3.3" -cudarc = { version = "0.13.3", features = ["cuda-12020", "nccl"] } +cudarc = { git = "https://github.com/worldcoin/cudarc-fork.git", features = [ + "cuda-12020", + "nccl", +] } eyre.workspace = true tracing.workspace = true bytemuck.workspace = true diff --git a/iris-mpc-gpu/src/bin/nccl.rs b/iris-mpc-gpu/src/bin/nccl.rs index ec27fe5c8..81248e555 100644 --- a/iris-mpc-gpu/src/bin/nccl.rs +++ b/iris-mpc-gpu/src/bin/nccl.rs @@ -85,15 +85,9 @@ async fn main() -> eyre::Result<()> { for i in 0..n_devices { devs[i].bind_to_thread().unwrap(); - comms[i] - .broadcast(slices[i].as_ref(), &mut slices1[i], 0) - .unwrap(); - comms[i] - .broadcast(slices[i].as_ref(), &mut slices2[i], 1) - .unwrap(); - comms[i] - .broadcast(slices[i].as_ref(), &mut slices3[i], 2) - .unwrap(); + comms[i].broadcast(&slices[i], &mut slices1[i], 0).unwrap(); + comms[i].broadcast(&slices[i], &mut slices2[i], 1).unwrap(); + comms[i].broadcast(&slices[i], &mut slices3[i], 2).unwrap(); } for dev in devs.iter() {