From 611cf01cffe3c6edfab47afd6835b2994d61c0b2 Mon Sep 17 00:00:00 2001 From: Prikshit7766 Date: Sat, 23 Dec 2023 20:55:17 +0530 Subject: [PATCH] fix: checkpoint for multi model --- langtest/utils/checkpoints.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/langtest/utils/checkpoints.py b/langtest/utils/checkpoints.py index a77684aff..32338e996 100644 --- a/langtest/utils/checkpoints.py +++ b/langtest/utils/checkpoints.py @@ -6,21 +6,19 @@ class CheckpointManager: - _instance = None # Class variable to store the singleton instance + def __init__(self, checkpoint_folder="checkpoints"): + """Initialize the CheckpointManager. - def __new__(cls, checkpoint_folder="checkpoints"): - if cls._instance is None: - cls._instance = super(CheckpointManager, cls).__new__(cls) - # Initialize the instance only if it doesn't exist - cls._instance.checkpoint_folder = checkpoint_folder - cls._instance.complete_folder = os.path.join(checkpoint_folder, "complete") - cls._instance.remaining_folder = os.path.join(checkpoint_folder, "remaining") - - os.makedirs(cls._instance.checkpoint_folder, exist_ok=True) - os.makedirs(cls._instance.complete_folder, exist_ok=True) - os.makedirs(cls._instance.remaining_folder, exist_ok=True) + Args: + checkpoint_folder (str): The directory to store checkpoints and batch information. + """ + self.checkpoint_folder = checkpoint_folder + self.complete_folder = os.path.join(checkpoint_folder, "complete") + self.remaining_folder = os.path.join(checkpoint_folder, "remaining") - return cls._instance + os.makedirs(self.checkpoint_folder, exist_ok=True) + os.makedirs(self.complete_folder, exist_ok=True) + os.makedirs(self.remaining_folder, exist_ok=True) def save_checkpoint(self, check_point_extension: str, results_so_far: List[Sample]): """Save a checkpoint with partial results.