Skip to content

Commit

Permalink
resolve missing type hints (work in progress)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaasRostock committed Jan 24, 2025
1 parent a458a7a commit 38f4358
Show file tree
Hide file tree
Showing 6 changed files with 368 additions and 249 deletions.
140 changes: 79 additions & 61 deletions src/sim_explorer/assertion.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import ast
from collections.abc import Callable, Iterable, Iterator
from types import CodeType
from typing import Any

import numpy as np

from sim_explorer.case import Case, Results
from sim_explorer.models import AssertionResult, Temporal


Expand All @@ -26,7 +28,7 @@ class Assertion:
funcs (dict) : Dictionary of module : <list-of-functions> of allowed functions inside assertion expressions.
"""

def __init__(self, imports: dict | None = None):
def __init__(self, imports: dict | None = None) -> None:
if imports is None:
self._imports = {"math": ["sin", "cos", "sqrt"]} # default imports
else:
Expand All @@ -46,7 +48,7 @@ def __init__(self, imports: dict | None = None):
def info(self, sym: str, typ: str = "instance") -> str | int:
"""Retrieve detailed information related to the registered symbol 'sym'."""
if sym == "t": # the independent variable
return {"instance": "none", "variable": "t", "length": 1, "model": "none"}[typ] # type: ignore
return {"instance": "none", "variable": "t", "length": 1, "model": "none"}[typ] # type: ignore[return-value]

parts = sym.split("_")
var = parts.pop()
Expand All @@ -72,7 +74,7 @@ def info(self, sym: str, typ: str = "instance") -> str | int:
return self._cases_variables[var]["model"]
raise KeyError(f"Unknown typ {typ} within info()") from None

def symbol(self, name: str, length: int = 1):
def symbol(self, name: str, length: int = 1) -> str:
"""Get or set a symbol.
Args:
Expand All @@ -87,13 +89,13 @@ def symbol(self, name: str, length: int = 1):
except KeyError: # not yet registered
assert length > 0, f"Vector length should be positive. Found {length}"
if length > 1:
self._symbols.update({name: np.ones(length, dtype=float)}) # type: ignore
self._symbols.update({name: np.ones(length, dtype=float)}) # type: ignore[dict-item]
else:
self._symbols.update({name: 1})
sym = self._symbols[name]
return sym

def expr(self, key: str, ex: str | None = None):
def expr(self, key: str, ex: str | None = None) -> str | CodeType:
"""Get or set an expression.
Args:
Expand All @@ -103,21 +105,20 @@ def expr(self, key: str, ex: str | None = None):
Returns: the sympified expression
"""

def make_func(name: str, args: dict, body: str):
def make_func(name: str, args: dict, body: str) -> str:
"""Make a python function from the body."""
code = "def _" + name + "("
for a in args:
code += a + ", "
code += "):\n"
# code += " print('dir:', dir())\n"
code += " return " + body + "\n"
return code

if ex is None: # getter
try:
ex = self._expr[key]
except KeyError as err:
raise Exception(f"Expression with identificator {key} is not found") from err
except KeyError as e:
raise KeyError(f"Expression with identificator {key} is not found") from e
else:
return ex
else: # setter
Expand All @@ -126,28 +127,26 @@ def make_func(name: str, args: dict, body: str):
self._funcs.update({key: funcs})
code = make_func(key, syms, ex)
try:
# print("GLOBALS", globals())
# print("LOCALS", locals())
# exec( code, globals(), locals()) # compile using the defined symbols
# exec( code, globals(), locals()) # compile using the defined symbols # noqa: ERA001
compiled = compile(code, "<string>", "exec") # compile using the defined symbols
except ValueError as err:
raise Exception(f"Something wrong with expression {ex}: {err}|. Cannot compile.") from None
except ValueError as e:
raise ValueError(f"Something wrong with expression {ex}: {e}|. Cannot compile.") from e
else:
self._expr.update({key: ex})
self._compiled.update({key: compiled})
# print("KEY", key, ex, syms, compiled)
# print("KEY", key, ex, syms, compiled) # noqa: ERA001
return compiled

def syms(self, key: str):
def syms(self, key: str) -> list[str]:
"""Get the symbols of the expression 'key'."""
try:
syms = self._syms[key]
except KeyError as err:
raise Exception(f"Expression {key} was not found") from err
except KeyError as e:
raise KeyError(f"Expression {key} was not found") from e
else:
return syms

def expr_get_symbols_functions(self, expr: str) -> tuple:
def expr_get_symbols_functions(self, expr: str) -> tuple[list[str], list[str]]:
"""Get the symbols used in the expression.
1. Symbol as listed in expression and function body. In general <instant>_<variable>[<index>]
Expand All @@ -164,7 +163,9 @@ def expr_get_symbols_functions(self, expr: str) -> tuple:
funcs is a list of functions used in the expression.
"""

