Skip to content

Commit

Permalink
Merge pull request #450 from ecmwf-ifs/naml-enriched-derived-type-sym…
Browse files Browse the repository at this point in the history
…bols

Correct symbols after derived type enrichment
  • Loading branch information
reuterbal authored Nov 29, 2024
2 parents c3aaa57 + fa42d02 commit ba7e230
Show file tree
Hide file tree
Showing 10 changed files with 1,518 additions and 1,358 deletions.
4 changes: 4 additions & 0 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def map_variable_symbol(self, expr, enclosing_prec, *args, **kwargs):

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
map_derived_type_symbol = map_variable_symbol

def map_meta_symbol(self, expr, enclosing_prec, *args, **kwargs):
return self.rec(expr._symbol, enclosing_prec, *args, **kwargs)
Expand Down Expand Up @@ -234,6 +235,7 @@ def map_variable_symbol(self, expr, *args, **kwargs):

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
map_derived_type_symbol = map_variable_symbol

def map_meta_symbol(self, expr, *args, **kwargs):
if not self.visit(expr):
Expand Down Expand Up @@ -611,6 +613,7 @@ def map_variable_symbol(self, expr, *args, **kwargs):

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
map_derived_type_symbol = map_variable_symbol

def map_meta_symbol(self, expr, *args, **kwargs):
symbol = self.rec(expr._symbol, *args, **kwargs)
Expand Down Expand Up @@ -823,6 +826,7 @@ def map_procedure_symbol(self, expr, *args, **kwargs):
return expr.clone(scope=kwargs['scope'])
return self.map_variable_symbol(expr, *args, **kwargs)


class DetachScopesMapper(LokiIdentityMapper):
"""
A Pymbolic expression mapper (i.e., a visitor for the expression tree)
Expand Down
39 changes: 34 additions & 5 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Mix-ins
'StrCompareMixin',
# Typed leaf nodes
'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol',
'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol', 'DerivedTypeSymbol',
'MetaSymbol', 'Scalar', 'Array', 'Variable',
# Non-typed leaf nodes
'FloatLiteral', 'IntLiteral', 'LogicLiteral', 'StringLiteral',
Expand Down Expand Up @@ -481,6 +481,34 @@ def __init__(self, name, scope=None, type=None, **kwargs):
mapper_method = intern('map_procedure_symbol')


class DerivedTypeSymbol(StrCompareMixin, TypedSymbol, _FunctionSymbol):
"""
Internal representation of a symbol that represents a named
derived type.
This is used to represent the derived type symbolically in
:any:`Import` statements and when defining derived types.
Parameters
----------
name : str
The name of the symbol.
scope : :any:`Scope`
The scope in which the symbol is declared.
type : optional
The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
"""

def __init__(self, name, scope=None, type=None, **kwargs):
# pylint: disable=redefined-builtin
assert type is None or isinstance(type.dtype, DerivedType)
if type is not None:
assert name.lower() == type.dtype.name.lower()
super().__init__(name=name, scope=scope, type=type, **kwargs)

mapper_method = intern('map_derived_type_symbol')


class MetaSymbol(StrCompareMixin, pmbl.AlgebraicLeaf):
"""
Base class for meta symbols to encapsulate a symbol node with optional
Expand Down Expand Up @@ -868,9 +896,8 @@ def __new__(cls, **kwargs):
return ProcedureSymbol(**kwargs)

if _type and isinstance(_type.dtype, DerivedType) and name.lower() == _type.dtype.name.lower():
# This is a constructor call (or a type imported in an ``IMPORT`` statement, in which
# case this is classified wrong...)
return ProcedureSymbol(**kwargs)
# This the name of a derived type, as found in USE import statements
return DerivedTypeSymbol(**kwargs)

if 'dimensions' in kwargs and kwargs['dimensions'] is None:
# Convenience: This way we can construct Scalar variables with `dimensions=None`
Expand Down Expand Up @@ -1315,7 +1342,9 @@ def __init__(self, function, parameters=None, kw_parameters=None, **kwargs):
# Unfortunately, have to accept MetaSymbol here for the time being as
# rescoping before injecting statement functions may create InlineCalls
# with Scalar/Variable function names.
assert isinstance(function, (ProcedureSymbol, DeferredTypeSymbol, MetaSymbol))
assert isinstance(function, (
ProcedureSymbol, DerivedTypeSymbol, DeferredTypeSymbol, MetaSymbol
))
parameters = parameters or ()
kw_parameters = kw_parameters or {}

Expand Down
40 changes: 20 additions & 20 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,6 +1780,26 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
# symbols in the spec part to make them coherent with the symbol table
spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True)

# To simplify things, we always declare the result-type of a function with
# a declaration in the spec as this can capture every possible situation.
# Therefore, if it has been declared as a prefix in the subroutine statement,
# we now have to inject a declaration instead. To ensure we do this in the
# right place in the spec to not violate the intrinsic order Fortran mandates,
# we search for the first occurence of any VariableDeclaration or
# ProcedureDeclaration and inject it before that one
if return_type is not None:
routine.symbol_attrs[routine.name] = return_type
return_var = sym.Variable(name=routine.name, scope=routine)
decl_source = self.get_source(subroutine_stmt, source=None)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)

decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
if not decls:
# No other declarations: add it to the end
spec.append(return_var_decl)
else:
spec.insert(spec.body.index(decls[0]), return_var_decl)

# Now all declarations are well-defined and we can parse the member routines
if contains_ast is not None:
contains = self.visit(contains_ast, **kwargs)
Expand Down Expand Up @@ -1821,26 +1841,6 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
comment_map[node] = None
spec = Transformer(comment_map, invalidate_source=False).visit(spec)

# To simplify things, we always declare the result-type of a function with
# a declaration in the spec as this can capture every possible situation.
# Therefore, if it has been declared as a prefix in the subroutine statement,
# we now have to inject a declaration instead. To ensure we do this in the
# right place in the spec to not violate the intrinsic order Fortran mandates,
# we search for the first occurence of any VariableDeclaration or
# ProcedureDeclaration and inject it before that one
if return_type is not None:
routine.symbol_attrs[routine.name] = return_type
return_var = sym.Variable(name=routine.name, scope=routine)
decl_source = self.get_source(subroutine_stmt, source=None)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)

decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
if not decls:
# No other declarations: add it to the end
spec.append(return_var_decl)
else:
spec.insert(spec.body.index(decls[0]), return_var_decl)

# Finally, call the subroutine constructor on the object again to register all
# bits and pieces in place and rescope all symbols
# pylint: disable=unnecessary-dunder-call
Expand Down
Loading

0 comments on commit ba7e230

Please sign in to comment.