From 3626f6a6b47e81c12b7a7cdaf7b64dd60f2ab6ac Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Thu, 31 Oct 2024 12:50:01 +0300 Subject: [PATCH 01/11] initial commit --- faster_whisper/audio.py | 21 +-- faster_whisper/feature_extractor.py | 235 ++++++++++++++++++++++------ faster_whisper/transcribe.py | 64 +++----- faster_whisper/vad.py | 11 +- 4 files changed, 225 insertions(+), 106 deletions(-) diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py index e7e225a2..222e73ba 100644 --- a/faster_whisper/audio.py +++ b/faster_whisper/audio.py @@ -14,7 +14,6 @@ import av import numpy as np -import torch def decode_audio( @@ -72,9 +71,9 @@ def decode_audio( if split_stereo: left_channel = audio[0::2] right_channel = audio[1::2] - return torch.from_numpy(left_channel), torch.from_numpy(right_channel) + return left_channel, right_channel - return torch.from_numpy(audio) + return audio def _ignore_invalid_frames(frames): @@ -113,20 +112,12 @@ def pad_or_trim(array, length: int = 3000, *, axis: int = -1): """ Pad or trim the Mel features array to 3000, as expected by the encoder. """ - axis = axis % array.ndim if array.shape[axis] > length: - idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1) - return array[idx] + array = array.take(indices=range(length), axis=axis) if array.shape[axis] < length: - pad_widths = ( - [ - 0, - ] - * array.ndim - * 2 - ) - pad_widths[2 * axis] = length - array.shape[axis] - array = torch.nn.functional.pad(array, tuple(pad_widths[::-1])) + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) return array diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 6371d5ef..70c74b25 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -1,7 +1,20 @@ -import torch +import numpy as np + +try: + import cupy as cp + + CUPY_AVAILABLE = True +except ImportError: + CUPY_AVAILABLE = False + + +def get_array_module(device: str = "auto"): + if device in ["auto", "cuda"] and CUPY_AVAILABLE and cp.cuda.is_available(): + return cp, "cuda" + else: + return np, "cpu" -# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501 class FeatureExtractor: def __init__( self, @@ -12,10 +25,8 @@ def __init__( chunk_length=30, n_fft=400, ): - if device == "auto": - self.device = "cuda" if torch.cuda.is_available() else "cpu" - else: - self.device = device + self.array_module: np + self.array_module, self._device = get_array_module(device) self.n_fft = n_fft self.hop_length = hop_length self.chunk_length = chunk_length @@ -25,24 +36,20 @@ def __init__( self.sampling_rate = sampling_rate self.mel_filters = self.get_mel_filters( sampling_rate, n_fft, n_mels=feature_size - ) + ).astype("float32") - @staticmethod - def get_mel_filters(sr, n_fft, n_mels=128): - """ - Implementation of librosa.filters.mel in Pytorch - """ + def get_mel_filters(self, sr, n_fft, n_mels=128): # Initialize the weights n_mels = int(n_mels) # Center freqs of each FFT bin - fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr) + fftfreqs = self.array_module.fft.rfftfreq(n=n_fft, d=1.0 / sr) # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = 0.0 max_mel = 45.245640471924965 - mels = torch.linspace(min_mel, max_mel, n_mels + 2) + mels = self.array_module.linspace(min_mel, max_mel, n_mels + 2) # Fill in the linear scale f_min = 0.0 @@ -52,30 
+59,32 @@ def get_mel_filters(sr, n_fft, n_mels=128): # And now the nonlinear scale min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region + logstep = self.array_module.log(6.4) / 27.0 # step size for log region # If we have vector data, vectorize log_t = mels >= min_log_mel - freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) - - mel_f = freqs + freqs[log_t] = min_log_hz * self.array_module.exp( + logstep * (mels[log_t] - min_log_mel) + ) - fdiff = torch.diff(mel_f) - ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1) + fdiff = self.array_module.diff(freqs) + ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1) - lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1) - upper = ramps[2:] / fdiff[1:].unsqueeze(1) + lower = -ramps[:-2] / self.array_module.expand_dims(fdiff[:-1], axis=1) + upper = ramps[2:] / self.array_module.expand_dims(fdiff[1:], axis=1) # Intersect them with each other and zero, vectorized across all i - weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper)) + weights = self.array_module.maximum( + self.array_module.zeros_like(lower), self.array_module.minimum(lower, upper) + ) # Slaney-style mel is scaled to be approx constant energy per channel - enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) - weights *= enorm.unsqueeze(1) + enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels]) + weights *= self.array_module.expand_dims(enorm, axis=1) return weights - def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False): + def __call__(self, waveform: np.ndarray, padding=True, chunk_length=None): """ Compute the log-Mel spectrogram of the provided audio. """ @@ -84,31 +93,167 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False): self.n_samples = chunk_length * self.sampling_rate self.nb_max_frames = self.n_samples // self.hop_length - if waveform.dtype is not torch.float32: - waveform = waveform.to(torch.float32) - - waveform = ( - waveform.to(self.device) - if self.device == "cuda" and not waveform.is_cuda - else waveform - ) + if waveform.dtype is not np.float32: + waveform = waveform.astype(np.float32) if padding: - waveform = torch.nn.functional.pad(waveform, (0, self.n_samples)) + waveform = np.pad(waveform, (0, self.n_samples)) - window = torch.hann_window(self.n_fft).to(waveform.device) + window = self.array_module.hanning(self.n_fft + 1)[:-1].astype("float32") - stft = torch.stft( - waveform, self.n_fft, self.hop_length, window=window, return_complex=True + stft_output = stft( + self.array_module, + waveform, + self.n_fft, + self.hop_length, + window=window, + return_complex=True, ) - magnitudes = stft[..., :-1].abs() ** 2 + magnitudes = self.array_module.abs(stft_output[..., :-1]) ** 2 - mel_spec = self.mel_filters.to(waveform.device) @ magnitudes + mel_spec = self.mel_filters @ magnitudes - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = self.array_module.log10( + self.array_module.clip(mel_spec, a_min=1e-10, a_max=None) + ) + log_spec = self.array_module.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - # When the model is running on multiple GPUs, the output should be moved - # to the CPU since we don't know which GPU will handle the next job. 
- return log_spec.cpu() if to_cpu else log_spec + return np.asarray(log_spec.tolist(), dtype=log_spec.dtype) + + @property + def device(self): + return self._device + + @device.setter + def device(self, device): + if device != self.device: + self.array_module, self._device = get_array_module(device) + + +def stft( + array_module: np, + input_tensor: np.ndarray, + n_fft: int, + hop_length: int = None, + win_length: int = None, + window: np.ndarray = None, + center=True, + mode="reflect", + normalized=False, + onesided=None, + return_complex=None, +): + + # Default initialization for hop_length and win_length + hop_length = hop_length if hop_length is not None else n_fft // 4 + win_length = win_length if win_length is not None else n_fft + input_is_complex = np.iscomplexobj(input_tensor) + + # Determine if the output should be complex + return_complex = ( + return_complex + if return_complex is not None + else (input_is_complex or (window is not None and np.iscomplexobj(window))) + ) + + if not return_complex and return_complex is None: + raise ValueError("stft requires the return_complex parameter for real inputs.") + + # Input checks + if not np.issubdtype(input_tensor.dtype, np.floating) and not input_is_complex: + raise ValueError( + f"stft: expected a tensor of floating point or complex values, got {input_tensor.dtype}" + ) + + if input_tensor.ndim > 2 or input_tensor.ndim < 1: + raise ValueError( + f"stft: expected a 1D or 2D tensor, but got {input_tensor.ndim}D tensor" + ) + + # Handle 1D input + if input_tensor.ndim == 1: + input_tensor = np.expand_dims(input_tensor, axis=0) + input_tensor_1d = True + else: + input_tensor_1d = False + + # Center padding if required + if center: + pad_amount = n_fft // 2 + input_tensor = np.pad( + input_tensor, ((0, 0), (pad_amount, pad_amount)), mode=mode + ) + + batch, length = input_tensor.shape + + # Additional input checks + if n_fft <= 0 or n_fft > length: + raise ValueError(f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}") + + if hop_length <= 0: + raise ValueError( + f"stft: expected hop_length > 0, but got hop_length={hop_length}" + ) + + if win_length <= 0 or win_length > n_fft: + raise ValueError( + f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}" + ) + + if window is not None: + if window.ndim != 1 or window.shape[0] != win_length: + raise ValueError( + f"stft: expected a 1D window tensor of size equal to win_length={win_length}, " + f"but got window with size {window.shape}" + ) + + # Handle padding of the window if necessary + if win_length < n_fft: + left = (n_fft - win_length) // 2 + window_ = array_module.zeros(n_fft, dtype=window.dtype) + window_[left : left + win_length] = window + else: + window_ = window + + # Calculate the number of frames + n_frames = 1 + (length - n_fft) // hop_length + + # Time to columns + input_tensor = np.lib.stride_tricks.as_strided( + input_tensor, + (batch, n_frames, n_fft), + ( + input_tensor.strides[0], + hop_length * input_tensor.strides[1], + input_tensor.strides[1], + ), + ) + + if window_ is not None: + input_tensor = array_module.asarray(input_tensor) * window_ + + # FFT and transpose + complex_fft = input_is_complex + onesided = onesided if onesided is not None else not complex_fft + + if normalized: + norm = "ortho" + else: + norm = None + + if complex_fft: + if onesided: + raise ValueError( + "Cannot have onesided output if window or input is complex" + ) + output = array_module.fft.fft(input_tensor, n=n_fft, axis=-1, norm=norm) + else: + output = 
array_module.fft.rfft(input_tensor, n=n_fft, axis=-1, norm=norm) + + output = output.transpose((0, 2, 1)) + + if input_tensor_1d: + output = output.squeeze(0) + + return output if return_complex else array_module.real(output) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index a8db5715..b5959ec7 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -13,7 +13,6 @@ import ctranslate2 import numpy as np import tokenizers -import torch from tqdm import tqdm @@ -206,7 +205,7 @@ def get_language_and_tokenizer( def transcribe( self, - audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], + audio: Union[str, BinaryIO, np.ndarray], language: Optional[str] = None, task: str = None, log_progress: bool = False, @@ -335,9 +334,7 @@ def transcribe( sampling_rate = self.model.feature_extractor.sampling_rate - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio) - elif not isinstance(audio, torch.Tensor): + if not isinstance(audio, np.ndarray): audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate @@ -435,14 +432,11 @@ def transcribe( ) audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps) - to_cpu = ( - self.model.model.device == "cuda" and len(self.model.model.device_index) > 1 - ) features = ( - torch.stack( + np.stack( [ pad_or_trim( - self.model.feature_extractor(chunk, to_cpu=to_cpu)[ + self.model.feature_extractor(chunk)[ ..., : chunk.shape[0] // self.model.feature_extractor.hop_length, ] @@ -629,7 +623,7 @@ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: def transcribe( self, - audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], + audio: Union[str, BinaryIO, np.ndarray], language: Optional[str] = None, task: str = "transcribe", beam_size: int = 5, @@ -757,9 +751,7 @@ def transcribe( sampling_rate = self.feature_extractor.sampling_rate - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio) - elif not isinstance(audio, torch.Tensor): + if not isinstance(audio, np.ndarray): audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate @@ -776,7 +768,7 @@ def transcribe( vad_parameters = VadOptions(**vad_parameters) speech_chunks = get_speech_timestamps(audio, vad_parameters) audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) - audio = torch.cat(audio_chunks, dim=0) + audio = np.concatenate(audio_chunks, axis=0) duration_after_vad = audio.shape[0] / sampling_rate self.logger.info( @@ -800,10 +792,7 @@ def transcribe( else: speech_chunks = None - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - features = self.feature_extractor( - audio, chunk_length=chunk_length, to_cpu=to_cpu - ) + features = self.feature_extractor(audio, chunk_length=chunk_length) encoder_output = None all_language_probs = None @@ -1034,7 +1023,7 @@ def _split_segments_by_timestamps( def generate_segments( self, - features: torch.Tensor, + features: np.ndarray, tokenizer: Tokenizer, options: TranscriptionOptions, encoder_output: Optional[ctranslate2.StorageView] = None, @@ -1338,13 +1327,13 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: prompt_reset_since = len(all_tokens) - def encode(self, features: torch.Tensor) -> ctranslate2.StorageView: + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. 
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 if features.ndim == 2: - features = features.unsqueeze(0) + features = np.expand_dims(features, 0) features = get_ctranslate2_storage(features) return self.model.encode(features, to_cpu=to_cpu) @@ -1715,7 +1704,7 @@ def find_alignment( def generate_segment_batched( self, - features: torch.Tensor, + features: np.ndarray, tokenizer: Tokenizer, options: dict, ): @@ -1764,9 +1753,8 @@ def generate_segment_batched( return encoder_output, output - def detect_language(self, audio: torch.Tensor): - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[ + def detect_language(self, audio: np.ndarray): + segment = self.feature_extractor(audio)[ :, : self.feature_extractor.nb_max_frames ] encoder_output = self.encode(pad_or_trim(segment)) @@ -1780,7 +1768,7 @@ def detect_language(self, audio: torch.Tensor): return language, language_probability, all_language_probs def detect_language_multi_segment( - self, audio: Union[str, BinaryIO, torch.Tensor], params: Optional[dict] = None + self, audio: Union[str, BinaryIO, np.ndarray], params: Optional[dict] = None ): """ Detect language based on N highly-confident segments of a language. @@ -1816,8 +1804,8 @@ def detect_language_multi_segment( # decode audio if it is not decoded already sampling_rate = self.feature_extractor.sampling_rate - if not isinstance(audio, torch.Tensor): - audio: torch.Tensor = decode_audio(audio, sampling_rate=sampling_rate) + if not isinstance(audio, np.ndarray): + audio: np.ndarray = decode_audio(audio, sampling_rate=sampling_rate) # calculate duration of audio as number of seconds # audio.shape[0] is the number of samples in the audio @@ -1832,7 +1820,7 @@ def detect_language_multi_segment( speech_chunks = get_speech_timestamps(audio, vad_params) # merge chunks of audio that contain speech into a single array audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) - audio = torch.cat(audio_chunks, dim=0) + audio = np.concatenate(audio_chunks, axis=0) # calculate new duration of audio without silence duration_vad = audio.shape[0] / sampling_rate @@ -1856,8 +1844,7 @@ def detect_language_multi_segment( nb_max_frames = self.feature_extractor.nb_max_frames # extract features from audio with padding (default) - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - features = self.feature_extractor(audio, to_cpu=to_cpu) + features = self.feature_extractor(audio) # number of segments in the audio num_segments = features.shape[-1] // nb_max_frames @@ -1969,8 +1956,8 @@ def key_func(language): dc_offset = audio.mean() audio_minus_dc_offset = audio - dc_offset is_silent = ( - torch.all(audio.abs() < 0.01) - or torch.sqrt(torch.mean(audio_minus_dc_offset**2)) < 0.01 + all(np.abs(audio) < 0.1) + or np.sqrt(np.mean(audio_minus_dc_offset**2)) < 0.01 ) if is_silent: @@ -2020,12 +2007,9 @@ def restore_speech_timestamps( yield segment -def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView: - segment = segment.contiguous() - segment = ctranslate2.StorageView.from_array( - segment if segment.is_cuda else segment.numpy() - ) # torch cpu tensors don't implement __array_interface__ - # https://github.com/pytorch/pytorch/issues/51156 +def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: + segment = np.ascontiguousarray(segment) + segment = ctranslate2.StorageView.from_array(segment) return segment diff --git 
a/faster_whisper/vad.py b/faster_whisper/vad.py index d448f5b4..d47515ec 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -5,7 +5,6 @@ from typing import Dict, List, NamedTuple, Optional, Tuple import numpy as np -import torch from faster_whisper.utils import get_assets_path @@ -42,7 +41,7 @@ class VadOptions(NamedTuple): def get_speech_timestamps( - audio: torch.Tensor, + audio: np.ndarray, vad_options: Optional[VadOptions] = None, sampling_rate: int = 16000, **kwargs, @@ -82,7 +81,7 @@ def get_speech_timestamps( model = get_vad_model() padded_audio = np.pad( - audio.numpy(), (0, window_size_samples - audio.shape[0] % window_size_samples) + audio, (0, window_size_samples - audio.shape[0] % window_size_samples) ) speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0) @@ -181,15 +180,15 @@ def get_speech_timestamps( def collect_chunks( - audio: torch.Tensor, chunks: List[dict], sampling_rate: int = 16000 -) -> Tuple[List[torch.Tensor], List[Dict[str, int]]]: + audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000 +) -> Tuple[List[np.ndarray], List[Dict[str, int]]]: """Collects audio chunks.""" if not chunks: chunk_metadata = { "start_time": 0, "end_time": 0, } - return [torch.tensor([], dtype=torch.float32)], [chunk_metadata] + return [np.array([], dtype=np.float32)], [chunk_metadata] audio_chunks = [] chunks_metadata = [] From 5f905ee1a8df8a486b5df09e6a1c0f50cacfd35f Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Thu, 31 Oct 2024 14:36:09 +0300 Subject: [PATCH 02/11] remove torch from reqirements --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 71fc482e..1b61b2c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,5 @@ ctranslate2>=4.0,<5 huggingface_hub>=0.13 tokenizers>=0.13,<1 onnxruntime>=1.14,<2 -torch>=2.1.1 av>=11 tqdm \ No newline at end of file From 65dc596177c5d2fbe87fbb85ef18ccb72490ac4c Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Thu, 31 Oct 2024 15:11:34 +0300 Subject: [PATCH 03/11] fix formatting and dtype --- faster_whisper/feature_extractor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 70c74b25..62d130bb 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -108,7 +108,7 @@ def __call__(self, waveform: np.ndarray, padding=True, chunk_length=None): self.hop_length, window=window, return_complex=True, - ) + ).astype("complex64") magnitudes = self.array_module.abs(stft_output[..., :-1]) ** 2 mel_spec = self.mel_filters @ magnitudes @@ -129,6 +129,10 @@ def device(self): def device(self, device): if device != self.device: self.array_module, self._device = get_array_module(device) + feature_size = self.mel_filters.shape[0] + self.mel_filters = self.get_mel_filters( + self.sampling_rate, self.n_fft, feature_size + ).astype("float32") def stft( @@ -144,7 +148,6 @@ def stft( onesided=None, return_complex=None, ): - # Default initialization for hop_length and win_length hop_length = hop_length if hop_length is not None else n_fft // 4 win_length = win_length if win_length is not None else n_fft From dcc95afd27e26f9e10aa35852015375f52cf954d Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Thu, 31 Oct 2024 15:37:04 +0300 Subject: [PATCH 04/11] add type annotations --- faster_whisper/feature_extractor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 62d130bb..0e83bca9 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -142,11 +142,11 @@ def stft( hop_length: int = None, win_length: int = None, window: np.ndarray = None, - center=True, - mode="reflect", - normalized=False, - onesided=None, - return_complex=None, + center: bool = True, + mode: str = "reflect", + normalized: bool = False, + onesided: bool = None, + return_complex: bool = None, ): # Default initialization for hop_length and win_length hop_length = hop_length if hop_length is not None else n_fft // 4 From f6adb22e8018d9638a2d3114b098ff9b6fc7bbd1 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 1 Nov 2024 00:20:01 +0200 Subject: [PATCH 05/11] remove CuPy --- faster_whisper/feature_extractor.py | 83 ++++++++--------------------- faster_whisper/transcribe.py | 4 +- 2 files changed, 24 insertions(+), 63 deletions(-) diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 0e83bca9..7ad5df5e 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -1,32 +1,15 @@ import numpy as np -try: - import cupy as cp - - CUPY_AVAILABLE = True -except ImportError: - CUPY_AVAILABLE = False - - -def get_array_module(device: str = "auto"): - if device in ["auto", "cuda"] and CUPY_AVAILABLE and cp.cuda.is_available(): - return cp, "cuda" - else: - return np, "cpu" - class FeatureExtractor: def __init__( self, - device: str = "auto", feature_size=80, sampling_rate=16000, hop_length=160, chunk_length=30, n_fft=400, ): - self.array_module: np - self.array_module, self._device = get_array_module(device) self.n_fft = n_fft self.hop_length = hop_length self.chunk_length = chunk_length @@ -38,18 +21,19 @@ def __init__( sampling_rate, n_fft, n_mels=feature_size ).astype("float32") - def get_mel_filters(self, sr, n_fft, n_mels=128): + @staticmethod + def get_mel_filters(sr, n_fft, n_mels=128): # Initialize the weights n_mels = int(n_mels) # Center freqs of each FFT bin - fftfreqs = self.array_module.fft.rfftfreq(n=n_fft, d=1.0 / sr) + fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr) # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = 0.0 max_mel = 45.245640471924965 - mels = self.array_module.linspace(min_mel, max_mel, n_mels + 2) + mels = np.linspace(min_mel, max_mel, n_mels + 2) # Fill in the linear scale f_min = 0.0 @@ -59,32 +43,28 @@ def get_mel_filters(self, sr, n_fft, n_mels=128): # And now the nonlinear scale min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = self.array_module.log(6.4) / 27.0 # step size for log region + logstep = np.log(6.4) / 27.0 # step size for log region # If we have vector data, vectorize log_t = mels >= min_log_mel - freqs[log_t] = min_log_hz * self.array_module.exp( - logstep * (mels[log_t] - min_log_mel) - ) + freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) - fdiff = self.array_module.diff(freqs) + fdiff = np.diff(freqs) ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1) - lower = -ramps[:-2] / self.array_module.expand_dims(fdiff[:-1], axis=1) - upper = ramps[2:] / self.array_module.expand_dims(fdiff[1:], axis=1) + lower = -ramps[:-2] / np.expand_dims(fdiff[:-1], axis=1) + upper = ramps[2:] / np.expand_dims(fdiff[1:], axis=1) # Intersect them with each other and zero, vectorized across all i - weights = self.array_module.maximum( - 
self.array_module.zeros_like(lower), self.array_module.minimum(lower, upper) - ) + weights = np.maximum(np.zeros_like(lower), np.minimum(lower, upper)) # Slaney-style mel is scaled to be approx constant energy per channel enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels]) - weights *= self.array_module.expand_dims(enorm, axis=1) + weights *= np.expand_dims(enorm, axis=1) return weights - def __call__(self, waveform: np.ndarray, padding=True, chunk_length=None): + def __call__(self, waveform: np.ndarray, padding=480000, chunk_length=None): """ Compute the log-Mel spectrogram of the provided audio. """ @@ -97,46 +77,29 @@ def __call__(self, waveform: np.ndarray, padding=True, chunk_length=None): waveform = waveform.astype(np.float32) if padding: - waveform = np.pad(waveform, (0, self.n_samples)) + waveform = np.pad(waveform, (0, padding)) - window = self.array_module.hanning(self.n_fft + 1)[:-1].astype("float32") + window = np.hanning(self.n_fft + 1)[:-1].astype("float32") stft_output = stft( - self.array_module, waveform, self.n_fft, self.hop_length, window=window, return_complex=True, ).astype("complex64") - magnitudes = self.array_module.abs(stft_output[..., :-1]) ** 2 + magnitudes = np.abs(stft_output[..., :-1]) ** 2 mel_spec = self.mel_filters @ magnitudes - log_spec = self.array_module.log10( - self.array_module.clip(mel_spec, a_min=1e-10, a_max=None) - ) - log_spec = self.array_module.maximum(log_spec, log_spec.max() - 8.0) + log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None)) + log_spec = np.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - return np.asarray(log_spec.tolist(), dtype=log_spec.dtype) - - @property - def device(self): - return self._device - - @device.setter - def device(self, device): - if device != self.device: - self.array_module, self._device = get_array_module(device) - feature_size = self.mel_filters.shape[0] - self.mel_filters = self.get_mel_filters( - self.sampling_rate, self.n_fft, feature_size - ).astype("float32") + return log_spec def stft( - array_module: np, input_tensor: np.ndarray, n_fft: int, hop_length: int = None, @@ -214,7 +177,7 @@ def stft( # Handle padding of the window if necessary if win_length < n_fft: left = (n_fft - win_length) // 2 - window_ = array_module.zeros(n_fft, dtype=window.dtype) + window_ = np.zeros(n_fft, dtype=window.dtype) window_[left : left + win_length] = window else: window_ = window @@ -234,7 +197,7 @@ def stft( ) if window_ is not None: - input_tensor = array_module.asarray(input_tensor) * window_ + input_tensor = input_tensor * window_ # FFT and transpose complex_fft = input_is_complex @@ -250,13 +213,13 @@ def stft( raise ValueError( "Cannot have onesided output if window or input is complex" ) - output = array_module.fft.fft(input_tensor, n=n_fft, axis=-1, norm=norm) + output = np.fft.fft(input_tensor, n=n_fft, axis=-1, norm=norm) else: - output = array_module.fft.rfft(input_tensor, n=n_fft, axis=-1, norm=norm) + output = np.fft.rfft(input_tensor, n=n_fft, axis=-1, norm=norm) output = output.transpose((0, 2, 1)) if input_tensor_1d: output = output.squeeze(0) - return output if return_complex else array_module.real(output) + return output if return_complex else np.real(output) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index b5959ec7..56e62129 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -582,9 +582,7 @@ def __init__( "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") ) self.feat_kwargs = 
self._get_feature_kwargs(model_path, preprocessor_bytes) - self.feature_extractor = FeatureExtractor( - **self.feat_kwargs, device=self.device - ) + self.feature_extractor = FeatureExtractor(**self.feat_kwargs) self.input_stride = 2 self.num_samples_per_token = ( self.feature_extractor.hop_length * self.input_stride From bc8650329e2c31913648d3d2ccf1e2d28bec0451 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 1 Nov 2024 00:23:26 +0200 Subject: [PATCH 06/11] reduce padding to `hop_length` insteadn of `n_samples` --- faster_whisper/feature_extractor.py | 2 +- faster_whisper/transcribe.py | 6 ++---- tests/test_transcribe.py | 6 +++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 7ad5df5e..354d3369 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -64,7 +64,7 @@ def get_mel_filters(sr, n_fft, n_mels=128): return weights - def __call__(self, waveform: np.ndarray, padding=480000, chunk_length=None): + def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None): """ Compute the log-Mel spectrogram of the provided audio. """ diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 56e62129..61392c3a 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -818,9 +818,7 @@ def transcribe( if isinstance(clip_timestamps, str) else clip_timestamps[0] ) - content_frames = ( - features.shape[-1] - self.feature_extractor.nb_max_frames - ) + content_frames = features.shape[-1] - 1 seek = ( int(start_timestamp * self.frames_per_second) if start_timestamp * self.frames_per_second < content_frames @@ -1026,7 +1024,7 @@ def generate_segments( options: TranscriptionOptions, encoder_output: Optional[ctranslate2.StorageView] = None, ) -> Iterable[Segment]: - content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames + content_frames = features.shape[-1] - 1 content_duration = float(content_frames * self.feature_extractor.time_per_frame) if isinstance(options.clip_timestamps, str): diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 08cc3cc7..e25af3ac 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -32,7 +32,7 @@ def test_transcribe(jfk_path): segment = segments[0] assert segment.text == ( - " And so my fellow Americans ask not what your country can do for you, " + " And so my fellow Americans, ask not what your country can do for you, " "ask what you can do for your country." ) @@ -97,12 +97,12 @@ def test_prefix_with_timestamps(jfk_path): segment = segments[0] assert segment.text == ( - " And so my fellow Americans ask not what your country can do for you, " + " And so my fellow Americans, ask not what your country can do for you, " "ask what you can do for your country." 
) assert segment.start == 0 - assert 10 < segment.end < 11 + assert 10 < segment.end <= 11 def test_vad(jfk_path): From fdadc6420a2ff17f0296578c75115f5b9b76fb0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Wed, 6 Nov 2024 15:53:36 +0100 Subject: [PATCH 07/11] Attempt to fix logging for batched Whisper --- faster_whisper/transcribe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index b629c565..f1207c95 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -133,6 +133,7 @@ def __init__( tokenizer=None, language: Optional[str] = None, ): + self.logger = get_logger() self.model: WhisperModel = model self.tokenizer = tokenizer self.options = options From 15c39f4bcf291cb1618f93846fda3719f12ed823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Thu, 7 Nov 2024 10:10:43 +0100 Subject: [PATCH 08/11] Hopefully same logging as with normal WhisperModel --- faster_whisper/transcribe.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index f1207c95..8b4446a2 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -361,6 +361,10 @@ def transcribe( audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate + self.logger.info( + "Processing audio with duration %s", format_timestamp(duration) + ) + chunk_length = chunk_length or self.model.feature_extractor.chunk_length # if no segment split is provided, use vad_model and generate segments if not clip_timestamps: @@ -410,6 +414,11 @@ def transcribe( / sampling_rate ) + self.logger.info( + "VAD filter removed %s of audio", + format_timestamp(duration - duration_after_vad), + ) + # batched options: see the difference with default options in WhisperModel batched_options = TranscriptionOptions( beam_size=beam_size, From ef8c2cea318ff09ffd0aa2d9969d0c2b97e30e4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Tue, 3 Dec 2024 11:35:15 +0100 Subject: [PATCH 09/11] Revert irrelevant changes --- faster_whisper/feature_extractor.py | 62 ----------------------------- 1 file changed, 62 deletions(-) diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 7e7a0fa2..cfa3aee2 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -228,65 +228,3 @@ def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None): log_spec = (log_spec + 4.0) / 4.0 return log_spec - - if win_length <= 0 or win_length > n_fft: - raise ValueError( - f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}" - ) - - if window is not None: - if window.ndim != 1 or window.shape[0] != win_length: - raise ValueError( - f"stft: expected a 1D window tensor of size equal to win_length={win_length}, " - f"but got window with size {window.shape}" - ) - - # Handle padding of the window if necessary - if win_length < n_fft: - left = (n_fft - win_length) // 2 - window_ = np.zeros(n_fft, dtype=window.dtype) - window_[left : left + win_length] = window - else: - window_ = window - - # Calculate the number of frames - n_frames = 1 + (length - n_fft) // hop_length - - # Time to columns - input_tensor = np.lib.stride_tricks.as_strided( - input_tensor, - (batch, n_frames, n_fft), - ( - input_tensor.strides[0], - hop_length * input_tensor.strides[1], - input_tensor.strides[1], - ), - ) - - if window_ is not None: - input_tensor = input_tensor * window_ 
- - # FFT and transpose - complex_fft = input_is_complex - onesided = onesided if onesided is not None else not complex_fft - - if normalized: - norm = "ortho" - else: - norm = None - - if complex_fft: - if onesided: - raise ValueError( - "Cannot have onesided output if window or input is complex" - ) - output = np.fft.fft(input_tensor, n=n_fft, axis=-1, norm=norm) - else: - output = np.fft.rfft(input_tensor, n=n_fft, axis=-1, norm=norm) - - output = output.transpose((0, 2, 1)) - - if input_tensor_1d: - output = output.squeeze(0) - - return output if return_complex else np.real(output) From 24448f699b17e9319478c66cb8f96a5f921c1c12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Mon, 23 Dec 2024 14:48:49 +0100 Subject: [PATCH 10/11] Avoid duplicate logger --- faster_whisper/transcribe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 8abdb033..a5401a8b 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -114,7 +114,6 @@ def __init__( self, model, ): - self.logger = get_logger() self.model: WhisperModel = model self.last_speech_timestamp = 0.0 From ecd74c3501cc5cb1c9b3c81a013c16c3a9db10b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Mon, 23 Dec 2024 14:52:41 +0100 Subject: [PATCH 11/11] Replace self.logger with self.model.logger --- faster_whisper/transcribe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index a5401a8b..1efd2eb0 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -388,7 +388,7 @@ def transcribe( audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate - self.logger.info( + self.model.logger.info( "Processing audio with duration %s", format_timestamp(duration) ) @@ -425,7 +425,7 @@ def transcribe( / sampling_rate ) - self.logger.info( + self.model.logger.info( "VAD filter removed %s of audio", format_timestamp(duration - duration_after_vad), )
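
Note on the resulting API: the patches above replace the torch-based feature pipeline with a pure-NumPy one (`FeatureExtractor` loses its `device` argument, `__call__` takes and returns plain `np.ndarray`s, and `pad_or_trim` in `faster_whisper/audio.py` now uses `np.pad`/`np.take`). As a rough sanity check of that post-patch interface, a minimal sketch might look like the following — the default constructor arguments (80 Mel bands, 16 kHz, `hop_length=160`) and the import locations are assumptions read off the diffs, not verified against the merged code:

```python
import numpy as np

from faster_whisper.audio import pad_or_trim          # assumed export, as in the patched audio.py
from faster_whisper.feature_extractor import FeatureExtractor

# Defaults per the patched __init__: feature_size=80, sampling_rate=16000, hop_length=160.
extractor = FeatureExtractor()

# Any 16 kHz mono float32 signal works; here, one second of a 440 Hz tone.
t = np.arange(16000, dtype=np.float32) / 16000.0
waveform = (0.1 * np.sin(2.0 * np.pi * 440.0 * t)).astype(np.float32)

# Log-Mel spectrogram as a plain np.ndarray; first axis is the 80 Mel bands,
# second axis is the number of STFT frames (about 100 for one second of audio
# with the reduced `hop_length` padding from patch 06).
features = extractor(waveform)

# Pad or trim the frame axis to the 3000 frames the encoder expects.
segment = pad_or_trim(features)

print(features.shape)   # (80, n_frames)
print(segment.shape)    # (80, 3000)
```

Since nothing in the pipeline touches torch or CuPy after patch 05, the same snippet should run on a CPU-only install with just `numpy` and the remaining requirements.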