diff --git a/loki/batch/tests/test_transformation.py b/loki/batch/tests/test_transformation.py index c41f1b843..c410441c1 100644 --- a/loki/batch/tests/test_transformation.py +++ b/loki/batch/tests/test_transformation.py @@ -414,22 +414,23 @@ def test_transformation_file_write(tmp_path): """ source = Sourcefile.from_source(fcode) source.path = Path('rick.F90') - item = ProcedureItem(name='#rick', source=source) + item = ProcedureItem(name='#rick', source=source, config={'mode': 'roll'}) - # Test default file writes - ricks_path = tmp_path/'rick.loki.F90' + # Test mode and suffix overrides + ricks_path = tmp_path/'rick.roll.java' if ricks_path.exists(): ricks_path.unlink() - FileWriteTransformation().apply(source=source, item=item, build_args={'output_dir': tmp_path}) + FileWriteTransformation(suffix='.java').apply(source=source, item=item, + build_args={'output_dir': tmp_path}) assert ricks_path.exists() ricks_path.unlink() - # Test mode and suffix overrides - ricks_path = tmp_path/'rick.roll.java' + item = ProcedureItem(name='#rick', source=source) + # Test default file writes + ricks_path = tmp_path/'rick.loki.F90' if ricks_path.exists(): ricks_path.unlink() - FileWriteTransformation(mode='roll', suffix='.java').apply(source=source, item=item, - build_args={'output_dir': tmp_path}) + FileWriteTransformation().apply(source=source, item=item, build_args={'output_dir': tmp_path}) assert ricks_path.exists() ricks_path.unlink() diff --git a/loki/transformations/build_system/file_write.py b/loki/transformations/build_system/file_write.py index 506e172d9..f9182d877 100644 --- a/loki/transformations/build_system/file_write.py +++ b/loki/transformations/build_system/file_write.py @@ -23,8 +23,6 @@ class FileWriteTransformation(Transformation): Parameters ---------- - mode : str, optional - "Mode" identifier string to add in front of the file suffix suffix : str, optional File suffix to determine file type for all written file. If omitted, it will preserve the original file type. @@ -39,10 +37,9 @@ class FileWriteTransformation(Transformation): traverse_file_graph = True def __init__( - self, mode='loki', suffix=None, cuf=False, + self, suffix=None, cuf=False, include_module_var_imports=False ): - self.mode = mode self.suffix = suffix self.cuf = cuf self.include_module_var_imports = include_module_var_imports @@ -66,9 +63,12 @@ def transform_file(self, sourcefile, **kwargs): if not item: raise ValueError('No Item provided; required to determine file write path') + _mode = item.mode if item.mode else 'loki' + _mode = _mode.replace('-', '_') # Sanitize mode string + path = Path(item.path) suffix = self.suffix if self.suffix else path.suffix - sourcepath = Path(item.path).with_suffix(f'.{self.mode}{suffix}') + sourcepath = Path(item.path).with_suffix(f'.{_mode}{suffix}') build_args = kwargs.get('build_args', {}) if build_args and (output_dir := build_args.get('output_dir', None)) is not None: sourcepath = Path(output_dir)/sourcepath.name diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index f39005bc2..2309d903e 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -140,6 +140,9 @@ def convert( config = SchedulerConfig.from_file(config) + # set default transformation mode in Scheduler config + config.default['mode'] = mode + directive = None if directive.lower() == 'none' else directive.lower() build_args = { @@ -186,7 +189,7 @@ def convert( # Write out all modified source files into the build directory file_write_trafo = scheduler.config.transformations.get('FileWriteTransformation', None) if not file_write_trafo: - file_write_trafo = FileWriteTransformation(mode=mode, cuf='cuf' in mode) + file_write_trafo = FileWriteTransformation(cuf='cuf' in mode) scheduler.process(transformation=file_write_trafo) return @@ -406,8 +409,7 @@ def convert( # Write out all modified source files into the build directory scheduler.process(transformation=FileWriteTransformation( - mode=mode, cuf='cuf' in mode, - include_module_var_imports=global_var_offload + cuf='cuf' in mode, include_module_var_imports=global_var_offload ))