From f5ae0b656363dbe7362ed2e1919a4f1e180599de Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:18:38 -0700 Subject: [PATCH 1/5] fix mypy errors: component.py (#549) --- CHANGELOG.rst | 4 + pyproject.toml | 1 - src/vivarium/component.py | 89 ++++++++++++----------- src/vivarium/framework/logging/manager.py | 5 +- tests/framework/test_engine.py | 2 +- 5 files changed, 55 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 97c71cac..6e873123 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**TBD/TBD/TBD** + + - Type-hinting: Fix mypy errors in vivarium/component.py + **3.2.4 - 12/03/24** - Fix type hints for pandas groupby objects diff --git a/pyproject.toml b/pyproject.toml index dcb48a96..2dfd433f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ exclude = [ # You will need to remove the mypy: ignore-errors comment from the file heading as well 'docs/source/conf.py', 'setup.py', - 'src/vivarium/component.py', 'src/vivarium/examples/boids/forces.py', 'src/vivarium/examples/boids/movement.py', 'src/vivarium/examples/boids/neighbors.py', diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 818bd174..9f5882e8 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ========= Component @@ -18,16 +17,17 @@ from importlib import import_module from inspect import signature from numbers import Number -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import pandas as pd from layered_config_tree import ConfigurationError, LayeredConfigTree -from loguru._logger import Logger from vivarium.framework.artifact import ArtifactException from vivarium.framework.population import PopulationError if TYPE_CHECKING: + import loguru + from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.lookup import LookupTable @@ -92,7 +92,24 @@ class Component(ABC): component. An empty dictionary indicates no managed configurations. """ - def __repr__(self): + def __init__(self) -> None: + """Initializes a new instance of the Component class. + + This method is the initializer for the Component class. It initializes + logger of type Logger and population_view of type PopulationView to None. + These attributes will be fully initialized in the setup_component method + of this class. + """ + self._repr: str = "" + self._name: str = "" + self._sub_components: list["Component"] = [] + self.logger: loguru.Logger | None = None + self.get_value_columns: Callable[[str | pd.DataFrame], list[str]] | None = None + self.configuration: LayeredConfigTree | None = None + self._population_view: PopulationView | None = None + self.lookup_tables: dict[str, LookupTable] = {} + + def __repr__(self) -> str: """Returns a string representation of the __init__ call made to create this object. @@ -111,16 +128,17 @@ def __repr__(self): object. """ if not self._repr: - args = [ - f"{name}={value.__repr__() if isinstance(value, Component) else value}" - for name, value in self.get_initialization_parameters().items() - ] - args = ", ".join(args) + args = ", ".join( + [ + f"{name}={value.__repr__() if isinstance(value, Component) else value}" + for name, value in self.get_initialization_parameters().items() + ] + ) self._repr = f"{type(self).__name__}({args})" return self._repr - def __str__(self): + def __str__(self) -> str: return self._repr ############## @@ -244,15 +262,15 @@ def initialization_requirements( return [] @property - def population_view_query(self) -> str | None: + def population_view_query(self) -> str: """Provides a query to use when filtering the component's `PopulationView`. Returns ------- A pandas query string for filtering the component's `PopulationView`. - Returns `None` if no filtering is required. + Returns an empty string if no filtering is required. """ - return None + return "" @property def post_setup_priority(self) -> int: @@ -324,23 +342,6 @@ def simulation_end_priority(self) -> int: # Lifecycle methods # ##################### - def __init__(self) -> None: - """Initializes a new instance of the Component class. - - This method is the initializer for the Component class. It initializes - logger of type Logger and population_view of type PopulationView to None. - These attributes will be fully initialized in the setup_component method - of this class. - """ - self._repr: str = "" - self._name: str = "" - self._sub_components: list["Component"] = [] - self.logger: Logger | None = None - self.get_value_columns: Callable[[str | pd.DataFrame], list[str]] | None = None - self.configuration: LayeredConfigTree | None = None - self._population_view: PopulationView | None = None - self.lookup_tables: dict[str, LookupTable] = {} - def setup_component(self, builder: "Builder") -> None: """Sets up the component for a Vivarium simulation. @@ -515,7 +516,7 @@ def get_initialization_parameters(self) -> dict[str, Any]: """ return { parameter_name: getattr(self, parameter_name) - for parameter_name in signature(self.__init__).parameters + for parameter_name in signature(self.__init__).parameters # type: ignore[misc] if hasattr(self, parameter_name) } @@ -538,7 +539,7 @@ def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None: """ if self.name in builder.configuration: - return builder.configuration[self.name] + return builder.configuration.get_tree(self.name) return None def build_all_lookup_tables(self, builder: "Builder") -> None: @@ -606,7 +607,9 @@ def build_lookup_table( raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.") if isinstance(data, list): - return builder.lookup.build_table(data, value_columns=list(value_columns)) + return builder.lookup.build_table( + data, value_columns=list(value_columns) if value_columns else () + ) if isinstance(data, pd.DataFrame): duplicated_columns = set(data.columns[data.columns.duplicated()]) if duplicated_columns: @@ -627,11 +630,15 @@ def build_lookup_table( return builder.lookup.build_table(data) def _get_columns( - self, value_columns: Sequence[str] | None, data: float | pd.DataFrame - ) -> tuple[list[str], list[str], list[str]]: + self, value_columns: Sequence[str] | None, data: pd.DataFrame + ) -> tuple[Sequence[str], list[str], list[str]]: all_columns = list(data.columns) if value_columns is None: - value_columns = self.get_value_columns(data) + # NOTE: self.get_value_columns cannot be None at this point of the call stack + value_column_getter = cast( + Callable[[str | pd.DataFrame], list[str]], self.get_value_columns + ) + value_columns = value_column_getter(data) potential_parameter_columns = [ str(col).removesuffix("_start") @@ -685,9 +692,9 @@ def get_data(self, builder: Builder, data_source: DataInput) -> Any: module, method = data_source.split("::") try: if module == "self": - data_source = getattr(self, method) + data_source_callable = getattr(self, method) else: - data_source = getattr(import_module(module), method) + data_source_callable = getattr(import_module(module), method) except ModuleNotFoundError: raise ConfigurationError(f"Unable to find module '{module}'.") except AttributeError: @@ -697,7 +704,7 @@ def get_data(self, builder: Builder, data_source: DataInput) -> Any: raise ConfigurationError( f"There is no method '{method}' for the {module_string}." ) - data = data_source(builder) + data = data_source_callable(builder) else: try: data = builder.data.load(data_source) @@ -705,7 +712,7 @@ def get_data(self, builder: Builder, data_source: DataInput) -> Any: raise ConfigurationError( f"Failed to find key '{data_source}' in artifact." ) - elif isinstance(data_source, Callable): + elif callable(data_source): data = data_source(builder) else: data = data_source @@ -791,7 +798,7 @@ def _register_simulant_initializer(self, builder: Builder) -> None: if type(self).on_initialize_simulants != Component.on_initialize_simulants: builder.population.initializes_simulants( - self, creates_columns=self.columns_created, **initialization_requirements + self, creates_columns=self.columns_created, **initialization_requirements # type: ignore[arg-type] ) def _register_time_step_prepare_listener(self, builder: "Builder") -> None: diff --git a/src/vivarium/framework/logging/manager.py b/src/vivarium/framework/logging/manager.py index 0e442761..64b15b57 100644 --- a/src/vivarium/framework/logging/manager.py +++ b/src/vivarium/framework/logging/manager.py @@ -7,7 +7,6 @@ from __future__ import annotations import loguru -from loguru import logger from vivarium.framework.logging.utilities import configure_logging_to_terminal from vivarium.manager import Interface, Manager @@ -38,7 +37,7 @@ def _terminal_logging_not_configured() -> bool: # fragile since it depends on a loguru's internals as well as the stability of code # paths in vivarium, but both are quite stable at this point, so I think it's pretty, # low risk. - return 1 not in logger._core.handlers # type: ignore[attr-defined] + return 1 not in loguru.logger._core.handlers # type: ignore[attr-defined] @property def name(self) -> str: @@ -48,7 +47,7 @@ def get_logger(self, component_name: str | None = None) -> loguru.Logger: bind_args = {"simulation": self._simulation_name} if component_name: bind_args["component"] = component_name - return logger.bind(**bind_args) + return loguru.logger.bind(**bind_args) class LoggingInterface(Interface): diff --git a/tests/framework/test_engine.py b/tests/framework/test_engine.py index dd888ac1..37025db5 100644 --- a/tests/framework/test_engine.py +++ b/tests/framework/test_engine.py @@ -62,7 +62,7 @@ def components(): @pytest.fixture def log(mocker): - return mocker.patch("vivarium.framework.logging.manager.logger") + return mocker.patch("vivarium.framework.logging.manager.loguru.logger") def test_simulation_with_non_components(SimulationContext, components: list[Component]): From 778bf82b6ed8f466447673b7b6e9e89e44920e22 Mon Sep 17 00:00:00 2001 From: Jim Albright <37345113+albrja@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:04:54 -0800 Subject: [PATCH 2/5] Albrja/mic-5603/myppy-results-observer (#551) Albrja/mic-5603/myppy-results-observer Remove mypy errors in framework/results/observer.py - *Category*: Type-hinting - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5603 Changes and notes -remove errors in framework/results/observer.py ### Testing --- .github/workflows/build.yml | 1 + CHANGELOG.rst | 3 ++- pyproject.toml | 1 - src/vivarium/framework/results/observer.py | 1 - 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 20244e6c..ffd1738a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,6 +19,7 @@ jobs: ihmeuw/vivarium_build_utils/.github/workflows/build.yml@main with: dependencies: "layered_config_tree" + python_version: ${{ matrix.python-version }} secrets: notify_email: ${{ secrets.NOTIFY_EMAIL }} NOTIFY_PASSWORD: ${{ secrets.NOTIFY_PASSWORD }} diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6e873123..0ba32667 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,7 @@ -**TBD/TBD/TBD** +**3.2.5 - TBD/TBD/TBD - Type-hinting: Fix mypy errors in vivarium/component.py + - Type-hinting: Fix mypy errors in vivarium/framework/results/observer.py **3.2.4 - 12/03/24** diff --git a/pyproject.toml b/pyproject.toml index 2dfd433f..c4ffa35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ exclude = [ 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', 'src/vivarium/framework/results/interface.py', - 'src/vivarium/framework/results/observer.py', 'src/vivarium/framework/state_machine.py', 'src/vivarium/interface/cli.py', 'src/vivarium/interface/interactive.py', diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 9cab55a6..39df182c 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ========= Observers From d585f57de99dec0736b24734ff1727f7da52a4e9 Mon Sep 17 00:00:00 2001 From: Hussain Jafari Date: Wed, 11 Dec 2024 15:04:40 -0800 Subject: [PATCH 3/5] results/interface.py typing (#550) results/interface.py typing --- CHANGELOG.rst | 3 +- docs/nitpick-exceptions | 2 + pyproject.toml | 1 - src/vivarium/framework/results/interface.py | 72 +++++++++++---------- 4 files changed, 41 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0ba32667..4d09c6c4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,5 +1,6 @@ -**3.2.5 - TBD/TBD/TBD +**3.2.5 - 12/11/24** + - Type-hinting: Fix mypy errors in vivarium/framework/results/interface.py - Type-hinting: Fix mypy errors in vivarium/component.py - Type-hinting: Fix mypy errors in vivarium/framework/results/observer.py diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index b6bcef6c..465f14f8 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -40,6 +40,8 @@ py:class VectorMapper py:class ScalarMapper py:class PandasObject py:class DataFrameGroupBy +py:class ResultsFormatter +py:class ResultsUpdater py:exc ResultsConfigurationError py:exc vivarium.framework.results.exceptions.ResultsConfigurationError diff --git a/pyproject.toml b/pyproject.toml index c4ffa35c..435dc82b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ exclude = [ 'src/vivarium/framework/engine.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/results/interface.py', 'src/vivarium/framework/state_machine.py', 'src/vivarium/interface/cli.py', 'src/vivarium/interface/interactive.py', diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 6a105b4f..8f621dde 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ================= Results Interface @@ -9,11 +8,13 @@ to a simulation. """ +from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Union import pandas as pd +from pandas.core.groupby.generic import DataFrameGroupBy from vivarium.framework.event import Event from vivarium.framework.results.observation import ( @@ -23,13 +24,29 @@ UnstratifiedObservation, ) from vivarium.manager import Interface -from vivarium.types import ScalarValue +from vivarium.types import ScalarMapper, VectorMapper if TYPE_CHECKING: from vivarium.framework.results.manager import ResultsManager -def _required_function_placeholder(*args, **kwargs) -> pd.DataFrame: +ResultsUpdater = Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] +"""This is a Callable that takes existing results and new observations and returns updated results.""" +ResultsFormatter = Callable[[str, pd.DataFrame], pd.DataFrame] +"""This is a Callable that takes a measure as a string and a DataFrame of observation results and returns formatted results.""" +ResultsGathererInput = Union[ + pd.DataFrame, DataFrameGroupBy, tuple[str, ...], None # type: ignore [type-arg] +] +ResultsGatherer = Callable[[ResultsGathererInput], pd.DataFrame] +"""This is a Callable that optionally takes a possibly stratified population and returns new observation results.""" + + +def _required_function_placeholder( + *args: ResultsGathererInput + | tuple[pd.DataFrame, pd.DataFrame] + | tuple[str, pd.DataFrame], + **kwargs: Any, +) -> pd.DataFrame: """Placeholder function to indicate that a required function is missing.""" return pd.DataFrame() @@ -56,8 +73,8 @@ class ResultsInterface(Interface): """ - def __init__(self, manager: "ResultsManager") -> None: - self._manager: "ResultsManager" = manager + def __init__(self, manager: ResultsManager) -> None: + self._manager: ResultsManager = manager self._name = "results_interface" @property @@ -75,9 +92,7 @@ def register_stratification( name: str, categories: list[str], excluded_categories: list[str] | None = None, - mapper: Callable[[pd.Series | pd.DataFrame], pd.Series] - | Callable[[ScalarValue], str] - | None = None, + mapper: VectorMapper | ScalarMapper | None = None, is_vectorized: bool = False, requires_columns: list[str] = [], requires_values: list[str] = [], @@ -127,7 +142,7 @@ def register_binned_stratification( labels: list[str] = [], excluded_categories: list[str] | None = None, target_type: str = "column", - **cut_kwargs: dict, + **cut_kwargs: int | str | bool, ) -> None: """Registers a binned stratification that can be used by stratified observations. @@ -173,16 +188,12 @@ def register_stratified_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_updater: Callable[ - [pd.DataFrame, pd.DataFrame], pd.DataFrame - ] = _required_function_placeholder, - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + results_updater: ResultsUpdater = _required_function_placeholder, + results_formatter: ResultsFormatter = lambda measure, results: results, additional_stratifications: list[str] = [], excluded_stratifications: list[str] = [], aggregator_sources: list[str] | None = None, - aggregator: Callable[[pd.DataFrame], float | pd.Series] = len, + aggregator: Callable[[pd.DataFrame], float | pd.Series[Any]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers a stratified observation to the results system. @@ -249,15 +260,9 @@ def register_unstratified_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_gatherer: Callable[ - [pd.DataFrame], pd.DataFrame - ] = _required_function_placeholder, - results_updater: Callable[ - [pd.DataFrame, pd.DataFrame], pd.DataFrame - ] = _required_function_placeholder, - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + results_gatherer: ResultsGatherer = _required_function_placeholder, + results_updater: ResultsUpdater = _required_function_placeholder, + results_formatter: ResultsFormatter = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers an unstratified observation to the results system. @@ -291,7 +296,7 @@ def register_unstratified_observation( ValueError If any required callable arguments are missing. """ - required_callables = { + required_callables: dict[str, Callable[..., pd.DataFrame]] = { "results_gatherer": results_gatherer, "results_updater": results_updater, } @@ -317,13 +322,11 @@ def register_adding_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results.reset_index(), + results_formatter: ResultsFormatter = lambda measure, results: results.reset_index(), additional_stratifications: list[str] = [], excluded_stratifications: list[str] = [], aggregator_sources: list[str] | None = None, - aggregator: Callable[[pd.DataFrame], float | pd.Series] = len, + aggregator: Callable[[pd.DataFrame], int | float | pd.Series[int | float]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers an adding observation to the results system. @@ -388,9 +391,7 @@ def register_concatenating_observation( when: str = "collect_metrics", requires_columns: list[str] = [], requires_values: list[str] = [], - results_formatter: Callable[ - [str, pd.DataFrame], pd.DataFrame - ] = lambda measure, results: results, + results_formatter: ResultsFormatter = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers a concatenating observation to the results system. @@ -438,7 +439,8 @@ def register_concatenating_observation( @staticmethod def _check_for_required_callables( - observation_name: str, required_callables: dict[str, Callable] + observation_name: str, + required_callables: dict[str, ResultsFormatter | ResultsGatherer | ResultsUpdater], ) -> None: """Raises a ValueError if any required callable arguments are missing.""" missing = [] From 0520d788da33b8180f38954e2aac8031c98802e9 Mon Sep 17 00:00:00 2001 From: patricktnast <130876799+patricktnast@users.noreply.github.com> Date: Thu, 12 Dec 2024 12:18:59 -0800 Subject: [PATCH 4/5] add git branch to conda env name (#555) * add git branch to conda env name * Update Jenkinsfile * drop "env." * Update CHANGELOG.rst --- CHANGELOG.rst | 4 ++++ Jenkinsfile | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4d09c6c4..6e6159e2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.2.6 - 12/12/24** + + - Change Jenkins conda env name + **3.2.5 - 12/11/24** - Type-hinting: Fix mypy errors in vivarium/framework/results/interface.py diff --git a/Jenkinsfile b/Jenkinsfile index 4d36290d..53a6f2af 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,7 +11,7 @@ def githubUsernameToSlackName(github_author) { } pipeline_name="vivarium" -conda_env_name="${pipeline_name}-${BUILD_NUMBER}" +conda_env_name="${pipeline_name}-${BRANCH_NAME}-${BUILD_NUMBER}" conda_env_path="/tmp/${conda_env_name}" // defaults for conda and pip are a local directory /svc-simsci for improved speed. // In the past, we used /ihme/code/* on the NFS (which is slower) @@ -260,4 +260,4 @@ pipeline { } // Python matrix bracket } // Python matrix stage bracket } // stages bracket -} // pipeline bracket \ No newline at end of file +} // pipeline bracket From cc190ea3b373d77d3d09ccb7b37a5323960afe9a Mon Sep 17 00:00:00 2001 From: patricktnast <130876799+patricktnast@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:35:48 -0800 Subject: [PATCH 5/5] Mypy State Machine (#554) * address first 20ish errors * use sequence for variance * remove number.Number * fix Categorical signature * remove deprecated fn * factor in str types for "null_transition" * use np.divide * add int * Update CHANGELOG.rst --- CHANGELOG.rst | 4 ++ pyproject.toml | 1 - src/vivarium/component.py | 7 +- src/vivarium/framework/lookup/manager.py | 5 +- src/vivarium/framework/state_machine.py | 83 +++++++++++++----------- src/vivarium/types.py | 3 +- 6 files changed, 55 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6e6159e2..a7c19de5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.2.7 - 12/12/24** + + - Type-hinting: Fix mypy errors in vivarium/framework/state_machine.py + **3.2.6 - 12/12/24** - Change Jenkins conda env name diff --git a/pyproject.toml b/pyproject.toml index 435dc82b..ae6b0ffb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ exclude = [ 'src/vivarium/framework/engine.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/state_machine.py', 'src/vivarium/interface/cli.py', 'src/vivarium/interface/interactive.py', 'src/vivarium/testing_utilities.py', diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 9f5882e8..8ccc8f26 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -16,7 +16,6 @@ from datetime import datetime, timedelta from importlib import import_module from inspect import signature -from numbers import Number from typing import TYPE_CHECKING, Any, cast import pandas as pd @@ -102,7 +101,7 @@ def __init__(self) -> None: """ self._repr: str = "" self._name: str = "" - self._sub_components: list["Component"] = [] + self._sub_components: Sequence["Component"] = [] self.logger: loguru.Logger | None = None self.get_value_columns: Callable[[str | pd.DataFrame], list[str]] | None = None self.configuration: LayeredConfigTree | None = None @@ -205,7 +204,7 @@ def population_view(self) -> PopulationView: return self._population_view @property - def sub_components(self) -> list["Component"]: + def sub_components(self) -> Sequence["Component"]: """Provide components managed by this component. Returns @@ -603,7 +602,7 @@ def build_lookup_table( data = self.get_data(builder, data_source) # TODO update this to use vivarium.types.LookupTableData once we drop # support for Python 3.9 - if not isinstance(data, (Number, timedelta, datetime, pd.DataFrame, list, tuple)): + if not isinstance(data, (float, int, timedelta, datetime, pd.DataFrame, list, tuple)): raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.") if isinstance(data, list): diff --git a/src/vivarium/framework/lookup/manager.py b/src/vivarium/framework/lookup/manager.py index 2863a3b9..fb96b974 100644 --- a/src/vivarium/framework/lookup/manager.py +++ b/src/vivarium/framework/lookup/manager.py @@ -14,7 +14,6 @@ from collections.abc import Sequence from datetime import datetime, timedelta -from numbers import Number from typing import TYPE_CHECKING import pandas as pd @@ -92,7 +91,7 @@ def _build_table( ) # Note datetime catches pandas timestamps - if isinstance(data, (Number, datetime, timedelta, list, tuple)): + if isinstance(data, (float, int, datetime, timedelta, list, tuple)): table: LookupTable = ScalarTable( table_number=table_number, data=data, @@ -204,7 +203,7 @@ def validate_build_table_parameters( ): raise ValueError("Must supply some data") - acceptable_types = (Number, datetime, timedelta, list, tuple, pd.DataFrame) + acceptable_types = (float, int, datetime, timedelta, list, tuple, pd.DataFrame) if not isinstance(data, acceptable_types): raise TypeError( f"The only allowable types for data are {acceptable_types}. " diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index f1ccc28b..fdc736c6 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ============= State Machine @@ -7,11 +6,12 @@ A state machine implementation for use in ``vivarium`` simulations. """ + from __future__ import annotations -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterator import numpy as np import pandas as pd @@ -23,16 +23,16 @@ from vivarium.framework.event import Event from vivarium.framework.population import PopulationView, SimulantData from vivarium.framework.resource import Resource - from vivarium.types import ClockTime, DataInput + from vivarium.types import ClockTime, DataInput, NumericArray -def default_probability_function(index: pd.Index) -> pd.Series: +def default_probability_function(index: pd.Index[int]) -> pd.Series[float]: """Transition decision function that always triggers this transition.""" return pd.Series(1.0, index=index) def _next_state( - index: pd.Index, + index: pd.Index[int], event_time: ClockTime, transition_set: TransitionSet, population_view: PopulationView, @@ -72,8 +72,8 @@ def _next_state( def _groupby_new_state( - index: pd.Index, outputs: list, decisions: pd.Series -) -> list[tuple[str, pd.Index]]: + index: pd.Index[int], outputs: list[State | str], decisions: pd.Series[Any] +) -> list[tuple[State | str, pd.Index[int]]]: """Groups the simulants in the index by their new output state. Parameters @@ -93,7 +93,7 @@ def _groupby_new_state( into that state. """ groups = pd.Series(index).groupby( - pd.Categorical(decisions.values, categories=outputs), observed=False + pd.CategoricalIndex(decisions.values, categories=outputs), observed=False ) return [(output, pd.Index(sub_group.values)) for output, sub_group in groups] @@ -104,7 +104,7 @@ class Trigger(Enum): START_ACTIVE = 2 -def _process_trigger(trigger): +def _process_trigger(trigger: Trigger) -> tuple[pd.Index[int] | None, bool]: if trigger == Trigger.NOT_TRIGGERED: return None, False elif trigger == Trigger.START_INACTIVE: @@ -126,10 +126,10 @@ def __init__( self, input_state: State, output_state: State, - probability_func: Callable[[pd.Index], pd.Series] = lambda index: pd.Series( - 1.0, index=index - ), - triggered=Trigger.NOT_TRIGGERED, + probability_func: Callable[ + [pd.Index[int]], pd.Series[float] + ] = lambda index: pd.Series(1.0, index=index), + triggered: Trigger = Trigger.NOT_TRIGGERED, ): """Initializes a transition between two states. @@ -155,7 +155,7 @@ def __init__( # Public methods # ################## - def set_active(self, index: pd.Index) -> None: + def set_active(self, index: pd.Index[int]) -> None: if self._active_index is None: raise ValueError( "This transition is not triggered. An active index cannot be set or modified." @@ -163,7 +163,7 @@ def set_active(self, index: pd.Index) -> None: else: self._active_index = self._active_index.union(pd.Index(index)) - def set_inactive(self, index: pd.Index) -> None: + def set_inactive(self, index: pd.Index[int]) -> None: if self._active_index is None: raise ValueError( "This transition is not triggered. An active index cannot be set or modified." @@ -171,7 +171,7 @@ def set_inactive(self, index: pd.Index) -> None: else: self._active_index = self._active_index.difference(pd.Index(index)) - def probability(self, index: pd.Index) -> pd.Series: + def probability(self, index: pd.Index[int]) -> pd.Series[float]: if self._active_index is None: return self._probability(index) @@ -180,7 +180,8 @@ def probability(self, index: pd.Index) -> pd.Series: null_index = index.difference(self._active_index) activated = pd.Series(self._probability(activated_index), index=activated_index) null = pd.Series(np.zeros(len(null_index), dtype=float), index=null_index) - return activated.append(null) + activated.update(null) + return activated class State(Component): @@ -210,7 +211,7 @@ def configuration_defaults(self) -> dict[str, Any]: } @property - def model(self) -> str: + def model(self) -> str | None: return self._model ##################### @@ -229,7 +230,7 @@ def __init__( self.state_id, allow_self_transition=allow_self_transition ) self.initialization_weights = initialization_weights - self._model = None + self._model: str | None = None self._sub_components = [self.transition_set] ################## @@ -248,7 +249,7 @@ def set_model(self, model_name: str) -> None: self._model = model_name def next_state( - self, index: pd.Index, event_time: ClockTime, population_view: PopulationView + self, index: pd.Index[int], event_time: ClockTime, population_view: PopulationView ) -> None: """Moves a population between different states. @@ -264,7 +265,7 @@ def next_state( return _next_state(index, event_time, self.transition_set, population_view) def transition_effect( - self, index: pd.Index, event_time: ClockTime, population_view: PopulationView + self, index: pd.Index[int], event_time: ClockTime, population_view: PopulationView ) -> None: """Updates the simulation state and triggers any side-effects associated with entering this state. @@ -280,15 +281,17 @@ def transition_effect( population_view.update(pd.Series(self.state_id, index=index)) self.transition_side_effect(index, event_time) - def cleanup_effect(self, index: pd.Index, event_time: ClockTime) -> None: + def cleanup_effect(self, index: pd.Index[int], event_time: ClockTime) -> None: pass def add_transition( self, transition: Transition | None = None, output_state: State | None = None, - probability_function: Callable[[pd.Index], pd.Series] = default_probability_function, - triggered=Trigger.NOT_TRIGGERED, + probability_function: Callable[ + [pd.Index[int]], pd.Series[float] + ] = default_probability_function, + triggered: Trigger = Trigger.NOT_TRIGGERED, ) -> Transition: """Adds a transition to this state and its `TransitionSet`. @@ -334,7 +337,7 @@ def allow_self_transitions(self) -> None: # Helper methods # ################## - def transition_side_effect(self, index: pd.Index, event_time: ClockTime) -> None: + def transition_side_effect(self, index: pd.Index[int], event_time: ClockTime) -> None: pass @@ -383,7 +386,7 @@ def __init__( super().__init__() self.state_id = state_id self.allow_null_transition = allow_self_transition - self.transitions = [] + self.transitions: list[Transition] = [] self._sub_components = self.transitions self.extend(transitions) @@ -403,7 +406,9 @@ def setup(self, builder: Builder) -> None: # Public methods # ################## - def choose_new_state(self, index: pd.Index) -> tuple[list, pd.Series]: + def choose_new_state( + self, index: pd.Index[int] + ) -> tuple[list[State | str], pd.Series[Any]]: """Chooses a new state for each simulant in the index. Parameters @@ -443,7 +448,9 @@ def extend(self, transitions: Iterable[Transition]) -> None: # Helper methods # ################## - def _normalize_probabilities(self, outputs, probabilities): + def _normalize_probabilities( + self, outputs: list[State | str], probabilities: NumericArray + ) -> tuple[list[State | str], NumericArray]: """Normalize probabilities to sum to 1 and add a null transition. Parameters @@ -493,17 +500,17 @@ def _normalize_probabilities(self, outputs, probabilities): if np.any(total == 0): raise ValueError("No valid transitions for some simulants.") else: # total might be less than zero in some places - probabilities /= total[:, np.newaxis] + probabilities = np.divide(probabilities, total[:, np.newaxis]) return outputs, probabilities - def __iter__(self): + def __iter__(self) -> Iterator[Transition]: return iter(self.transitions) - def __len__(self): + def __len__(self) -> int: return len(self.transitions) - def __hash__(self): + def __hash__(self) -> int: return hash(id(self)) @@ -524,7 +531,7 @@ class Machine(Component): ############## @property - def sub_components(self): + def sub_components(self) -> Sequence[Component]: return self.states @property @@ -548,7 +555,7 @@ def __init__( initial_state: State | None = None, ) -> None: super().__init__() - self.states = [] + self.states: list[State] = [] self.state_column = state_column if states: self.add_states(states) @@ -615,7 +622,7 @@ def add_states(self, states: Iterable[State]) -> None: self.states.append(state) state.set_model(self.state_column) - def transition(self, index: pd.Index, event_time: ClockTime) -> None: + def transition(self, index: pd.Index[int], event_time: ClockTime) -> None: """Finds the population in each state and moves them to the next state. Parameters @@ -633,12 +640,12 @@ def transition(self, index: pd.Index, event_time: ClockTime) -> None: self.population_view.subview(self.state_column), ) - def cleanup(self, index: pd.Index, event_time: ClockTime) -> None: + def cleanup(self, index: pd.Index[int], event_time: ClockTime) -> None: for state, affected in self._get_state_pops(index): if not affected.empty: state.cleanup_effect(affected.index, event_time) - def _get_state_pops(self, index: pd.Index) -> list[tuple[State, pd.DataFrame]]: + def _get_state_pops(self, index: pd.Index[int]) -> list[tuple[State, pd.DataFrame]]: population = self.population_view.get(index) return [ (state, population[population[self.state_column] == state.state_id]) diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 0708cfa3..ba8d4a2f 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -1,6 +1,5 @@ from collections.abc import Callable from datetime import datetime, timedelta -from numbers import Number from typing import TYPE_CHECKING, Union import numpy as np @@ -17,7 +16,7 @@ ClockTime = Time | int ClockStepSize = Timedelta | int -ScalarValue = Number | Timedelta | Time +ScalarValue = float | int | Timedelta | Time LookupTableData = ScalarValue | pd.DataFrame | list[ScalarValue] | tuple[ScalarValue] DataInput = LookupTableData | str | Callable[["Builder"], LookupTableData]