Skip to content

Commit

Permalink
Reducing pyre-fixme's in visualization.py 2/n (#1385)
Browse files Browse the repository at this point in the history
Summary:

This diff helps address the number of pyre-fixme's and return type annotation pyre errors in the visualizations.py file

Reviewed By: jsawruk

Differential Revision: D64546026
  • Loading branch information
jjuncho authored and facebook-github-bot committed Oct 23, 2024
1 parent b80e488 commit 4bba508
Showing 1 changed file with 46 additions and 81 deletions.
127 changes: 46 additions & 81 deletions captum/attr/_utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from matplotlib import cm, colors, pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection
from matplotlib.colors import Colormap, LinearSegmentedColormap
from matplotlib.colors import Colormap, LinearSegmentedColormap, Normalize
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import make_axes_locatable
Expand Down Expand Up @@ -111,23 +111,20 @@ def _normalize_attr(


def _create_default_plot(
# pyre-fixme[2]: Parameter must be annotated.
plt_fig_axis,
# pyre-fixme[2]: Parameter must be annotated.
use_pyplot,
# pyre-fixme[2]: Parameter must be annotated.
fig_size,
**pyplot_kwargs: Any,
) -> Tuple[Figure, Axes]:
plt_fig_axis: Optional[Tuple[Figure, Union[Axes, List[Axes]]]],
use_pyplot: bool,
fig_size: Tuple[int, int],
**kwargs: Any,
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
# Create plot if figure, axis not provided
if plt_fig_axis is not None:
plt_fig, plt_axis = plt_fig_axis
else:
if use_pyplot:
plt_fig, plt_axis = plt.subplots(figsize=fig_size, **pyplot_kwargs)
plt_fig, plt_axis = plt.subplots(figsize=fig_size, **kwargs)
else:
plt_fig = Figure(figsize=fig_size)
plt_axis = plt_fig.subplots(**pyplot_kwargs)
plt_axis = plt_fig.subplots(**kwargs)
return plt_fig, plt_axis
# Figure.subplots returns Axes or array of Axes

Expand Down Expand Up @@ -362,6 +359,9 @@ def visualize_image_attr(
>>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
"""
plt_fig, plt_axis = _create_default_plot(plt_fig_axis, use_pyplot, fig_size)
if isinstance(plt_axis, list):
# To ensure plt_axis is always a single axis, not a list of axes.
plt_axis = plt_axis[0]

if original_image is not None:
if np.max(original_image) <= 1.0:
Expand Down Expand Up @@ -545,61 +545,42 @@ def visualize_image_attr_multiple(


def _plot_attrs_as_axvspan(
# pyre-fixme[2]: Parameter must be annotated.
attr_vals,
# pyre-fixme[2]: Parameter must be annotated.
x_vals,
# pyre-fixme[2]: Parameter must be annotated.
ax,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
# pyre-fixme[2]: Parameter must be annotated.
alpha_overlay,
attr_vals: npt.NDArray,
x_vals: npt.NDArray,
ax: Axes,
x_values: npt.NDArray,
cmap: LinearSegmentedColormap,
cm_norm: Normalize,
alpha_overlay: float,
) -> None:
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
half_col_width = (x_values[1] - x_values[0]) / 2.0

for icol, col_center in enumerate(x_vals):
left = col_center - half_col_width
right = col_center + half_col_width
ax.axvspan(
xmin=left,
xmax=right,
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
edgecolor=None,
alpha=alpha_overlay,
)


def _visualize_overlay_individual(
# pyre-fixme[2]: Parameter must be annotated.
num_channels,
# pyre-fixme[2]: Parameter must be annotated.
plt_axis_list,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
data,
# pyre-fixme[2]: Parameter must be annotated.
channel_labels,
# pyre-fixme[2]: Parameter must be annotated.
norm_attr,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
# pyre-fixme[2]: Parameter must be annotated.
alpha_overlay,
# pyre-fixme[2]: Parameter must be annotated.
num_channels: int,
plt_axis_list: npt.NDArray,
x_values: npt.NDArray,
data: npt.NDArray,
channel_labels: List[str],
norm_attr: npt.NDArray,
cmap: LinearSegmentedColormap,
cm_norm: Normalize,
alpha_overlay: float,
**kwargs: Any,
) -> None:
# helper method for visualize_timeseries_attr
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})

for chan in range(num_channels):
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
if channel_labels is not None:
Expand All @@ -620,24 +601,15 @@ def _visualize_overlay_individual(


def _visualize_overlay_combined(
# pyre-fixme[2]: Parameter must be annotated.
num_channels,
# pyre-fixme[2]: Parameter must be annotated.
plt_axis_list,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
data,
# pyre-fixme[2]: Parameter must be annotated.
channel_labels,
# pyre-fixme[2]: Parameter must be annotated.
norm_attr,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
# pyre-fixme[2]: Parameter must be annotated.
alpha_overlay,
num_channels: int,
plt_axis_list: npt.NDArray,
x_values: npt.NDArray,
data: npt.NDArray,
channel_labels: List[str],
norm_attr: npt.NDArray,
cmap: LinearSegmentedColormap,
cm_norm: Normalize,
alpha_overlay: float,
**kwargs: Any,
) -> None:
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
Expand All @@ -663,22 +635,15 @@ def _visualize_overlay_combined(


def _visualize_colored_graph(
# pyre-fixme[2]: Parameter must be annotated.
num_channels,
# pyre-fixme[2]: Parameter must be annotated.
plt_axis_list,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
data,
# pyre-fixme[2]: Parameter must be annotated.
channel_labels,
# pyre-fixme[2]: Parameter must be annotated.
norm_attr,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
num_channels: int,
plt_axis_list: npt.NDArray,
x_values: npt.NDArray,
data: npt.NDArray,
channel_labels: List[str],
norm_attr: npt.NDArray,
cmap: LinearSegmentedColormap,
cm_norm: Normalize,
alpha_overlay: float,
**kwargs: Any,
) -> None:
# helper method for visualize_timeseries_attr
Expand Down

0 comments on commit 4bba508

Please sign in to comment.