Skip to content

Commit

Permalink
small improvements related to persistent_local_path
Browse files: browse the repository at this point in the history
  • Loading branch information
guipenedo committed Nov 8, 2023
1 parent 6b1a822 commit 7a681ae
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/datatrove/executor/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def launch_job(self):
launchscript_f.write(
self.get_launch_file(
self.sbatch_args,
f"srun -l python -u -c \"import dill;dill.load(open('{executor_f.path_in_local_disk}', 'rb')).run()\"",
f"srun -l python -u -c \"import dill;dill.load(open('{executor_f.persistent_local_path}', 'rb')).run()\"",
)
)
logger.info(f'Launching Slurm job {self.job_name} with launch script "{launchscript_f.path}"')
Expand All @@ -139,7 +139,7 @@ def launch_job(self):
args = ["sbatch", f"--export=ALL,RUN_OFFSET={len(self.job_ids)}"]
if self.dependency:
args.append(f"--dependency={self.dependency}")
output = subprocess.check_output(args + [launchscript_f.path_in_local_disk]).decode("utf-8")
output = subprocess.check_output(args + [launchscript_f.persistent_local_path]).decode("utf-8")
self.job_ids.append(output.split()[-1])
self.launched = True
logger.info(f"Slurm job launched successfully with id(s)={','.join(self.job_ids)}.")
Expand All @@ -159,8 +159,8 @@ def sbatch_args(self) -> dict:
"partition": self.partition,
"job-name": self.job_name,
"time": self.time,
"output": slurm_logfile.path_in_local_disk,
"error": slurm_logfile.path_in_local_disk,
"output": slurm_logfile.persistent_local_path,
"error": slurm_logfile.persistent_local_path,
"array": f"0-{self.max_array - 1}{f'%{self.workers}' if self.workers != -1 else ''}",
**self._sbatch_args,
}
Expand Down
2 changes: 1 addition & 1 deletion src/datatrove/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def delete(self):

@property
@abstractmethod
def path_in_local_disk(self):
def persistent_local_path(self):
raise NotImplementedError


Expand Down
2 changes: 1 addition & 1 deletion src/datatrove/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def delete(self):
self._fs.rm(self.path)

@property
def path_in_local_disk(self):
def persistent_local_path(self):
raise NotImplementedError


Expand Down
2 changes: 1 addition & 1 deletion src/datatrove/io/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def delete(self):
os.remove(self.path)

@property
def path_in_local_disk(self):
def persistent_local_path(self):
return self.path


Expand Down
7 changes: 3 additions & 4 deletions src/datatrove/io/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ class S3OutputDataFile(BaseOutputDataFile):
def __post_init__(self):
if not self.path.startswith("s3://"):
raise ValueError("S3OutputDataFile path must start with s3://")
if not self.local_path:
self._tmpdir = tempfile.TemporaryDirectory()
self.local_path = self._tmpdir.name
assert self.local_path is not None, "S3OutputDataFile must have a local_path"

def close(self):
super().close()
Expand All @@ -90,7 +88,8 @@ def delete(self):
os.remove(self.local_path)

@property
def path_in_local_disk(self):
def persistent_local_path(self):
assert not self.folder or (not self.folder._tmpdir and not self.folder.cleanup)
return self.local_path


Expand Down

0 comments on commit 7a681ae

Please sign in to comment.