From ee35e7c356beb2b04296ceeacc19989f85bc9840 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Wed, 30 Oct 2024 09:14:45 +0100
Subject: [PATCH] Fixes #212: workaround that disables SDPA attention in the
 latest version of openai-whisper (20240930), which prevents access to
 attention weights

---
 whisper_timestamped/transcribe.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 9a444c2..2162ca2 100755
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -3,7 +3,7 @@
 __author__ = "Jérôme Louradour"
 __credits__ = ["Jérôme Louradour"]
 __license__ = "GPLv3"
-__version__ = "1.15.4"
+__version__ = "1.15.5"
 
 # Set some environment variables
 import os
@@ -46,6 +46,20 @@
 AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / SAMPLE_RATE # 0.02 (sec)
 SEGMENT_DURATION = N_FRAMES * HOP_LENGTH / SAMPLE_RATE # 30.0 (sec)
 
+# Access attention in latest versions...
+if whisper.__version__ >= "20240930":
+    from whisper.model import disable_sdpa
+else:
+    from contextlib import contextmanager
+
+    # Dummy context manager that does nothing
+    @contextmanager
+    def disable_sdpa():
+        try:
+            yield
+        finally:
+            pass
+
 # Logs
 import logging
 logger = logging.getLogger("whisper_timestamped")
@@ -885,7 +899,8 @@ def hook_output_logits(layer, ins, outs):
         if compute_word_confidence or no_speech_threshold is not None:
             all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits))
 
-        transcription = model.transcribe(audio, **whisper_options)
+        with disable_sdpa():
+            transcription = model.transcribe(audio, **whisper_options)
 
     finally:
 
@@ -1047,7 +1062,8 @@ def hook_output_logits(layer, ins, outs):
 
     try:
         model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?""
-        transcription = model.transcribe(audio, **whisper_options)
+        with disable_sdpa():
+            transcription = model.transcribe(audio, **whisper_options)
     finally:
         for hook in all_hooks:
             hook.remove()
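
Standalone sketch (not part of the patch above): a minimal Python script illustrating how the version-guarded disable_sdpa shim added by this patch is meant to be used, assuming openai-whisper is installed; the model size "tiny" and the path "audio.wav" are hypothetical placeholders.

import whisper

# openai-whisper >= 20240930 computes attention with scaled dot-product
# attention (SDPA), which does not expose attention weights to forward hooks;
# its whisper.model.disable_sdpa context manager restores the explicit
# attention path. Older versions need no change, so a no-op stand-in suffices.
if whisper.__version__ >= "20240930":
    from whisper.model import disable_sdpa
else:
    from contextlib import contextmanager

    @contextmanager
    def disable_sdpa():
        yield  # nothing to disable on older versions

model = whisper.load_model("tiny")  # placeholder model size
with disable_sdpa():
    # Inside this block, hooks registered on cross-attention modules receive
    # attention weights again, which is what whisper_timestamped relies on.
    result = model.transcribe("audio.wav")  # placeholder audio path
print(result["text"])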