Skip to content

Commit

Permalink
upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Nov 15, 2024
1 parent ca5fcb0 commit 2c80978
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 220 deletions.
4 changes: 3 additions & 1 deletion libdetectability/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .detectability import Detectability, DetectabilityLoss
from .detectability import Detectability
from .detectability_loss import DetectabilityLoss
from .segmented_detectability import SegmentedDetectability
131 changes: 3 additions & 128 deletions libdetectability/detectability.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import scipy as sp
import torch as tc
import torch as torch

from .internal.gammatone_filterbank import gammatone_filterbank
from .internal.outer_middle_ear_filter import outer_middle_ear_filter
Expand All @@ -14,7 +14,7 @@ def __init__(
taps=32,
dbspl=94.0,
spl=1.0,
relax_threshold=False,
threshold_mode="hearing",
normalize_gain=False,
norm="backward",
):
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
self.spl,
self.dbspl,
self.sampling_rate,
relax_threshold=relax_threshold,
thershold_mode=threshold_mode,
)
),
2.0,
Expand Down Expand Up @@ -152,128 +152,3 @@ def gain(self, reference):
factor = np.linalg.norm(gain, ord=2, axis=0)
gain = gain / factor
return gain


class DetectabilityLoss(tc.nn.Module):
    """Torch module that evaluates detectability with the filters of a
    ``Detectability`` instance.

    A ``Detectability`` object is built from the constructor arguments; its
    scalars (``ca``, ``cs``, ``leff``) and arrays (``h``, ``g``) are reused
    here with torch operations.

    All signal inputs must be 2-D tensors whose second dimension equals
    ``frame_size`` (enforced by assertions).

    Args:
        frame_size, sampling_rate, taps, dbspl, spl, relax_threshold, norm:
            forwarded unchanged to ``Detectability``.
        normalize_gain: when true, ``frame`` uses the (normalised) gain from
            ``gain`` instead of the masker-power ratio.
        reduction: ``"mean"``, ``"meanlog"`` or ``None``; how ``forward``
            reduces the per-row values returned by ``frame``.
        eps: added before the logarithm in the ``"meanlog"`` reduction.
    """

    def __init__(
        self,
        frame_size=2048,
        sampling_rate=48000,
        taps=32,
        dbspl=94.0,
        spl=1.0,
        relax_threshold=False,
        normalize_gain=False,
        norm="backward",
        reduction="meanlog",
        eps=1e-8,
    ):
        super().__init__()
        self.detectability = Detectability(
            frame_size=frame_size,
            sampling_rate=sampling_rate,
            taps=taps,
            dbspl=dbspl,
            spl=spl,
            relax_threshold=relax_threshold,
            norm=norm,
        )
        self.ca = self.detectability.ca
        self.cs = self.detectability.cs
        self.frame_size = self.detectability.frame_size
        self.taps = self.detectability.taps
        self.leff = self.detectability.leff
        self.norm = self.detectability.norm

        h = tc.from_numpy(self.detectability.h)
        g = tc.from_numpy(self.detectability.g)
        # Registered as buffers so that Module.to()/.cuda()/.double() move
        # them with the module, also when this loss is a child of another
        # module (a hand-written ``to`` override is not invoked in that case).
        # persistent=False keeps them out of ``state_dict``.
        self.register_buffer("h", h, persistent=False)
        self.register_buffer("g", g, persistent=False)
        self.register_buffer("G", h * g.unsqueeze(0), persistent=False)

        self.reduction = reduction
        self.eps = eps
        self.normalize_gain = normalize_gain

    def _check_frames(self, *frames):
        """Assert that every tensor is 2-D with ``frame_size`` columns."""
        assert all(
            len(f.shape) == 2 for f in frames
        ), "only support for batched one-dimensional inputs"
        assert all(
            f.shape[1] == self.frame_size for f in frames
        ), "input frame size different the specified upon construction"

    def _spectrum(self, a):
        """Return the squared magnitude of the real FFT of ``a`` along dim 1."""
        return tc.pow(tc.abs(tc.fft.rfft(a, dim=1, norm=self.norm)), 2.0)

    def _masker_power_array(self, a):
        """Weight the spectrum ``a`` by every row of ``G`` and sum over the
        last dimension (one value per row of ``G`` for each input row)."""
        return tc.sum(a.unsqueeze(1) * self.G, dim=2)

    def _detectability(self, s, m, cs, ca):
        """Return ``cs * leff * sum(s / (m + ca))`` summed over dim 1."""
        return cs * self.leff * (s / (m + ca)).sum(dim=1)

    def frame(self, reference, test):
        """Detectability of the difference ``test - reference`` for each row.

        Returns a 1-D tensor with one value per input row.
        """
        self._check_frames(reference, test)

        e = self._spectrum(test - reference)

        if self.normalize_gain:
            gain = self.gain(reference)
            return tc.pow(tc.norm(gain * e, p="fro", dim=1), 2.0)

        x = self._spectrum(reference)
        e = self._masker_power_array(e)
        x = self._masker_power_array(x)

        return self._detectability(e, x, self.cs, self.ca)

    def frame_absolute(self, reference, test):
        """Like ``frame`` but uses ``test`` itself (not ``test - reference``)
        as the numerator signal."""
        self._check_frames(reference, test)

        t = self._masker_power_array(self._spectrum(test))
        x = self._masker_power_array(self._spectrum(reference))

        return self._detectability(t, x, self.cs, self.ca)

    def gain(self, reference):
        """Per-frequency gain derived from ``reference``.

        The result is normalised to unit norm along dim 1 when
        ``normalize_gain`` is set.
        """
        self._check_frames(reference)

        x = self._masker_power_array(self._spectrum(reference))
        numer = (self.cs * self.leff * self.h * self.g).unsqueeze(0)
        denom = (x + self.ca).unsqueeze(-1)
        gain = (numer / denom).sum(dim=1).sqrt()

        if self.normalize_gain:
            gain = gain / tc.norm(gain, p="fro", dim=1).unsqueeze(-1)

        return gain

    def forward(self, reference, test):
        """Evaluate ``frame`` and apply the configured reduction.

        Raises:
            ValueError: if ``reduction`` is not ``"mean"``, ``"meanlog"`` or
                ``None`` (previously this case silently returned ``None``).
        """
        if self.reduction not in ("mean", "meanlog", None):
            raise ValueError(f"Invalid 'reduction' {self.reduction!r}")

        batches = self.frame(reference, test)

        if self.reduction is None:
            return batches

        if self.reduction == "meanlog":
            batches = tc.log(batches + self.eps)

        return batches.mean()
