From 7a681ae46d78935fd901e76e92b85969bd3ab597 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Wed, 8 Nov 2023 23:37:18 +0000 Subject: [PATCH] small improvements related to persistent_local_path --- src/datatrove/executor/slurm.py | 8 ++++---- src/datatrove/io/base.py | 2 +- src/datatrove/io/fsspec.py | 2 +- src/datatrove/io/local.py | 2 +- src/datatrove/io/s3.py | 7 +++---- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py index c9f01146..f69e5e2a 100644 --- a/src/datatrove/executor/slurm.py +++ b/src/datatrove/executor/slurm.py @@ -129,7 +129,7 @@ def launch_job(self): launchscript_f.write( self.get_launch_file( self.sbatch_args, - f"srun -l python -u -c \"import dill;dill.load(open('{executor_f.path_in_local_disk}', 'rb')).run()\"", + f"srun -l python -u -c \"import dill;dill.load(open('{executor_f.persistent_local_path}', 'rb')).run()\"", ) ) logger.info(f'Launching Slurm job {self.job_name} with launch script "{launchscript_f.path}"') @@ -139,7 +139,7 @@ def launch_job(self): args = ["sbatch", f"--export=ALL,RUN_OFFSET={len(self.job_ids)}"] if self.dependency: args.append(f"--dependency={self.dependency}") - output = subprocess.check_output(args + [launchscript_f.path_in_local_disk]).decode("utf-8") + output = subprocess.check_output(args + [launchscript_f.persistent_local_path]).decode("utf-8") self.job_ids.append(output.split()[-1]) self.launched = True logger.info(f"Slurm job launched successfully with id(s)={','.join(self.job_ids)}.") @@ -159,8 +159,8 @@ def sbatch_args(self) -> dict: "partition": self.partition, "job-name": self.job_name, "time": self.time, - "output": slurm_logfile.path_in_local_disk, - "error": slurm_logfile.path_in_local_disk, + "output": slurm_logfile.persistent_local_path, + "error": slurm_logfile.persistent_local_path, "array": f"0-{self.max_array - 1}{f'%{self.workers}' if self.workers != -1 else ''}", **self._sbatch_args, } diff --git a/src/datatrove/io/base.py b/src/datatrove/io/base.py index 1c990179..dc65d0a7 100644 --- a/src/datatrove/io/base.py +++ b/src/datatrove/io/base.py @@ -377,7 +377,7 @@ def delete(self): @property @abstractmethod - def path_in_local_disk(self): + def persistent_local_path(self): raise NotImplementedError diff --git a/src/datatrove/io/fsspec.py b/src/datatrove/io/fsspec.py index a6867cab..ab4d6d24 100644 --- a/src/datatrove/io/fsspec.py +++ b/src/datatrove/io/fsspec.py @@ -26,7 +26,7 @@ def delete(self): self._fs.rm(self.path) @property - def path_in_local_disk(self): + def persistent_local_path(self): raise NotImplementedError diff --git a/src/datatrove/io/local.py b/src/datatrove/io/local.py index 091686b1..505659d4 100644 --- a/src/datatrove/io/local.py +++ b/src/datatrove/io/local.py @@ -24,7 +24,7 @@ def delete(self): os.remove(self.path) @property - def path_in_local_disk(self): + def persistent_local_path(self): return self.path diff --git a/src/datatrove/io/s3.py b/src/datatrove/io/s3.py index a14abf6b..66bd05bc 100644 --- a/src/datatrove/io/s3.py +++ b/src/datatrove/io/s3.py @@ -67,9 +67,7 @@ class S3OutputDataFile(BaseOutputDataFile): def __post_init__(self): if not self.path.startswith("s3://"): raise ValueError("S3OutputDataFile path must start with s3://") - if not self.local_path: - self._tmpdir = tempfile.TemporaryDirectory() - self.local_path = self._tmpdir.name + assert self.local_path is not None, "S3OutputDataFile must have a local_path" def close(self): super().close() @@ -90,7 +88,8 @@ def delete(self): os.remove(self.local_path) @property - def path_in_local_disk(self): + def persistent_local_path(self): + assert not self.folder or (not self.folder._tmpdir and not self.folder.cleanup) return self.local_path