diff --git a/python/pylibwholegraph/pylibwholegraph/torch/comm.py b/python/pylibwholegraph/pylibwholegraph/torch/comm.py index c7cca2e7b..cc27fa41e 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/comm.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/comm.py @@ -31,6 +31,17 @@ all_comm_local_rank = 0 all_comm_local_size = 1 +def reset_communicators(): + global all_comm_world_rank, all_comm_world_size, all_comm_local_rank, all_comm_local_size + global global_communicators, local_node_communicator, local_device_communicator + global_communicators = {} + local_node_communicator = None + local_device_communicator = None + + all_comm_world_rank = 0 + all_comm_world_size = 1 + all_comm_local_rank = 0 + all_comm_local_size = 1 def set_world_info(world_rank: int, world_size: int, local_rank: int, local_size: int): """ diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 3e1238c2f..339bd0492 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -15,7 +15,7 @@ import torch import torch.utils.dlpack import pylibwholegraph.binding.wholememory_binding as wmb -from .comm import set_world_info, get_global_communicator, get_local_node_communicator +from .comm import set_world_info, get_global_communicator, get_local_node_communicator, reset_communicators def init(world_rank: int, world_size: int, local_rank: int, local_size: int): @@ -73,3 +73,4 @@ def finalize(): :return: None """ wmb.finalize() + reset_communicators()