Skip to content

Commit

Permalink
Merge pull request #3 from nielsdekoeijer/segmented_detectability
Browse files Browse the repository at this point in the history
Segmenter detectability
  • Loading branch information
macoustics authored Apr 30, 2024
2 parents 5fa6811 + 7907327 commit cd3a0da
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 1 deletion.
30 changes: 30 additions & 0 deletions libdetectability/detectability.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,21 @@ def frame(self, reference, test):

return self._detectability(e, x, self.cs, self.ca)

def frame_absolute(self, reference, test):
assert (
reference.size == self.frame_size and test.size == self.frame_size
), f"input frame size different the specified upon construction"
assert (
len(reference.shape) == 1 and len(test.shape) == 1
), f"only support for one-dimensional inputs"

t = self._spectrum(test)
x = self._spectrum(reference)
t = self._masker_power_array(t)
x = self._masker_power_array(x)

return self._detectability(t, x, self.cs, self.ca)

def gain(self, reference):
assert (
reference.size == self.frame_size
Expand Down Expand Up @@ -214,6 +229,21 @@ def frame(self, reference, test):

return self._detectability(e, x, self.cs, self.ca)

def frame_absolute(self, reference, test):
assert (
len(reference.shape) == 2 and len(test.shape) == 2
), f"only support for batched one-dimensional inputs"
assert (
reference.shape[1] == self.frame_size and test.shape[1] == self.frame_size
), f"input frame size different the specified upon construction"

t = self._spectrum(test)
x = self._spectrum(reference)
t = self._masker_power_array(t)
x = self._masker_power_array(x)

return self._detectability(t, x, self.cs, self.ca)

def gain(self, reference):
assert (
len(reference.shape) == 2
Expand Down
64 changes: 64 additions & 0 deletions libdetectability/segmentedDetectability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from .detectability import Detectability, DetectabilityLoss
import libsegmenter
import torch

class segmented_detectability:
def __init__(
self,
frame_size=2048,
sampling_rate=48000,
taps=32,
dbspl=94.0,
spl=1.0,
relax_threshold=False,
normalize_gain=False,
norm="backward",
):
self.detectability = DetectabilityLoss(
frame_size = frame_size,
sampling_rate = sampling_rate,
taps = taps,
dbspl = dbspl,
spl = spl,
relax_threshold = relax_threshold,
normalize_gain = normalize_gain,
norm = norm
)

window = libsegmenter.hann(frame_size)
assert frame_size % 2 == 0, "only evenly-sizes frames are supported"
hop_size = frame_size // 2;
self.segmenter = libsegmenter.make_segmenter(
backend = "torch",
frame_size = frame_size,
hop_size = hop_size,
window = window,
mode = "wola",
edge_correction = False
)

def calculate_segment_detectability_relative(self, reference_signal, test_signal):
assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape"
reference_segments = self.segmenter.segment(reference_signal)
test_segments = self.segmenter.segment(test_signal)
number_of_batch_elements = reference_segments.shape[0]
number_of_frames = reference_segments.shape[1]
detectability_of_segments = torch.zeros((number_of_batch_elements, number_of_frames))
for fIdx in range(0, number_of_frames):
detectability_of_segments[:, fIdx] = self.detectability.frame(reference_segments[:, fIdx], test_segments[:, fIdx])

return detectability_of_segments


def calculate_segment_detectability_absolute(self, reference_signal, test_signal):
assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape"
reference_segments = self.segmenter.segment(reference_signal)
test_segments = self.segmenter.segment(test_signal)
number_of_batch_elements = reference_segments.shape[0]
number_of_frames = reference_segments.shape[1]
detectability_of_segments = torch.zeros((number_of_batch_elements, number_of_frames))
for fIdx in range(0, number_of_frames):
detectability_of_segments[:, fIdx] = self.detectability.frame_absolute(reference_segments[:, fIdx], test_segments[:, fIdx])

return detectability_of_segments

47 changes: 47 additions & 0 deletions libdetectability/testSegmenterDetectability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from .segmentedDetectability import segmented_detectability
import torch as tc
import libsegmenter
import pytest
import numpy as np

def test_segmenter_detectability_relative():
sampling_rate = 48000
window_size = 2048
hop_size = window_size // 2;

signal_length = 4*window_size+10
batch_size = 3
input = tc.ones((batch_size, signal_length))
input2 = tc.ones((batch_size, signal_length))

segmenter = segmented_detectability(frame_size = window_size, sampling_rate = sampling_rate)

results = segmenter.calculate_segment_detectability_relative(input, input2)
max_detectability = results.numpy()
test_value = np.max(max_detectability)
ref_value = 0.0


assert test_value == pytest.approx(ref_value)

def test_segmenter_detectability_absolute():
sampling_rate = 48000
window_size = 2048
hop_size = window_size // 2;

signal_length = 4*window_size+10
batch_size = 3
input = tc.ones((batch_size, signal_length))
input2 = tc.zeros((batch_size, signal_length))

segmenter = segmented_detectability(frame_size = window_size, sampling_rate = sampling_rate)

results = segmenter.calculate_segment_detectability_absolute(input, input2)
max_detectability = results.numpy()
test_value = np.max(max_detectability)
ref_value = 0.0


assert test_value == pytest.approx(ref_value)


2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
name="libdetectability",
version="0.4.2",
packages=find_packages(),
install_requires=["pytest", "numpy", "scipy", "torch"],
install_requires=["pytest", "numpy", "scipy", "torch","libsegmenter"],
test_suite="test",
)

0 comments on commit cd3a0da

Please sign in to comment.