From 88c6c2fd0e5d67ec180677f453948e480788e30f Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Fri, 26 Jan 2024 09:57:42 +0100
Subject: [PATCH 01/11] Enable transformers as a backend

---
 whisper_timestamped/transcribe.py | 444 ++++++++++++++++++++++++++++--
 1 file changed, 424 insertions(+), 20 deletions(-)

diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 345338f..b48ecf0 100755
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -10,7 +10,7 @@
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # Remove warning "This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)..."
 os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' # GPU in the right order
 
-# Whisper and Torch
+# openai-whisper and pytorch
 import whisper
 import torch
 import torch.nn.functional as F
@@ -50,6 +50,7 @@ import logging
 logger = logging.getLogger("whisper_timestamped")
 
+DEFAULT_BACKEND = "openai-whisper" # "transformers"
 USE_EFFICIENT_BY_DEFAULT = True
 TRUST_WHISPER_TIMESTAMP_BY_DEFAULT = True
 DISFLUENCY_MARK = "[*]"
@@ -77,6 +78,7 @@ def transcribe_timestamped(
     plot_word_alignment=False,
     word_alignement_most_top_layers=None, # Was 6 before 1.9
     remove_empty_words=False,
+    use_backend_timestamps=False,
 
     # Reproducibility
     seed=1234,
@@ -158,6 +160,9 @@ def transcribe_timestamped(
     remove_empty_words: bool
         Whether to remove words with no duration occurring at the end of segments (probable Whisper hallucinations).
 
+    use_backend_timestamps: bool
+        Whether to use the word timestamps provided by the backend (openai-whisper or transformers), instead of the ones computed by the more elaborate heuristics of whisper-timestamped.
+
     seed: int
         Random seed to use for temperature sampling, for the sake of reproducibility. Choose None for unpredictable randomness.
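A minimal sketch of how the new option is meant to be used from Python (assuming the package's README-style API — load_audio, load_model, transcribe — and a hypothetical local "audio.wav"):

    import whisper_timestamped as whisper

    audio = whisper.load_audio("audio.wav")  # hypothetical input file
    model = whisper.load_model("tiny", device="cpu")

    # Let the backend produce the word timestamps, instead of
    # whisper-timestamped's own cross-attention alignment heuristics
    result = whisper.transcribe(model, audio, use_backend_timestamps=True)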
@@ -213,13 +218,16 @@ def transcribe_timestamped( if isinstance(temperature, (list, tuple)) and len(temperature) == 1: temperature = temperature[0] - if isinstance(temperature, (list, tuple)): - # temperature fallback + if isinstance(temperature, (list, tuple)): # temperature fallback + naive_approach = True + elif temperature > 0 and best_of is not None and best_of > 1: # random sampling naive_approach = True - elif temperature > 0 and best_of is not None and best_of > 1: + if beam_size is not None: # beam-search naive_approach = True - if beam_size is not None: - # beam-search + + # TODO: check if efficient approach is possible with transformers backend + # (careful: decoding heuristics are completely different from the ones used in openai-whisper) + if is_transformer_model(model) or use_backend_timestamps: naive_approach = True # Input options @@ -280,6 +288,7 @@ def transcribe_timestamped( (transcription, words) = _transcribe_timestamped_naive(model, audio, min_word_duration=0.0, # Was 0.04 before 1.11 trust_whisper_timestamps=trust_whisper_timestamps, + use_backend_timestamps=use_backend_timestamps, **alignment_options, **whisper_options, **other_options) else: (transcription, words) = _transcribe_timestamped_efficient(model, audio, @@ -297,8 +306,8 @@ def transcribe_timestamped( for word in words: if verbose and not naive_approach and not vad: print_timestamped(word) - word.pop("tokens") - word.pop("tokens_indices") + word.pop("tokens", None) + word.pop("tokens_indices", None) if "avg_logprob_reliable" in word: word.pop("avg_logprob_reliable") idx_segment = word.pop("idx_segment") @@ -975,6 +984,7 @@ def _transcribe_timestamped_naive( compute_word_confidence, include_punctuation_in_confidence, refine_whisper_precision_nframes, + use_backend_timestamps, alignment_heads, plot_word_alignment, word_alignement_most_top_layers, @@ -999,12 +1009,20 @@ def _transcribe_timestamped_naive( tokenizer = get_tokenizer(model, task=whisper_options["task"], language=language) + transformer_backend = is_transformer_model(model) + if transformer_backend: + # Additional options specific to transformer models + whisper_options["remove_punctuation_from_words"] = remove_punctuation_from_words + whisper_options["use_token_timestamps"] = use_backend_timestamps + else: + whisper_options["word_timestamps"] = use_backend_timestamps + language_probs = None def hook_output_logits(layer, ins, outs): nonlocal language_probs, tokenizer # Get language probabilities - if language_probs is None: + if language is None and language_probs is None: if outs.shape[1] == 1: embedding_weights = torch.transpose(model.decoder.token_embedding.weight, 0, 1).to(outs[0].dtype) index_start = tokenizer.sot + 1 @@ -1025,17 +1043,32 @@ def hook_output_logits(layer, ins, outs): for hook in all_hooks: hook.remove() - if verbose and language is None and not whisper_options["verbose"]: + if not transformer_backend and verbose and language is None and not whisper_options["verbose"]: # Reproduce whisper verbose (2/2) print(f"Detected language: {whisper.tokenizer.LANGUAGES[transcription['language']].title()}") sys.stdout.flush() - language = norm_language(transcription["language"]) + # End early if timestamps have been computed by the backend + if transcription.get("segments") and "words" in transcription["segments"][0]: + words = [] + for i_segment, segment in enumerate(transcription["segments"]): + ws = segment.pop("words", []) + for w in ws: + # Rename openai-whisper -> whisper-timestamped + if "word" in w: w["text"] = w.pop("word") + if 
"probability" in w: w["confidence"] = round_confidence(w.pop("probability")) + w["idx_segment"] = i_segment + words.extend(ws) + if language_probs: + transcription["language_probs"] = language_probs + return transcription, words + + language = norm_language(transcription.get("language", language)) use_space = should_use_space(language) n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80 - attention_weights = [[] for _ in range(min(word_alignement_most_top_layers,len(model.decoder.blocks)))] + attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))] try: @@ -1047,9 +1080,15 @@ def hook_output_logits(layer, ins, outs): for i, block in enumerate(model.decoder.blocks): if i < nblocks - word_alignement_most_top_layers: continue + def hook(layer, ins, outs, index=j): + if is_transformer_model(model): + attention_weights[index] = outs[1].log() + else: + attention_weights[index] = outs[1] all_hooks.append( block.cross_attn.register_forward_hook( - lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[-1]) + hook + # lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[1]) ) ) j += 1 @@ -1159,12 +1198,19 @@ def hook_output_logits(layer, ins, outs): last_token_check = tokens[-1] tokens = tokens[:-1] + sot_sequence = tokenizer.sot_sequence + if language and len(sot_sequence) == 3: + sot_sequence = ( + sot_sequence[0], + tokenizer.to_language_token(language), + sot_sequence[2], + ) tokens = [ - *tokenizer.sot_sequence, + *sot_sequence, tokenizer.timestamp_begin, ] + tokens - i_start = len(tokenizer.sot_sequence) + i_start = len(sot_sequence) with torch.no_grad(): logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0)) @@ -1234,8 +1280,10 @@ def hook_output_logits(layer, ins, outs): segment_tokens_check.append(last_token_check) if trust_whisper_timestamps: if segment_tokens_check != segment["tokens"]: - assert len(segment_tokens_check) < len(segment["tokens"]) and segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \ - f"Got inconsistent tokens: {tokenizer.decode(segment_tokens_check)} != {tokenizer.decode(segment['tokens'])}" + assert len(segment_tokens_check) < len(segment["tokens"]), \ + f"First should be longer by one token: '{tokenizer.decode_with_timestamps(segment_tokens_check)}' should include '{tokenizer.decode_with_timestamps(segment['tokens'])}'" + assert segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \ + f"Got inconsistent tokens: {tokenizer.decode_with_timestamps(segment_tokens_check)} != {tokenizer.decode_with_timestamps(segment['tokens'])}" segment["tokens"] = segment_tokens_check segment["text"] = tokenizer.decode(segment["tokens"]) # else: TODO @@ -1293,6 +1341,10 @@ def print_timestamped(w): def get_logit_filters(model, whisper_options, prompt = None): + if is_transformer_model(model): + # import transformers + # transformers.WhisperTimeStampLogitsProcessor + raise NotImplementedError("TODO") decoding_options = get_decoding_options(whisper_options) if "initial_prompt" in decoding_options: prompt0 = decoding_options.pop("initial_prompt") @@ -1324,6 +1376,15 @@ def get_decoding_options(whisper_options): ]) def get_tokenizer(model, task="transcribe", language="en"): + if is_transformer_model(model): + tokenizer = model.tokenizer + tokenizer.sot_sequence = ( + tokenizer.sot, + tokenizer.to_language_token(language or "en"), + tokenizer.to_task_token(task), + ) + tokenizer.sot_sequence + return 
    try:
        return whisper.tokenizer.get_tokenizer(
            model.is_multilingual,
@@ -2260,7 +2321,7 @@ def _get_alignment_heads(model_name, num_layers, num_heads):
 
 def _get_number_of_parameters(model):
     num_parameters = 0
     for name, p in model.named_parameters():
-        if name in ["decoder.proj_out.weight"]:
+        if name in ["decoder.proj_out.weight", "model.encoder.embed_positions.weight"]:
             continue
         num_parameters += p.numel()
     return num_parameters
@@ -2269,9 +2330,54 @@ def _get_number_of_parameters(model):
 def load_model(
     name: str,
     device: Optional[Union[str, torch.device]] = None,
+    backend: str = DEFAULT_BACKEND,
     download_root: str = None,
     in_memory: bool = False,
 ):
+    """
+    Load a model from the given name or path.
+
+    Parameters
+    ----------
+    name : str
+        Name of the model or path to the model.
+        Examples:
+        - OpenAI-Whisper identifier: "large-v3", "medium.en", ...
+        - HuggingFace identifier: "openai/whisper-large-v3", "distil-whisper/distil-large-v2", ...
+        - File name: "path/to/model.pt", "path/to/model.ckpt", "path/to/model.bin"
+        - Folder name: "path/to/folder". The folder must contain either "pytorch_model.bin", "model.safetensors", or sharded versions of those, or "whisper.ckpt".
+    device : str or torch.device, optional
+        Device to use. If None, use CUDA if there is a GPU available, otherwise CPU.
+    backend : str, optional
+        Backend to use. Either "transformers" or "openai-whisper".
+    download_root : str, optional
+        Root folder to download the model to. If None, use the default download root.
+    in_memory : bool, optional
+        Whether to preload the model weights into host memory.
+    """
+    if backend == "transformers":
+        try:
+            import transformers
+        except ImportError:
+            raise ImportError("If you want to use the transformers backend, please install the transformers library first")
+        if name in whisper.available_models():
+            name = f"openai/whisper-{name}"
+        # TODO: use download_root
+        # TODO: does in_memory make sense?
+ try: + generation_config = transformers.GenerationConfig.from_pretrained(name) + except OSError: + generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny") + processor = transformers.WhisperProcessor.from_pretrained(name) + model = transformers.WhisperForConditionalGeneration.from_pretrained(name) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config) + + elif backend not in ["openai", "openai-whisper"]: + raise ValueError(f"Got unexpected backend {backend}") + extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None if name in whisper.available_models() or extension == ".pt": @@ -2359,7 +2465,303 @@ def torch_load(model_path): hf_state_dict = torch.load(model_path, map_location="cpu") return hf_state_dict +# Some helpers to manage transformers/openai-whisper model + +class TransformerWhisperAsOpenAIWhisper: + """ + Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped) + """ + + def __init__(self, model, processor, generation_config): + + self.model = model # transformers.WhisperForConditionalGeneration + self.processor = processor # transformers.WhisperProcessor + self.generation_config = generation_config # transformers.GenerationConfig + + self.device = model.device + + # Dimensions + self.dims = whisper.model.ModelDimensions( + n_mels = model.get_encoder().get_input_embeddings().in_channels, + n_audio_ctx = 1500, + n_audio_state = model.get_encoder().get_input_embeddings().out_channels, + n_audio_head = model.get_encoder().layers[0].self_attn.num_heads, + n_audio_layer = len(model.get_encoder().layers), + n_vocab = model.get_decoder().get_input_embeddings().num_embeddings, + n_text_ctx = 448, + n_text_state = model.get_decoder().get_input_embeddings().embedding_dim, + n_text_head = model.get_decoder().layers[0].self_attn.num_heads, + n_text_layer = len(model.get_decoder().layers), + ) + + # Tokenization + self.tokenizer = processor.tokenizer + ( + self.tokenizer.sot, + self.tokenizer.eot, + self.tokenizer.timestamp_begin, + self.tokenizer.no_speech, + self.tokenizer.no_timestamps, + ) = self.tokenizer.convert_tokens_to_ids([ + "<|startoftranscript|>", + "<|endoftext|>", + "<|0.00|>", + "<|nospeech|>", + "<|notimestamps|>", + ]) + if self.tokenizer.decode([self.tokenizer.timestamp_begin], decode_with_timestamps=True) != "<|0.00|>": + # Sometimes, the tokenizer is weird and it is impossible to get the timestamp_begin token easily (e.g. with "qanastek/whisper-tiny-french-cased") + logger.warning("Getting timestamp_begin token is not straightforward for this model") + i = self.tokenizer.no_timestamps + 1 + maxi = i + 1000 + while self.tokenizer.decode([i], decode_with_timestamps=True) != "<|0.00|>": + i += 1 + if i == maxi: + raise RuntimeError("Could not find timestamp_begin token") + self.tokenizer.timestamp_begin = i + + self.tokenizer.all_language_tokens = self.tokenizer.convert_tokens_to_ids([ + t for t in self.tokenizer.additional_special_tokens if len(t) in [6,7] + ]) + # Update old Whisper generation config (ex: error: "The generation config is outdated and is thus not compatible with the `task` argument to `generate` [...] 
update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224")
+        if not hasattr(self.generation_config, "lang_to_id"):
+            self.generation_config.lang_to_id = dict(
+                (self.tokenizer.decode(itoken), itoken)
+                for itoken in self.tokenizer.all_language_tokens
+            )
+        if not hasattr(self.generation_config, "task_to_id"):
+            self.generation_config.task_to_id = dict(
+                (task, self.tokenizer.encode("<|" + task + "|>", add_special_tokens=False)[0])
+                for task in ["transcribe", "translate"])
+        self.tokenizer.to_language_token = lambda language: self.generation_config.lang_to_id["<|" + norm_language(language) + "|>"]
+        self.tokenizer.to_task_token = lambda task: self.generation_config.task_to_id[task]
+
+        self.tokenizer.to_timestamp_token = lambda t: self.tokenizer.encode(f"<|{t:0.2f}|>", add_special_tokens=False)[0]
+        self.tokenizer.decode_with_timestamps = lambda tokens: self.tokenizer.decode(tokens, decode_with_timestamps=True)
+
+        self.generation_config.no_timestamps_token_id = self.tokenizer.no_timestamps
+        self.model.generation_config = self.generation_config
+
+        # Access to layers (renamed attributes)
+        self.decoder = self.model.get_decoder()
+        self.decoder.ln = self.decoder.layer_norm
+        self.decoder.token_embedding = self.decoder.embed_tokens
+        self.decoder.blocks = self.decoder.layers
+        for block in self.decoder.blocks:
+            block.cross_attn = block.encoder_attn
+
+        # From the config
+        if hasattr(generation_config, "is_multilingual"):
+            self.is_multilingual = generation_config.is_multilingual
+        else:
+            self.is_multilingual = generation_config.is_multilingual = (self.tokenizer.sot != 50257)
+
+        # Alignment heads
+        if hasattr(generation_config, "alignment_heads"):
+            a = generation_config.alignment_heads
+            self.alignment_heads = torch.sparse_coo_tensor(np.array(a).transpose(), [True]*len(a)).coalesce().to(self.device)
+
+    def named_parameters(self):
+        return self.model.named_parameters()
+
+    def transcribe(self, audio, use_token_timestamps=False, **kwargs):
+
+        # Decoding options
+        # TODO: double check that this setup is correct
+        generation_config = self.generation_config
+        generation_config.num_beams = kwargs.get("beam_size", None) or 1
+        temperature = kwargs.get("temperature", 0.0)
+        if isinstance(temperature, (list, tuple)):
+            # Not supported with transformers
+            temperature = min(temperature)
+        if temperature != 0.0:
+            generation_config.do_sample = True
+            generation_config.temperature = temperature
+            generation_config.top_k = kwargs.get("best_of", None)
+
+        initial_prompt = kwargs.get("initial_prompt")
+        prompt_ids = self.processor.get_prompt_ids(initial_prompt) if (initial_prompt and initial_prompt.strip()) else None
+
+        generate_kwargs = dict(
+            return_dict_in_generate = True,
+            return_segments = True,
+            return_timestamps = True,
+            return_token_timestamps = use_token_timestamps,
+            max_length = self.dims.n_text_ctx,
+            is_multilingual = self.is_multilingual,
+            prompt_ids = prompt_ids,
+            generation_config = generation_config,
+        )
+        if self.is_multilingual:
+            generate_kwargs["language"] = kwargs.get("language")
+            generate_kwargs["task"] = kwargs.get("task", "transcribe")
+
+        # Extract features
+        features = self.processor(
+            audio,
+            return_tensors="pt",
+            sampling_rate=16_000,
+            truncation=False,
+        ).input_features.to(self.device)
+
+        # Transcribe
+        output = self.model.generate(
+            features,
+            **generate_kwargs
+        )
+
+        # Because the output format is different when there is only one segment (e.g. audio duration < 30 seconds)
audio duration < 30 seconds)... (WTF) + if "segments" not in output: + tokens = output.sequences[0] + new_output = { + "segments": [[{ + "tokens": tokens[1:], + "start": torch.tensor(0.0), + "result": { + "sequences": output.sequences[0], + "past_key_values": output.past_key_values, + } + }]] + } + if use_token_timestamps: + new_output["segments"][0][0]["result"]["token_timestamps"] = output.token_timestamps[0] + output = new_output + + # Language detection + first_segment_tokens = output["segments"][0][0]["tokens"].tolist() + if self.tokenizer.sot in first_segment_tokens: + i_sot = first_segment_tokens.index(self.tokenizer.sot) + else: + i_sot = -1 + if self.is_multilingual: + language = self.tokenizer.decode([first_segment_tokens[i_sot+1]], decode_with_timestamps=True) + assert len(language) in [6,7], f"Unexpected language detected: '{language}' ({first_segment_tokens[i_sot+1]}) in '{self.tokenizer.decode(first_segment_tokens, decode_with_timestamps=True)}'" + language = language[2:-2] + else: + language = "en" + + if use_token_timestamps: + remove_punctuation_from_words = kwargs.get("remove_punctuation_from_words", False) + use_space = should_use_space(language) + + full_text = "" + segments = [] + for id, (segment_dict, segment) in enumerate(self._iter_segments(output, prompt_ids)): + + segment_dict = segment_dict | { + "temperature": temperature, + # "avg_logprob": -0.6982866287231445, + # "compression_ratio": 0.5294117647058824, + # "no_speech_prob": 0.019023602828383446 + } + + # Accumulate + if use_token_timestamps: + tokens = segment_dict["tokens_no_timestamp"] + offset = segment_dict["offset"] + all_tokens = segment["result"]["sequences"].tolist() + token_timestamps = segment["result"]["token_timestamps"] + assert len(all_tokens) == len(token_timestamps) + n_tokens = len(tokens) + for i in range(0, len(all_tokens) + 1 - n_tokens): + if all_tokens[i:i+n_tokens] == tokens: + token_timestamps = token_timestamps[i:i+n_tokens+1] + break + assert len(tokens)+1 == len(token_timestamps) + split_tokens = split_tokens_on_spaces if use_space else split_tokens_on_unicode + words, word_tokens, word_tokens_indices = split_tokens(tokens, self.tokenizer, remove_punctuation_from_words=remove_punctuation_from_words) + words_dicts = [] + i_end = 0 + for w, toks in zip(words, word_tokens_indices): + i_start = i_end + i_end = i_start + len(toks) + words_dicts.append({ + "text": w, + "start": offset + token_timestamps[i_start].item(), + "end": offset + token_timestamps[i_end].item(), + # "probability": 0.199 + }) + segment_dict["words"] = words_dicts + + segment_dict.pop("tokens_no_timestamp") + segment_dict.pop("offset") + segments.append(segment_dict) + full_text += segment_dict["text"] + + output_dict = { + "text": full_text, + "segments": segments, + } + if not kwargs.get("language"): + output_dict["language"] = language + + return output_dict + + def _iter_segments(self, output, prompt_ids): + + id = -1 + for sub_segments in output["segments"]: + for segment in sub_segments: + id += 1 + chunk_start = round(max(0, segment["start"].item()), 2) + tokens = segment["tokens"] + if id == 0 and prompt_ids is not None: + tokens = tokens[len(prompt_ids):] + time_tokens = [(i, t.item()) for i, t in enumerate(tokens) if t >= self.tokenizer.timestamp_begin] + i = 0 + while i < len(time_tokens): + i_start, token_start = time_tokens[i] + relative_start = round((token_start - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN, 2) + assert relative_start >= 0 + if i == 0: + offset = chunk_start - 
relative_start + assert offset >= 0, f"Got negative offset ({offset}) with {chunk_start=} and {relative_start=}" + has_end = i + 1 < len(time_tokens) + if has_end: + i_end, token_end = time_tokens[i+1] + # Ends on either consecutive timestamps, or the next timestamp followed by <|endoftext|> + while i + 2 < len(time_tokens): + if time_tokens[i+2][0] == i_end + 1: break + if i_end + 1 >= len(tokens) or tokens[i_end+1] in [self.tokenizer.eot]: break + logger.warning(f"Unexpected prediction without 2 consecutive timestamps") + i += 1 + i_end, token_end = time_tokens[i+1] + relative_end = round((token_end - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN, 2) + else: + i_end = len(tokens) - 1 + if tokens[i_end] == self.tokenizer.eot: i_end -= 1 + relative_end = SEGMENT_DURATION + start = offset + relative_start + duration = relative_end - relative_start + assert duration >= 0, f"Got negative duration ({duration}) with {relative_end=} and {relative_start=}" + tokens_with_timestamps = tokens[i_start:i_end+1] # include timestamps + text = self.tokenizer.decode(tokens_with_timestamps, skip_special_tokens=True) + tokens_with_timestamps = tokens_with_timestamps.tolist() + tokens_no_timestamp = tokens_with_timestamps[1:-1] if has_end else tokens_with_timestamps[1:] + i += 2 + if not len(tokens_no_timestamp): continue + yield ( + { + "id": id, + "seek": round(offset * SAMPLE_RATE / HOP_LENGTH), + "start": start, + "end": start + duration, + "text": text, + "tokens": tokens_with_timestamps, + "tokens_no_timestamp": tokens_no_timestamp, + "offset": offset, + }, + segment, + ) + + + def __call__(self, mfcc, tokens): + output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True) + return output.logits + +def is_transformer_model(model): + return isinstance(model, TransformerWhisperAsOpenAIWhisper) # Credit: https://github.com/openai/whisper/discussions/830 @@ -2500,6 +2902,7 @@ def get_do_write(output_format): parser.add_argument('--model', help=f"name of the Whisper model to use. 
Examples: {', '.join(whisper.available_models())}", default="small")
     parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
     parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
+    parser.add_argument("--backend", default=DEFAULT_BACKEND, help="Which backend to use", choices=["openai-whisper", "transformers"], type=str)
     parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
     valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
     def str2output_formats(string):
@@ -2551,7 +2954,7 @@ def __call__(self, parser, namespace, values, option_string=None):
             setattr(namespace, "best_of", 5)
             setattr(namespace, "beam_size", 5)
             setattr(namespace, "temperature_increment_on_fallback", 0.2)
-    parser.add_argument('--accurate', help="Shortcut to use the same default option as in Whisper (best_of=5, beam_search=5, temperature_increment_on_fallback=0.2)", action=ActionSetAccurate)
+    parser.add_argument('--accurate', help="Shortcut to use the same default options as in openai-whisper (best_of=5, beam_size=5, temperature_increment_on_fallback=0.2)", action=ActionSetAccurate)
 
     class ActionSetEfficient(argparse.Action):
         def __init__(self, option_strings, dest, nargs=None, **kwargs):
@@ -2590,8 +2993,9 @@ def __call__(self, parser, namespace, values, option_string=None):
         force_cudnn_initialization(device)
 
     output_format = args.pop("output_format")
+    backend = args.pop("backend")
 
-    model = load_model(model, device=device, download_root=model_dir)
+    model = load_model(model, device=device, download_root=model_dir, backend=backend)
 
     plot_word_alignment = args.pop("plot")

From 801a13c1df554698850eee3dd33120a7f17d4476 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Sun, 28 Jan 2024 11:16:33 +0100
Subject: [PATCH 02/11] fix json schema

---
 tests/json_schema.json | 104 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 102 insertions(+), 2 deletions(-)

diff --git a/tests/json_schema.json b/tests/json_schema.json
index e87b76b..c3df000 100644
--- a/tests/json_schema.json
+++ b/tests/json_schema.json
@@ -39,8 +39,108 @@
         },
         "language": {"type": "string"},
         "language_probs": {
-            "type": "array",
-            "items": {"type": "number", "minimum":0, "maximum":1}
+            "type": "object",
+            "properties": {
+                "en": {"type": "number", "minimum":0, "maximum":1},
+                "zh": {"type": "number", "minimum":0, "maximum":1},
+                "de": {"type": "number", "minimum":0, "maximum":1},
+                "es": {"type": "number", "minimum":0, "maximum":1},
+                "ru": {"type": "number", "minimum":0, "maximum":1},
+                "ko": {"type": "number", "minimum":0, "maximum":1},
+                "fr": {"type": "number", "minimum":0, "maximum":1},
+                "ja": {"type": "number", "minimum":0, "maximum":1},
+                "pt": {"type": "number", "minimum":0, "maximum":1},
+                "tr": {"type": "number", "minimum":0, "maximum":1},
+                "pl": {"type": "number", "minimum":0, "maximum":1},
+                "ca": {"type": "number", "minimum":0, "maximum":1},
+                "nl": {"type": "number", "minimum":0, "maximum":1},
+                "ar": {"type": "number", "minimum":0, "maximum":1},
+                "sv": {"type": "number", "minimum":0, "maximum":1},
+                "it": {"type": "number", "minimum":0, "maximum":1},
+                "id": {"type": "number", "minimum":0, "maximum":1},
+                "hi": {"type": "number", "minimum":0, "maximum":1},
+                "fi": {"type": "number", "minimum":0, "maximum":1},
+                "vi": {"type": "number", "minimum":0, "maximum":1},
+                "he": {"type": "number", "minimum":0, "maximum":1},
+                "uk": {"type": "number", "minimum":0,
"maximum":1}, + "el": {"type": "number", "minimum":0, "maximum":1}, + "ms": {"type": "number", "minimum":0, "maximum":1}, + "cs": {"type": "number", "minimum":0, "maximum":1}, + "ro": {"type": "number", "minimum":0, "maximum":1}, + "da": {"type": "number", "minimum":0, "maximum":1}, + "hu": {"type": "number", "minimum":0, "maximum":1}, + "ta": {"type": "number", "minimum":0, "maximum":1}, + "no": {"type": "number", "minimum":0, "maximum":1}, + "th": {"type": "number", "minimum":0, "maximum":1}, + "ur": {"type": "number", "minimum":0, "maximum":1}, + "hr": {"type": "number", "minimum":0, "maximum":1}, + "bg": {"type": "number", "minimum":0, "maximum":1}, + "lt": {"type": "number", "minimum":0, "maximum":1}, + "la": {"type": "number", "minimum":0, "maximum":1}, + "mi": {"type": "number", "minimum":0, "maximum":1}, + "ml": {"type": "number", "minimum":0, "maximum":1}, + "cy": {"type": "number", "minimum":0, "maximum":1}, + "sk": {"type": "number", "minimum":0, "maximum":1}, + "te": {"type": "number", "minimum":0, "maximum":1}, + "fa": {"type": "number", "minimum":0, "maximum":1}, + "lv": {"type": "number", "minimum":0, "maximum":1}, + "bn": {"type": "number", "minimum":0, "maximum":1}, + "sr": {"type": "number", "minimum":0, "maximum":1}, + "az": {"type": "number", "minimum":0, "maximum":1}, + "sl": {"type": "number", "minimum":0, "maximum":1}, + "kn": {"type": "number", "minimum":0, "maximum":1}, + "et": {"type": "number", "minimum":0, "maximum":1}, + "mk": {"type": "number", "minimum":0, "maximum":1}, + "br": {"type": "number", "minimum":0, "maximum":1}, + "eu": {"type": "number", "minimum":0, "maximum":1}, + "is": {"type": "number", "minimum":0, "maximum":1}, + "hy": {"type": "number", "minimum":0, "maximum":1}, + "ne": {"type": "number", "minimum":0, "maximum":1}, + "mn": {"type": "number", "minimum":0, "maximum":1}, + "bs": {"type": "number", "minimum":0, "maximum":1}, + "kk": {"type": "number", "minimum":0, "maximum":1}, + "sq": {"type": "number", "minimum":0, "maximum":1}, + "sw": {"type": "number", "minimum":0, "maximum":1}, + "gl": {"type": "number", "minimum":0, "maximum":1}, + "mr": {"type": "number", "minimum":0, "maximum":1}, + "pa": {"type": "number", "minimum":0, "maximum":1}, + "si": {"type": "number", "minimum":0, "maximum":1}, + "km": {"type": "number", "minimum":0, "maximum":1}, + "sn": {"type": "number", "minimum":0, "maximum":1}, + "yo": {"type": "number", "minimum":0, "maximum":1}, + "so": {"type": "number", "minimum":0, "maximum":1}, + "af": {"type": "number", "minimum":0, "maximum":1}, + "oc": {"type": "number", "minimum":0, "maximum":1}, + "ka": {"type": "number", "minimum":0, "maximum":1}, + "be": {"type": "number", "minimum":0, "maximum":1}, + "tg": {"type": "number", "minimum":0, "maximum":1}, + "sd": {"type": "number", "minimum":0, "maximum":1}, + "gu": {"type": "number", "minimum":0, "maximum":1}, + "am": {"type": "number", "minimum":0, "maximum":1}, + "yi": {"type": "number", "minimum":0, "maximum":1}, + "lo": {"type": "number", "minimum":0, "maximum":1}, + "uz": {"type": "number", "minimum":0, "maximum":1}, + "fo": {"type": "number", "minimum":0, "maximum":1}, + "ht": {"type": "number", "minimum":0, "maximum":1}, + "ps": {"type": "number", "minimum":0, "maximum":1}, + "tk": {"type": "number", "minimum":0, "maximum":1}, + "nn": {"type": "number", "minimum":0, "maximum":1}, + "mt": {"type": "number", "minimum":0, "maximum":1}, + "sa": {"type": "number", "minimum":0, "maximum":1}, + "lb": {"type": "number", "minimum":0, "maximum":1}, + "my": {"type": "number", 
"minimum":0, "maximum":1}, + "bo": {"type": "number", "minimum":0, "maximum":1}, + "tl": {"type": "number", "minimum":0, "maximum":1}, + "mg": {"type": "number", "minimum":0, "maximum":1}, + "as": {"type": "number", "minimum":0, "maximum":1}, + "tt": {"type": "number", "minimum":0, "maximum":1}, + "haw": {"type": "number", "minimum":0, "maximum":1}, + "ln": {"type": "number", "minimum":0, "maximum":1}, + "ha": {"type": "number", "minimum":0, "maximum":1}, + "ba": {"type": "number", "minimum":0, "maximum":1}, + "jw": {"type": "number", "minimum":0, "maximum":1}, + "su": {"type": "number", "minimum":0, "maximum":1} + } } } } From e1fded57d7da04bb85894e9ef8185c136ed1a5f4 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Sun, 28 Jan 2024 11:17:25 +0100 Subject: [PATCH 03/11] update test --- tests/test_transcribe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 1e35573..fc5643c 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -661,15 +661,17 @@ def test_hugging_face_model(self): ) import tempfile - from transformers import WhisperForConditionalGeneration + from transformers import WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig tempfolder = os.path.join(tempfile.gettempdir(), "tmp_whisper-tiny-french-cased") for safe_serialization in False, True,: for max_shard_size in "100MB", "10GB", : shutil.rmtree(tempfolder, ignore_errors=True) model = WhisperForConditionalGeneration.from_pretrained("qanastek/whisper-tiny-french-cased") + processor = WhisperProcessor.from_pretrained("qanastek/whisper-tiny-french-cased") try: model.save_pretrained(tempfolder, safe_serialization=safe_serialization, max_shard_size=max_shard_size) + processor.save_pretrained(tempfolder) self._test_cli_( ["--model", tempfolder, "--verbose", "True"], "verbose", files=["bonjour.wav"], extensions=None, From 85fefa3db42d4d8950af6b8467554096de96e640 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Sun, 28 Jan 2024 11:18:04 +0100 Subject: [PATCH 04/11] user matplotlib old version (latest is not the best) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e49e71e..d3e55ca 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ }, include_package_data=True, extras_require={ - 'dev': ['matplotlib', 'transformers'], + 'dev': ['matplotlib==3.7.4', 'transformers'], 'vad_silero': ['onnxruntime', 'torchaudio'], 'vad_auditok': ['auditok'], 'test': ['jsonschema'], From 987c5ff29c2dc44070bdfbd3fb5aac4c536fc097 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 29 Jan 2024 13:27:09 +0100 Subject: [PATCH 05/11] Fix auditok corner case --- whisper_timestamped/transcribe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index b48ecf0..cd29896 100755 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -1978,14 +1978,16 @@ def apply_folder_hack(): data = (audio.numpy() * 32767).astype(np.int16).tobytes() + audio_duration = len(audio) / SAMPLE_RATE + segments = auditok.split( data, sampling_rate=SAMPLE_RATE, # sampling frequency in Hz channels=1, # number of channels sample_width=2, # number of bytes per sample min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds - max_dur=len(audio)/SAMPLE_RATE, # maximum duration of an event - max_silence=min_silence_duration, # maximum duration of tolerated continuous silence within an 
event + max_dur=audio_duration, # maximum duration of an event + max_silence=min(audio_duration*.95, min_silence_duration), # maximum duration of tolerated continuous silence within an event energy_threshold=50, drop_trailing_silence=True, ) From dabd52ce6471cfc786879eebc016355842d631e7 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 29 Jan 2024 14:37:28 +0100 Subject: [PATCH 06/11] more convenient way of getting meta-information (dimensions) from transformers model --- whisper_timestamped/transcribe.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) mode change 100755 => 100644 whisper_timestamped/transcribe.py diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py old mode 100755 new mode 100644 index cd29896..8750725 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -2483,17 +2483,18 @@ def __init__(self, model, processor, generation_config): self.device = model.device # Dimensions + model_config = model.config self.dims = whisper.model.ModelDimensions( - n_mels = model.get_encoder().get_input_embeddings().in_channels, - n_audio_ctx = 1500, - n_audio_state = model.get_encoder().get_input_embeddings().out_channels, - n_audio_head = model.get_encoder().layers[0].self_attn.num_heads, - n_audio_layer = len(model.get_encoder().layers), - n_vocab = model.get_decoder().get_input_embeddings().num_embeddings, - n_text_ctx = 448, - n_text_state = model.get_decoder().get_input_embeddings().embedding_dim, - n_text_head = model.get_decoder().layers[0].self_attn.num_heads, - n_text_layer = len(model.get_decoder().layers), + n_mels = model_config.num_mel_bins, # model.get_encoder().get_input_embeddings().in_channels, # 80 + n_audio_ctx = model_config.max_source_positions, # 1500 + n_audio_state = model_config.d_model, # model.get_encoder().get_input_embeddings().out_channels, # 768 + n_audio_head = model_config.encoder_attention_heads, # model.get_encoder().layers[0].self_attn.num_heads, + n_audio_layer = model_config.encoder_layers, # len(model.get_encoder().layers), + n_vocab = model_config.vocab_size, # model.get_decoder().get_input_embeddings().num_embeddings, # ~51865 + n_text_ctx = model_config.max_length, # 448 + n_text_state = model_config.d_model, # model.get_decoder().get_input_embeddings().embedding_dim, # 768 + n_text_head = model_config.decoder_attention_heads, # model.get_decoder().layers[0].self_attn.num_heads, + n_text_layer = model_config.decoder_layers, # len(model.get_decoder().layers), ) # Tokenization @@ -3092,4 +3093,4 @@ def filtered_keys(result, keys = [ if __name__ == "__main__": - cli() \ No newline at end of file + cli() From e5b5819c741ef1159a5251778b4a8f5fa8095555 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 15 Feb 2024 18:15:11 +0100 Subject: [PATCH 07/11] add options about precision when decoding with transformers (WIP) --- whisper_timestamped/transcribe.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index 8750725..e7127ee 100644 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -2371,11 +2371,22 @@ def load_model( except OSError: generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny") processor = transformers.WhisperProcessor.from_pretrained(name) - model = transformers.WhisperForConditionalGeneration.from_pretrained(name) if device is None: device = "cuda" if torch.cuda.is_available() 
else "cpu" + precision = torch.float32 + model = transformers.WhisperForConditionalGeneration.from_pretrained( + name, + # load_in_8bit=True, + # load_in_4bit=True, + torch_dtype=precision, + # torch_dtype=torch.bfloat16, + # attn_implementation="flash_attention_2", + # attn_implementation="sdpa", + ) + # model = model.to_bettertransformer() + model = model.to(device) - return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config) + return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config, precision) elif backend not in ["openai", "openai-whisper"]: raise ValueError(f"Got unexpected backend {backend}") @@ -2474,13 +2485,14 @@ class TransformerWhisperAsOpenAIWhisper: Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped) """ - def __init__(self, model, processor, generation_config): + def __init__(self, model, processor, generation_config, precision): self.model = model # transformers.WhisperForConditionalGeneration self.processor = processor # transformers.WhisperProcessor self.generation_config = generation_config # transformers.GenerationConfig self.device = model.device + self.precision = precision # Dimensions model_config = model.config @@ -2609,7 +2621,7 @@ def transcribe(self, audio, use_token_timestamps=False, **kwargs): # Transcribe output = self.model.generate( - features, + features.to(self.precision), **generate_kwargs ) @@ -2759,7 +2771,7 @@ def _iter_segments(self, output, prompt_ids): def __call__(self, mfcc, tokens): - output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True) + output = self.model(mfcc.to(self.precision), decoder_input_ids=tokens, output_attentions=True) return output.logits From 90720b8c3af8985328b01c736260da34888e3af3 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 15 Feb 2024 18:15:43 +0100 Subject: [PATCH 08/11] add alignment_heads attribute to the model. That might be needed when decoding --- whisper_timestamped/transcribe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index e7127ee..255c994 100644 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -1038,6 +1038,7 @@ def hook_output_logits(layer, ins, outs): all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits)) try: + model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. 
Did you mean: 'set_alignment_heads'?"
         transcription = model.transcribe(audio, **whisper_options)
     finally:
         for hook in all_hooks:

From 6197f0868186a6afe59828583c9cd98ce2462772 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Sun, 25 Feb 2024 15:49:24 +0100
Subject: [PATCH 09/11] forgotten break

---
 whisper_timestamped/transcribe.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 255c994..bc0eb02 100644
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -2427,6 +2427,7 @@ def load_model(
                     model_path = list(set(mapping["weight_map"].values()))
                     folder = os.path.dirname(index_file)
                     model_path = [os.path.join(folder, p) for p in model_path]
+                    break
             assert model_path is not None
         except:
             raise RuntimeError(f"Original error: {err}\nCould not find model {name} from HuggingFace nor local folders.")

From cf576e529b9beea7318e7aaa633d8707f93f4c08 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Sun, 25 Feb 2024 16:43:22 +0100
Subject: [PATCH 10/11] fixes #64: fix inconsistency between segments when there is empty text

---
 whisper_timestamped/transcribe.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index bc0eb02..68b9fe0 100644
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -905,6 +905,9 @@ def filter_tokens(tokens):
     assert len(segment_logprobs) == len(segment_tokens), f"Inconsistent number of segments: logprobs ({len(segment_logprobs)}) != tokens ({len(segment_tokens)})"
 
     whisper_segments = transcription["segments"]
+    # See issue 64: some segments may have empty text
+    if any(not s["text"] for s in whisper_segments):
+        whisper_segments = [s for s in whisper_segments if s["text"]]
     l1 = len(whisper_segments)
     l2 = len(timestamped_word_segments)
     if l1 != l2 and l1 != 0:

From 58909dca962e6012bac49f67cc47cf2dadc94c7e Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Sun, 25 Feb 2024 16:45:18 +0100
Subject: [PATCH 11/11] bump version

---
 whisper_timestamped/transcribe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 68b9fe0..09eb6cc 100755
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -3,7 +3,7 @@
 __author__ = "Jérôme Louradour"
 __credits__ = ["Jérôme Louradour"]
 __license__ = "GPLv3"
-__version__ = "1.14.4"
+__version__ = "1.15.0"
 
 # Set some environment variables
 import os
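Putting the series together, a short usage sketch (assuming the package's public Python API and installed CLI entry point; "openai/whisper-tiny" is just an example HuggingFace identifier):

    import whisper_timestamped as whisper

    # Load through the new transformers backend; plain openai-whisper
    # names such as "tiny" are mapped to "openai/whisper-<name>"
    model = whisper.load_model("openai/whisper-tiny", device="cpu", backend="transformers")

    # Word timestamps computed by the backend (see PATCH 01)
    result = whisper.transcribe(model, "audio.wav", use_backend_timestamps=True)
    print(result["segments"][0]["words"])

    # Equivalent CLI call:
    #   whisper_timestamped audio.wav --model tiny --backend transformers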