Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bo-nemk committed Feb 14, 2024
1 parent dffce64 commit 9ff7f6d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
36 changes: 29 additions & 7 deletions libdetectability/detectability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from .internal.outer_middle_ear_filter import outer_middle_ear_filter

class Detectability:
def __init__(self, frame_size=2048, sampling_rate=48000, taps=64, dbspl=94.0, spl=1.0, relax_threshold=False, norm="backward"):
def __init__(self,
frame_size=2048,
sampling_rate=48000,
taps=64,
dbspl=94.0,
spl=1.0,
relax_threshold=False,
norm="backward"):
assert frame_size % 2 == 0, "only evenly-sized frames are supported"
self.frame_size = frame_size
self.freq_size = frame_size // 2 + 1
Expand All @@ -18,7 +25,9 @@ def __init__(self, frame_size=2048, sampling_rate=48000, taps=64, dbspl=94.0, sp

# prealloc
self.g = np.power(np.abs(gammatone_filterbank(self.taps, self.frame_size, self.sampling_rate)), 2.0)
self.h = np.power(np.abs(outer_middle_ear_filter(self.frame_size, self.spl, self.dbspl, self.sampling_rate, relax_threshold=relax_threshold)), 2.0)
self.h = np.power(np.abs(
outer_middle_ear_filter(self.frame_size, self.spl, self.dbspl, self.sampling_rate, relax_threshold=relax_threshold)
), 2.0)
self.leff = min(float(self.frame_size) / float(sampling_rate) / 0.30, 1.0)

# calibration
Expand Down Expand Up @@ -53,8 +62,10 @@ def _detectability(self, s, m, cs, ca):
return cs * self.leff * (s / (m + ca)).sum()

def frame(self, reference, test):
assert reference.size == self.frame_size and test.size == self.frame_size, f"input frame size different the specified upon construction"
assert len(reference.shape) == 1 and len(test.shape) == 1, f"only support for one-dimensional inputs"
assert reference.size == self.frame_size and test.size == self.frame_size, \
f"input frame size different the specified upon construction"
assert len(reference.shape) == 1 and len(test.shape) == 1, \
f"only support for one-dimensional inputs"

e = self._spectrum(test - reference)
x = self._spectrum(reference)
Expand All @@ -64,8 +75,10 @@ def frame(self, reference, test):
return self._detectability(e, x, self.cs, self.ca)

def gain(self, reference):
assert reference.size == self.frame_size, f"input frame size different the specified upon construction"
assert len(reference.shape) == 1, f"only support for one-dimensional inputs"
assert reference.size == self.frame_size, \
f"input frame size different the specified upon construction"
assert len(reference.shape) == 1, \
f"only support for one-dimensional inputs"

x = self._spectrum(reference)
x = self._masker_power_array(x)
Expand All @@ -75,7 +88,16 @@ def gain(self, reference):
return np.sqrt(G.sum(axis=0))

class DetectabilityLoss(tc.nn.Module):
def __init__(self, frame_size=2048, sampling_rate=48000, taps=32, dbspl=94.0, spl=1.0, relax_threshold=True, norm = "backward", reduction="meanlog", eps=1e-8):
def __init__(self,
frame_size=2048,
sampling_rate=48000,
taps=64,
dbspl=94.0,
spl=1.0,
relax_threshold=False,
norm = "backward",
reduction="meanlog",
eps=1e-8):
super(DetectabilityLoss, self).__init__()
self.detectability = Detectability(frame_size=frame_size, sampling_rate=sampling_rate, taps=taps, dbspl=dbspl, \
spl=spl, relax_threshold=relax_threshold, norm=norm)
Expand Down
2 changes: 2 additions & 0 deletions libdetectability/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def test_cost():
y = tc.concatenate((y, y))

tcv = tcd.frame(x, y)
print(tcv)
print(npv)

assert tcv[0] == pytest.approx(npv)
assert tcv[1] == pytest.approx(npv)
Expand Down

0 comments on commit 9ff7f6d

Please sign in to comment.