Skip to content

Commit

Permalink
Fix progress pyre fixme issues
Browse files Browse the repository at this point in the history
Summary: Fixing unresolved pyre fixme issues in corresponding file

Differential Revision: D67725994
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 31, 2024
1 parent 6bfb570 commit 88c1c47
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 64 deletions.
126 changes: 79 additions & 47 deletions captum/_utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,34 @@
# pyre-strict

import sys
import typing
import warnings
from time import time
from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO
from types import TracebackType
from typing import (
Any,
Callable,
cast,
Generic,
Iterable,
Iterator,
Literal,
Optional,
Sized,
TextIO,
Type,
TypeVar,
Union,
)

try:
from tqdm.auto import tqdm
except ImportError:
tqdm = None

T = TypeVar("T")
IterableType = TypeVar("IterableType")


class DisableErrorIOWrapper(object):
def __init__(self, wrapped: TextIO) -> None:
Expand All @@ -21,15 +40,13 @@ def __init__(self, wrapped: TextIO) -> None:
"""
self._wrapped = wrapped

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __getattr__(self, name):
def __getattr__(self, name: str) -> object:
return getattr(self._wrapped, name)

@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _wrapped_run(func, *args, **kwargs):
def _wrapped_run(
func: Callable[..., T], *args: object, **kwargs: object
) -> Union[T, None]:
try:
return func(*args, **kwargs)
except OSError as e:
Expand All @@ -38,19 +55,16 @@ def _wrapped_run(func, *args, **kwargs):
except ValueError as e:
if "closed" not in str(e):
raise
return None

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def write(self, *args, **kwargs):
def write(self, *args: object, **kwargs: object) -> Optional[int]:
return self._wrapped_run(self._wrapped.write, *args, **kwargs)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def flush(self, *args, **kwargs):
def flush(self, *args: object, **kwargs: object) -> None:
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)


class NullProgress:
class NullProgress(Iterable[IterableType]):
"""Passthrough class that implements the progress API.
This class implements the tqdm and SimpleProgressBar api but
Expand All @@ -61,27 +75,28 @@ class NullProgress:

def __init__(
self,
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
iterable: Optional[Iterable] = None,
iterable: Optional[Iterable[IterableType]] = None,
*args: Any,
**kwargs: Any,
) -> None:
del args, kwargs
self.iterable = iterable

def __enter__(self) -> "NullProgress":
def __enter__(self) -> "NullProgress[IterableType]":
return self

# pyre-fixme[2]: Parameter must be annotated.
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
def __exit__(
self,
exc_type: Union[Type[BaseException], None],
exc_value: Union[BaseException, None],
exc_traceback: Union[TracebackType, None],
) -> Literal[False]:
return False

# pyre-fixme[3]: Return type must be annotated.
def __iter__(self):
def __iter__(self) -> Iterator[IterableType]:
if not self.iterable:
return
# pyre-fixme[16]: `Optional` has no attribute `__iter__`.
for it in self.iterable:
for it in cast(Iterable[IterableType], self.iterable):
yield it

def update(self, amount: int = 1) -> None:
Expand All @@ -91,11 +106,10 @@ def close(self) -> None:
pass


