Skip to content

Commit

Permalink
Fix pandas groupby type hints for pandas stubs update
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Dec 2, 2024
1 parent feace39 commit de57b41
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def gather_results(
if filtered_pop.empty:
yield None, None, None
else:
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool]
if stratification_names is None:
pop = filtered_pop
else:
Expand Down Expand Up @@ -335,7 +335,7 @@ def _filter_population(
@staticmethod
def _get_groups(
stratifications: tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy[tuple[str, ...] | str]:
) -> DataFrameGroupBy[tuple[str, ...] | str, bool]:
"""Group the population by stratification.
Notes
Expand All @@ -356,7 +356,7 @@ def _get_groups(
)
else:
pop_groups = filtered_pop.groupby(lambda _: "all")
return pop_groups # type: ignore[return-value]
return pop_groups # type: ignore [return-value]

def _rename_stratification_columns(self, results: pd.DataFrame) -> None:
"""Convert the temporary stratified mapped index names back to their original names."""
Expand Down
13 changes: 8 additions & 5 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ class Observation(ABC):
DataFrame or one with a complete set of stratifications as the index and
all values set to 0.0."""
results_gatherer: Callable[
[pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None],
[
pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
tuple[str, ...] | None,
],
pd.DataFrame,
]
"""Method or function that gathers the new observation results."""
Expand All @@ -78,7 +81,7 @@ class Observation(ABC):
def observe(
self,
event: Event,
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str],
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
stratifications: tuple[str, ...] | None,
) -> pd.DataFrame | None:
"""Determine whether to observe the given event, and if so, gather the results.
Expand Down Expand Up @@ -141,7 +144,7 @@ def __init__(
to_observe: Callable[[Event], bool] = lambda event: True,
):
def _wrap_results_gatherer(
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str],
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool],
_: tuple[str, ...] | None,
) -> pd.DataFrame:
if isinstance(df, DataFrameGroupBy):
Expand Down Expand Up @@ -302,7 +305,7 @@ def create_expanded_df(

def get_complete_stratified_results(
self,
pop_groups: DataFrameGroupBy[str],
pop_groups: DataFrameGroupBy[str, bool],
stratifications: tuple[str, ...],
) -> pd.DataFrame:
"""Gather results for this observation.
Expand All @@ -327,7 +330,7 @@ def get_complete_stratified_results(

@staticmethod
def _aggregate(
pop_groups: DataFrameGroupBy[str],
pop_groups: DataFrameGroupBy[str, bool],
aggregator_sources: list[str] | None,
aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
) -> pd.Series[float] | pd.DataFrame:
Expand Down

0 comments on commit de57b41

Please sign in to comment.