def ast_walk(node: ast.AST, syms: list | None = None, funcs: list | None = None):
def ast_walk(
node: ast.AST, syms: list[str] | None = None, funcs: list[str] | None = None
) -> tuple[list[str], list[str]]:
"""Recursively walk an ast node (width first) and collect symbol and function names."""
if syms is None:
syms = []
Expand All @@ -189,7 +190,7 @@ def ast_walk(node: ast.AST, syms: list | None = None, funcs: list | None = None)
syms = sorted(syms, key=list(self._symbols.keys()).index)
return (syms, funcs)

def temporal(self, key: str, typ: Temporal | str | None = None, args: tuple | None = None):
def temporal(self, key: str, typ: Temporal | str | None = None, args: tuple | None = None) -> dict[str, Any]:
"""Get or set a temporal instruction.
Args:
Expand All @@ -199,8 +200,8 @@ def temporal(self, key: str, typ: Temporal | str | None = None, args: tuple | No
if typ is None: # getter
try:
temp = self._temporal[key]
except KeyError as err:
raise Exception(f"Temporal instruction for {key} is not found") from err
except KeyError as e:
raise KeyError(f"Temporal instruction for {key} is not found") from e
else:
return temp
else: # setter
Expand All @@ -212,33 +213,39 @@ def temporal(self, key: str, typ: Temporal | str | None = None, args: tuple | No
raise ValueError(f"Unknown temporal type {typ}") from None
return self._temporal[key]

def description(self, key: str, descr: str | None = None):
def description(self, key: str, descr: str | None = None) -> str:
"""Get or set a description."""
if descr is None: # getter
try:
_descr = self._description[key]
except KeyError as err:
raise Exception(f"Description for {key} not found") from err
except KeyError as e:
raise KeyError(f"Description for {key} not found") from e
else:
return _descr
else: # setter
self._description.update({key: descr})
return descr

def assertions(self, key: str, res: bool | None = None, details: str | None = None, case_name: str | None = None):
def assertions(
self,
key: str,
res: bool | None = None,
details: str | None = None,
case_name: str | None = None,
) -> int | float | bool:
"""Get or set an assertion result."""
if res is None: # getter
try:
_res = self._assertions[key]
except KeyError as err:
raise Exception(f"Assertion results for {key} not found") from err
except KeyError as e:
raise KeyError(f"Assertion results for {key} not found") from e
else:
return _res
else: # setter
self._assertions.update({key: {"passed": res, "details": details, "case": case_name}})
return self._assertions[key]

def register_vars(self, variables: dict):
def register_vars(self, variables: dict[str, dict[str, Any]]) -> None:
"""Register the variables in varnames as symbols.
Can be used directly from Cases with varnames = tuple( Cases.variables.keys())
Expand All @@ -247,10 +254,10 @@ def register_vars(self, variables: dict):
for key, info in variables.items():
for inst in info["instances"]:
if len(info["instances"]) == 1: # the instance is unique
self.symbol(key, len(info["names"])) # we allow to use the 'short name' if unique
self.symbol(inst + "_" + key, len(info["names"])) # fully qualified name can always be used
_ = self.symbol(key, len(info["names"])) # we allow to use the 'short name' if unique
_ = self.symbol(inst + "_" + key, len(info["names"])) # fully qualified name can always be used

def make_locals(self, loc: dict):
def make_locals(self, loc: dict[str, Any]) -> dict[str, Any]:
"""Adapt the locals with 'allowed' functions."""
from importlib import import_module

Expand All @@ -261,7 +268,9 @@ def make_locals(self, loc: dict):
loc.update({"np": import_module("numpy")})
return loc

def _eval(self, func: Callable, kvargs: dict | list | tuple):
def _eval(
self, func: Callable[..., int | float | bool], kvargs: dict[str, Any] | list[Any] | tuple[Any, ...]
) -> int | float | bool:
"""Call a function of multiple arguments and return the single result.
All internal vecor arguments are transformed to np.arrays.
"""
Expand All @@ -275,16 +284,16 @@ def _eval(self, func: Callable, kvargs: dict | list | tuple):
if isinstance(v, Iterable):
kvargs[i] = np.array(v, dtype=float)
return func(*kvargs)
if isinstance(kvargs, tuple):
_args = [] # make new, because tuple is not mutable
for v in kvargs:
if isinstance(v, Iterable):
_args.append(np.array(v, dtype=float))
else:
_args.append(v)
return func(*_args)
assert isinstance(kvargs, tuple), f"Unknown type of kvargs {kvargs}"
_args = [] # make new, because tuple is not mutable
for v in kvargs:
if isinstance(v, Iterable):
_args.append(np.array(v, dtype=float))
else:
_args.append(v)
return func(*_args)

