diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index a3246f7bf..4dd8fcc65 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -88,6 +88,9 @@ def plot_token_attr( fig, ax = plt.subplots() + # Hide the grid + ax.grid(False) + # Plot the heatmap data = token_attr.numpy() @@ -119,7 +122,7 @@ def plot_token_attr( # Create colorbar cbar = fig.colorbar(im, ax=ax) # type: ignore - cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom") + cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom") # Show all ticks and label them with the respective list entries. shortened_tokens = [ @@ -204,7 +207,7 @@ def plot_seq_attr( color="#d0365b", ) - ax.set_ylabel("Sequence Attribuiton", rotation=90, va="bottom") + ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom") if show: plt.show()