129 changes: 129 additions & 0 deletions libdetectability/detectability_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import numpy as np
import scipy as sp
import torch as torch

from .detectability import Detectability
from .internal.gammatone_filterbank import gammatone_filterbank
from .internal.outer_middle_ear_filter import outer_middle_ear_filter


class DetectabilityLoss(torch.nn.Module):
    """Torch module that evaluates detectability with the filters of a
    ``Detectability`` instance.

    A ``Detectability`` object is built from the constructor arguments; its
    scalars (``ca``, ``cs``, ``leff``) and arrays (``h``, ``g``) are reused
    here with torch operations.

    All signal inputs must be 2-D tensors whose second dimension equals
    ``frame_size`` (enforced by assertions).

    Args:
        frame_size, sampling_rate, taps, dbspl, spl, threshold_mode,
        normalize_gain, norm: forwarded unchanged to ``Detectability``.
        normalize_gain: additionally, when true, ``frame`` uses the
            (normalised) gain from ``gain`` instead of the masker-power ratio.
        reduction: ``"mean"`` or ``None``; how ``forward`` reduces the
            per-row values returned by ``frame``.
        eps: stored on the instance; not used by any reduction in this class
            (kept so existing callers passing it keep working).
    """

    def __init__(
        self,
        frame_size=2048,
        sampling_rate=48000,
        taps=32,
        dbspl=94.0,
        spl=1.0,
        threshold_mode="hearing",
        normalize_gain=False,
        norm="backward",
        reduction="mean",
        eps=1e-8,
    ):
        super().__init__()
        self.detectability = Detectability(
            frame_size=frame_size,
            sampling_rate=sampling_rate,
            taps=taps,
            dbspl=dbspl,
            spl=spl,
            threshold_mode=threshold_mode,
            normalize_gain=normalize_gain,
            norm=norm,
        )
        self.ca = self.detectability.ca
        self.cs = self.detectability.cs
        self.frame_size = self.detectability.frame_size
        self.taps = self.detectability.taps
        self.leff = self.detectability.leff
        self.norm = self.detectability.norm

        h = torch.from_numpy(self.detectability.h)
        g = torch.from_numpy(self.detectability.g)
        # Registered as buffers so that Module.to()/.cuda()/.double() move
        # them with the module, also when this loss is a child of another
        # module (a hand-written ``to`` override is not invoked in that case).
        # persistent=False keeps them out of ``state_dict``.
        self.register_buffer("h", h, persistent=False)
        self.register_buffer("g", g, persistent=False)
        self.register_buffer("G", h * g.unsqueeze(0), persistent=False)

        self.reduction = reduction
        self.eps = eps
        self.normalize_gain = normalize_gain

    def _check_frames(self, *frames):
        """Assert that every tensor is 2-D with ``frame_size`` columns."""
        assert all(
            len(f.shape) == 2 for f in frames
        ), "only support for batched one-dimensional inputs"
        assert all(
            f.shape[1] == self.frame_size for f in frames
        ), "input frame size different the specified upon construction"

    def _spectrum(self, a):
        """Return the squared magnitude of the real FFT of ``a`` along dim 1."""
        return torch.pow(torch.abs(torch.fft.rfft(a, dim=1, norm=self.norm)), 2.0)

    def _masker_power_array(self, a):
        """Weight the spectrum ``a`` by every row of ``G`` and sum over the
        last dimension (one value per row of ``G`` for each input row)."""
        return torch.sum(a.unsqueeze(1) * self.G, dim=2)

    def _detectability(self, s, m, cs, ca):
        """Return ``cs * leff * sum(s / (m + ca))`` summed over dim 1."""
        return cs * self.leff * (s / (m + ca)).sum(dim=1)

    def frame(self, reference, test):
        """Detectability of the difference ``test - reference`` for each row.

        Returns a 1-D tensor with one value per input row.
        """
        self._check_frames(reference, test)

        e = self._spectrum(test - reference)

        if self.normalize_gain:
            gain = self.gain(reference)
            return torch.pow(torch.norm(gain * e, p="fro", dim=1), 2.0)

        x = self._spectrum(reference)
        e = self._masker_power_array(e)
        x = self._masker_power_array(x)

        return self._detectability(e, x, self.cs, self.ca)

    def frame_absolute(self, reference, test):
        """Like ``frame`` but uses ``test`` itself (not ``test - reference``)
        as the numerator signal."""
        self._check_frames(reference, test)

        t = self._masker_power_array(self._spectrum(test))
        x = self._masker_power_array(self._spectrum(reference))

        return self._detectability(t, x, self.cs, self.ca)

    def gain(self, reference):
        """Per-frequency gain derived from ``reference``.

        The result is normalised to unit norm along dim 1 when
        ``normalize_gain`` is set.
        """
        self._check_frames(reference)

        x = self._masker_power_array(self._spectrum(reference))
        numer = (self.cs * self.leff * self.h * self.g).unsqueeze(0)
        denom = (x + self.ca).unsqueeze(-1)
        gain = (numer / denom).sum(dim=1).sqrt()

        if self.normalize_gain:
            gain = gain / torch.norm(gain, p="fro", dim=1).unsqueeze(-1)

        return gain

    def forward(self, reference, test):
        """Evaluate ``frame`` and apply the configured reduction.

        Raises:
            ValueError: if ``reduction`` is neither ``"mean"`` nor ``None``
                (previously this case silently returned ``None``).
        """
        if self.reduction not in ("mean", None):
            raise ValueError(f"Invalid 'reduction' {self.reduction!r}")

        batches = self.frame(reference, test)

        if self.reduction is None:
            return batches

        return batches.mean()
