From e0d1c4aca41804c7f78c41f3fb0de7340ec54c8b Mon Sep 17 00:00:00 2001 From: Olli Saarikivi <olsaarik@microsoft.com> Date: Thu, 19 Oct 2023 18:02:18 +0000 Subject: [PATCH] Update setup example to use cupy --- docs/setup_example.ipynb | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/setup_example.ipynb b/docs/setup_example.ipynb index 743883b4f..f8cefa289 100644 --- a/docs/setup_example.ipynb +++ b/docs/setup_example.ipynb @@ -12,18 +12,15 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import mscclpp\n", - "import torch\n", "\n", "def setup_channels(comm, memory, proxy_service):\n", " # Register the memory with the communicator\n", - " ptr = memory.data_ptr()\n", - " size = memory.numel() * memory.element_size()\n", - " reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.CudaIpc)\n", + " reg_mem = comm.register_memory(memory.data.ptr, memory.nbytes, mscclpp.Transport.CudaIpc)\n", "\n", " # Create connections to all other ranks and exchange registered memories\n", " connections = []\n", @@ -60,17 +57,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ + "import cupy as cp\n", + "\n", "def run(rank, world_size, if_ip_port_trio):\n", " # Use the right GPU for this rank\n", - " torch.cuda.set_device(rank)\n", + " cp.cuda.Device(rank).use()\n", " \n", " # Allocate memory on the GPU\n", - " memory = torch.zeros(1024, dtype=torch.int32)\n", - " memory = memory.to(\"cuda\")\n", + " memory = cp.zeros(1024, dtype=cp.int32)\n", "\n", " # Initialize a bootstrapper using a known interface/IP/port trio for the root rank\n", " boot = mscclpp.TcpBootstrap.create(rank, world_size)\n", @@ -109,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -155,7 +153,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.11.5" }, "orig_nbformat": 4 },