def eval_single(self, key: str, kvargs: dict | list | tuple):
def eval_single(self, key: str, kvargs: dict[str, Any] | list[Any] | tuple[Any, ...]) -> int | float | bool:
"""Perform assertion of 'key' on a single data point.
Args:
Expand All @@ -296,11 +305,19 @@ def eval_single(self, key: str, kvargs: dict | list | tuple):
"""
assert key in self._compiled, f"Expression {key} not found"
loc = self.make_locals(locals())
exec(self._compiled[key], loc, loc)
# print("kvargs", kvargs, self._syms[key], self.expr_get_symbols_functions(key))
exec(self._compiled[key], loc, loc) # noqa: S102
# print("kvargs", kvargs, self._syms[key], self.expr_get_symbols_functions(key)) # noqa: ERA001
return self._eval(locals()["_" + key], kvargs)

def eval_series(self, key: str, data: list, ret: float | str | Callable | None = None):
def eval_series(
self,
key: str,
data: list[list[int | float | bool]],
ret: float
| str
| Callable[[list[int | float | bool]], list[int | float | bool] | int | float | bool]
| None = None,
) -> tuple[int | float | list[int | float], int | float | bool | list[int | float | bool]]:
"""Perform assertion on a (time) series.
Args:
Expand Down Expand Up @@ -328,12 +345,13 @@ def eval_series(self, key: str, data: list, ret: float | str | Callable | None =
)
argnames = self._syms[key]
loc = self.make_locals(locals())
exec(self._compiled[key], loc, loc) # the function is then available as _<key> among locals()
exec(self._compiled[key], loc, loc) # the function is then available as _<key> among locals() # noqa: S102
func = locals()["_" + key] # scalar function of all used arguments
_temp = self._temporal[key]["type"] if ret is None else Temporal.UNDEFINED

for row in data:
if not isinstance(row, (tuple, list)): # can happen if the time itself is evaluated
for _row in data:
row = _row
if not isinstance(row, tuple | list): # can happen if the time itself is evaluated
time = row
row = [row]
elif "t" not in argnames: # the independent variable is not explicitly used in the expression
Expand Down Expand Up @@ -370,19 +388,21 @@ def eval_series(self, key: str, data: list, ret: float | str | Callable | None =
else:
assert len(self._temporal[key]["args"]), "Need a temporal argument (time at which to interpolate)"
t0 = self._temporal[key]["args"][0]
# idx = min(range(len(times)), key=lambda i: abs(times[i]-t0))
# print("INDEX", t0, idx, results[idx-10:idx+10])
# return (t0, results[idx])
# else:
interpolated = np.interp(t0, times, results)
return (t0, bool(interpolated) if all(isinstance(res, bool) for res in results) else interpolated)
if callable(ret):
return (times, ret(results))
raise ValueError(f"Unknown return type '{ret}'") from None

def do_assert(self, key: str, result: Any, case_name: str | None = None):
def do_assert(
self,
key: str,
result: Results,
case_name: str | None = None,
) -> int | float | bool | list[int | float | bool]:
"""Perform assert action 'key' on data of 'result' object."""
assert isinstance(key, str) and key in self._temporal, f"Assertion key {key} not found"
assert isinstance(key, str), f"Key should be a string. Found {key}"
assert key in self._temporal, f"Assertion key {key} not found"
from sim_explorer.case import Results

assert isinstance(result, Results), f"Results object expected. Found {result}"
Expand All @@ -406,7 +426,7 @@ def do_assert(self, key: str, result: Any, case_name: str | None = None):
self.assertions(key, res[1], f"@{res[0]} (interpolated)", case_name)
return res[1]

def do_assert_case(self, result: Any) -> list[int]:
def do_assert_case(self, result: Results) -> list[int]:
"""Perform all assertions defined for the case related to the result object."""
count = [0, 0]
for key in result.case.asserts:
Expand All @@ -415,19 +435,17 @@ def do_assert_case(self, result: Any) -> list[int]:
count[1] += 1
return count

def report(self, case: Any = None) -> Iterator[AssertionResult]:
def report(self, case: Case | None = None) -> Iterator[AssertionResult]:
"""Report on all registered asserts.
If case denotes a case object, only the results for this case are reported.
"""

def do_report(key: str):
def do_report(key: str) -> AssertionResult:
time_arg = self._temporal[key].get("args", None)
return AssertionResult(
key=key,
expression=self._expr[key],
time=time_arg[0]
if len(time_arg) > 0 and (isinstance(time_arg[0], int) or isinstance(time_arg[0], float))
else None,
time=(time_arg[0] if len(time_arg) > 0 and (isinstance(time_arg[0], int | float)) else None),
result=self._assertions[key].get("passed", False),
description=self._description[key],
temporal=self._temporal[key].get("type", None),
Expand Down
Loading

0 comments on commit 38f4358

Please sign in to comment.