diff --git a/i2_client/client.py b/i2_client/client.py index 84cc4af..45cec62 100644 --- a/i2_client/client.py +++ b/i2_client/client.py @@ -13,6 +13,7 @@ import archipel_utils as utils import msgpack +import nest_asyncio import websockets log = logging.getLogger(__name__) @@ -204,4 +205,12 @@ async def _inference(self, inputs): await self.__aexit__(exc_type=None, exc_value=None, traceback=None) return outputs + try: + loop = asyncio.get_running_loop() + if loop.is_running(): + # when executed in jupyter notebook or something + nest_asyncio.apply(loop) + except RuntimeError: + pass + return asyncio.run(_inference(self, inputs)) diff --git a/setup.py b/setup.py index 0cf7379..53de557 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ "imutils>=0.5.4", "msgpack>=1.0", "numpy>=1.19", + "nest-asyncio>=1.5", "rich>=10.13", "websockets>=8.1", "opencv-python==4.6.0.66",