Skip to content

Commit

Permalink
FileWrite: make mode a runtime arg too
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Nov 4, 2024
1 parent f6b3c18 commit 74d9660
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
17 changes: 9 additions & 8 deletions loki/batch/tests/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,22 +414,23 @@ def test_transformation_file_write(tmp_path):
"""
source = Sourcefile.from_source(fcode)
source.path = Path('rick.F90')
item = ProcedureItem(name='#rick', source=source)
item = ProcedureItem(name='#rick', source=source, config={'mode': 'roll'})

# Test default file writes
ricks_path = tmp_path/'rick.loki.F90'
# Test mode and suffix overrides
ricks_path = tmp_path/'rick.roll.java'
if ricks_path.exists():
ricks_path.unlink()
FileWriteTransformation().apply(source=source, item=item, build_args={'output_dir': tmp_path})
FileWriteTransformation(suffix='.java').apply(source=source, item=item,
build_args={'output_dir': tmp_path})
assert ricks_path.exists()
ricks_path.unlink()

# Test mode and suffix overrides
ricks_path = tmp_path/'rick.roll.java'
item = ProcedureItem(name='#rick', source=source)
# Test default file writes
ricks_path = tmp_path/'rick.loki.F90'
if ricks_path.exists():
ricks_path.unlink()
FileWriteTransformation(mode='roll', suffix='.java').apply(source=source, item=item,
build_args={'output_dir': tmp_path})
FileWriteTransformation().apply(source=source, item=item, build_args={'output_dir': tmp_path})
assert ricks_path.exists()
ricks_path.unlink()

Expand Down
10 changes: 5 additions & 5 deletions loki/transformations/build_system/file_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ class FileWriteTransformation(Transformation):
Parameters
----------
mode : str, optional
"Mode" identifier string to add in front of the file suffix
suffix : str, optional
File suffix to determine file type for all written file. If
omitted, it will preserve the original file type.
Expand All @@ -39,10 +37,9 @@ class FileWriteTransformation(Transformation):
traverse_file_graph = True

def __init__(
self, mode='loki', suffix=None, cuf=False,
self, suffix=None, cuf=False,
include_module_var_imports=False
):
self.mode = mode
self.suffix = suffix
self.cuf = cuf
self.include_module_var_imports = include_module_var_imports
Expand All @@ -66,9 +63,12 @@ def transform_file(self, sourcefile, **kwargs):
if not item:
raise ValueError('No Item provided; required to determine file write path')

_mode = item.mode if item.mode else 'loki'
_mode = _mode.replace('-', '_') # Sanitize mode string

path = Path(item.path)
suffix = self.suffix if self.suffix else path.suffix
sourcepath = Path(item.path).with_suffix(f'.{self.mode}{suffix}')
sourcepath = Path(item.path).with_suffix(f'.{_mode}{suffix}')
build_args = kwargs.get('build_args', {})
if build_args and (output_dir := build_args.get('output_dir', None)) is not None:
sourcepath = Path(output_dir)/sourcepath.name
Expand Down
8 changes: 5 additions & 3 deletions scripts/loki_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def convert(

config = SchedulerConfig.from_file(config)

# set default transformation mode in Scheduler config
config.default['mode'] = mode

directive = None if directive.lower() == 'none' else directive.lower()

build_args = {
Expand Down Expand Up @@ -186,7 +189,7 @@ def convert(
# Write out all modified source files into the build directory
file_write_trafo = scheduler.config.transformations.get('FileWriteTransformation', None)
if not file_write_trafo:
file_write_trafo = FileWriteTransformation(mode=mode, cuf='cuf' in mode)
file_write_trafo = FileWriteTransformation(cuf='cuf' in mode)
scheduler.process(transformation=file_write_trafo)

return
Expand Down Expand Up @@ -406,8 +409,7 @@ def convert(

# Write out all modified source files into the build directory
scheduler.process(transformation=FileWriteTransformation(
mode=mode, cuf='cuf' in mode,
include_module_var_imports=global_var_offload
cuf='cuf' in mode, include_module_var_imports=global_var_offload
))


Expand Down

0 comments on commit 74d9660

Please sign in to comment.