-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from nielsdekoeijer/segmented_detectability
Segmenter detectability
- Loading branch information
Showing
4 changed files
with
142 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from .detectability import Detectability, DetectabilityLoss | ||
import libsegmenter | ||
import torch | ||
|
||
class segmented_detectability:
    """Score a test signal against a reference signal frame by frame.

    Both signals are cut into frames by a ``libsegmenter`` torch segmenter
    (Hann window, hop of half the frame size, ``"wola"`` mode, no edge
    correction) and each pair of frames is handed to a ``DetectabilityLoss``
    instance.
    """

    def __init__(
        self,
        frame_size=2048,
        sampling_rate=48000,
        taps=32,
        dbspl=94.0,
        spl=1.0,
        relax_threshold=False,
        normalize_gain=False,
        norm="backward",
    ):
        """Build the per-frame detectability model and the segmenter.

        All arguments are forwarded unchanged to ``DetectabilityLoss``;
        ``frame_size`` additionally sets the window length of the segmenter,
        whose hop size is ``frame_size // 2``.

        Raises:
            AssertionError: if ``frame_size`` is odd.
        """
        # Validate before building anything so an unsupported frame size
        # fails fast instead of after the model and window were constructed.
        assert frame_size % 2 == 0, "only evenly-sizes frames are supported"

        self.detectability = DetectabilityLoss(
            frame_size=frame_size,
            sampling_rate=sampling_rate,
            taps=taps,
            dbspl=dbspl,
            spl=spl,
            relax_threshold=relax_threshold,
            normalize_gain=normalize_gain,
            norm=norm,
        )

        self.segmenter = libsegmenter.make_segmenter(
            backend="torch",
            frame_size=frame_size,
            hop_size=frame_size // 2,
            window=libsegmenter.hann(frame_size),
            mode="wola",
            edge_correction=False,
        )

    def _score_segments(self, reference_signal, test_signal, frame_scorer):
        """Segment both signals and apply ``frame_scorer`` to each frame pair.

        ``frame_scorer`` is called as ``frame_scorer(reference_frame,
        test_frame)`` with the slices ``segments[:, i]``; its result is
        written into column ``i`` of the returned tensor.

        Returns:
            Tensor of shape ``(segments.shape[0], segments.shape[1])``
            (presumably batch x frames), allocated on the device of the
            segmented reference signal.

        Raises:
            AssertionError: if the two signals differ in shape.
        """
        assert reference_signal.shape == test_signal.shape, "reference and test signal must be of the same shape"

        reference_segments = self.segmenter.segment(reference_signal)
        test_segments = self.segmenter.segment(test_signal)
        number_of_batch_elements = reference_segments.shape[0]
        number_of_frames = reference_segments.shape[1]

        # Follow the input's device and floating dtype rather than always
        # producing a CPU float32 tensor; non-floating inputs keep the
        # default dtype so the scores are never truncated to integers.
        if reference_segments.is_floating_point():
            score_dtype = reference_segments.dtype
        else:
            score_dtype = torch.get_default_dtype()
        detectability_of_segments = torch.zeros(
            (number_of_batch_elements, number_of_frames),
            dtype=score_dtype,
            device=reference_segments.device,
        )

        for frame_index in range(number_of_frames):
            detectability_of_segments[:, frame_index] = frame_scorer(
                reference_segments[:, frame_index],
                test_segments[:, frame_index],
            )

        return detectability_of_segments

    def calculate_segment_detectability_relative(self, reference_signal, test_signal):
        """Return the per-frame result of ``DetectabilityLoss.frame``.

        See ``_score_segments`` for the shape of the result and the
        shape-mismatch ``AssertionError``.
        """
        return self._score_segments(
            reference_signal, test_signal, self.detectability.frame
        )

    def calculate_segment_detectability_absolute(self, reference_signal, test_signal):
        """Return the per-frame result of ``DetectabilityLoss.frame_absolute``.

        See ``_score_segments`` for the shape of the result and the
        shape-mismatch ``AssertionError``.
        """
        return self._score_segments(
            reference_signal, test_signal, self.detectability.frame_absolute
        )
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from .segmentedDetectability import segmented_detectability | ||
import torch as tc | ||
import libsegmenter | ||
import pytest | ||
import numpy as np | ||
|
||
def test_segmenter_detectability_relative():
    """Identical reference and test signals yield a maximum score of zero.

    Feeds two all-ones batches through
    ``calculate_segment_detectability_relative`` and checks that the largest
    per-frame value is approximately 0.0.
    """
    sampling_rate = 48000
    window_size = 2048

    # Deliberately not a whole number of windows long.
    signal_length = 4 * window_size + 10
    batch_size = 3
    reference_signal = tc.ones((batch_size, signal_length))
    test_signal = tc.ones((batch_size, signal_length))

    segmenter = segmented_detectability(
        frame_size=window_size, sampling_rate=sampling_rate
    )

    results = segmenter.calculate_segment_detectability_relative(
        reference_signal, test_signal
    )

    assert np.max(results.numpy()) == pytest.approx(0.0)
|
||
def test_segmenter_detectability_absolute():
    """An all-zeros test signal yields a maximum absolute score of zero.

    Feeds an all-ones reference batch and an all-zeros test batch through
    ``calculate_segment_detectability_absolute`` and checks that the largest
    per-frame value is approximately 0.0.
    """
    sampling_rate = 48000
    window_size = 2048

    # Deliberately not a whole number of windows long.
    signal_length = 4 * window_size + 10
    batch_size = 3
    reference_signal = tc.ones((batch_size, signal_length))
    test_signal = tc.zeros((batch_size, signal_length))

    segmenter = segmented_detectability(
        frame_size=window_size, sampling_rate=sampling_rate
    )

    results = segmenter.calculate_segment_detectability_absolute(
        reference_signal, test_signal
    )

    assert np.max(results.numpy()) == pytest.approx(0.0)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters