fix max recursion depth on slurm dependencies
guipenedo committed Nov 2, 2023
1 parent f6e8707 commit 98ec10d
Showing 1 changed file with 9 additions and 5 deletions.
src/datatrove/executor/slurm.py (14 changes: 9 additions & 5 deletions)
@@ -62,6 +62,7 @@ def __init__(
         self._sbatch_args = sbatch_args if sbatch_args else {}
         self.max_array_size = max_array_size
         self.job_ids = []
+        self.depends_job_ids = []
 
     def run(self):
         if "SLURM_ARRAY_TASK_ID" in os.environ:
@@ -98,8 +99,8 @@ def launch_merge_stats(self):
     @property
     def dependency(self):
         dependency = []
-        if self.depends:
-            dependency.append(f"afterok:{','.join(self.depends.job_ids)}")
+        if self.depends_job_ids:
+            dependency.append(f"afterok:{','.join(self.depends_job_ids)}")
         if self.job_ids:
             dependency.append(f"afterany:{self.job_ids[-1]}")
         return ",".join(dependency)
@@ -112,9 +113,12 @@ def launch_job(self):
         assert not self.depends or (
             isinstance(self.depends, SlurmPipelineExecutor)
         ), "depends= must be a SlurmPipelineExecutor"
-        if self.depends and not self.depends.job_ids:
-            logger.info(f'Launching dependency job "{self.depends.job_name}"')
-            self.depends.launch_job()
+        if self.depends:
+            if not self.depends.job_ids:
+                logger.info(f'Launching dependency job "{self.depends.job_name}"')
+                self.depends.launch_job()
+            self.depends_job_ids = self.depends.job_ids
+            self.depends = None  # avoid pickling the entire dependency and possibly its dependencies
 
         # pickle
         with open(os.path.join(self.logging_dir, "executor.pik"), "wb") as f:
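
Why this fixes the recursion error: each SlurmPipelineExecutor previously held a live reference to its depends executor, so pickling one executor (done above so the SLURM array tasks can reload it) recursively pickled the entire dependency chain, and a long enough chain exceeds Python's recursion limit. A minimal sketch of the failure mode, using a hypothetical Stage class rather than datatrove code:

# A minimal sketch of the failure mode (hypothetical Stage class, not datatrove code):
# pickling an object that references its dependency pickles the whole chain,
# and a long chain exceeds Python's recursion limit.
import pickle


class Stage:
    def __init__(self, depends=None):
        self.depends = depends  # reference to the previous stage, as depends= did here


stage = None
for _ in range(5000):  # build a long dependency chain
    stage = Stage(depends=stage)

try:
    pickle.dumps(stage)  # recurses once per link in the chain
except RecursionError as err:
    print(f"pickling failed: {err}")

# The fix applied in this commit: keep only the plain data needed later
# (the job ids) and drop the object reference, so each executor pickles flat.
stage.depends = None
pickle.dumps(stage)  # succeeds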
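For illustration, the rewritten dependency property joins the stored ids into a SLURM dependency specification (the afterok:/afterany: syntax accepted by sbatch's --dependency option). With hypothetical values:

# Hypothetical values, mirroring the logic of the dependency property above:
depends_job_ids = ["1001", "1002"]  # ids collected from the dependency executor
job_ids = ["2000"]                  # this executor's own previously submitted jobs

dependency = []
if depends_job_ids:
    dependency.append(f"afterok:{','.join(depends_job_ids)}")
if job_ids:
    dependency.append(f"afterany:{job_ids[-1]}")

print(",".join(dependency))  # afterok:1001,1002,afterany:2000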
