Skip to content

Commit

Permalink
Handle reserved names during code-printing
Browse files Browse the repository at this point in the history
Avoid symbolic replacements for handling reserved symbols in amici.
Instead, only handle them during code printing. This has the advantage that
users won't be bothered by changed IDs.
This, however, makes it a breaking change for anybody currently relying on "amici_*" entity IDs.
I don't expect this to be much of a problem, since probably most users would have renamed their model
entities to avoid this "amici_"-prefixing.

Fixes AMICI-dev#2461.
  • Loading branch information
dweindl committed Jul 8, 2024
1 parent 837e70b commit c2caa45
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 31 deletions.
7 changes: 7 additions & 0 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class AmiciCxxCodePrinter(CXX11CodePrinter):
"""

optimizations: Iterable[Optimization] = ()
RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]

def __init__(self):
"""Create code printer"""
Expand Down Expand Up @@ -67,6 +68,12 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str:
f'Encountered unsupported function in expression "{expr}"'
) from e

def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if name in self.RESERVED_SYMBOLS:
return f"amici_{name}"
return name

def _print_min_max(self, expr, cpp_fun: str, sympy_fun):
# C++ doesn't like mixing int and double for arguments for min/max,
# therefore, we just always convert to float
Expand Down
8 changes: 7 additions & 1 deletion python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ def _write_index_files(self, name: str) -> None:
lines = []
for index, symbol in enumerate(symbols):
symbol_name = strip_pysb(symbol)
# symbol_name is a mix of symbols and strings
symbol_name = self._code_printer._print_Symbol(
sp.Symbol(symbol_name)
if isinstance(symbol_name, str)
else symbol_name
)
if str(symbol) == "0":
continue
if str(symbol_name) == "":
Expand Down Expand Up @@ -1221,7 +1227,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
f'"{symbol}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
8 changes: 0 additions & 8 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sympy as sp

from .import_utils import (
RESERVED_SYMBOLS,
ObservableTransformation,
amici_time_symbol,
cast_to_sym,
Expand Down Expand Up @@ -66,13 +65,6 @@ def __init__(
f"identifier must be sympy.Symbol, was " f"{type(identifier)}"
)

if str(identifier) in RESERVED_SYMBOLS or (
hasattr(identifier, "name") and identifier.name in RESERVED_SYMBOLS
):
raise ValueError(
f'Cannot add model quantity with name "{name}", '
f"please rename."
)
self._identifier: sp.Symbol = identifier

if not isinstance(name, str):
Expand Down
2 changes: 0 additions & 2 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from sympy.logic.boolalg import BooleanAtom
from toposort import toposort

RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]

try:
import pysb
except ImportError:
Expand Down
20 changes: 0 additions & 20 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from .de_model_components import symbol_to_type, Expression
from .sympy_utils import smart_is_zero_matrix, smart_multiply
from .import_utils import (
RESERVED_SYMBOLS,
_check_unsupported_functions,
_get_str_symbol_identifiers,
_parse_special_functions,
Expand Down Expand Up @@ -523,7 +522,6 @@ def _build_ode_model(
)
self._replace_compartments_with_volumes()

self._clean_reserved_symbols()
self._process_time()

ode_model = DEModel(
Expand Down Expand Up @@ -2596,24 +2594,6 @@ def _replace_in_all_expressions(
for spline in self.splines:
spline._replace_in_all_expressions(old, new)

def _clean_reserved_symbols(self) -> None:
"""
Remove all reserved symbols from self.symbols
"""
for sym in RESERVED_SYMBOLS:
old_symbol = symbol_with_assumptions(sym)
new_symbol = symbol_with_assumptions(f"amici_{sym}")
self._replace_in_all_expressions(
old_symbol, new_symbol, replace_identifiers=True
)
for symbols_ids, symbols in self.symbols.items():
if old_symbol in symbols:
# reconstitute the whole dict in order to keep the ordering
self.symbols[symbols_ids] = {
new_symbol if k is old_symbol else k: v
for k, v in symbols.items()
}

def _sympy_from_sbml_math(
self, var_or_math: [sbml.SBase, str]
) -> sp.Expr | float | None:
Expand Down
36 changes: 36 additions & 0 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,39 @@ def test_constraints():
amici_solver.getAbsoluteTolerance(),
)
)


def test_reserved_symbols():
"""Test handling of reserved one-letter names."""
from amici.antimony_import import antimony2amici

ant_model = """
model test_non_negative_species
t = 0.1
x = 0.2
y = 0.3
w = 0.4
h = 0.5
p = 0.6
k = 0.7
x' = k + x + p + y + w + h + t
end
"""
module_name = "test_reserved_symbols"
with TemporaryDirectory(prefix=module_name) as outdir:
antimony2amici(
ant_model,
model_name=module_name,
output_dir=outdir,
compute_conservation_laws=False,
)
# ensure it compiled successfully and can be imported
model_module = amici.import_model_module(
module_name=module_name, module_path=outdir
)
model = model_module.get_model()
ids = list(model.getParameterIds())
ids.extend(model.getStateIds())
# all symbols should be present with their original IDs
for symbol in "txywhpk":
assert symbol in ids

0 comments on commit c2caa45

Please sign in to comment.