diff --git a/libdetectability/detectability.py b/libdetectability/detectability.py index 0c8a4a2..7fc5156 100644 --- a/libdetectability/detectability.py +++ b/libdetectability/detectability.py @@ -6,7 +6,14 @@ from .internal.outer_middle_ear_filter import outer_middle_ear_filter class Detectability: - def __init__(self, frame_size=2048, sampling_rate=48000, taps=64, dbspl=94.0, spl=1.0, relax_threshold=False, norm="backward"): + def __init__(self, + frame_size=2048, + sampling_rate=48000, + taps=64, + dbspl=94.0, + spl=1.0, + relax_threshold=False, + norm="backward"): assert frame_size % 2 == 0, "only evenly-sized frames are supported" self.frame_size = frame_size self.freq_size = frame_size // 2 + 1 @@ -18,7 +25,9 @@ def __init__(self, frame_size=2048, sampling_rate=48000, taps=64, dbspl=94.0, sp # prealloc self.g = np.power(np.abs(gammatone_filterbank(self.taps, self.frame_size, self.sampling_rate)), 2.0) - self.h = np.power(np.abs(outer_middle_ear_filter(self.frame_size, self.spl, self.dbspl, self.sampling_rate, relax_threshold=relax_threshold)), 2.0) + self.h = np.power(np.abs( + outer_middle_ear_filter(self.frame_size, self.spl, self.dbspl, self.sampling_rate, relax_threshold=relax_threshold) + ), 2.0) self.leff = min(float(self.frame_size) / float(sampling_rate) / 0.30, 1.0) # calibration @@ -53,8 +62,10 @@ def _detectability(self, s, m, cs, ca): return cs * self.leff * (s / (m + ca)).sum() def frame(self, reference, test): - assert reference.size == self.frame_size and test.size == self.frame_size, f"input frame size different the specified upon construction" - assert len(reference.shape) == 1 and len(test.shape) == 1, f"only support for one-dimensional inputs" + assert reference.size == self.frame_size and test.size == self.frame_size, \ + f"input frame size different the specified upon construction" + assert len(reference.shape) == 1 and len(test.shape) == 1, \ + f"only support for one-dimensional inputs" e = self._spectrum(test - reference) x = self._spectrum(reference) @@ -64,8 +75,10 @@ def frame(self, reference, test): return self._detectability(e, x, self.cs, self.ca) def gain(self, reference): - assert reference.size == self.frame_size, f"input frame size different the specified upon construction" - assert len(reference.shape) == 1, f"only support for one-dimensional inputs" + assert reference.size == self.frame_size, \ + f"input frame size different the specified upon construction" + assert len(reference.shape) == 1, \ + f"only support for one-dimensional inputs" x = self._spectrum(reference) x = self._masker_power_array(x) @@ -75,7 +88,16 @@ def gain(self, reference): return np.sqrt(G.sum(axis=0)) class DetectabilityLoss(tc.nn.Module): - def __init__(self, frame_size=2048, sampling_rate=48000, taps=32, dbspl=94.0, spl=1.0, relax_threshold=True, norm = "backward", reduction="meanlog", eps=1e-8): + def __init__(self, + frame_size=2048, + sampling_rate=48000, + taps=64, + dbspl=94.0, + spl=1.0, + relax_threshold=False, + norm = "backward", + reduction="meanlog", + eps=1e-8): super(DetectabilityLoss, self).__init__() self.detectability = Detectability(frame_size=frame_size, sampling_rate=sampling_rate, taps=taps, dbspl=dbspl, \ spl=spl, relax_threshold=relax_threshold, norm=norm) diff --git a/libdetectability/test.py b/libdetectability/test.py index 17a41e9..5453d1c 100644 --- a/libdetectability/test.py +++ b/libdetectability/test.py @@ -18,6 +18,8 @@ def test_cost(): y = tc.concatenate((y, y)) tcv = tcd.frame(x, y) + print(tcv) + print(npv) assert tcv[0] == pytest.approx(npv) assert tcv[1] == pytest.approx(npv)