Skip to content

Commit

Permalink
#2845 add new _replace_symbol method
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Feb 4, 2025
1 parent 1adb642 commit 0c3ed07
Showing 1 changed file with 80 additions and 51 deletions.
131 changes: 80 additions & 51 deletions src/psyclone/psyir/symbols/symbol_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,9 @@ def check_for_clashes(self, other_table, symbols_to_skip=()):
f"check_for_clashes: 'symbols_to_skip' must be an instance of "
f"Iterable but got '{type(symbols_to_skip).__name__}'")

if other_table is self:
return

# Check whether there are any wildcard imports common to both tables.
self_imports = set(sym.name for sym in self.wildcard_imports())
other_imports = set(sym.name for sym in other_table.wildcard_imports())
Expand Down Expand Up @@ -649,7 +652,7 @@ def check_for_clashes(self, other_table, symbols_to_skip=()):
f"interface '{this_sym.interface}' but the supplied "
f"table imports it via '{other_sym.interface}'.")
continue

if other_sym.is_unresolved and this_sym.is_unresolved:
# Both Symbols are unresolved.
if shared_wildcard_imports and not unique_wildcard_imports:
Expand All @@ -675,6 +678,16 @@ def check_for_clashes(self, other_table, symbols_to_skip=()):
f"A symbol named '{this_sym.name}' is present but "
f"unresolved in one or both tables.")

elif other_sym.is_unresolved or this_sym.is_unresolved:
# Only one is unresolved. Could it be imported from the same
# location?
if len(shared_wildcard_imports) == 1:
usym = this_sym if this_sym.is_unresolved else other_sym
(name,) = shared_wildcard_imports
csym = self.lookup(name)
usym.interface = ImportInterface(csym)
continue

# Can either of them be renamed?
try:
self.rename_symbol(this_sym, "", dry_run=True)
Expand Down Expand Up @@ -1108,11 +1121,47 @@ def swap(self, old_symbol, new_symbol):
raise SymbolError(
f"Cannot swap symbols that have different names, got: "
f"'{old_symbol.name}' and '{new_symbol.name}'")
# TODO #898 remove() does not currently check for any uses of
# old_symbol.
self._replace_symbol(self.node, old_symbol, new_symbol)
self.remove(old_symbol)
self.add(new_symbol)

@staticmethod
def _replace_symbol(node, test_symbol, outer_sym):
'''
'''
if not node:
return

norm_name = SymbolTable._normalize(test_symbol.name)

from psyclone.core.variables_access_info import VariablesAccessInfo
from psyclone.psyir.nodes import Call, Literal, ScopingNode, Reference
vai = VariablesAccessInfo()
node.reference_accesses(vai)
for sig in vai.all_signatures:
if sig.var_name.lower() != norm_name:
continue
for access in vai[sig].all_accesses:
if access.node.scope is access.node:
# This is just the symbol table associated
# with a ScopingNode.
continue
if access.node.scope.symbol_table.lookup(norm_name) is test_symbol:
if isinstance(access.node, Reference):
access.node.symbol = outer_sym
elif isinstance(access.node, Call):
import pdb; pdb.set_trace()
print(access.node)
elif isinstance(access.node, Literal):
oldtype = access.node.datatype
newtype = ScalarType(oldtype.intrinsic,
outer_sym)
access.node.replace_with(
Literal(access.node.value, newtype))
else:
import pdb; pdb.set_trace()
print("oh dear")

def _validate_remove_routinesymbol(self, symbol):
'''
Checks whether the supplied RoutineSymbol can be removed from this
Expand Down Expand Up @@ -1218,7 +1267,7 @@ def remove(self, symbol):
if isinstance(symbol, RoutineSymbol):
self._validate_remove_routinesymbol(symbol)
elif self.node:
from psyclone.psyir.nodes import ScopingNode
from psyclone.psyir.nodes import Call, Literal, Reference, ScopingNode
from psyclone.core.variables_access_info import VariablesAccessInfo
vai = VariablesAccessInfo()
self.node.reference_accesses(vai)
Expand All @@ -1234,22 +1283,32 @@ def remove(self, symbol):
# symbol we want to remove, not whether it just has the
# same name...
# TODO
if access.node.scope.symbol_table.lookup(norm_name) is symbol:
# This access does refer to the target symbol.
if symbol.find_symbol_table(access.node) is self:
# The symbol we've found an access of is the one
# in this table. Therefore, we can only remove it
# provided that it also exists in an outer scope.
outer_sym = self.parent_symbol_table().lookup(
norm_name, otherwise=None)
if outer_sym is not symbol:
from psyclone.psyir.nodes import Statement
stmt = access.node.ancestor(Statement)
if not stmt:
stmt = access.node
raise ValueError(
f"Cannot remove {type(symbol).__name__} '{symbol.name}' because it is "
f"accessed in '{stmt.debug_string().strip()}'")
if isinstance(access.node, Reference):
this_sym = access.node.symbol
elif isinstance(access.node, Call):
this_sym = access.node.routine.symbol
elif isinstance(access.node, Literal):
this_sym = access.node.datatype.precision
else:
import pdb; pdb.set_trace()
print("oh dear2")
if this_sym is not symbol:
continue
# This access does refer to the target symbol.
if symbol.find_symbol_table(access.node) is self:
# The symbol we've found an access of is the one
# in this table. Therefore, we can only remove it
# provided that it also exists in an outer scope.
outer_sym = self.parent_symbol_table().lookup(
norm_name, otherwise=None)
if outer_sym is not symbol:
from psyclone.psyir.nodes import Statement
stmt = access.node.ancestor(Statement)
if not stmt:
stmt = access.node
raise ValueError(
f"Cannot remove {type(symbol).__name__} '{symbol.name}' because it is "
f"accessed in '{stmt.debug_string().strip()}'")

# If the symbol had a tag, it should be disassociated
for tag, tagged_symbol in list(self._tags.items()):
Expand Down Expand Up @@ -1787,44 +1846,14 @@ def resolve_imports(self, container_symbols=None, symbol_target=None):
wildcard_imports = symbol_table.wildcard_imports()
if not all(csym in self.containersymbols for
csym in wildcard_imports):
# TODO, check whether the wildcard imports are all at
# the same scope.
import pdb; pdb.set_trace()
# There are wildcard imports other than those in the
# outer scope so we can't be certain of the origin of
# this symbol.
continue

# We want to replace the local symbol with the new one
# in the outer scope (`outer_sym`).
from psyclone.core.variables_access_info import VariablesAccessInfo
vai = VariablesAccessInfo()
scoping_node.reference_accesses(vai)
for sig in vai.all_signatures:
if sig.var_name.lower() != norm_name:
continue
if norm_name == "wp":
import pdb; pdb.set_trace()
for access in vai[sig].all_accesses:
if access.node.scope is access.node:
# This is just the symbol table associated
# with a ScopingNode.
continue
if access.node.scope.symbol_table.lookup(norm_name) is test_symbol:
if isinstance(access.node, Reference):
access.node.symbol = outer_sym
elif isinstance(access.node, Call):
import pdb; pdb.set_trace()
print(access.node)
elif isinstance(access.node, Literal):
oldtype = access.node.datatype
newtype = ScalarType(oldtype.intrinsic,
outer_sym)
access.node.replace_with(
Literal(access.node.value, newtype))
else:
import pdb; pdb.set_trace()
print("oh dear")
self._replace_symbol(scoping_node, test_symbol, outer_sym)
symbol_table.remove(test_symbol)

if symbol_target:
Expand Down

0 comments on commit 0c3ed07

Please sign in to comment.