diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 44642757..97c71cac 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.2.4 - 12/03/24** + + - Fix type hints for pandas groupby objects + **3.2.3 - 11/21/24** - Feature: Allow users to define initialization weights as LookupTableData or an artifact key diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 987c5777..0fd660f0 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -301,7 +301,7 @@ def gather_results( if filtered_pop.empty: yield None, None, None else: - pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] + pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool] if stratification_names is None: pop = filtered_pop else: @@ -335,7 +335,7 @@ def _filter_population( @staticmethod def _get_groups( stratifications: tuple[str, ...], filtered_pop: pd.DataFrame - ) -> DataFrameGroupBy[tuple[str, ...] | str]: + ) -> DataFrameGroupBy[tuple[str, ...] | str, bool]: """Group the population by stratification. Notes diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index b3fbfcdf..ac1117e4 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -61,7 +61,10 @@ class Observation(ABC): DataFrame or one with a complete set of stratifications as the index and all values set to 0.0.""" results_gatherer: Callable[ - [pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None], + [ + pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool], + tuple[str, ...] | None, + ], pd.DataFrame, ] """Method or function that gathers the new observation results.""" @@ -78,7 +81,7 @@ class Observation(ABC): def observe( self, event: Event, - df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool], stratifications: tuple[str, ...] | None, ) -> pd.DataFrame | None: """Determine whether to observe the given event, and if so, gather the results. @@ -141,7 +144,7 @@ def __init__( to_observe: Callable[[Event], bool] = lambda event: True, ): def _wrap_results_gatherer( - df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool], _: tuple[str, ...] | None, ) -> pd.DataFrame: if isinstance(df, DataFrameGroupBy): @@ -302,7 +305,7 @@ def create_expanded_df( def get_complete_stratified_results( self, - pop_groups: DataFrameGroupBy[str], + pop_groups: DataFrameGroupBy[str, bool], stratifications: tuple[str, ...], ) -> pd.DataFrame: """Gather results for this observation. @@ -327,7 +330,7 @@ def get_complete_stratified_results( @staticmethod def _aggregate( - pop_groups: DataFrameGroupBy[str], + pop_groups: DataFrameGroupBy[str, bool], aggregator_sources: list[str] | None, aggregator: Callable[[pd.DataFrame], float | pd.Series[float]], ) -> pd.Series[float] | pd.DataFrame: