Skip to content

Commit

Permalink
Fix baselines utils pyre fix me issues
Browse files Browse the repository at this point in the history
Differential Revision: D67706854
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent 2bc4321 commit e5f2af2
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions captum/attr/_utils/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

# pyre-strict
import random
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, Generic, List, Tuple, TypeVar, Union

GenericBaselineType = TypeVar("GenericBaselineType")

class ProductBaselines:

class ProductBaselines(Generic[GenericBaselineType]):
"""
A Callable Baselines class that returns a sample from the Cartesian product of
the inputs' available baselines.
Expand All @@ -22,10 +24,9 @@ class ProductBaselines:

def __init__(
self,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
baseline_values: Union[
List[List[Any]],
Dict[Union[str, Tuple[str, ...]], List[Any]],
List[List[GenericBaselineType]],
Dict[Union[str, Tuple[str, ...]], List[GenericBaselineType]],
],
) -> None:
if isinstance(baseline_values, dict):
Expand All @@ -38,9 +39,10 @@ def __init__(
self.dict_keys = dict_keys
self.baseline_values = baseline_values

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def sample(self) -> Union[List[Any], Dict[str, Any]]:
baselines = [
def sample(
self,
) -> Union[List[GenericBaselineType], Dict[str, GenericBaselineType]]:
baselines: List[GenericBaselineType] = [
random.choice(baseline_list) for baseline_list in self.baseline_values
]

Expand All @@ -50,15 +52,18 @@ def sample(self) -> Union[List[Any], Dict[str, Any]]:
dict_baselines = {}
for key, val in zip(self.dict_keys, baselines):
if not isinstance(key, tuple):
key, val = (key,), (val,)
key_tuple, val_tuple = (key,), (val,)
else:
key_tuple, val_tuple = key, val

for k, v in zip(key, val):
for k, v in zip(key_tuple, val_tuple):
dict_baselines[k] = v

return dict_baselines

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def __call__(self) -> Union[List[Any], Dict[str, Any]]:
def __call__(
self,
) -> Union[List[GenericBaselineType], Dict[str, GenericBaselineType]]:
"""
Returns:
Expand Down

0 comments on commit e5f2af2

Please sign in to comment.