diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0ba32667..4d09c6c4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,5 +1,6 @@ -**3.2.5 - TBD/TBD/TBD +**3.2.5 - 12/11/24** + - Type-hinting: Fix mypy errors in vivarium/framework/results/interface.py - Type-hinting: Fix mypy errors in vivarium/component.py - Type-hinting: Fix mypy errors in vivarium/framework/results/observer.py diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index b6bcef6c..465f14f8 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -40,6 +40,8 @@ py:class VectorMapper py:class ScalarMapper py:class PandasObject py:class DataFrameGroupBy +py:class ResultsFormatter +py:class ResultsUpdater py:exc ResultsConfigurationError py:exc vivarium.framework.results.exceptions.ResultsConfigurationError diff --git a/pyproject.toml b/pyproject.toml index c4ffa35c..435dc82b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ exclude = [ 'src/vivarium/framework/engine.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/results/interface.py', 'src/vivarium/framework/state_machine.py', 'src/vivarium/interface/cli.py', 'src/vivarium/interface/interactive.py', diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 6a105b4f..8f621dde 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ================= Results Interface @@ -9,11 +8,13 @@ to a simulation. """ +from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Union import pandas as pd +from pandas.core.groupby.generic import DataFrameGroupBy from vivarium.framework.event import Event from vivarium.framework.results.observation import ( @@ -23,13 +24,29 @@ UnstratifiedObservation, ) from vivarium.manager import Interface -from vivarium.types import ScalarValue +from vivarium.types import ScalarMapper, VectorMapper if TYPE_CHECKING: from vivarium.framework.results.manager import ResultsManager -def _required_function_placeholder(*args, **kwargs) -> pd.DataFrame: +ResultsUpdater = Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] +"""This is a Callable that takes existing results and new observations and returns updated results.""" +ResultsFormatter = Callable[[str, pd.DataFrame], pd.DataFrame] +"""This is a Callable that takes a measure as a string and a DataFrame of observation results and returns formatted results.""" +ResultsGathererInput = Union[ + pd.DataFrame, DataFrameGroupBy, tuple[str, ...], None # type: ignore [type-arg] +] +ResultsGatherer = Callable[[ResultsGathererInput], pd.DataFrame] +"""This is a Callable that optionally takes a possibly stratified population and returns new observation results.""" + + +def _required_function_placeholder( + *args: ResultsGathererInput + | tuple[pd.DataFrame, pd.DataFrame] + | tuple[str, pd.DataFrame], + **kwargs: Any, +) -> pd.DataFrame: """Placeholder function to indicate that a required function is missing.""" return pd.DataFrame() @@ -56,8 +73,8 @@ class ResultsInterface(Interface): """ - def __init__(self, manager: "ResultsManager") -> None: - self._manager: "ResultsManager" = manager + def __init__(self, manager: ResultsManager) -> None: + self._manager: ResultsManager = manager self._name = "results_interface" @property @@ -75,9 +92,7 @@ def register_stratification( name: str, categories: list[str], excluded_categories: list[str] | None = None, - mapper: Callable[[pd.Series | pd.DataFrame], pd.Series] - | Callable[[ScalarValue], str] - | None = None, + mapper: VectorMapper | ScalarMapper | None = None, is_vectorized: bool = False, requires_columns: list[str] = [], requires_values: list[str] = [], @@ -127,7 +142,7 @@ def register_binned_stratification( labels: list[str] = [], excluded_categories: list[str] | None = None, target_type: str = "column", - **cut_kwargs: dict, + **cut_kwargs: int | str | bool, ) -> None: """Registers a binned stratification that can be used by stratified observations. @@ -173,16 +188,12 @@ def register_stratified_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_updater: Callable[ - [pd.DataFrame, pd.DataFrame], pd.DataFrame - ] = _required_function_placeholder, - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + results_updater: ResultsUpdater = _required_function_placeholder, + results_formatter: ResultsFormatter = lambda measure, results: results, additional_stratifications: list[str] = [], excluded_stratifications: list[str] = [], aggregator_sources: list[str] | None = None, - aggregator: Callable[[pd.DataFrame], float | pd.Series] = len, + aggregator: Callable[[pd.DataFrame], float | pd.Series[Any]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers a stratified observation to the results system. @@ -249,15 +260,9 @@ def register_unstratified_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_gatherer: Callable[ - [pd.DataFrame], pd.DataFrame - ] = _required_function_placeholder, - results_updater: Callable[ - [pd.DataFrame, pd.DataFrame], pd.DataFrame - ] = _required_function_placeholder, - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + results_gatherer: ResultsGatherer = _required_function_placeholder, + results_updater: ResultsUpdater = _required_function_placeholder, + results_formatter: ResultsFormatter = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers an unstratified observation to the results system. @@ -291,7 +296,7 @@ def register_unstratified_observation( ValueError If any required callable arguments are missing. """ - required_callables = { + required_callables: dict[str, Callable[..., pd.DataFrame]] = { "results_gatherer": results_gatherer, "results_updater": results_updater, } @@ -317,13 +322,11 @@ def register_adding_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results.reset_index(), + results_formatter: ResultsFormatter = lambda measure, results: results.reset_index(), additional_stratifications: list[str] = [], excluded_stratifications: list[str] = [], aggregator_sources: list[str] | None = None, - aggregator: Callable[[pd.DataFrame], float | pd.Series] = len, + aggregator: Callable[[pd.DataFrame], int | float | pd.Series[int | float]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers an adding observation to the results system. @@ -388,9 +391,7 @@ def register_concatenating_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + results_formatter: ResultsFormatter = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers a concatenating observation to the results system. @@ -438,7 +439,8 @@ def register_concatenating_observation( @staticmethod def _check_for_required_callables( - observation_name: str, required_callables: dict[str, Callable] + observation_name: str, + required_callables: dict[str, ResultsFormatter | ResultsGatherer | ResultsUpdater], ) -> None: """Raises a ValueError if any required callable arguments are missing.""" missing = []