Skip to content

Commit

Permalink
SBML import: Support for rateOf (AMICI-dev#2120)
Browse files Browse the repository at this point in the history
Adds support for SBML's `rateOf(.)` for most not-so-exotic cases.

We currently can't check for functions with `<csymbol encoding="text" definitionURL="http://www.sbml.org/sbml/symbols/rateOf">`, and only check for functions named `rateOf`. This should be safe as long as we use the `SBMLFunctionDefinitionConverter` before. 

Tested in the following SBML semantic test suite cases: 01248,01249,01250,01251,01252,01253,01254,01255,01256,01257,01258,01259,01260,01261,01262,01263,01264,01265,01266,01267,01268,01269,01270,01293,01294,01295,01296,01297,01298,01299,01321,01322,01400,01401,01402,01403,01405,01406,01408,01409,01455,01456,01457,01458,01459,01460,01461,01462,01463,01482,01483,01525,01526,01527,01528,01529,01540,01541,01542,01543

Closes AMICI-dev#769
  • Loading branch information
dweindl authored Jun 26, 2023
1 parent a12d173 commit b1b4b2b
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 12 deletions.
3 changes: 2 additions & 1 deletion documentation/python_interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ AMICI can import :term:`SBML` models via the
Status of SBML support in Python-AMICI
++++++++++++++++++++++++++++++++++++++

Python-AMICI currently **passes 1215 out of the 1821 (~67%) test cases** from
Python-AMICI currently **passes 1247 out of the 1821 (~68%) test cases** from
the semantic
`SBML Test Suite <https://github.com/sbmlteam/sbml-test-suite/>`_
(`current status <https://github.com/AMICI-dev/AMICI/actions>`_).
Expand All @@ -42,6 +42,7 @@ The following SBML test suite tags are currently supported
* comp
* Compartment
* CSymbolAvogadro
* CSymbolRateOf
* CSymbolTime
* Deletion
* EventNoDelay
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ filterwarnings =
ignore:Model.initial_conditions will be removed in a future version. Instead, you can get a list of Initial objects with Model.initials.:DeprecationWarning:pysb\.core
# https://github.com/pytest-dev/pytest-xdist/issues/825#issuecomment-1292283870
ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning
ignore:.*:ImportWarning:tellurium
ignore:.*PyDevIPCompleter6.*:DeprecationWarning

norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples
90 changes: 87 additions & 3 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,87 @@ def transform_dxdt_to_concentration(species_id, dxdt):
for llh in si.symbols[SymbolId.LLHY].values()
)

self._process_sbml_rate_of(symbols)# substitute SBML-rateOf constructs

def _process_sbml_rate_of(self, symbols) -> None:
"""Substitute any SBML-rateOf constructs in the model equations"""
rate_of_func = sp.core.function.UndefinedFunction("rateOf")
species_sym_to_xdot = dict(zip(self.sym("x"), self.sym("xdot")))
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}

def get_rate(symbol: sp.Symbol):
"""Get rate of change of the given symbol"""
nonlocal symbols

if symbol.find(rate_of_func):
raise SBMLException("Nesting rateOf() is not allowed.")

# Replace all rateOf(some_species) by their respective xdot equation
with contextlib.suppress(KeyError):
return self._eqs["xdot"][species_sym_to_idx[symbol]]

# For anything other than a state, rateOf(.) is 0 or invalid
return 0

# replace rateOf-instances in xdot by xdot symbols
for i_state in range(len(self.eq("xdot"))):
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
{
# either the rateOf argument is a state, or it's 0
rate_of: species_sym_to_xdot.get(rate_of.args[0], 0)
for rate_of in rate_ofs
}
)
# substitute in topological order
subs = toposort_symbols(dict(zip(self.sym("xdot"), self.eq("xdot"))))
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)

