diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 4dd8fcc65..de33e0472 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -412,8 +412,6 @@ def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
         """
         # return tensor(1, n_tokens)
         if isinstance(model_input, str):
-            # pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
-            # Tensor
             return self.tokenizer.encode(  # type: ignore
                 model_input, return_tensors="pt"
             ).to(self.device)
@@ -609,10 +607,14 @@ class created with the llm model that follows huggingface style
             else next(self.model.parameters()).device
         )
 
-    def _format_model_input(self, model_input: Tensor) -> Tensor:
+    def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor:
         """
         Convert str to tokenized tensor
         """
+        if isinstance(model_input, str):
+            return self.tokenizer.encode(  # type: ignore
+                model_input, return_tensors="pt"
+            ).to(self.device)
         return model_input.to(self.device)
 
     def attribute(
diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py
index dd2c702d0..46ee2479b 100644
--- a/captum/attr/_utils/interpretable_input.py
+++ b/captum/attr/_utils/interpretable_input.py
@@ -1,6 +1,6 @@
 # pyre-strict
 from abc import ABC, abstractmethod
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -116,17 +116,19 @@ def to_tensor(self) -> Tensor:
         pass
 
     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
-    def to_model_input(self, itp_tensor: Optional[Tensor] = None) -> Any:
+    def to_model_input(
+        self, perturbed_tensor: Optional[Tensor] = None
+    ) -> Union[str, Tensor]:
         """
         Get the (perturbed) input in the format required by the model
         based on the given (perturbed) interpretable representation.
 
         Args:
 
-            itp_tensor (Tensor, optional): tensor of the interpretable representation
-                    of this input. If it is None, assume the interpretable
-                    representation is pristine and return the original model input
+            perturbed_tensor (Tensor, optional): tensor of the interpretable
+                    representation of this input. If it is None, assume the
+                    interpretable representation is pristine and return the
+                    original model input
                     Default: None.
 
@@ -198,13 +200,25 @@ class TextTemplateInput(InterpretableInput):
 
     """
 
+    values: List[str]
+    dict_keys: List[str]
+    baselines: Union[List[str], Callable[[], Union[List[str], Dict[str, str]]]]
+    n_features: int
+    n_itp_features: int
+    format_fn: Callable[..., str]
+    mask: Union[List[int], Dict[str, int], None]
+    formatted_mask: List[int]
+
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        template: Union[str, Callable],
+        template: Union[str, Callable[..., str]],
         values: Union[List[str], Dict[str, str]],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        baselines: Union[List[str], Dict[str, str], Callable, None] = None,
+        baselines: Union[
+            List[str],
+            Dict[str, str],
+            Callable[[], Union[List[str], Dict[str, str]]],
+            None,
+        ] = None,
         mask: Union[List[int], Dict[str, int], None] = None,
     ) -> None:
         # convert values dict to list
@@ -217,8 +231,8 @@ def __init__(
             ), f"the values must be either a list or a dict, received: {type(values)}"
             dict_keys = []
 
-        self.values: List[str] = values
-        self.dict_keys: List[str] = dict_keys
+        self.values = values
+        self.dict_keys = dict_keys
 
         n_features = len(values)
 
@@ -261,15 +275,12 @@ def __init__(
 
             # internal compressed mask of continuous interpretable indices from 0
             # cannot replace original mask of ids for grouping across values externally
-            # pyre-fixme[4]: Attribute must be annotated.
            self.formatted_mask = [mask_id_to_idx[mid] for mid in mask]
 
            n_itp_features = len(mask_ids)
 
         # number of raw features and intepretable features
-        # pyre-fixme[4]: Attribute must be annotated.
         self.n_features = n_features
-        # pyre-fixme[4]: Attribute must be annotated.
         self.n_itp_features = n_itp_features
 
         if isinstance(template, str):
@@ -280,7 +291,6 @@ def __init__(
                 f"received: {type(template)}"
             )
             template = template
-        # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
         self.format_fn = template
 
         self.mask = mask
@@ -289,8 +299,6 @@ def to_tensor(self) -> torch.Tensor:
         # Interpretable representation in shape(1, n_itp_features)
         return torch.tensor([[1.0] * self.n_itp_features])
 
-    # pyre-fixme[14]: `to_model_input` overrides method defined in
-    #  `InterpretableInput` inconsistently.
     def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
         values = list(self.values)  # clone
 
@@ -321,18 +329,12 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
                 itp_val = perturbed_tensor[0][itp_idx]
 
                 if not itp_val:
-                    # pyre-fixme[16]: Item `None` of `Union[None, Dict[str, str],
-                    #  List[typing.Any]]` has no attribute `__getitem__`.
                     values[i] = baselines[i]
 
         if self.dict_keys:
             dict_values = dict(zip(self.dict_keys, values))
-            # pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not
-            #  a function.
             input_str = self.format_fn(**dict_values)
         else:
-            # pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not
-            #  a function.
             input_str = self.format_fn(*values)
 
         return input_str
@@ -391,6 +393,14 @@ class TextTokenInput(InterpretableInput):
 
     """
 
+    inp_tensor: Tensor
+    itp_tensor: Tensor
+    itp_mask: Optional[Tensor]
+    values: List[str]
+    tokenizer: TokenizerLike
+    n_itp_features: int
+    baselines: int
+
     def __init__(
         self,
         text: str,
@@ -401,11 +411,11 @@ def __init__(
         inp_tensor = tokenizer.encode(text, return_tensors="pt")
 
         # input tensor into the model of token ids
-        self.inp_tensor: Tensor = inp_tensor
+        self.inp_tensor = inp_tensor
         # tensor of interpretable token ids
-        self.itp_tensor: Tensor = inp_tensor
+        self.itp_tensor = inp_tensor
         # interpretable mask
-        self.itp_mask: Optional[Tensor] = None
+        self.itp_mask = None
 
         if skip_tokens:
             if isinstance(skip_tokens[0], str):
@@ -426,13 +436,11 @@ def __init__(
         self.skip_tokens = skip_tokens
 
         # features values, the tokens
-        self.values: List[str] = tokenizer.convert_ids_to_tokens(
-            self.itp_tensor[0].tolist()
-        )
-        self.tokenizer: TokenizerLike = tokenizer
-        self.n_itp_features: int = len(self.values)
+        self.values = tokenizer.convert_ids_to_tokens(self.itp_tensor[0].tolist())
+        self.tokenizer = tokenizer
+        self.n_itp_features = len(self.values)
 
-        self.baselines: int = (
+        self.baselines = (
             baselines
             if type(baselines) is int
             else tokenizer.convert_tokens_to_ids([baselines])[0]  # type: ignore
@@ -442,8 +450,6 @@ def to_tensor(self) -> torch.Tensor:
         # return the perturbation indicator as interpretable tensor instead of token ids
         return torch.ones_like(self.itp_tensor)
 
-    # pyre-fixme[14]: `to_model_input` overrides method defined in
-    #  `InterpretableInput` inconsistently.
     def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Tensor:
         if perturbed_tensor is None:
             return self.inp_tensor
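
Not part of the patch: a minimal usage sketch of the retyped TextTemplateInput.to_model_input, in the spirit of the captum LLM-attribution examples. The template, values, and baselines below are hypothetical, and the commented outputs describe the expected behavior rather than a captured run.

import torch

from captum.attr._utils.interpretable_input import TextTemplateInput

# Three positional template features; baselines are the substitutes used when
# a feature is "turned off" in the interpretable representation.
inp = TextTemplateInput(
    template="{} lives in {} and works as a {}.",
    values=["Dave", "Palm Coast", "lawyer"],
    baselines=["someone", "somewhere", "something"],
)

# Pristine interpretable representation: one indicator per feature.
print(inp.to_tensor())  # tensor([[1., 1., 1.]])

# Zeroing the second feature swaps in its baseline when the model input is
# rebuilt; to_model_input now returns str for this subclass, consistent with
# the base class's Union[str, Tensor] annotation that replaces Any.
perturbed = torch.tensor([[1.0, 0.0, 1.0]])
print(inp.to_model_input(perturbed))
# expected: "Dave lives in somewhere and works as a lawyer."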