Skip to content

Commit

Permalink
fix(runner): output dir property as a guard
Browse files Browse the repository at this point in the history
  • Loading branch information
brian-mulier-p committed Apr 9, 2024
1 parent ba22cad commit a680653
Showing 1 changed file with 52 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,34 +127,37 @@ public class AzureBatchTaskRunner extends TaskRunner implements AbstractBatchInt
private ContainerRegistry registry;

@Override
public RunnerResult run(RunContext runContext, TaskCommands commandsWrapper, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
boolean hasBlobStorage = blobStorage != null && blobStorage.valid();

boolean hasFilesToUpload = !ListUtils.isEmpty(filesToUpload);
if (hasFilesToUpload && !hasBlobStorage) {
throw new IllegalArgumentException("You must provide a way to connect to a Blob Storage container to use `inputFiles` or `namespaceFiles`");
}
boolean hasFilesToDownload = !ListUtils.isEmpty(filesToDownload);
if (hasFilesToDownload && !hasBlobStorage) {
throw new IllegalArgumentException("You must provide a way to connect to a Blob Storage container to use `outputFiles`");
boolean outputDirectoryEnabled = taskCommands.outputDirectoryEnabled();
if ((hasFilesToDownload || outputDirectoryEnabled) && !hasBlobStorage) {
throw new IllegalArgumentException("You must provide a way to connect to a Blob Storage container to use `outputFiles` or `{{ outputDir }}`");
}

Map<String, Object> additionalVars = this.additionalVars(runContext, commandsWrapper);
Map<String, Object> additionalVars = this.additionalVars(runContext, taskCommands);
Path outputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR);
String blobStorageWdir = additionalVars.get(ScriptService.VAR_BUCKET_PATH).toString();

String jobId = ScriptService.jobName(runContext);
List<ResourceFile> resourceFiles = new ArrayList<>();
if (hasBlobStorage) {
String relativeOutputDirectoryMarkerPath = outputDirectory + "/.kestradirectory";
File outputDirectoryMarker = runContext.resolve(Path.of(relativeOutputDirectoryMarkerPath)).toFile();
outputDirectoryMarker.getParentFile().mkdirs();
outputDirectoryMarker.createNewFile();
if (hasFilesToUpload || outputDirectoryEnabled) {
List<String> filesToUploadWithOutputDir = new ArrayList<>(filesToUpload);
if (outputDirectoryEnabled) {
String relativeOutputDirectoryMarkerPath = outputDirectory + "/.kestradirectory";
File outputDirectoryMarker = runContext.resolve(Path.of(relativeOutputDirectoryMarkerPath)).toFile();
outputDirectoryMarker.getParentFile().mkdirs();
outputDirectoryMarker.createNewFile();
filesToUploadWithOutputDir.add(relativeOutputDirectoryMarkerPath);
}

BlobContainerClient blobContainerClient = blobStorage.blobContainerClient(runContext);

List<String> filesToUploadWithOutputDir = new ArrayList<>(filesToUpload);
filesToUploadWithOutputDir.add(relativeOutputDirectoryMarkerPath);
filesToUploadWithOutputDir.stream().map(throwFunction(file -> {
// Use path to eventually deduplicate leading '/'
String blobName = blobStorageWdir + Path.of("/" + file);
Expand Down Expand Up @@ -182,9 +185,35 @@ public RunnerResult run(RunContext runContext, TaskCommands commandsWrapper, Lis
})).forEach(resourceFiles::add);
}

AbstractLogConsumer logConsumer = commandsWrapper.getLogConsumer();
AbstractLogConsumer logConsumer = taskCommands.getLogConsumer();

List<String> commands = taskCommands.getCommands();
Task.TaskBuilder taskBuilder = Task.builder()
.id("task-" + jobId)
.constraints(
TaskConstraints.builder()
.maxWallClockTime(this.waitUntilCompletion)
.maxTaskRetryCount(0)
.build()
)
.interpreter(commands.get(0))
.interpreterArgs(commands.size() > 1 ? new String[]{commands.get(1)} : new String[0])
.commands(commands.size() > 2 ? commands.subList(2, commands.size()) : Collections.emptyList())
.resourceFiles(resourceFiles)
.outputFiles(filesToDownload)
.containerSettings(
TaskContainerSettings.builder()
.workingDirectory(ContainerWorkingDirectory.TASK_WORKING_DIRECTORY)
.registry(registry)
.imageName(taskCommands.getContainerImage())
.build()
)
.environments(this.env(runContext, taskCommands));

if (outputDirectoryEnabled) {
taskBuilder.outputDirs(List.of(additionalVars.get(ScriptService.VAR_OUTPUT_DIR).toString()));
}

List<String> commands = commandsWrapper.getCommands();
Create createJob = Create.builder()
.id("create")
.type(Create.class.getName())
Expand All @@ -199,31 +228,7 @@ public RunnerResult run(RunContext runContext, TaskCommands commandsWrapper, Lis
.labels(ScriptService.labels(runContext, "kestra-", true, true))
.build()
)
.tasks(List.of(
Task.builder()
.id("task-" + jobId)
.constraints(
TaskConstraints.builder()
.maxWallClockTime(this.waitUntilCompletion)
.maxTaskRetryCount(0)
.build()
)
.interpreter(commands.get(0))
.interpreterArgs(commands.size() > 1 ? new String[]{commands.get(1)} : new String[0])
.commands(commands.size() > 2 ? commands.subList(2, commands.size()) : Collections.emptyList())
.resourceFiles(resourceFiles)
.outputFiles(filesToDownload)
.outputDirs(List.of(additionalVars.get(ScriptService.VAR_OUTPUT_DIR).toString()))
.containerSettings(
TaskContainerSettings.builder()
.workingDirectory(ContainerWorkingDirectory.TASK_WORKING_DIRECTORY)
.registry(registry)
.imageName(commandsWrapper.getContainerImage())
.build()
)
.environments(this.env(runContext, commandsWrapper))
.build()
))
.tasks(List.of(taskBuilder.build()))
.logConsumer(new AbstractLogConsumer() {
@Override
public void accept(String log, Boolean isStdErr) {
Expand All @@ -246,15 +251,17 @@ public void accept(String log, Boolean isStdErr) {

@Override
public Map<String, Object> runnerAdditionalVars(RunContext runContext, TaskCommands taskCommands) {
Map<String, Object> vars = new HashMap<>();
if (blobStorage != null && blobStorage.valid()) {
Path outputDirectory = taskCommands.getWorkingDirectory().relativize(taskCommands.getOutputDirectory());
return Map.of(
ScriptService.VAR_WORKING_DIR, "",
ScriptService.VAR_OUTPUT_DIR, outputDirectory,
ScriptService.VAR_BUCKET_PATH, outputDirectory
);
Path nestedDir = taskCommands.getWorkingDirectory().relativize(taskCommands.getOutputDirectory());
vars.put(ScriptService.VAR_WORKING_DIR, taskCommands.getWorkingDirectory().toString());
vars.put(ScriptService.VAR_BUCKET_PATH, nestedDir.toString());

if (taskCommands.outputDirectoryEnabled()) {
vars.put(ScriptService.VAR_OUTPUT_DIR, nestedDir);
}
}

return Collections.emptyMap();
return vars;
}
}

0 comments on commit a680653

Please sign in to comment.