# replace rateOf-instances in x0 by xdot equation
for i_state in range(len(self.eq("x0"))):
if rate_ofs := self._eqs["x0"][i_state].find(rate_of_func):
self._eqs["x0"][i_state] = self._eqs["x0"][i_state].subs(
{rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
)

for component in chain(self.observables(), self.expressions(), self.events(), self._algebraic_equations):
if rate_ofs := component.get_val().find(rate_of_func):
if isinstance(component, Event):
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
# see, e.g., sbml test case 01293
raise SBMLException(
"AMICI does currently not support rateOf(.) inside event trigger functions."
)

if isinstance(component, AlgebraicEquation):
# TODO IDACalcIC fails with
# "The linesearch algorithm failed: step too small or too many backtracks."
# see, e.g., sbml test case 01482
raise SBMLException(
"AMICI does currently not support rateOf(.) inside AlgebraicRules."
)

component.set_val(
component.get_val().subs(
{rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
)
)

for event in self.events():
if event._state_update is None:
continue

for i_state in range(len(event._state_update)):
if rate_ofs := event._state_update[i_state].find(rate_of_func):
raise SBMLException(
"AMICI does currently not support rateOf(.) inside event state updates."
)
# TODO here we need xdot sym, not eqs
# event._state_update[i_state] = event._state_update[i_state].subs(
# {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
# )


def add_component(
self, component: ModelQuantity, insert_first: Optional[bool] = False
) -> None:
Expand Down Expand Up @@ -2758,11 +2839,11 @@ def _generate_c_code(self) -> None:
# only generate for those that have nontrivial implementation,
# check for both basic variables (not in functions) and function
# computed values
if (
if ((
name in self.functions
and not self.functions[name].body
and name not in nobody_functions
) or (name not in self.functions and len(self.model.sym(name)) == 0):
) or name not in self.functions) and len(self.model.sym(name)) == 0:
continue
self._write_index_files(name)

Expand Down Expand Up @@ -2982,7 +3063,10 @@ def _write_function_file(self, function: str) -> None:
else:
iszero = len(self.model.sym(sym)) == 0

if iszero:
if iszero and not (
(sym == "y" and "Jy" in function)
or (sym == "w" and "xdot" in function and len(self.model.sym(sym)))
):
continue

lines.append(f'#include "{sym}.h"')
Expand Down
7 changes: 4 additions & 3 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,10 @@ def _check_unsupported_functions(
sp.core.function.UndefinedFunction,
)

if isinstance(sym.func, unsupported_functions) or isinstance(
sym, unsupported_functions
):
if (
isinstance(sym.func, unsupported_functions)
or isinstance(sym, unsupported_functions)
) and getattr(sym.func, "name", "") != "rateOf":
raise RuntimeError(
f"Encountered unsupported expression "
f'"{sym.func}" of type '
Expand Down
59 changes: 56 additions & 3 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,10 +804,20 @@ def _process_species_initial(self):
if species:
species["init"] = initial

# hide rateOf-arguments from toposort and the substitution below
all_rateof_dummies = []
for species in self.symbols[SymbolId.SPECIES].values():
species["init"], rateof_dummies = _rateof_to_dummy(species["init"])
all_rateof_dummies.append(rateof_dummies)

# don't assign this since they need to stay in order
sorted_species = toposort_symbols(self.symbols[SymbolId.SPECIES], "init")
for species in self.symbols[SymbolId.SPECIES].values():
species["init"] = smart_subs_dict(species["init"], sorted_species, "init")
for species, rateof_dummies in zip(self.symbols[SymbolId.SPECIES].values(), all_rateof_dummies):
species["init"] = _dummy_to_rateof(
smart_subs_dict(species["init"], sorted_species, "init"),
rateof_dummies
)


@log_execution_time("processing SBML rate rules", logger)
def _process_rate_rules(self):
Expand Down Expand Up @@ -990,6 +1000,18 @@ def _process_parameters(self, constant_parameters: List[str] = None) -> None:
"value": par.getValue(),
}

# Parameters that need to be turned into expressions
# so far, this concerns parameters with initial assignments containing rateOf(.)
# (those have been skipped above)
for par in self.sbml.getListOfParameters():
if (ia := self._get_element_initial_assignment(par.getId())) is not None \
and ia.find(sp.core.function.UndefinedFunction("rateOf")):
self.symbols[SymbolId.EXPRESSION][_get_identifier_symbol(par)] = {
"name": par.getName() if par.isSetName() else par.getId(),
"value": ia,
}


@log_execution_time("processing SBML reactions", logger)
def _process_reactions(self):
"""
Expand Down Expand Up @@ -1774,16 +1796,19 @@ def _make_initial(
:return:
transformed expression
"""

if not isinstance(sym_math, sp.Expr):
return sym_math

sym_math, rateof_to_dummy = _rateof_to_dummy(sym_math)

for species_id, species in self.symbols[SymbolId.SPECIES].items():
if "init" in species:
sym_math = smart_subs(sym_math, species_id, species["init"])

sym_math = smart_subs(sym_math, self._local_symbols["time"], sp.Float(0))

sym_math = _dummy_to_rateof(sym_math, rateof_to_dummy)

return sym_math

def process_conservation_laws(self, ode_model) -> None:
Expand Down Expand Up @@ -2663,3 +2688,31 @@ def _non_const_conservation_laws_supported(sbml_model: sbml.Model) -> bool:
return False

return True


def _rateof_to_dummy(sym_math):
"""Replace rateOf(...) by dummy variable
if `rateOf(some_species)` is used in an initial assignment, we don't want to substitute the species argument
by its initial value.
Usage:
sym_math, rateof_to_dummy = _rateof_to_dummy(sym_math)
[...substitute...]
sym_math = _dummy_to_rateof(sym_math, rateof_to_dummy)
"""
if rate_ofs := sym_math.find(
sp.core.function.UndefinedFunction("rateOf")
):
# replace by dummies to avoid species substitution
rateof_dummies = {rate_of: sp.Dummy(f"Dummy_RateOf_{rate_of.args[0].name}") for rate_of in rate_ofs}

return sym_math.subs(rateof_dummies), rateof_dummies
return sym_math, {}


def _dummy_to_rateof(sym_math, rateof_dummies):
"""Back-substitution of dummies from `_rateof_to_dummy`"""
if rateof_dummies:
return sym_math.subs({v: k for k, v in rateof_dummies.items()})
return sym_math
1 change: 1 addition & 0 deletions python/sdist/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ test =
pytest-rerunfailures
coverage
shyaml
tellurium
vis =
matplotlib
seaborn
Expand Down
60 changes: 58 additions & 2 deletions python/tests/test_sbml_import_special_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import os

import amici
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal_nulp, assert_approx_equal
from scipy.special import loggamma

import amici
from amici.gradient_check import check_derivatives
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from scipy.special import loggamma


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -153,3 +155,57 @@ def negative_binomial_nllh(m: np.ndarray, y: np.ndarray, p: float):
- r * np.log(1 - p)
- m * np.log(p)
)

@pytest.mark.filterwarnings("ignore:the imp module is deprecated:DeprecationWarning")
def test_rateof():
"""Test chained rateOf to verify that model expressions are evaluated in the correct order."""
import tellurium as te

ant_model = """
model test_chained_rateof
species S1, S2, S3, S4;
S1 = 0;
S3 = 0;
p2 = 1;
rate = 1;
S4 = 0.5 * rateOf(S3);
S2' = 2 * rateOf(S3);
S1' = S2 + rateOf(S2);
S3' = rate;
p1 = 2 * rateOf(S1);
p2' = rateOf(S1);
p3 = rateOf(rate);
end
"""
sbml_str = te.antimonyToSBML(ant_model)
sbml_importer = amici.SbmlImporter(sbml_str, from_file=False)

module_name = "test_chained_rateof"
with TemporaryDirectoryWinSafe(prefix=module_name) as outdir:
sbml_importer.sbml2amici(
model_name=module_name,
output_dir=outdir,
)
model_module = amici.import_model_module(module_name=module_name, module_path=outdir)
amici_model = model_module.getModel()
t = np.linspace(0, 10, 11)
amici_model.setTimepoints(t)
amici_solver = amici_model.getSolver()
rdata = amici.runAmiciSimulation(amici_model, amici_solver)

state_ids_solver = amici_model.getStateIdsSolver()
i_S1 = state_ids_solver.index("S1")
i_S2 = state_ids_solver.index("S2")
i_S3 = state_ids_solver.index("S3")
i_p2 = state_ids_solver.index("p2")
assert_approx_equal(rdata["xdot"][i_S3], 1)
assert_approx_equal(rdata["xdot"][i_S2], 2)
assert_approx_equal(rdata["xdot"][i_S1], rdata.by_id("S2")[-1] + rdata["xdot"][i_S2])
assert_approx_equal(rdata["xdot"][i_S1], rdata["xdot"][i_p2])

assert_array_almost_equal_nulp(rdata.by_id("S3"), t, 10)
assert_array_almost_equal_nulp(rdata.by_id("S2"), 2 * rdata.by_id("S3"))
assert_array_almost_equal_nulp(rdata.by_id("S4")[1:], 0.5 * np.diff(rdata.by_id("S3")), 10)
assert_array_almost_equal_nulp(rdata.by_id("p3"), 0)
assert_array_almost_equal_nulp(rdata.by_id("p2"), 1 + rdata.by_id("S1"))

0 comments on commit b1b4b2b

Please sign in to comment.