From 4bba5088b60de6fb8e078120b6f6b64a7045813b Mon Sep 17 00:00:00 2001 From: Jason Cho Date: Wed, 23 Oct 2024 07:35:52 -0700 Subject: [PATCH] Reducing pyre-fixme's in visualization.py 2/n (#1385) Summary: This diff helps address the number of pyre-fixme's and return type annotation pyre errors in the visualizations.py file Reviewed By: jsawruk Differential Revision: D64546026 --- captum/attr/_utils/visualization.py | 127 ++++++++++------------------ 1 file changed, 46 insertions(+), 81 deletions(-) diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 223428244..06e4651c2 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -12,7 +12,7 @@ from matplotlib import cm, colors, pyplot as plt from matplotlib.axes import Axes from matplotlib.collections import LineCollection -from matplotlib.colors import Colormap, LinearSegmentedColormap +from matplotlib.colors import Colormap, LinearSegmentedColormap, Normalize from matplotlib.figure import Figure from matplotlib.image import AxesImage from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -111,23 +111,20 @@ def _normalize_attr( def _create_default_plot( - # pyre-fixme[2]: Parameter must be annotated. - plt_fig_axis, - # pyre-fixme[2]: Parameter must be annotated. - use_pyplot, - # pyre-fixme[2]: Parameter must be annotated. - fig_size, - **pyplot_kwargs: Any, -) -> Tuple[Figure, Axes]: + plt_fig_axis: Optional[Tuple[Figure, Union[Axes, List[Axes]]]], + use_pyplot: bool, + fig_size: Tuple[int, int], + **kwargs: Any, +) -> Tuple[Figure, Union[Axes, List[Axes]]]: # Create plot if figure, axis not provided if plt_fig_axis is not None: plt_fig, plt_axis = plt_fig_axis else: if use_pyplot: - plt_fig, plt_axis = plt.subplots(figsize=fig_size, **pyplot_kwargs) + plt_fig, plt_axis = plt.subplots(figsize=fig_size, **kwargs) else: plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots(**pyplot_kwargs) + plt_axis = plt_fig.subplots(**kwargs) return plt_fig, plt_axis # Figure.subplots returns Axes or array of Axes @@ -362,6 +359,9 @@ def visualize_image_attr( >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map") """ plt_fig, plt_axis = _create_default_plot(plt_fig_axis, use_pyplot, fig_size) + if isinstance(plt_axis, list): + # To ensure plt_axis is always a single axis, not a list of axes. + plt_axis = plt_axis[0] if original_image is not None: if np.max(original_image) <= 1.0: @@ -545,31 +545,21 @@ def visualize_image_attr_multiple( def _plot_attrs_as_axvspan( - # pyre-fixme[2]: Parameter must be annotated. - attr_vals, - # pyre-fixme[2]: Parameter must be annotated. - x_vals, - # pyre-fixme[2]: Parameter must be annotated. - ax, - # pyre-fixme[2]: Parameter must be annotated. - x_values, - # pyre-fixme[2]: Parameter must be annotated. - cmap, - # pyre-fixme[2]: Parameter must be annotated. - cm_norm, - # pyre-fixme[2]: Parameter must be annotated. - alpha_overlay, + attr_vals: npt.NDArray, + x_vals: npt.NDArray, + ax: Axes, + x_values: npt.NDArray, + cmap: LinearSegmentedColormap, + cm_norm: Normalize, + alpha_overlay: float, ) -> None: - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. half_col_width = (x_values[1] - x_values[0]) / 2.0 - for icol, col_center in enumerate(x_vals): left = col_center - half_col_width right = col_center + half_col_width ax.axvspan( xmin=left, xmax=right, - # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function. facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore edgecolor=None, alpha=alpha_overlay, @@ -577,29 +567,20 @@ def _plot_attrs_as_axvspan( def _visualize_overlay_individual( - # pyre-fixme[2]: Parameter must be annotated. - num_channels, - # pyre-fixme[2]: Parameter must be annotated. - plt_axis_list, - # pyre-fixme[2]: Parameter must be annotated. - x_values, - # pyre-fixme[2]: Parameter must be annotated. - data, - # pyre-fixme[2]: Parameter must be annotated. - channel_labels, - # pyre-fixme[2]: Parameter must be annotated. - norm_attr, - # pyre-fixme[2]: Parameter must be annotated. - cmap, - # pyre-fixme[2]: Parameter must be annotated. - cm_norm, - # pyre-fixme[2]: Parameter must be annotated. - alpha_overlay, - # pyre-fixme[2]: Parameter must be annotated. + num_channels: int, + plt_axis_list: npt.NDArray, + x_values: npt.NDArray, + data: npt.NDArray, + channel_labels: List[str], + norm_attr: npt.NDArray, + cmap: LinearSegmentedColormap, + cm_norm: Normalize, + alpha_overlay: float, **kwargs: Any, ) -> None: # helper method for visualize_timeseries_attr pyplot_kwargs = kwargs.get("pyplot_kwargs", {}) + for chan in range(num_channels): plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs) if channel_labels is not None: @@ -620,24 +601,15 @@ def _visualize_overlay_individual( def _visualize_overlay_combined( - # pyre-fixme[2]: Parameter must be annotated. - num_channels, - # pyre-fixme[2]: Parameter must be annotated. - plt_axis_list, - # pyre-fixme[2]: Parameter must be annotated. - x_values, - # pyre-fixme[2]: Parameter must be annotated. - data, - # pyre-fixme[2]: Parameter must be annotated. - channel_labels, - # pyre-fixme[2]: Parameter must be annotated. - norm_attr, - # pyre-fixme[2]: Parameter must be annotated. - cmap, - # pyre-fixme[2]: Parameter must be annotated. - cm_norm, - # pyre-fixme[2]: Parameter must be annotated. - alpha_overlay, + num_channels: int, + plt_axis_list: npt.NDArray, + x_values: npt.NDArray, + data: npt.NDArray, + channel_labels: List[str], + norm_attr: npt.NDArray, + cmap: LinearSegmentedColormap, + cm_norm: Normalize, + alpha_overlay: float, **kwargs: Any, ) -> None: pyplot_kwargs = kwargs.get("pyplot_kwargs", {}) @@ -663,22 +635,15 @@ def _visualize_overlay_combined( def _visualize_colored_graph( - # pyre-fixme[2]: Parameter must be annotated. - num_channels, - # pyre-fixme[2]: Parameter must be annotated. - plt_axis_list, - # pyre-fixme[2]: Parameter must be annotated. - x_values, - # pyre-fixme[2]: Parameter must be annotated. - data, - # pyre-fixme[2]: Parameter must be annotated. - channel_labels, - # pyre-fixme[2]: Parameter must be annotated. - norm_attr, - # pyre-fixme[2]: Parameter must be annotated. - cmap, - # pyre-fixme[2]: Parameter must be annotated. - cm_norm, + num_channels: int, + plt_axis_list: npt.NDArray, + x_values: npt.NDArray, + data: npt.NDArray, + channel_labels: List[str], + norm_attr: npt.NDArray, + cmap: LinearSegmentedColormap, + cm_norm: Normalize, + alpha_overlay: float, **kwargs: Any, ) -> None: # helper method for visualize_timeseries_attr