Skip to content

Commit

Permalink
factor in str types for "null_transition"
Browse files Browse the repository at this point in the history
  • Loading branch information
patricktnast committed Dec 11, 2024
1 parent eeafe9c commit 2987250
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/vivarium/framework/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _next_state(


def _groupby_new_state(
index: pd.Index[int], outputs: list[State], decisions: pd.Series[State]
) -> list[tuple[State, pd.Index[int]]]:
index: pd.Index[int], outputs: list[State | str], decisions: pd.Series[Any]
) -> list[tuple[State | str, pd.Index[int]]]:
"""Groups the simulants in the index by their new output state.
Parameters
Expand Down Expand Up @@ -230,7 +230,7 @@ def __init__(
self.state_id, allow_self_transition=allow_self_transition
)
self.initialization_weights = initialization_weights
self._model = None
self._model: str | None = None
self._sub_components = [self.transition_set]

##################
Expand Down Expand Up @@ -406,7 +406,9 @@ def setup(self, builder: Builder) -> None:
# Public methods #
##################

def choose_new_state(self, index: pd.Index[int]) -> tuple[list[State], pd.Series[State]]:
def choose_new_state(
self, index: pd.Index[int]
) -> tuple[list[State | str], pd.Series[Any]]:
"""Chooses a new state for each simulant in the index.
Parameters
Expand Down Expand Up @@ -447,8 +449,8 @@ def extend(self, transitions: Iterable[Transition]) -> None:
##################

def _normalize_probabilities(
self, outputs: list[State], probabilities: NumericArray
) -> tuple[list[State], NumericArray]:
self, outputs: list[State | str], probabilities: NumericArray
) -> tuple[list[State | str], NumericArray]:
"""Normalize probabilities to sum to 1 and add a null transition.
Parameters
Expand Down

0 comments on commit 2987250

Please sign in to comment.