Turn the attribute _return_attention_scores into an argument
apehex committed Jan 23, 2025
1 parent 90568da commit 4c5557e
Showing 1 changed file with 4 additions and 4 deletions.
keras/src/layers/attention/multi_head_attention.py (4 additions, 4 deletions)

@@ -158,7 +158,6 @@ def __init__(
         self.seed = seed

         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
-        self._return_attention_scores = False

         # Check for flash attention constraints
         if self._flash_attention and self._dropout > 0.0:

@@ -419,6 +418,7 @@ def _compute_attention(
         value,
         attention_mask=None,
         training=None,
+        return_attention_scores=False,
     ):
         """Applies Dot-product attention with query, key, value tensors.

@@ -442,7 +442,7 @@
             attention_scores: Multi-headed attention weights.
         """
         # Check for flash attention constraints
-        if self._flash_attention and self._return_attention_scores:
+        if self._flash_attention and return_attention_scores:
             raise ValueError(
                 "Returning attention scores is not supported when flash "
                 "attention is enabled. Please disable flash attention to access"

@@ -452,7 +452,7 @@
         # Determine whether to use dot-product attention
         use_dot_product_attention = not (
             self._dropout > 0.0
-            or self._return_attention_scores
+            or return_attention_scores
             or (len(query.shape) != 4)
         )

@@ -525,7 +525,6 @@ def call(
         training=None,
         use_causal_mask=False,
     ):
-        self._return_attention_scores = return_attention_scores
         if key is None:
             key = value

@@ -562,6 +561,7 @@
             value,
             attention_mask,
             training,
+            return_attention_scores,
         )
         attention_output = self._output_dense(attention_output)
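For context, here is a minimal usage sketch of the public behavior this refactor serves. The constructor values, tensor shapes, and the flash-attention note below are illustrative assumptions, not part of the commit; the point is that return_attention_scores now flows from call() into _compute_attention() as a plain argument instead of being cached on the layer as mutable state.

import numpy as np
import keras

# Hypothetical layer configuration; num_heads, key_dim, and the input
# shapes are illustrative, not taken from the commit.
mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)

query = np.random.normal(size=(1, 8, 16)).astype("float32")
value = np.random.normal(size=(1, 8, 16)).astype("float32")

# The flag is passed per call rather than stored on the layer, so two
# calls with different settings cannot leak state into each other.
output, scores = mha(query, value, return_attention_scores=True)
output_only = mha(query, value)  # default False: only the output is returned

print(output.shape)  # (1, 8, 16)
print(scores.shape)  # (1, 2, 8, 8): (batch, num_heads, query_len, key_len)

# Per the diff above, requesting scores while flash attention is enabled
# raises a ValueError, since flash kernels do not materialize the scores.

Passing the flag as an argument also keeps call() re-entrant: with the old attribute, a call with return_attention_scores=True could overwrite the setting seen by a concurrent or reused invocation of the same layer.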