25 changes: 20 additions & 5 deletions libdetectability/internal/outer_middle_ear_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
from .threshold_in_quiet import threshold_in_quiet


def outer_middle_ear_filter(frame_size, spl, dbspl, sampling_rate, threshold_mode):
    """Return one filter weight per real-FFT bin, derived from the threshold in quiet.

    The bins are those of ``np.fft.rfftfreq(frame_size, 1 / sampling_rate)``.

    Args:
        frame_size: FFT length used to derive the bin frequencies.
        spl, dbspl: forwarded to ``threshold_in_quiet``.
        sampling_rate: sampling rate used to derive the bin frequencies.
        threshold_mode: one of
            ``"relaxed"`` - ``1 / threshold`` evaluated per bin;
            ``"hearing"`` - the constant ``1 / min(threshold)`` in every bin;
            ``"hearing_regularized"`` - as ``"hearing"`` but with
            ``threshold_in_quiet(..., regularized=True)``.

    Returns:
        1-D ``np.ndarray`` with one value per bin.

    Raises:
        ValueError: if ``threshold_mode`` is not one of the values above.
    """
    # NOTE(review): the earlier boolean API returned the per-bin curve when
    # the threshold was *not* relaxed; here "relaxed" selects the per-bin
    # curve and "hearing" the flat one - confirm the naming is intended.
    freqs = np.fft.rfftfreq(frame_size, d=(1.0 / sampling_rate))

    if threshold_mode == "relaxed":
        return np.array([1.0 / threshold_in_quiet(f, spl, dbspl) for f in freqs])

    if threshold_mode == "hearing":
        threshold = np.array([threshold_in_quiet(f, spl, dbspl) for f in freqs])
        # Same constant in every bin; computed once instead of once per bin.
        return np.full(freqs.shape, 1.0 / np.min(threshold))

    if threshold_mode == "hearing_regularized":
        threshold = np.array(
            [threshold_in_quiet(f, spl, dbspl, regularized=True) for f in freqs]
        )
        return np.full(freqs.shape, 1.0 / np.min(threshold))

    # Raising a bare string is itself a TypeError; use a real exception type.
    raise ValueError(f"Invalid 'threshold_mode' {threshold_mode}")
