Skip to content

Commit

Permalink
Fix class summarizer pyre fix me issues
Browse files Browse the repository at this point in the history
Differential Revision: D67706853
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 30, 2024
1 parent bed0f68 commit 582760e
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions captum/attr/_utils/class_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, cast, Dict, Generic, List, Optional, TypeVar, Union

from captum._utils.common import _format_tensor_into_tuples
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
Expand All @@ -11,8 +11,10 @@
from captum.log import log_usage
from torch import Tensor

KeyType = TypeVar("KeyType")

class ClassSummarizer(Summarizer):

class ClassSummarizer(Summarizer, Generic[KeyType]):
r"""
Used to keep track of summaries for associated classes. The
classes/labels can be of any type that are supported by `dict`.
Expand All @@ -23,8 +25,7 @@ class ClassSummarizer(Summarizer):
@log_usage()
def __init__(self, stats: List[Stat]) -> None:
Summarizer.__init__.__wrapped__(self, stats)
# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
self.summaries: Dict[Any, Summarizer] = defaultdict(
self.summaries: Dict[KeyType, Summarizer] = defaultdict(
lambda: Summarizer(stats=stats)
)

Expand Down Expand Up @@ -84,15 +85,15 @@ def update( # type: ignore
tensors_to_summarize_copy = tuple(tensor[i].clone() for tensor in x)
label = labels_typed[0] if len(labels_typed) == 1 else labels_typed[i]

self.summaries[label].update(tensors_to_summarize)
self.summaries[cast(KeyType, label)].update(tensors_to_summarize)
super().update(tensors_to_summarize_copy)

@property
# pyre-fixme[3]: Return annotation cannot contain `Any`.
def class_summaries(
self,
) -> Dict[
Any, Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
KeyType,
Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]],
]:
r"""
Returns:
Expand Down

0 comments on commit 582760e

Please sign in to comment.