Skip to content

Commit

Permalink
'visualize_timeseries_attr' is too complex (#1384)
Browse files Browse the repository at this point in the history
Summary:

This diff addresses the C901 in visualization.py by breaking down the method

Reviewed By: vivekmig

Differential Revision: D64513163
  • Loading branch information
jjuncho authored and facebook-github-bot committed Oct 21, 2024
1 parent ed5daa3 commit a90529e
Showing 1 changed file with 207 additions and 106 deletions.
313 changes: 207 additions & 106 deletions captum/attr/_utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,28 @@ def _normalize_attr(
return _normalize_scale(attr_combined, threshold)


def _create_default_plot(
# pyre-fixme[2]: Parameter must be annotated.
plt_fig_axis,
# pyre-fixme[2]: Parameter must be annotated.
use_pyplot,
# pyre-fixme[2]: Parameter must be annotated.
fig_size,
**pyplot_kwargs: Any,
) -> Tuple[Figure, Axes]:
# Create plot if figure, axis not provided
if plt_fig_axis is not None:
plt_fig, plt_axis = plt_fig_axis
else:
if use_pyplot:
plt_fig, plt_axis = plt.subplots(figsize=fig_size, **pyplot_kwargs)
else:
plt_fig = Figure(figsize=fig_size)
plt_axis = plt_fig.subplots(**pyplot_kwargs)
return plt_fig, plt_axis
# Figure.subplots returns Axes or array of Axes


def _initialize_cmap_and_vmin_vmax(
sign: str,
) -> Tuple[Union[str, Colormap], float, float]:
Expand Down Expand Up @@ -338,16 +360,7 @@ def visualize_image_attr(
>>> # Displays blended heat map visualization of computed attributions.
>>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
"""
# Create plot if figure, axis not provided
if plt_fig_axis is not None:
plt_fig, plt_axis = plt_fig_axis
else:
if use_pyplot:
plt_fig, plt_axis = plt.subplots(figsize=fig_size)
else:
plt_fig = Figure(figsize=fig_size)
plt_axis = plt_fig.subplots()
# Figure.subplots returns Axes or array of Axes
plt_fig, plt_axis = _create_default_plot(plt_fig_axis, use_pyplot, fig_size)

if original_image is not None:
if np.max(original_image) <= 1.0:
Expand All @@ -362,8 +375,10 @@ def visualize_image_attr(
)

# Remove ticks and tick labels from plot.
plt_axis.xaxis.set_ticks_position("none")
plt_axis.yaxis.set_ticks_position("none")
if plt_axis.xaxis is not None:
plt_axis.xaxis.set_ticks_position("none")
if plt_axis.yaxis is not None:
plt_axis.yaxis.set_ticks_position("none")
plt_axis.set_yticklabels([])
plt_axis.set_xticklabels([])
plt_axis.grid(visible=False)
Expand Down Expand Up @@ -528,6 +543,161 @@ def visualize_image_attr_multiple(
return plt_fig, plt_axis


def _plot_attrs_as_axvspan(
# pyre-fixme[2]: Parameter must be annotated.
attr_vals,
# pyre-fixme[2]: Parameter must be annotated.
x_vals,
# pyre-fixme[2]: Parameter must be annotated.
ax,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
# pyre-fixme[2]: Parameter must be annotated.
alpha_overlay,
) -> None:
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
half_col_width = (x_values[1] - x_values[0]) / 2.0

for icol, col_center in enumerate(x_vals):
left = col_center - half_col_width
right = col_center + half_col_width
ax.axvspan(
xmin=left,
xmax=right,
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
edgecolor=None,
alpha=alpha_overlay,
)


def _visualize_overlay_individual(
# pyre-fixme[2]: Parameter must be annotated.
num_channels,
# pyre-fixme[2]: Parameter must be annotated.
plt_axis_list,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
data,
# pyre-fixme[2]: Parameter must be annotated.
channel_labels,
# pyre-fixme[2]: Parameter must be annotated.
norm_attr,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
# pyre-fixme[2]: Parameter must be annotated.
alpha_overlay,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs: Any,
) -> None:
# helper method for visualize_timeseries_attr
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
for chan in range(num_channels):
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
if channel_labels is not None:
plt_axis_list[chan].set_ylabel(channel_labels[chan])

_plot_attrs_as_axvspan(
norm_attr[chan],
x_values,
plt_axis_list[chan],
x_values,
cmap,
cm_norm,
alpha_overlay,
)

plt.subplots_adjust(hspace=0)
pass


def _visualize_overlay_combined(
# pyre-fixme[2]: Parameter must be annotated.
num_channels,
# pyre-fixme[2]: Parameter must be annotated.
plt_axis_list,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
data,
# pyre-fixme[2]: Parameter must be annotated.
channel_labels,
# pyre-fixme[2]: Parameter must be annotated.
norm_attr,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
# pyre-fixme[2]: Parameter must be annotated.
alpha_overlay,
**kwargs: Any,
) -> None:
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})

cycler = plt.cycler("color", matplotlib.colormaps["Dark2"].colors) # type: ignore
plt_axis_list[0].set_prop_cycle(cycler)

for chan in range(num_channels):
label = channel_labels[chan] if channel_labels else None
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)

_plot_attrs_as_axvspan(
norm_attr,
x_values,
plt_axis_list[0],
x_values,
cmap,
cm_norm,
alpha_overlay,
)

plt_axis_list[0].legend(loc="best")


def _visualize_colored_graph(
# pyre-fixme[2]: Parameter must be annotated.
num_channels,
# pyre-fixme[2]: Parameter must be annotated.
plt_axis_list,
# pyre-fixme[2]: Parameter must be annotated.
x_values,
# pyre-fixme[2]: Parameter must be annotated.
data,
# pyre-fixme[2]: Parameter must be annotated.
channel_labels,
# pyre-fixme[2]: Parameter must be annotated.
norm_attr,
# pyre-fixme[2]: Parameter must be annotated.
cmap,
# pyre-fixme[2]: Parameter must be annotated.
cm_norm,
**kwargs: Any,
) -> None:
# helper method for visualize_timeseries_attr
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
for chan in range(num_channels):
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
lc.set_array(norm_attr[chan, :])
plt_axis_list[chan].add_collection(lc)
plt_axis_list[chan].set_ylim(
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
)
if channel_labels is not None:
plt_axis_list[chan].set_ylabel(channel_labels[chan])

plt.subplots_adjust(hspace=0)


def visualize_timeseries_attr(
attr: npt.NDArray,
data: npt.NDArray,
Expand Down Expand Up @@ -686,8 +856,8 @@ def visualize_timeseries_attr(

num_subplots = num_channels
if (
TimeseriesVisualizationMethod[method]
== TimeseriesVisualizationMethod.overlay_combined
TimeseriesVisualizationMethod[method].value
== TimeseriesVisualizationMethod.overlay_combined.value
):
num_subplots = 1
attr = np.sum(attr, axis=0) # Merge attributions across channels
Expand All @@ -700,17 +870,9 @@ def visualize_timeseries_attr(
x_values = np.arange(timeseries_length)

# Create plot if figure, axis not provided
if plt_fig_axis is not None:
plt_fig, plt_axis = plt_fig_axis
else:
if use_pyplot:
plt_fig, plt_axis = plt.subplots( # type: ignore
figsize=fig_size, nrows=num_subplots, sharex=True
)
else:
plt_fig = Figure(figsize=fig_size)
plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) # type: ignore
# Figure.subplots returns Axes or array of Axes
plt_fig, plt_axis = _create_default_plot(
plt_fig_axis, use_pyplot, fig_size, nrows=num_subplots, sharex=True
)

if not isinstance(plt_axis, ndarray):
plt_axis_list = np.array([plt_axis])
Expand All @@ -720,91 +882,30 @@ def visualize_timeseries_attr(
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)

# Set default colormap and bounds based on sign.
if VisualizeSign[sign] == VisualizeSign.all:
default_cmap: Union[str, LinearSegmentedColormap] = (
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
)
vmin, vmax = -1, 1
elif VisualizeSign[sign] == VisualizeSign.positive:
default_cmap = "Greens"
vmin, vmax = 0, 1
elif VisualizeSign[sign] == VisualizeSign.negative:
default_cmap = "Reds"
vmin, vmax = 0, 1
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
default_cmap = "Blues"
vmin, vmax = 0, 1
else:
raise AssertionError("Visualize Sign type is not valid.")
default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign)
cmap = cmap if cmap is not None else default_cmap
cmap = cm.get_cmap(cmap) # type: ignore
cm_norm = colors.Normalize(vmin, vmax)

# pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax) -> None:
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
half_col_width = (x_values[1] - x_values[0]) / 2.0
for icol, col_center in enumerate(x_vals):
left = col_center - half_col_width
right = col_center + half_col_width
ax.axvspan(
xmin=left,
xmax=right,
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
edgecolor=None,
alpha=alpha_overlay,
)

if (
TimeseriesVisualizationMethod[method]
== TimeseriesVisualizationMethod.overlay_individual
):
for chan in range(num_channels):
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
if channel_labels is not None:
plt_axis_list[chan].set_ylabel(channel_labels[chan])

_plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis_list[chan])

plt.subplots_adjust(hspace=0)

elif (
TimeseriesVisualizationMethod[method]
== TimeseriesVisualizationMethod.overlay_combined
):
# Dark colors are better in this case
cycler = plt.cycler("color", matplotlib.colormaps["Dark2"]) # type: ignore
plt_axis_list[0].set_prop_cycle(cycler)

for chan in range(num_channels):
label = channel_labels[chan] if channel_labels else None
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)

_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis_list[0])

plt_axis_list[0].legend(loc="best")

elif (
TimeseriesVisualizationMethod[method]
== TimeseriesVisualizationMethod.colored_graph
):
for chan in range(num_channels):
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
lc.set_array(norm_attr[chan, :])
plt_axis_list[chan].add_collection(lc)
plt_axis_list[chan].set_ylim(
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
)
if channel_labels is not None:
plt_axis_list[chan].set_ylabel(channel_labels[chan])

plt.subplots_adjust(hspace=0)

visualization_methods: Dict[str, Callable[..., Union[None, AxesImage]]] = {
"overlay_individual": _visualize_overlay_individual,
"overlay_combined": _visualize_overlay_combined,
"colored_graph": _visualize_colored_graph,
}
kwargs = {
"num_channels": num_channels,
"plt_axis_list": plt_axis_list,
"x_values": x_values,
"data": data,
"channel_labels": channel_labels,
"norm_attr": norm_attr,
"cmap": cmap,
"cm_norm": cm_norm,
"alpha_overlay": alpha_overlay,
"pyplot_kwargs": pyplot_kwargs,
}
if method in visualization_methods:
visualization_methods[method](**kwargs)
else:
raise AssertionError("Invalid visualization method: {}".format(method))

Expand Down

0 comments on commit a90529e

Please sign in to comment.