Skip to content

Commit

Permalink
Fix pandas groupby type hints for pandas stubs update (#548)
Browse files Browse the repository at this point in the history
Fix pandas groupby type hints for pandas stubs update

Fix type hints for pandas gropuby
- *Category*: Typing
- *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5605

Changes and notes
-fix pd.groupby objects type hints

### Testing
<!--
Details on how code was verified, any unit tests local for the
repo, regression testing, etc. At a minimum, this should include an
integration test for a framework change. Consider: plots, images,
(small) csv file.
-->
  • Loading branch information
albrja authored Dec 3, 2024
1 parent feace39 commit aaa271f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.4 - 12/03/24**

- Fix type hints for pandas groupby objects

**3.2.3 - 11/21/24**

- Feature: Allow users to define initialization weights as LookupTableData or an artifact key
Expand Down
4 changes: 2 additions & 2 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def gather_results(
if filtered_pop.empty:
yield None, None, None
else:
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool]
if stratification_names is None:
pop = filtered_pop
else:
Expand Down Expand Up @@ -335,7 +335,7 @@ def _filter_population(
@staticmethod
def _get_groups(
stratifications: tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy[tuple[str, ...] | str]:
) -> DataFrameGroupBy[tuple[str, ...] | str, bool]:
"""Group the population by stratification.
Notes
Expand Down
13 changes: 8 additions & 5 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ class Observation(ABC):
DataFrame or one with a complete set of stratifications as the index and
all values set to 0.0."""
results_gatherer: Callable[
[pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None],
[
pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
tuple[str, ...] | None,
],
pd.DataFrame,
]
"""Method or function that gathers the new observation results."""
Expand All @@ -78,7 +81,7 @@ class Observation(ABC):
def observe(
self,
event: Event,
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str],
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
stratifications: tuple[str, ...] | None,
) -> pd.DataFrame | None:
"""Determine whether to observe the given event, and if so, gather the results.
Expand Down Expand Up @@ -141,7 +144,7 @@ def __init__(
to_observe: Callable[[Event], bool] = lambda event: True,
):
def _wrap_results_gatherer(
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str],
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
_: tuple[str, ...] | None,
) -> pd.DataFrame:
if isinstance(df, DataFrameGroupBy):
Expand Down Expand Up @@ -302,7 +305,7 @@ def create_expanded_df(

def get_complete_stratified_results(
self,
pop_groups: DataFrameGroupBy[str],
pop_groups: DataFrameGroupBy[str, bool],
stratifications: tuple[str, ...],
) -> pd.DataFrame:
"""Gather results for this observation.
Expand All @@ -327,7 +330,7 @@ def get_complete_stratified_results(

@staticmethod
def _aggregate(
pop_groups: DataFrameGroupBy[str],
pop_groups: DataFrameGroupBy[str, bool],
aggregator_sources: list[str] | None,
aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
) -> pd.Series[float] | pd.DataFrame:
Expand Down

0 comments on commit aaa271f

Please sign in to comment.