class SimpleProgress:
class SimpleProgress(Iterable[IterableType]):
def __init__(
self,
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
iterable: Optional[Iterable] = None,
iterable: Optional[Iterable[IterableType]] = None,
desc: Optional[str] = None,
total: Optional[int] = None,
file: Optional[TextIO] = None,
Expand All @@ -117,34 +131,33 @@ def __init__(

self.desc = desc

# pyre-fixme[9]: file has type `Optional[TextIO]`; used as
# `DisableErrorIOWrapper`.
file = DisableErrorIOWrapper(file if file else sys.stderr)
cast(TextIO, file)
self.file = file
file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr)
self.file: DisableErrorIOWrapper = file_wrapper

self.mininterval = mininterval
self.last_print_t = 0.0
self.closed = False
self._is_parent = False

def __enter__(self) -> "SimpleProgress":
def __enter__(self) -> "SimpleProgress[IterableType]":
self._is_parent = True
self._refresh()
return self

# pyre-fixme[2]: Parameter must be annotated.
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
def __exit__(
self,
exc_type: Union[Type[BaseException], None],
exc_value: Union[BaseException, None],
exc_traceback: Union[TracebackType, None],
) -> Literal[False]:
self.close()
return False

# pyre-fixme[3]: Return type must be annotated.
def __iter__(self):
def __iter__(self) -> Iterator[IterableType]:
if self.closed or not self.iterable:
return
self._refresh()
# pyre-fixme[16]: `Optional` has no attribute `__iter__`.
for it in self.iterable:
for it in cast(Iterable[IterableType], self.iterable):
yield it
self.update()
self.close()
Expand All @@ -153,9 +166,7 @@ def _refresh(self) -> None:
progress_str = self.desc + ": " if self.desc else ""
if self.total:
# e.g., progress: 60% 3/5
# pyre-fixme[58]: `//` is not supported for operand types `int` and
# `Optional[int]`.
progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
progress_str += f"{100 * self.cur // cast(int, self.total)}% {self.cur}/{cast(int, self.total)}"
else:
# e.g., progress: .....
progress_str += "." * self.cur
Expand All @@ -179,18 +190,39 @@ def close(self) -> None:
self.closed = True


# pyre-fixme[3]: Return type must be annotated.
@typing.overload
def progress(
iterable: None = None,
desc: Optional[str] = None,
total: Optional[int] = None,
use_tqdm: bool = True,
file: Optional[TextIO] = None,
mininterval: float = 0.5,
**kwargs: object,
) -> Union[SimpleProgress[None], tqdm]: ...


@typing.overload
def progress(
iterable: Iterable[IterableType],
desc: Optional[str] = None,
total: Optional[int] = None,
use_tqdm: bool = True,
file: Optional[TextIO] = None,
mininterval: float = 0.5,
**kwargs: object,
) -> Union[SimpleProgress[IterableType], tqdm]: ...


def progress(
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
iterable: Optional[Iterable] = None,
iterable: Optional[Iterable[IterableType]] = None,
desc: Optional[str] = None,
total: Optional[int] = None,
use_tqdm: bool = True,
file: Optional[TextIO] = None,
mininterval: float = 0.5,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
):
**kwargs: object,
) -> Union[SimpleProgress[IterableType], tqdm]:
# Try to use tqdm is possible. Fall back to simple progress print
if tqdm and use_tqdm:
return tqdm(
Expand Down
29 changes: 21 additions & 8 deletions captum/influence/_core/tracincp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
import warnings
from abc import abstractmethod
from os.path import join
from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
cast,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
)

import torch
from captum._utils.av import AV
Expand Down Expand Up @@ -1033,10 +1044,12 @@ def _influence(
inputs = _format_inputs_dataset(inputs)

train_dataloader = self.train_dataloader

data_iterable: Union[Iterable[Tuple[object, ...]], DataLoader] = (
train_dataloader
)
if show_progress:
train_dataloader = progress(
train_dataloader,
data_iterable = progress(
cast(Iterable[Tuple[object, ...]], train_dataloader),
desc=(
f"Using {self.get_name()} to compute "
"influence for training batches"
Expand All @@ -1053,7 +1066,7 @@ def _influence(
return torch.cat(
[
self._influence_batch_tracincp(inputs_checkpoint_jacobians, batch)
for batch in train_dataloader
for batch in data_iterable
],
dim=1,
)
Expand Down Expand Up @@ -1250,7 +1263,7 @@ def get_checkpoint_contribution(checkpoint: str) -> Tensor:
# the same)
checkpoint_contribution = []

_inputs = inputs
_inputs: Union[DataLoader, Iterable[Tuple[Tensor, ...]]] = inputs
# If `show_progress` is true, create an inner progress bar that keeps track
# of how many batches have been processed for the current checkpoint
if show_progress:
Expand All @@ -1266,8 +1279,8 @@ def get_checkpoint_contribution(checkpoint: str) -> Tensor:
for batch in _inputs:

layer_jacobians = self._basic_computation_tracincp(
batch[0:-1],
batch[-1],
cast(Tuple[Tensor, ...], batch)[0:-1],
cast(Tuple[Tensor, ...], batch)[-1],
self.loss_fn,
self.reduction_type,
)
Expand Down
22 changes: 18 additions & 4 deletions captum/influence/_core/tracincp_fast_rand_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
import threading
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)

import torch
from captum._utils.common import _get_module_from_name, _sort_key_list
Expand Down Expand Up @@ -418,10 +429,13 @@ def _influence( # type: ignore[override]
"""

train_dataloader = self.train_dataloader
train_dataloader_iterable: Union[DataLoader, Iterable[Tuple[object, ...]]] = (
train_dataloader
)

if show_progress:
train_dataloader = progress(
train_dataloader,
train_dataloader_iterable = progress(
cast(Iterable[Tuple[object, ...]], train_dataloader),
desc=(
f"Using {self.get_name()} to compute "
"influence for training batches"
Expand All @@ -432,7 +446,7 @@ def _influence( # type: ignore[override]
return torch.cat(
[
self._influence_batch_tracincp_fast(inputs, batch)
for batch in train_dataloader
for batch in train_dataloader_iterable
],
dim=1,
)
Expand Down
Loading

0 comments on commit 88c1c47

Please sign in to comment.