Skip to content

Commit

Permalink
results/interface.py typing (#550)
Browse files Browse the repository at this point in the history
results/interface.py typing
  • Loading branch information
hussain-jafari authored Dec 11, 2024
1 parent 778bf82 commit d585f57
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
**3.2.5 - TBD/TBD/TBD
**3.2.5 - 12/11/24**

- Type-hinting: Fix mypy errors in vivarium/framework/results/interface.py
- Type-hinting: Fix mypy errors in vivarium/component.py
- Type-hinting: Fix mypy errors in vivarium/framework/results/observer.py

Expand Down
2 changes: 2 additions & 0 deletions docs/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ py:class VectorMapper
py:class ScalarMapper
py:class PandasObject
py:class DataFrameGroupBy
py:class ResultsFormatter
py:class ResultsUpdater
py:exc ResultsConfigurationError
py:exc vivarium.framework.results.exceptions.ResultsConfigurationError

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ exclude = [
'src/vivarium/framework/engine.py',
'src/vivarium/framework/population/manager.py',
'src/vivarium/framework/population/population_view.py',
'src/vivarium/framework/results/interface.py',
'src/vivarium/framework/state_machine.py',
'src/vivarium/interface/cli.py',
'src/vivarium/interface/interactive.py',
Expand Down
72 changes: 37 additions & 35 deletions src/vivarium/framework/results/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
=================
Results Interface
Expand All @@ -9,11 +8,13 @@
to a simulation.
"""
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Union

import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy

from vivarium.framework.event import Event
from vivarium.framework.results.observation import (
Expand All @@ -23,13 +24,29 @@
UnstratifiedObservation,
)
from vivarium.manager import Interface
from vivarium.types import ScalarValue
from vivarium.types import ScalarMapper, VectorMapper

if TYPE_CHECKING:
from vivarium.framework.results.manager import ResultsManager


def _required_function_placeholder(*args, **kwargs) -> pd.DataFrame:
ResultsUpdater = Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
"""This is a Callable that takes existing results and new observations and returns updated results."""
ResultsFormatter = Callable[[str, pd.DataFrame], pd.DataFrame]
"""This is a Callable that takes a measure as a string and a DataFrame of observation results and returns formatted results."""
ResultsGathererInput = Union[
pd.DataFrame, DataFrameGroupBy, tuple[str, ...], None # type: ignore [type-arg]
]
ResultsGatherer = Callable[[ResultsGathererInput], pd.DataFrame]
"""This is a Callable that optionally takes a possibly stratified population and returns new observation results."""


def _required_function_placeholder(
*args: ResultsGathererInput
| tuple[pd.DataFrame, pd.DataFrame]
| tuple[str, pd.DataFrame],
**kwargs: Any,
) -> pd.DataFrame:
"""Placeholder function to indicate that a required function is missing."""
return pd.DataFrame()

Expand All @@ -56,8 +73,8 @@ class ResultsInterface(Interface):
"""

def __init__(self, manager: "ResultsManager") -> None:
self._manager: "ResultsManager" = manager
def __init__(self, manager: ResultsManager) -> None:
self._manager: ResultsManager = manager
self._name = "results_interface"

@property
Expand All @@ -75,9 +92,7 @@ def register_stratification(
name: str,
categories: list[str],
excluded_categories: list[str] | None = None,
mapper: Callable[[pd.Series | pd.DataFrame], pd.Series]
| Callable[[ScalarValue], str]
| None = None,
mapper: VectorMapper | ScalarMapper | None = None,
is_vectorized: bool = False,
requires_columns: list[str] = [],
requires_values: list[str] = [],
Expand Down Expand Up @@ -127,7 +142,7 @@ def register_binned_stratification(
labels: list[str] = [],
excluded_categories: list[str] | None = None,
target_type: str = "column",
**cut_kwargs: dict,
**cut_kwargs: int | str | bool,
) -> None:
"""Registers a binned stratification that can be used by stratified observations.
Expand Down Expand Up @@ -173,16 +188,12 @@ def register_stratified_observation(
when: str = "collect_metrics",
requires_columns: list[str] = [],
requires_values: list[str] = [],
results_updater: Callable[
[pd.DataFrame, pd.DataFrame], pd.DataFrame
] = _required_function_placeholder,
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
results_updater: ResultsUpdater = _required_function_placeholder,
results_formatter: ResultsFormatter = lambda measure, results: results,
additional_stratifications: list[str] = [],
excluded_stratifications: list[str] = [],
aggregator_sources: list[str] | None = None,
aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
aggregator: Callable[[pd.DataFrame], float | pd.Series[Any]] = len,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Registers a stratified observation to the results system.
Expand Down Expand Up @@ -249,15 +260,9 @@ def register_unstratified_observation(
when: str = "collect_metrics",
requires_columns: list[str] = [],
requires_values: list[str] = [],
results_gatherer: Callable[
[pd.DataFrame], pd.DataFrame
] = _required_function_placeholder,
results_updater: Callable[
[pd.DataFrame, pd.DataFrame], pd.DataFrame
] = _required_function_placeholder,
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
results_gatherer: ResultsGatherer = _required_function_placeholder,
results_updater: ResultsUpdater = _required_function_placeholder,
results_formatter: ResultsFormatter = lambda measure, results: results,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Registers an unstratified observation to the results system.
Expand Down Expand Up @@ -291,7 +296,7 @@ def register_unstratified_observation(
ValueError
If any required callable arguments are missing.
"""
required_callables = {
required_callables: dict[str, Callable[..., pd.DataFrame]] = {
"results_gatherer": results_gatherer,
"results_updater": results_updater,
}
Expand All @@ -317,13 +322,11 @@ def register_adding_observation(
when: str = "collect_metrics",
requires_columns: list[str] = [],
requires_values: list[str] = [],
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results.reset_index(),
results_formatter: ResultsFormatter = lambda measure, results: results.reset_index(),
additional_stratifications: list[str] = [],
excluded_stratifications: list[str] = [],
aggregator_sources: list[str] | None = None,
aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
aggregator: Callable[[pd.DataFrame], int | float | pd.Series[int | float]] = len,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Registers an adding observation to the results system.
Expand Down Expand Up @@ -388,9 +391,7 @@ def register_concatenating_observation(
when: str = "collect_metrics",
requires_columns: list[str] = [],
requires_values: list[str] = [],
results_formatter: Callable[
[str, pd.DataFrame], pd.DataFrame
] = lambda measure, results: results,
results_formatter: ResultsFormatter = lambda measure, results: results,
to_observe: Callable[[Event], bool] = lambda event: True,
) -> None:
"""Registers a concatenating observation to the results system.
Expand Down Expand Up @@ -438,7 +439,8 @@ def register_concatenating_observation(

@staticmethod
def _check_for_required_callables(
observation_name: str, required_callables: dict[str, Callable]
observation_name: str,
required_callables: dict[str, ResultsFormatter | ResultsGatherer | ResultsUpdater],
) -> None:
"""Raises a ValueError if any required callable arguments are missing."""
missing = []
Expand Down

0 comments on commit d585f57

Please sign in to comment.