From e0d1c4aca41804c7f78c41f3fb0de7340ec54c8b Mon Sep 17 00:00:00 2001
From: Olli Saarikivi <olsaarik@microsoft.com>
Date: Thu, 19 Oct 2023 18:02:18 +0000
Subject: [PATCH] Update setup example to use cupy

---
 docs/setup_example.ipynb | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/docs/setup_example.ipynb b/docs/setup_example.ipynb
index 743883b4f..f8cefa289 100644
--- a/docs/setup_example.ipynb
+++ b/docs/setup_example.ipynb
@@ -12,18 +12,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
     "import mscclpp\n",
-    "import torch\n",
     "\n",
     "def setup_channels(comm, memory, proxy_service):\n",
     "    # Register the memory with the communicator\n",
-    "    ptr = memory.data_ptr()\n",
-    "    size = memory.numel() * memory.element_size()\n",
-    "    reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.CudaIpc)\n",
+    "    reg_mem = comm.register_memory(memory.data.ptr, memory.nbytes, mscclpp.Transport.CudaIpc)\n",
     "\n",
     "    # Create connections to all other ranks and exchange registered memories\n",
     "    connections = []\n",
@@ -60,17 +57,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
+    "import cupy as cp\n",
+    "\n",
     "def run(rank, world_size, if_ip_port_trio):\n",
     "    # Use the right GPU for this rank\n",
-    "    torch.cuda.set_device(rank)\n",
+    "    cp.cuda.Device(rank).use()\n",
     "    \n",
     "    # Allocate memory on the GPU\n",
-    "    memory = torch.zeros(1024, dtype=torch.int32)\n",
-    "    memory = memory.to(\"cuda\")\n",
+    "    memory = cp.zeros(1024, dtype=cp.int32)\n",
     "\n",
     "    # Initialize a bootstrapper using a known interface/IP/port trio for the root rank\n",
     "    boot = mscclpp.TcpBootstrap.create(rank, world_size)\n",
@@ -109,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -155,7 +153,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.10"
+   "version": "3.11.5"
   },
   "orig_nbformat": 4
  },