# utils

Defined in `utils.py`
```python
class DiarizationDataset(dataset_name=None,
                         data_dir=None,
                         sr=16000,
                         window_len=240,
                         window_step=120,
                         transform=None,
                         batch_size_for_ecapa=512,
                         vad_step=4,
                         split='full',
                         use_precomputed_vad=True,
                         use_oracle_vad=False,
                         skip_overlap=True)
```
Abstract class for loading a diarization dataset. It applies the necessary pre-processing and x-vector feature extraction, returning each audio file as a batch of windowed x-vector features that can be passed directly to the clustering algorithm to predict speaker labels. Pre-computed x-vectors are used when available; otherwise they are extracted at runtime.
Parameters:

| Argument | Detail |
|---|---|
| `dataset_name` | *str*, Name of the pre-existing dataset to use. Options: `ami`, `ami_dev`, `voxconverse` |
| `data_dir` | *str*, Directory of any dataset other than the options specified in `dataset_name`. `dataset_name` and `data_dir` cannot both be `None` |
| `sr` | *int*, Sampling rate of the audio signal |
| `window_len` | *int*, Window length (in ms) of each audio segment passed for feature extraction |
| `window_step` | *int*, Step (in ms) between two consecutive windows of audio segments passed for feature extraction |
| `transform` | *list*, List of transforms (e.g., a mel transform) applied to the audio during preprocessing, default = `None` |
| `batch_size_for_ecapa` | *int*, Batch size of audio segments while performing feature extraction using ECAPA-TDNN |
| `vad_step` | *int*, Number of windows each audio chunk is split into; used by the Silero-VAD module |
| `split` | *str*, Type of dataset split; default = `'full'` indicates no split |
| `use_precomputed_vad` | *bool*, If True, downloads pre-computed Voice Activity Detection (VAD) labels for the dataset. Only available for the dataset options specified in `dataset_name` |
| `use_oracle_vad` | *bool*, If True, performs Voice Activity Detection directly from the groundtruth RTTM files, bypassing the Silero-VAD module |
| `skip_overlap` | *bool*, If True, skips windows where multiple speakers are speaking, as determined from the groundtruth RTTM files |
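A minimal instantiation sketch (assuming `utils.py` is on the import path; the argument values simply echo the defaults listed above):

```python
from utils import DiarizationDataset

# Load the AMI dev set with pre-computed VAD labels,
# skipping windows that contain overlapped speech.
dataset = DiarizationDataset(dataset_name='ami_dev',
                             sr=16000,
                             window_len=240,
                             window_step=120,
                             use_precomputed_vad=True,
                             skip_overlap=True)
```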
Class Functions:

1. `__getitem__(idx)`:

Parameters:
| Argument | Detail |
|---|---|
| `idx` | *int*, Index of the required audio file in the list of audio files in the root directory |
Returns:

| Variable | Detail |
|---|---|
| `audio_segments` | *torch.Tensor*, (n_windows, features_len) Tensor of feature vectors for each audio segment window |
| `diarization_segments` | *torch.Tensor*, (n_windows, n_spks) Tensor of groundtruth speaker labels: 1 if the i-th window has the j-th speaker speaking, else 0 |
| `speech_segments` | *torch.Tensor*, (n_windows,) Tensor whose i-th value is 1 if the VAD detects speech in the i-th window, else 0 |
| `label_path` | *str*, Path of the RTTM file containing labels for the `idx`-th audio |
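Continuing the instantiation sketch above, indexing the dataset returns the four values in the table (the return order shown here is an assumption):

```python
# Fetch the windowed features and labels of the first audio file.
audio_segments, diarization_segments, speech_segments, label_path = dataset[0]

print(audio_segments.shape)        # (n_windows, features_len)
print(diarization_segments.shape)  # (n_windows, n_spks)
print(speech_segments.shape)       # (n_windows,)
print(label_path)                  # path to the groundtruth RTTM file
```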
2. `read_rttm(path)`:

Parameters:
| Argument | Detail |
|---|---|
| `path` | *str*, Path to the RTTM diarization file |
Returns:

| Variable | Detail |
|---|---|
| `rttm_out` | *numpy.ndarray*, (..., 3) Array with column 1 holding the start time of each speaker turn, column 2 the end time, and column 3 the speaker label |
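For reference, a minimal parser producing the array described above, based on the standard RTTM `SPEAKER <file> <chan> <start> <dur> <NA> <NA> <name> ...` line format (an illustrative sketch, not the repository's exact implementation):

```python
import numpy as np

def read_rttm_sketch(path):
    """Parse an RTTM file into an (n_turns, 3) array of
    [start_time, end_time, speaker_label] rows."""
    rows, speakers = [], {}
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields or fields[0] != 'SPEAKER':
                continue
            start, dur, name = float(fields[3]), float(fields[4]), fields[7]
            label = speakers.setdefault(name, len(speakers))  # map names to 0, 1, 2, ...
            rows.append([start, start + dur, label])
    return np.array(rows)
```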
```python
def make_rttm(out_dir, name, labels, win_step)
```

Defined in `utils.py`

Creates RTTM diarization files for the non-overlapping speaker labels in `labels`. Assumes the non-speech parts have the value -1 and the speech parts have some speaker label (0, 1, 2, ...).
Parameters:

| Argument | Detail |
|---|---|
| `out_dir` | *str*, Directory where the output RTTM diarization files are saved |
| `name` | *str*, Name of the audio file for which diarization was predicted |
| `labels` | *int*, Speaker/non-speech labels assigned to the audio segments, based on the `win_step` used to extract the feature vectors |
| `win_step` | *int*, Step (in ms) between two windows of audio segments used for feature extraction |

Returns:

| Variable | Detail |
|---|---|
| return variable | *str*, Path to the saved RTTM diarization file |
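The core idea, merging runs of identical window labels into speaker turns and writing one `SPEAKER` line per turn, can be sketched as follows (illustrative only; `make_rttm` in `utils.py` is the authoritative version):

```python
import os
import numpy as np

def make_rttm_sketch(out_dir, name, labels, win_step):
    """Write window labels (-1 = non-speech, 0/1/2/... = speaker) to an
    RTTM file, where each label covers win_step milliseconds of audio."""
    step_s = win_step / 1000.0
    path = os.path.join(out_dir, f"{name}.rttm")
    start = None
    with open(path, "w") as f:
        for i, lab in enumerate(np.append(labels, -1)):  # -1 sentinel flushes the last turn
            if start is not None and lab != current:     # current turn ends here
                dur = i * step_s - start
                f.write(f"SPEAKER {name} 1 {start:.3f} {dur:.3f} "
                        f"<NA> <NA> spk{current} <NA> <NA>\n")
                start = None
            if start is None and lab != -1:              # a new turn begins
                start, current = i * step_s, lab
    return path
```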
```python
def get_metrics(groundtruth_path, hypothesis_path, collar=0.25, skip_overlap=True)
```

Defined in `utils.py`

Evaluates the diarization results of all predicted RTTM files in the hypothesis directory against the groundtruth RTTM files in the groundtruth directory.
Parameters:

| Argument | Detail |
|---|---|
| `groundtruth_path` | *str*, Directory of the groundtruth RTTM files |
| `hypothesis_path` | *str*, Directory of the hypothesis RTTM files |
| `collar` | *float*, Duration (in seconds) of the collar removed from evaluation around the boundaries of reference segments |
| `skip_overlap` | *bool*, If True, calculates the Diarization Error Rate ignoring the overlapped regions |

Returns:

| Variable | Detail |
|---|---|
| `metric` | *pyannote.metrics*, Pyannote metric object holding the diarization DERs for all the files |
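For context, a DER evaluation over two RTTM directories typically looks like this with `pyannote.metrics` (a sketch of the underlying library calls; the directory names are placeholders):

```python
import glob
import os

from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

metric = DiarizationErrorRate(collar=0.25, skip_overlap=True)

for hyp_file in glob.glob(os.path.join("hypothesis_dir", "*.rttm")):
    name = os.path.splitext(os.path.basename(hyp_file))[0]
    ref_file = os.path.join("groundtruth_dir", f"{name}.rttm")
    reference = list(load_rttm(ref_file).values())[0]   # pyannote.core.Annotation
    hypothesis = list(load_rttm(hyp_file).values())[0]
    metric(reference, hypothesis)  # accumulates per-file DER components

print(f"Overall DER: {abs(metric):.3f}")  # abs() aggregates over all files
```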
def plot_annot(name="IS1009a", collar=0.25, skip_overlap=True, groundtruth_path=None, hypothesis_path=None):
Defined in utils.py
Calculate the Diarization Error Rate for filename specified, and print the groundtruth and hypothesis time series plot.
Parameters:

| Argument | Detail |
|---|---|
| `name` | *str*, Name of the file whose time series plot is to be generated. The file must be present in the `hypothesis_path` folder |
| `collar` | *float*, Duration (in seconds) of the collar removed from evaluation around the boundaries of reference segments |
| `skip_overlap` | *bool*, If True, calculates the Diarization Error Rate ignoring the overlapped regions |
| `groundtruth_path` | *str*, Directory of the groundtruth RTTM files |
| `hypothesis_path` | *str*, Directory of the hypothesis RTTM files |
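Such a plot can be reproduced with `pyannote.core`'s plotting helpers; a sketch (reusing the RTTM loading from the previous example; `IS1009a` is the default meeting ID in the signature above):

```python
import matplotlib.pyplot as plt
from pyannote.core import notebook
from pyannote.database.util import load_rttm

name = "IS1009a"
reference = list(load_rttm(f"groundtruth_dir/{name}.rttm").values())[0]
hypothesis = list(load_rttm(f"hypothesis_dir/{name}.rttm").values())[0]

# Stack the groundtruth and hypothesis timelines on a shared time axis.
fig, (ax_ref, ax_hyp) = plt.subplots(2, 1, figsize=(10, 4), sharex=True)
notebook.plot_annotation(reference, ax=ax_ref, legend=True)
ax_ref.set_title("groundtruth")
notebook.plot_annotation(hypothesis, ax=ax_hyp, legend=True)
ax_hyp.set_title("hypothesis")
plt.show()
```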