Skip to content

Commit

Permalink
Lazy import audio (#2181)
Browse files Browse the repository at this point in the history
(cherry picked from commit fe7830c)
  • Loading branch information
SkafteNicki authored and Borda committed Nov 30, 2023
1 parent f00ceb1 commit 7740724
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 29 deletions.
8 changes: 2 additions & 6 deletions src/torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _MULTIPROCESSING_AVAILABLE, _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
import pesq as pesq_backend
else:
pesq_backend = None


__doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]}


Expand Down Expand Up @@ -88,6 +82,8 @@ def perceptual_evaluation_speech_quality(
"PESQ metric requires that pesq is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install pesq`."
)
import pesq as pesq_backend

if fs not in (8000, 16000):
raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
if mode not in ("wb", "nb"):
Expand Down
11 changes: 3 additions & 8 deletions src/torchmetrics/functional/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE

solve = torch.linalg.solve

if _FAST_BSS_EVAL_AVAILABLE:
from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient
else:
toeplitz_conjugate_gradient = None


def _symmetric_toeplitz(vector: Tensor) -> Tensor:
"""Construct a symmetric Toeplitz matrix using one vector.
Expand Down Expand Up @@ -176,6 +169,8 @@ def signal_distortion_ratio(
r_0[..., 0] += load_diag

if use_cg_iter is not None and _FAST_BSS_EVAL_AVAILABLE:
from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient

# use preconditioned conjugate gradient
sol = toeplitz_conjugate_gradient(r_0, b, n_iter=use_cg_iter)
else:
Expand All @@ -189,7 +184,7 @@ def signal_distortion_ratio(
)
# regular matrix solver
r = _symmetric_toeplitz(r_0) # the auto-correlation of the L shifts of `target`
sol = solve(r, b)
sol = torch.linalg.solve(r, b)

# compute the coherence
coh = torch.einsum("...l,...l->...", b, sol)
Expand Down
21 changes: 10 additions & 11 deletions src/torchmetrics/functional/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,14 @@
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

if _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10:
from torchaudio.functional.filtering import lfilter
else:
lfilter = None
__doctest_skip__ = ["speech_reverberation_modulation_energy_ratio"]

if _GAMMATONE_AVAILABLE:
from gammatone.fftweight import fft_gtgram
from gammatone.filters import centre_freqs, make_erb_filters
else:
fft_gtgram, centre_freqs, make_erb_filters = None, None, None
if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE:
__doctest_skip__ = ["speech_reverberation_modulation_energy_ratio"]


@lru_cache(maxsize=100)
def _calc_erbs(low_freq: float, fs: int, n_filters: int, device: torch.device) -> Tensor:
from gammatone.filters import centre_freqs

ear_q = 9.26449 # Glasberg and Moore Parameters
min_bw = 24.7
order = 1
Expand All @@ -55,6 +47,8 @@ def _calc_erbs(low_freq: float, fs: int, n_filters: int, device: torch.device) -

@lru_cache(maxsize=100)
def _make_erb_filters(fs: int, num_freqs: int, cutoff: float, device: torch.device) -> Tensor:
from gammatone.filters import centre_freqs, make_erb_filters

cfs = centre_freqs(fs, num_freqs, cutoff)
fcoefs = make_erb_filters(fs, cfs)
return torch.tensor(fcoefs, device=device)
Expand Down Expand Up @@ -130,6 +124,8 @@ def _erb_filterbank(wave: Tensor, coefs: Tensor) -> Tensor:
Tensor: shape [B, N, time]
"""
from torchaudio.functional.filtering import lfilter

num_batch, time = wave.shape
wave = wave.to(dtype=coefs.dtype).reshape(num_batch, 1, time) # [B, time]
wave = wave.expand(-1, coefs.shape[0], -1) # [B, N, time]
Expand Down Expand Up @@ -239,6 +235,9 @@ def speech_reverberation_modulation_energy_ratio(
" `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
"``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
)
from gammatone.fftweight import fft_gtgram
from torchaudio.functional.filtering import lfilter

_srmr_arg_validate(
fs=fs,
n_cochlear_filters=n_cochlear_filters,
Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/functional/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

if _PYSTOI_AVAILABLE:
from pystoi import stoi as stoi_backend
else:
stoi_backend = None
if not _PYSTOI_AVAILABLE:
__doctest_skip__ = ["short_time_objective_intelligibility"]


Expand Down Expand Up @@ -76,6 +73,8 @@ def short_time_objective_intelligibility(
"ShortTimeObjectiveIntelligibility metric requires that `pystoi` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install pystoi`."
)
from pystoi import stoi as stoi_backend

_check_same_shape(preds, target)

if len(preds.shape) == 1:
Expand Down

0 comments on commit 7740724

Please sign in to comment.