diff --git a/loki/batch/scheduler.py b/loki/batch/scheduler.py index edb6cd94a..a6d3fef98 100644 --- a/loki/batch/scheduler.py +++ b/loki/batch/scheduler.py @@ -120,6 +120,8 @@ class Scheduler: By default a full parse is executed, use this flag to suppress. frontend : :any:`Frontend`, optional Frontend to use for full parse of source files (default :any:`FP`). + output_dir : str or path + Directory for the output to be written to """ # TODO: Should be user-definable! @@ -127,7 +129,7 @@ class Scheduler: def __init__(self, paths, config=None, seed_routines=None, preprocess=False, includes=None, defines=None, definitions=None, xmods=None, - omni_includes=None, full_parse=True, frontend=FP): + omni_includes=None, full_parse=True, frontend=FP, output_dir=None): # Derive config from file or dict if isinstance(config, SchedulerConfig): self.config = config @@ -153,7 +155,8 @@ def __init__(self, paths, config=None, seed_routines=None, preprocess=False, 'defines': defines, 'xmods': xmods, 'omni_includes': omni_includes, - 'frontend': frontend + 'frontend': frontend, + 'output_dir': output_dir } # Internal data structures to store the callgraph @@ -495,7 +498,7 @@ def _get_definition_items(_item, sgraph_items): _item.scope_ir, role=_item.role, mode=_item.mode, item=_item, targets=_item.targets, items=_get_definition_items(_item, sgraph_items), successors=graph.successors(_item, item_filter=item_filter), - depths=graph.depths + depths=graph.depths, build_args=self.build_args ) if transformation.renames_items: diff --git a/loki/batch/tests/test_transformation.py b/loki/batch/tests/test_transformation.py index 16c1bbf67..c410441c1 100644 --- a/loki/batch/tests/test_transformation.py +++ b/loki/batch/tests/test_transformation.py @@ -414,21 +414,23 @@ def test_transformation_file_write(tmp_path): """ source = Sourcefile.from_source(fcode) source.path = Path('rick.F90') - item = ProcedureItem(name='#rick', source=source) + item = ProcedureItem(name='#rick', source=source, config={'mode': 'roll'}) - # Test default file writes - ricks_path = tmp_path/'rick.loki.F90' + # Test mode and suffix overrides + ricks_path = tmp_path/'rick.roll.java' if ricks_path.exists(): ricks_path.unlink() - FileWriteTransformation(builddir=tmp_path).apply(source=source, item=item) + FileWriteTransformation(suffix='.java').apply(source=source, item=item, + build_args={'output_dir': tmp_path}) assert ricks_path.exists() ricks_path.unlink() - # Test mode and suffix overrides - ricks_path = tmp_path/'rick.roll.java' + item = ProcedureItem(name='#rick', source=source) + # Test default file writes + ricks_path = tmp_path/'rick.loki.F90' if ricks_path.exists(): ricks_path.unlink() - FileWriteTransformation(builddir=tmp_path, mode='roll', suffix='.java').apply(source=source, item=item) + FileWriteTransformation().apply(source=source, item=item, build_args={'output_dir': tmp_path}) assert ricks_path.exists() ricks_path.unlink() @@ -436,13 +438,13 @@ def test_transformation_file_write(tmp_path): ricks_path = tmp_path/'rick.loki.F90' if ricks_path.exists(): ricks_path.unlink() - FileWriteTransformation(builddir=tmp_path).apply(source=source, items=(item,)) + FileWriteTransformation().apply(source=source, items=(item,), build_args={'output_dir': tmp_path}) assert ricks_path.exists() ricks_path.unlink() # Check error behaviour if no item provided with pytest.raises(ValueError): - FileWriteTransformation(builddir=tmp_path).apply(source=source) + FileWriteTransformation().apply(source=source) def test_transformation_pipeline_simple(): diff --git a/loki/transformations/build_system/file_write.py b/loki/transformations/build_system/file_write.py index 7dc406a15..f9182d877 100644 --- a/loki/transformations/build_system/file_write.py +++ b/loki/transformations/build_system/file_write.py @@ -23,10 +23,6 @@ class FileWriteTransformation(Transformation): Parameters ---------- - builddir : str or path - Directory for the output to be written to - mode : str, optional - "Mode" identifier string to add in front of the file suffix suffix : str, optional File suffix to determine file type for all written file. If omitted, it will preserve the original file type. @@ -41,11 +37,9 @@ class FileWriteTransformation(Transformation): traverse_file_graph = True def __init__( - self, builddir=None, mode='loki', suffix=None, cuf=False, + self, suffix=None, cuf=False, include_module_var_imports=False ): - self.builddir = Path(builddir) - self.mode = mode self.suffix = suffix self.cuf = cuf self.include_module_var_imports = include_module_var_imports @@ -69,9 +63,13 @@ def transform_file(self, sourcefile, **kwargs): if not item: raise ValueError('No Item provided; required to determine file write path') + _mode = item.mode if item.mode else 'loki' + _mode = _mode.replace('-', '_') # Sanitize mode string + path = Path(item.path) suffix = self.suffix if self.suffix else path.suffix - sourcepath = Path(item.path).with_suffix(f'.{self.mode}{suffix}') - if self.builddir is not None: - sourcepath = self.builddir/sourcepath.name + sourcepath = Path(item.path).with_suffix(f'.{_mode}{suffix}') + build_args = kwargs.get('build_args', {}) + if build_args and (output_dir := build_args.get('output_dir', None)) is not None: + sourcepath = Path(output_dir)/sourcepath.name sourcefile.write(path=sourcepath, cuf=self.cuf) diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 1f9589171..2309d903e 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -140,6 +140,9 @@ def convert( config = SchedulerConfig.from_file(config) + # set default transformation mode in Scheduler config + config.default['mode'] = mode + directive = None if directive.lower() == 'none' else directive.lower() build_args = { @@ -169,7 +172,7 @@ def convert( paths = [Path(p).resolve() for p in as_tuple(source)] paths += [Path(h).resolve().parent for h in as_tuple(header)] scheduler = Scheduler( - paths=paths, config=config, frontend=frontend, definitions=definitions, **build_args + paths=paths, config=config, frontend=frontend, definitions=definitions, output_dir=build, **build_args ) # If requested, apply a custom pipeline from the scheduler config @@ -186,7 +189,7 @@ def convert( # Write out all modified source files into the build directory file_write_trafo = scheduler.config.transformations.get('FileWriteTransformation', None) if not file_write_trafo: - file_write_trafo = FileWriteTransformation(builddir=build, mode=mode, cuf='cuf' in mode) + file_write_trafo = FileWriteTransformation(cuf='cuf' in mode) scheduler.process(transformation=file_write_trafo) return @@ -352,7 +355,7 @@ def convert( transformation_type='hoist', derived_types = ['TECLDP'], block_dim=block_dim, dim_vars=(vertical.size,), as_kwarguments=True, remove_vector_section=True) scheduler.process( pipeline ) - + if mode == 'cuf-parametrise': pipeline = scheduler.config.transformations.get('cuf-parametrise', None) if not pipeline: @@ -406,8 +409,7 @@ def convert( # Write out all modified source files into the build directory scheduler.process(transformation=FileWriteTransformation( - builddir=build, mode=mode, cuf='cuf' in mode, - include_module_var_imports=global_var_offload + cuf='cuf' in mode, include_module_var_imports=global_var_offload ))