6 changes: 4 additions & 2 deletions libdetectability/internal/threshold_in_quiet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from .threshold_in_quiet_db import threshold_in_quiet_db


def threshold_in_quiet(freq, spl, dbspl, regularized=False):
    """Convert the dB threshold in quiet at ``freq`` to a linear amplitude.

    The dB value from ``threshold_in_quiet_db`` is shifted by
    ``dbspl - 20 * log10(spl)`` and mapped through ``10 ** (x / 20)``.

    Args:
        freq: frequency passed on to ``threshold_in_quiet_db``.
        spl: linear amplitude that corresponds to ``dbspl``.
        dbspl: level in dB assigned to the amplitude ``spl``.
        regularized: passed on to ``threshold_in_quiet_db``.
    """
    reference_db = dbspl - 20 * np.log10(spl)
    level_db = threshold_in_quiet_db(freq, regularized=regularized)
    return np.power(10.0, (level_db - reference_db) / 20.0)
9 changes: 7 additions & 2 deletions libdetectability/internal/threshold_in_quiet_db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import numpy as np


def threshold_in_quiet_db(freq, regularized=False):
    """Evaluate the closed-form threshold-in-quiet curve at ``freq``.

    With ``k = freq / 1000`` the value is
    ``3.64 * k**-0.8 - 6.5 * exp(-0.6 * (k - 3.3)**2) + 1e-3 * k**4``
    (this has the form of Terhardt's threshold-in-quiet approximation;
    ``freq`` is presumably in Hz).

    Args:
        freq: scalar or array of frequencies.
        regularized: when true, the result is floored at 30 via
            ``np.maximum``.
    """
    khz = freq / 1000.0
    low_term = 3.64 * np.power(khz, -0.8)
    dip_term = 6.5 * np.exp(-0.6 * np.power(khz - 3.3, 2))
    high_term = 1e-3 * np.power(khz, 4)
    value = low_term - dip_term + high_term

    return np.maximum(value, 30) if regularized else value
Loading

0 comments on commit 2c80978

Please sign in to comment.