From 2c80978b3e4fa26f73fbab30e07317aa3635fb21 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 15 Nov 2024 10:21:39 +0100 Subject: [PATCH] upgrades --- libdetectability/__init__.py | 4 +- libdetectability/detectability.py | 131 +----------------- libdetectability/detectability_loss.py | 129 +++++++++++++++++ .../internal/outer_middle_ear_filter.py | 25 +++- .../internal/threshold_in_quiet.py | 6 +- .../internal/threshold_in_quiet_db.py | 9 +- ...tability.py => segmented_detectability.py} | 62 +++++---- libdetectability/test.py | 55 +++++++- .../testSegmenterDetectability.py | 47 ------- setup.py | 4 +- validation/test_outer_middle_ear_filter.py | 15 +- 11 files changed, 267 insertions(+), 220 deletions(-) create mode 100644 libdetectability/detectability_loss.py rename libdetectability/{segmentedDetectability.py => segmented_detectability.py} (53%) delete mode 100644 libdetectability/testSegmenterDetectability.py diff --git a/libdetectability/__init__.py b/libdetectability/__init__.py index 451224d..6b8b25e 100644 --- a/libdetectability/__init__.py +++ b/libdetectability/__init__.py @@ -1 +1,3 @@ -from .detectability import Detectability, DetectabilityLoss +from .detectability import Detectability +from .detectability_loss import DetectabilityLoss +from .segmented_detectability import SegmentedDetectability diff --git a/libdetectability/detectability.py b/libdetectability/detectability.py index 3008018..923fe3f 100644 --- a/libdetectability/detectability.py +++ b/libdetectability/detectability.py @@ -1,6 +1,6 @@ import numpy as np import scipy as sp -import torch as tc +import torch as torch from .internal.gammatone_filterbank import gammatone_filterbank from .internal.outer_middle_ear_filter import outer_middle_ear_filter @@ -14,7 +14,7 @@ def __init__( taps=32, dbspl=94.0, spl=1.0, - relax_threshold=False, + threshold_mode="hearing", normalize_gain=False, norm="backward", ): @@ -42,7 +42,7 @@ def __init__( self.spl, self.dbspl, self.sampling_rate, - relax_threshold=relax_threshold, + thershold_mode=threshold_mode, ) ), 2.0, @@ -152,128 +152,3 @@ def gain(self, reference): factor = np.linalg.norm(gain, ord=2, axis=0) gain = gain / factor return gain - - -class DetectabilityLoss(tc.nn.Module): - def __init__( - self, - frame_size=2048, - sampling_rate=48000, - taps=32, - dbspl=94.0, - spl=1.0, - relax_threshold=False, - normalize_gain=False, - norm="backward", - reduction="meanlog", - eps=1e-8, - ): - super(DetectabilityLoss, self).__init__() - self.detectability = Detectability( - frame_size=frame_size, - sampling_rate=sampling_rate, - taps=taps, - dbspl=dbspl, - spl=spl, - relax_threshold=relax_threshold, - norm=norm, - ) - self.ca = self.detectability.ca - self.cs = self.detectability.cs - self.frame_size = self.detectability.frame_size - self.taps = self.detectability.taps - self.leff = self.detectability.leff - self.norm = self.detectability.norm - self.h = tc.from_numpy(self.detectability.h) - self.g = tc.from_numpy(self.detectability.g) - self.G = tc.from_numpy(self.detectability.h) * tc.from_numpy( - self.detectability.g - ).unsqueeze(0) - self.reduction = reduction - self.eps = eps - self.normalize_gain = normalize_gain - - def _spectrum(self, a): - return tc.pow(tc.abs(tc.fft.rfft(a, axis=1, norm=self.norm)), 2.0) - - def _masker_power_array(self, a): - return tc.sum(a.unsqueeze(1) * self.G, axis=2) - - def _detectability(self, s, m, cs, ca): - return cs * self.leff * (s / (m + ca)).sum(axis=1) - - def to(self, device): - super().to(device) - self.G = self.G.to(device) - self.h = self.h.to(device) - self.g = self.g.to(device) - return self - - def frame(self, reference, test): - assert ( - len(reference.shape) == 2 and len(test.shape) == 2 - ), f"only support for batched one-dimensional inputs" - assert ( - reference.shape[1] == self.frame_size and test.shape[1] == self.frame_size - ), f"input frame size different the specified upon construction" - - if self.normalize_gain: - e = self._spectrum(test - reference) - gain = self.gain(reference) - return tc.pow(tc.norm(gain * e, p="fro", dim=1), 2.0) - - e = self._spectrum(test - reference) - x = self._spectrum(reference) - e = self._masker_power_array(e) - x = self._masker_power_array(x) - - return self._detectability(e, x, self.cs, self.ca) - - def frame_absolute(self, reference, test): - assert ( - len(reference.shape) == 2 and len(test.shape) == 2 - ), f"only support for batched one-dimensional inputs" - assert ( - reference.shape[1] == self.frame_size and test.shape[1] == self.frame_size - ), f"input frame size different the specified upon construction" - - t = self._spectrum(test) - x = self._spectrum(reference) - t = self._masker_power_array(t) - x = self._masker_power_array(x) - - return self._detectability(t, x, self.cs, self.ca) - - def gain(self, reference): - assert ( - len(reference.shape) == 2 - ), f"only support for batched one-dimensional inputs" - assert ( - reference.shape[1] == self.frame_size - ), f"input frame size different the specified upon construction" - - x = self._spectrum(reference) - x = self._masker_power_array(x) - numer = (self.cs * self.leff * self.h * self.g).unsqueeze(0) - denom = (x + self.ca).unsqueeze(-1) - G = numer / denom - gain = G.sum(axis=1).sqrt() - - if self.normalize_gain: - factor = tc.norm(gain, p="fro", dim=1).unsqueeze(-1) - gain = gain / factor - - return gain - - def forward(self, reference, test): - batches = self.frame(reference, test) - - if self.reduction == "mean": - return batches.mean() - - if self.reduction == "meanlog": - batches = tc.log(batches + self.eps) - return batches.mean() - - if self.reduction == None: - return batches diff --git a/libdetectability/detectability_loss.py b/libdetectability/detectability_loss.py new file mode 100644 index 0000000..c9868b2 --- /dev/null +++ b/libdetectability/detectability_loss.py @@ -0,0 +1,129 @@ +import numpy as np +import scipy as sp +import torch as torch + +from .detectability import Detectability +from .internal.gammatone_filterbank import gammatone_filterbank +from .internal.outer_middle_ear_filter import outer_middle_ear_filter + + +class DetectabilityLoss(torch.nn.Module): + def __init__( + self, + frame_size=2048, + sampling_rate=48000, + taps=32, + dbspl=94.0, + spl=1.0, + threshold_mode="hearing", + normalize_gain=False, + norm="backward", + reduction="mean", + eps=1e-8, + ): + super(DetectabilityLoss, self).__init__() + self.detectability = Detectability( + frame_size=frame_size, + sampling_rate=sampling_rate, + taps=taps, + dbspl=dbspl, + spl=spl, + threshold_mode=threshold_mode, + normalize_gain=normalize_gain, + norm=norm, + ) + self.ca = self.detectability.ca + self.cs = self.detectability.cs + self.frame_size = self.detectability.frame_size + self.taps = self.detectability.taps + self.leff = self.detectability.leff + self.norm = self.detectability.norm + self.h = torch.from_numpy(self.detectability.h) + self.g = torch.from_numpy(self.detectability.g) + self.G = torch.from_numpy(self.detectability.h) * torch.from_numpy( + self.detectability.g + ).unsqueeze(0) + self.reduction = reduction + self.eps = eps + self.normalize_gain = normalize_gain + + def _spectrum(self, a): + return torch.pow(torch.abs(torch.fft.rfft(a, axis=1, norm=self.norm)), 2.0) + + def _masker_power_array(self, a): + return torch.sum(a.unsqueeze(1) * self.G, axis=2) + + def _detectability(self, s, m, cs, ca): + return cs * self.leff * (s / (m + ca)).sum(axis=1) + + def to(self, device): + super().to(device) + self.G = self.G.to(device) + self.h = self.h.to(device) + self.g = self.g.to(device) + return self + + def frame(self, reference, test): + assert ( + len(reference.shape) == 2 and len(test.shape) == 2 + ), f"only support for batorchhed one-dimensional inputs" + assert ( + reference.shape[1] == self.frame_size and test.shape[1] == self.frame_size + ), f"input frame size different the specified upon construction" + + if self.normalize_gain: + e = self._spectrum(test - reference) + gain = self.gain(reference) + return torch.pow(torch.norm(gain * e, p="fro", dim=1), 2.0) + + e = self._spectrum(test - reference) + x = self._spectrum(reference) + e = self._masker_power_array(e) + x = self._masker_power_array(x) + + return self._detectability(e, x, self.cs, self.ca) + + def frame_absolute(self, reference, test): + assert ( + len(reference.shape) == 2 and len(test.shape) == 2 + ), f"only support for batorchhed one-dimensional inputs" + assert ( + reference.shape[1] == self.frame_size and test.shape[1] == self.frame_size + ), f"input frame size different the specified upon construction" + + t = self._spectrum(test) + x = self._spectrum(reference) + t = self._masker_power_array(t) + x = self._masker_power_array(x) + + return self._detectability(t, x, self.cs, self.ca) + + def gain(self, reference): + assert ( + len(reference.shape) == 2 + ), f"only support for batorchhed one-dimensional inputs" + assert ( + reference.shape[1] == self.frame_size + ), f"input frame size different the specified upon construction" + + x = self._spectrum(reference) + x = self._masker_power_array(x) + numer = (self.cs * self.leff * self.h * self.g).unsqueeze(0) + denom = (x + self.ca).unsqueeze(-1) + G = numer / denom + gain = G.sum(axis=1).sqrt() + + if self.normalize_gain: + factor = torch.norm(gain, p="fro", dim=1).unsqueeze(-1) + gain = gain / factor + + return gain + + def forward(self, reference, test): + batches = self.frame(reference, test) + + if self.reduction == "mean": + return batches.mean() + + if self.reduction == None: + return batches diff --git a/libdetectability/internal/outer_middle_ear_filter.py b/libdetectability/internal/outer_middle_ear_filter.py index f061abf..fadb769 100644 --- a/libdetectability/internal/outer_middle_ear_filter.py +++ b/libdetectability/internal/outer_middle_ear_filter.py @@ -2,17 +2,16 @@ from .threshold_in_quiet import threshold_in_quiet -def outer_middle_ear_filter( - frame_size, spl, dbspl, sampling_rate, relax_threshold=False -): - if not relax_threshold: +def outer_middle_ear_filter(frame_size, spl, dbspl, sampling_rate, threshold_mode): + if threshold_mode == "relaxed": return np.array( [ 1.0 / threshold_in_quiet(f, spl, dbspl) for f in np.fft.rfftfreq(frame_size, d=(1.0 / sampling_rate)) ] ) - else: + + if threshold_mode == "hearing": threshold = np.array( [ threshold_in_quiet(f, spl, dbspl) @@ -25,3 +24,19 @@ def outer_middle_ear_filter( for f in np.fft.rfftfreq(frame_size, d=(1.0 / sampling_rate)) ] ) + + if threshold_mode == "hearing_regularized": + threshold = np.array( + [ + threshold_in_quiet(f, spl, dbspl, regularized=True) + for f in np.fft.rfftfreq(frame_size, d=(1.0 / sampling_rate)) + ] + ) + return np.array( + [ + 1.0 / np.min(threshold) + for f in np.fft.rfftfreq(frame_size, d=(1.0 / sampling_rate)) + ] + ) + + raise (f"Invalid 'threshold_mode' {threshold_mode}") diff --git a/libdetectability/internal/threshold_in_quiet.py b/libdetectability/internal/threshold_in_quiet.py index 0b8cc38..23a6ecd 100644 --- a/libdetectability/internal/threshold_in_quiet.py +++ b/libdetectability/internal/threshold_in_quiet.py @@ -2,6 +2,8 @@ from .threshold_in_quiet_db import threshold_in_quiet_db -def threshold_in_quiet(freq, spl, dbspl): +def threshold_in_quiet(freq, spl, dbspl, regularized=False): offset = dbspl - 20 * np.log10(spl) - return np.power(10.0, (threshold_in_quiet_db(freq) - offset) / 20.0) + return np.power( + 10.0, (threshold_in_quiet_db(freq, regularized=regularized) - offset) / 20.0 + ) diff --git a/libdetectability/internal/threshold_in_quiet_db.py b/libdetectability/internal/threshold_in_quiet_db.py index 7091072..b17c1ce 100644 --- a/libdetectability/internal/threshold_in_quiet_db.py +++ b/libdetectability/internal/threshold_in_quiet_db.py @@ -1,9 +1,14 @@ import numpy as np -def threshold_in_quiet_db(freq): - return ( +def threshold_in_quiet_db(freq, regularized=False): + value = ( 3.64 * np.power(freq / 1000.0, -0.8) - 6.5 * np.exp(-0.6 * np.power(freq / 1000.0 - 3.3, 2)) + 10e-4 * np.power(freq / 1000.0, 4) ) + + if regularized: + return np.maximum(value, 30) + else: + return value diff --git a/libdetectability/segmentedDetectability.py b/libdetectability/segmented_detectability.py similarity index 53% rename from libdetectability/segmentedDetectability.py rename to libdetectability/segmented_detectability.py index f82dfd6..52063b6 100644 --- a/libdetectability/segmentedDetectability.py +++ b/libdetectability/segmented_detectability.py @@ -1,8 +1,10 @@ -from .detectability import Detectability, DetectabilityLoss +from .detectability import Detectability +from .detectability_loss import DetectabilityLoss import libsegmenter import torch -class segmented_detectability: + +class SegmentedDetectability: def __init__( self, frame_size=2048, @@ -15,50 +17,60 @@ def __init__( norm="backward", ): self.detectability = DetectabilityLoss( - frame_size = frame_size, - sampling_rate = sampling_rate, - taps = taps, - dbspl = dbspl, - spl = spl, - relax_threshold = relax_threshold, - normalize_gain = normalize_gain, - norm = norm + frame_size=frame_size, + sampling_rate=sampling_rate, + taps=taps, + dbspl=dbspl, + spl=spl, + relax_threshold=relax_threshold, + normalize_gain=normalize_gain, + norm=norm, ) window = libsegmenter.hann(frame_size) assert frame_size % 2 == 0, "only evenly-sizes frames are supported" - hop_size = frame_size // 2; + hop_size = frame_size // 2 self.segmenter = libsegmenter.make_segmenter( - backend = "torch", - frame_size = frame_size, - hop_size = hop_size, - window = window, - mode = "wola", - edge_correction = False + backend="torch", + frame_size=frame_size, + hop_size=hop_size, + window=window, + mode="wola", + edge_correction=False, ) def calculate_segment_detectability_relative(self, reference_signal, test_signal): - assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape" + assert ( + reference_signal.shape == test_signal.shape + ), "reference and test signal must be of the same shape" reference_segments = self.segmenter.segment(reference_signal) test_segments = self.segmenter.segment(test_signal) number_of_batch_elements = reference_segments.shape[0] number_of_frames = reference_segments.shape[1] - detectability_of_segments = torch.zeros((number_of_batch_elements, number_of_frames)) + detectability_of_segments = torch.zeros( + (number_of_batch_elements, number_of_frames) + ) for fIdx in range(0, number_of_frames): - detectability_of_segments[:, fIdx] = self.detectability.frame(reference_segments[:, fIdx], test_segments[:, fIdx]) + detectability_of_segments[:, fIdx] = self.detectability.frame( + reference_segments[:, fIdx], test_segments[:, fIdx] + ) return detectability_of_segments - def calculate_segment_detectability_absolute(self, reference_signal, test_signal): - assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape" + assert ( + reference_signal.shape == test_signal.shape + ), "reference and test signal must be of the same shape" reference_segments = self.segmenter.segment(reference_signal) test_segments = self.segmenter.segment(test_signal) number_of_batch_elements = reference_segments.shape[0] number_of_frames = reference_segments.shape[1] - detectability_of_segments = torch.zeros((number_of_batch_elements, number_of_frames)) + detectability_of_segments = torch.zeros( + (number_of_batch_elements, number_of_frames) + ) for fIdx in range(0, number_of_frames): - detectability_of_segments[:, fIdx] = self.detectability.frame_absolute(reference_segments[:, fIdx], test_segments[:, fIdx]) + detectability_of_segments[:, fIdx] = self.detectability.frame_absolute( + reference_segments[:, fIdx], test_segments[:, fIdx] + ) return detectability_of_segments - diff --git a/libdetectability/test.py b/libdetectability/test.py index bbccd4d..eebfe5e 100644 --- a/libdetectability/test.py +++ b/libdetectability/test.py @@ -1,4 +1,5 @@ from .detectability import Detectability, DetectabilityLoss +from .segmentedDetectability import segmented_detectability import numpy as np import torch as tc import pytest @@ -100,9 +101,9 @@ def test_cost_old(): x = np.sin(2 * np.pi * 1000 * np.arange(2048) / 48000) y = 7.0 * np.sin(2 * np.pi * 1000 * np.arange(2048) / 48000) - print(old.detectability_gain(x, x - y), 2.0) - print(new.frame(x, y)) - print("end") + # print(old.detectability_gain(x, x - y), 2.0) + # print(new.frame(x, y)) + # print("end") def test_gain_old(): @@ -111,11 +112,53 @@ def test_gain_old(): new = Detectability() g = new.gain(x) - print(pytest.approx(np.power(np.linalg.norm(g * np.fft.rfft(x - y)), 2.0))) + # print(pytest.approx(np.power(np.linalg.norm(g * np.fft.rfft(x - y)), 2.0))) import pydetectability as pd old = pd.par_model(48000, 2048, pd.signal_pressure_mapping(1.0, 94.0)) g = old.gain(x) - print(pytest.approx(np.power(np.linalg.norm(g * np.fft.rfft(x - y)), 2.0))) - print("end") + # print(pytest.approx(np.power(np.linalg.norm(g * np.fft.rfft(x - y)), 2.0))) + # print("end") + + +def test_segmenter_detectability_relative(): + sampling_rate = 48000 + window_size = 2048 + hop_size = window_size // 2 + signal_length = 4 * window_size + 10 + batch_size = 3 + input = tc.ones((batch_size, signal_length)) + input2 = tc.ones((batch_size, signal_length)) + + segmenter = segmented_detectability( + frame_size=window_size, sampling_rate=sampling_rate + ) + + results = segmenter.calculate_segment_detectability_relative(input, input2) + max_detectability = results.numpy() + test_value = np.max(max_detectability) + ref_value = 0.0 + + assert test_value == pytest.approx(ref_value) + + +def test_segmenter_detectability_absolute(): + sampling_rate = 48000 + window_size = 2048 + hop_size = window_size // 2 + signal_length = 4 * window_size + 10 + batch_size = 3 + input = tc.ones((batch_size, signal_length)) + input2 = tc.zeros((batch_size, signal_length)) + + segmenter = segmented_detectability( + frame_size=window_size, sampling_rate=sampling_rate + ) + + results = segmenter.calculate_segment_detectability_absolute(input, input2) + max_detectability = results.numpy() + test_value = np.max(max_detectability) + ref_value = 0.0 + + assert test_value == pytest.approx(ref_value) diff --git a/libdetectability/testSegmenterDetectability.py b/libdetectability/testSegmenterDetectability.py deleted file mode 100644 index cf20c22..0000000 --- a/libdetectability/testSegmenterDetectability.py +++ /dev/null @@ -1,47 +0,0 @@ -from .segmentedDetectability import segmented_detectability -import torch as tc -import libsegmenter -import pytest -import numpy as np - -def test_segmenter_detectability_relative(): - sampling_rate = 48000 - window_size = 2048 - hop_size = window_size // 2; - - signal_length = 4*window_size+10 - batch_size = 3 - input = tc.ones((batch_size, signal_length)) - input2 = tc.ones((batch_size, signal_length)) - - segmenter = segmented_detectability(frame_size = window_size, sampling_rate = sampling_rate) - - results = segmenter.calculate_segment_detectability_relative(input, input2) - max_detectability = results.numpy() - test_value = np.max(max_detectability) - ref_value = 0.0 - - - assert test_value == pytest.approx(ref_value) - -def test_segmenter_detectability_absolute(): - sampling_rate = 48000 - window_size = 2048 - hop_size = window_size // 2; - - signal_length = 4*window_size+10 - batch_size = 3 - input = tc.ones((batch_size, signal_length)) - input2 = tc.zeros((batch_size, signal_length)) - - segmenter = segmented_detectability(frame_size = window_size, sampling_rate = sampling_rate) - - results = segmenter.calculate_segment_detectability_absolute(input, input2) - max_detectability = results.numpy() - test_value = np.max(max_detectability) - ref_value = 0.0 - - - assert test_value == pytest.approx(ref_value) - - diff --git a/setup.py b/setup.py index bc314fb..926431d 100644 --- a/setup.py +++ b/setup.py @@ -2,8 +2,8 @@ setup( name="libdetectability", - version="0.5.0", + version="0.6.0", packages=find_packages(), - install_requires=["pytest", "numpy", "scipy", "torch","libsegmenter"], + install_requires=["pytest", "numpy", "scipy", "torch", "libsegmenter"], test_suite="test", ) diff --git a/validation/test_outer_middle_ear_filter.py b/validation/test_outer_middle_ear_filter.py index 394cf46..d40b977 100644 --- a/validation/test_outer_middle_ear_filter.py +++ b/validation/test_outer_middle_ear_filter.py @@ -4,7 +4,7 @@ def test_outer_middle_ear_filter(): import matplotlib.pyplot as plt - filter = outer_middle_ear_filter(2048, 1.0, 94.0, 48000.0, relax_threshold=False) + filter = outer_middle_ear_filter(2048, 1.0, 94.0, 48000.0, threshold_mode="hearing") plt.plot(filter) plt.xlabel("Frequency") plt.ylabel("Amplitude") @@ -16,10 +16,21 @@ def test_outer_middle_ear_filter(): def test_outer_middle_ear_filter_relaxed(): import matplotlib.pyplot as plt - filter = outer_middle_ear_filter(2048, 1.0, 94.0, 48000.0, relax_threshold=True) + filter = outer_middle_ear_filter(2048, 1.0, 94.0, 48000.0, threshold_mode="relaxed") plt.plot(filter) plt.xlabel("Frequency") plt.ylabel("Amplitude") plt.tight_layout() plt.savefig("test_outer_middle_ear_filter_relaxed.png") plt.close() + +def test_outer_middle_ear_filter_hearing_regularized(): + import matplotlib.pyplot as plt + + filter = outer_middle_ear_filter(2048, 1.0, 94.0, 48000.0, threshold_mode="hearing_regularized") + plt.plot(filter) + plt.xlabel("Frequency") + plt.ylabel("Amplitude") + plt.tight_layout() + plt.savefig("test_outer_middle_ear_filter_hearing_regularized.png") + plt.close()