diff --git a/libdetectability/detectability.py b/libdetectability/detectability.py index 7c896eb..0cbd033 100644 --- a/libdetectability/detectability.py +++ b/libdetectability/detectability.py @@ -120,6 +120,21 @@ def frame(self, reference, test): return self._detectability(e, x, self.cs, self.ca) + def frame_absolute(self, reference, test): + assert ( + reference.size == self.frame_size and test.size == self.frame_size + ), f"input frame size different the specified upon construction" + assert ( + len(reference.shape) == 1 and len(test.shape) == 1 + ), f"only support for one-dimensional inputs" + + t = self._spectrum(test) + x = self._spectrum(reference) + t = self._masker_power_array(t) + x = self._masker_power_array(x) + + return self._detectability(t, x, self.cs, self.ca) + def gain(self, reference): assert ( reference.size == self.frame_size @@ -214,6 +229,21 @@ def frame(self, reference, test): return self._detectability(e, x, self.cs, self.ca) + def frame_absolute(self, reference, test): + assert ( + len(reference.shape) == 2 and len(test.shape) == 2 + ), f"only support for batched one-dimensional inputs" + assert ( + reference.shape[1] == self.frame_size and test.shape[1] == self.frame_size + ), f"input frame size different the specified upon construction" + + t = self._spectrum(test) + x = self._spectrum(reference) + t = self._masker_power_array(t) + x = self._masker_power_array(x) + + return self._detectability(t, x, self.cs, self.ca) + def gain(self, reference): assert ( len(reference.shape) == 2 diff --git a/libdetectability/segmentedDetectability.py b/libdetectability/segmentedDetectability.py new file mode 100644 index 0000000..f82dfd6 --- /dev/null +++ b/libdetectability/segmentedDetectability.py @@ -0,0 +1,64 @@ +from .detectability import Detectability, DetectabilityLoss +import libsegmenter +import torch + +class segmented_detectability: + def __init__( + self, + frame_size=2048, + sampling_rate=48000, + taps=32, + dbspl=94.0, + spl=1.0, + relax_threshold=False, + normalize_gain=False, + norm="backward", + ): + self.detectability = DetectabilityLoss( + frame_size = frame_size, + sampling_rate = sampling_rate, + taps = taps, + dbspl = dbspl, + spl = spl, + relax_threshold = relax_threshold, + normalize_gain = normalize_gain, + norm = norm + ) + + window = libsegmenter.hann(frame_size) + assert frame_size % 2 == 0, "only evenly-sizes frames are supported" + hop_size = frame_size // 2; + self.segmenter = libsegmenter.make_segmenter( + backend = "torch", + frame_size = frame_size, + hop_size = hop_size, + window = window, + mode = "wola", + edge_correction = False + ) + + def calculate_segment_detectability_relative(self, reference_signal, test_signal): + assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape" + reference_segments = self.segmenter.segment(reference_signal) + test_segments = self.segmenter.segment(test_signal) + number_of_batch_elements = reference_segments.shape[0] + number_of_frames = reference_segments.shape[1] + detectability_of_segments = torch.zeros((number_of_batch_elements, number_of_frames)) + for fIdx in range(0, number_of_frames): + detectability_of_segments[:, fIdx] = self.detectability.frame(reference_segments[:, fIdx], test_segments[:, fIdx]) + + return detectability_of_segments + + + def calculate_segment_detectability_absolute(self, reference_signal, test_signal): + assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape" + reference_segments = self.segmenter.segment(reference_signal) + test_segments = self.segmenter.segment(test_signal) + number_of_batch_elements = reference_segments.shape[0] + number_of_frames = reference_segments.shape[1] + detectability_of_segments = torch.zeros((number_of_batch_elements, number_of_frames)) + for fIdx in range(0, number_of_frames): + detectability_of_segments[:, fIdx] = self.detectability.frame_absolute(reference_segments[:, fIdx], test_segments[:, fIdx]) + + return detectability_of_segments + diff --git a/libdetectability/testSegmenterDetectability.py b/libdetectability/testSegmenterDetectability.py new file mode 100644 index 0000000..cf20c22 --- /dev/null +++ b/libdetectability/testSegmenterDetectability.py @@ -0,0 +1,47 @@ +from .segmentedDetectability import segmented_detectability +import torch as tc +import libsegmenter +import pytest +import numpy as np + +def test_segmenter_detectability_relative(): + sampling_rate = 48000 + window_size = 2048 + hop_size = window_size // 2; + + signal_length = 4*window_size+10 + batch_size = 3 + input = tc.ones((batch_size, signal_length)) + input2 = tc.ones((batch_size, signal_length)) + + segmenter = segmented_detectability(frame_size = window_size, sampling_rate = sampling_rate) + + results = segmenter.calculate_segment_detectability_relative(input, input2) + max_detectability = results.numpy() + test_value = np.max(max_detectability) + ref_value = 0.0 + + + assert test_value == pytest.approx(ref_value) + +def test_segmenter_detectability_absolute(): + sampling_rate = 48000 + window_size = 2048 + hop_size = window_size // 2; + + signal_length = 4*window_size+10 + batch_size = 3 + input = tc.ones((batch_size, signal_length)) + input2 = tc.zeros((batch_size, signal_length)) + + segmenter = segmented_detectability(frame_size = window_size, sampling_rate = sampling_rate) + + results = segmenter.calculate_segment_detectability_absolute(input, input2) + max_detectability = results.numpy() + test_value = np.max(max_detectability) + ref_value = 0.0 + + + assert test_value == pytest.approx(ref_value) + + diff --git a/setup.py b/setup.py index b1441dc..3e40cca 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,6 @@ name="libdetectability", version="0.4.2", packages=find_packages(), - install_requires=["pytest", "numpy", "scipy", "torch"], + install_requires=["pytest", "numpy", "scipy", "torch","libsegmenter"], test_suite="test", )