Aditya Singh edited this page May 12, 2021



class DiarizationDataset()

Defined in

class DiarizationDataset(dataset_name=None,
                         use_precomputed_vad=True,
                         use_oracle_vad=False,
                         skip_overlap=True)

An abstract class for loading a diarization dataset. It applies the necessary pre-processing and x-vector feature extraction, and returns each audio file as a sequence of segmented x-vector features that can be passed directly to the clustering algorithm to predict speaker labels. Pre-computed x-vectors are used if available; otherwise they are extracted at runtime.


Argument Detail
dataset_name: str, Name of the pre-existing dataset to use. Options: ami, ami_dev, voxconverse
data_dir: str, Directory of any dataset other than the options listed in dataset_name.
dataset_name and data_dir cannot both be None.
sr: int, Sampling rate (in Hz) of the audio signal
window_len: int, Window length (in ms) of each audio segment passed to feature extraction
window_step: int, Step (in ms) between the starts of two consecutive audio segment windows passed to feature extraction
transform: list, List of transforms (e.g. a mel transform) applied to the audio during pre-processing, default = None
batch_size_for_ecapa: int, Batch size of audio segments during feature extraction with ECAPA-TDNN
vad_step: int, Number of windows each audio chunk is split into. Used by the Silero-VAD module
split: str, Type of dataset split, default = 'full' (no split)
use_precomputed_vad: bool, If True, downloads precomputed Voice Activity Detection (VAD) labels for the dataset. Only available for the datasets listed in dataset_name
use_oracle_vad: bool, If True, derives VAD labels directly from the groundtruth RTTM files, bypassing the Silero VAD module
skip_overlap: bool, If True, skips windows in which multiple speakers are speaking, as determined from the groundtruth RTTM files
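To make the relationship between sr, window_len and window_step concrete, here is a minimal sketch of how millisecond window settings translate into sample-index windows. The helper window_bounds is hypothetical (not part of this module), and it assumes a trailing partial window is dropped rather than padded; the actual class may handle the signal end differently.

```python
def window_bounds(n_samples, sr, window_len, window_step):
    """Hypothetical helper: (start, end) sample indices of each analysis window.

    window_len and window_step are in ms, matching the arguments above.
    Assumes a trailing partial window is dropped rather than padded.
    """
    win = sr * window_len // 1000   # window length in samples
    hop = sr * window_step // 1000  # hop between consecutive window starts
    if hop <= 0:
        raise ValueError("window_step is too small for this sampling rate")
    if n_samples < win:
        return []
    n_windows = 1 + (n_samples - win) // hop
    return [(i * hop, i * hop + win) for i in range(n_windows)]


# 10 s of 16 kHz audio with 1500 ms windows and a 750 ms step (illustrative values).
bounds = window_bounds(160000, sr=16000, window_len=1500, window_step=750)
print(len(bounds), bounds[0], bounds[-1])  # 12 (0, 24000) (132000, 156000)
```

With window_step smaller than window_len the windows overlap, so each sample contributes to several feature vectors.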

Class Functions:

  1. __getitem__: def __getitem__(self, idx)


Argument Detail
idx: int, Index of the required audio file in the list of audio files in the root directory


Variable Detail
audio_segments: torch.Tensor, (n_windows, features_len) Tensor of feature vectors of each audio segment window
diarization_segments: torch.Tensor, (n_windows, n_spks) Tensor of groundtruth speaker labels: entry (i, j) is 1 if the j-th speaker is speaking in the i-th window, else 0
speech_segments: torch.Tensor, (n_windows,) Tensor with i-th value 1 if VAD returns presence of speech audio in i-th window, else 0
label_path: str, Path of the rttm file containing labels for the 'idx' wav audio
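The (n_windows, n_spks) layout of diarization_segments can be illustrated with a small sketch. The helper window_speaker_matrix is hypothetical, and it assumes a speaker counts as active in a window if any of that speaker's reference segments overlaps the window; the module itself may use a different rule. Rows with more than one active speaker are the ones a skip_overlap-style filter would drop, and rows with at least one speaker give an oracle-VAD-style speech indicator.

```python
import numpy as np


def window_speaker_matrix(segments, n_windows, window_len, window_step):
    """Hypothetical helper. segments: (start_s, end_s, speaker) tuples.

    Returns (labels, speakers) where labels[i, j] == 1 if speaker j is
    active anywhere inside window i (window_len/window_step in ms).
    """
    speakers = sorted({spk for _, _, spk in segments})
    index = {spk: j for j, spk in enumerate(speakers)}
    labels = np.zeros((n_windows, len(speakers)), dtype=np.int64)
    for i in range(n_windows):
        w_start = i * window_step / 1000.0
        w_end = w_start + window_len / 1000.0
        for start, end, spk in segments:
            if start < w_end and end > w_start:  # the two intervals overlap
                labels[i, index[spk]] = 1
    return labels, speakers


segments = [(0.0, 2.0, "A"), (1.5, 3.0, "B")]
labels, speakers = window_speaker_matrix(segments, n_windows=4, window_len=1000, window_step=750)
overlap = labels.sum(axis=1) > 1  # windows with multiple speakers
speech = labels.sum(axis=1) > 0   # windows containing any speech
print(labels.tolist())  # [[1, 0], [1, 1], [1, 1], [0, 1]]
```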
  2. read_rttm: def read_rttm(self, path)


Argument Detail
path: str, Path to the RTTM diarization file


Variable Detail
rttm_out: numpy.ndarray, (..., 3) Array in which column 1 holds the start time of each speaker segment, column 2 its end time, and column 3 the speaker label
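As an illustration of the (..., 3) output, the sketch below parses RTTM SPEAKER records into start/end/label rows. It relies on the standard RTTM field order (type, file, channel, onset, duration, ortho, speaker type, speaker name, confidence, lookahead), so the end time is computed as onset plus duration. read_rttm_sketch is hypothetical and takes an iterable of lines instead of a path so it can be tried without a file; with a real file it could be called as read_rttm_sketch(open(path)).

```python
import numpy as np


def read_rttm_sketch(lines):
    """Hypothetical stand-in for read_rttm: returns an (N, 3) object array
    of (start_s, end_s, speaker_label) rows."""
    rows = []
    for line in lines:
        fields = line.split()
        if not fields or fields[0] != "SPEAKER":
            continue  # skip blank lines and non-speaker records
        start = float(fields[3])
        end = start + float(fields[4])  # RTTM stores a duration, not an end time
        rows.append((start, end, fields[7]))
    # object dtype keeps floats and string labels side by side
    return np.array(rows, dtype=object).reshape(-1, 3)


lines = [
    "SPEAKER IS1009a 1 0.50 1.25 <NA> <NA> spk0 <NA> <NA>",
    "SPEAKER IS1009a 1 2.00 0.75 <NA> <NA> spk1 <NA> <NA>",
]
out = read_rttm_sketch(lines)
print(out.tolist())  # [[0.5, 1.75, 'spk0'], [2.0, 2.75, 'spk1']]
```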

def make_rttm()

def make_rttm(out_dir, name, labels, win_step):

Defined in

Creates an RTTM diarization file from the non-overlapping speaker labels in labels. Non-speech windows are assumed to have the value -1 and speech windows a speaker label (0, 1, 2, ...).


Argument Detail
out_dir: str, Directory in which the output RTTM diarization file is saved
name: str, Name of the audio file for which diarization was predicted
labels: int, Speaker or non-speech labels assigned to the audio segment windows, based on the win_step used to extract feature vectors
win_step: int, Step (in ms) between two windows of audio segments used for feature extraction


Variable Detail
return variable: str, Path to the saved RTTM diarization file
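The conversion from per-window labels to RTTM can be sketched as run-length merging: consecutive windows with the same label become one SPEAKER line, and -1 runs are skipped. This is a minimal sketch, not the module's implementation; it assumes window i covers [i * win_step, (i + 1) * win_step) ms and that the file is named <name>.rttm, and both helpers (labels_to_rttm_lines, make_rttm_sketch) are hypothetical.

```python
import os
import tempfile


def labels_to_rttm_lines(name, labels, win_step):
    """Merge runs of identical per-window labels into RTTM SPEAKER lines.

    Assumes window i covers [i * win_step, (i + 1) * win_step) ms and that
    -1 marks non-speech. Times are written in seconds.
    """
    lines = []
    i = 0
    while i < len(labels):
        # Advance j to the last window of the current run of identical labels.
        j = i
        while j + 1 < len(labels) and labels[j + 1] == labels[i]:
            j += 1
        if labels[i] != -1:  # non-speech runs are not written
            start = i * win_step / 1000.0
            dur = (j - i + 1) * win_step / 1000.0
            lines.append(
                f"SPEAKER {name} 1 {start:.3f} {dur:.3f} <NA> <NA> {labels[i]} <NA> <NA>"
            )
        i = j + 1
    return lines


def make_rttm_sketch(out_dir, name, labels, win_step):
    """Hypothetical stand-in for make_rttm: writes <out_dir>/<name>.rttm, returns its path."""
    path = os.path.join(out_dir, f"{name}.rttm")
    with open(path, "w") as f:
        f.writelines(line + "\n" for line in labels_to_rttm_lines(name, labels, win_step))
    return path


# Usage: 500 ms windows; -1 marks non-speech.
labels = [-1, 0, 0, 1, -1, -1, 0]
with tempfile.TemporaryDirectory() as out_dir:
    path = make_rttm_sketch(out_dir, "IS1009a", labels, win_step=500)
    with open(path) as f:
        written = f.read().splitlines()
print(written[0])  # SPEAKER IS1009a 1 0.500 1.000 <NA> <NA> 0 <NA> <NA>
```

Merging runs keeps the file compact: seven windows here produce three segments.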

def get_metrics()

def get_metrics(groundtruth_path, hypothesis_path, collar=0.25, skip_overlap=True):

Defined in

Evaluates all predicted RTTM files in the hypothesis directory against the groundtruth RTTM files in the groundtruth directory.


Argument Detail
groundtruth_path: str, Directory of groundtruth RTTM files
hypothesis_path: str, Directory of hypothesis RTTM files
collar: float, Duration (in seconds) of collars removed from evaluation around boundaries of reference segments
skip_overlap: bool, If True, calculates Diarization Error Rate ignoring the overlapped region


Variable Detail
metric: pyannote.metrics, pyannote metric object holding the Diarization Error Rate (DER) of every evaluated file.
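For reference, the Diarization Error Rate is the standard ratio below, where each T is a total duration measured after the collar regions (and, with skip_overlap=True, the overlapped speech regions) have been removed from evaluation:

```latex
\mathrm{DER} = \frac{T_{\text{false alarm}} + T_{\text{missed detection}} + T_{\text{speaker confusion}}}{T_{\text{reference speech}}}
```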

def plot_annot()

def plot_annot(name="IS1009a", collar=0.25, skip_overlap=True, groundtruth_path=None, hypothesis_path=None):

Defined in

Calculates the Diarization Error Rate for the specified file and displays the groundtruth and hypothesis time series plots.


Argument Detail
name: str, Name of the file whose time series plot is to be generated. File must be present in the hypothesis_path folder
collar: float, Duration (in seconds) of collars removed from evaluation around boundaries of reference segments
skip_overlap: bool, If True, calculates Diarization Error Rate ignoring the overlapped region
groundtruth_path: str, Directory of groundtruth rttm files
hypothesis_path: str, Directory of hypothesis rttm files
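Both get_metrics and plot_annot take a collar argument. The sketch below shows which time ranges a collar removes from scoring, assuming the pyannote.metrics convention of excluding collar/2 before and collar/2 after every reference segment boundary. collar_zones is a hypothetical illustration and is not used by this module; clipping at 0 s is an added assumption.

```python
def collar_zones(reference, collar=0.25):
    """Hypothetical helper. reference: list of (start_s, end_s) segments.

    Returns merged (start_s, end_s) intervals excluded from scoring,
    removing collar/2 on each side of every reference boundary.
    """
    half = collar / 2.0
    zones = []
    for start, end in reference:
        for boundary in (start, end):
            zones.append((max(0.0, boundary - half), boundary + half))  # clip at 0 s
    zones.sort()
    merged = []
    for s, e in zones:
        if merged and s <= merged[-1][1]:
            # Overlapping collars around nearby boundaries fuse into one zone.
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    return merged


zones = collar_zones([(1.0, 3.0), (3.125, 5.0)], collar=0.25)
print(zones)  # [(0.875, 1.125), (2.875, 3.25), (4.875, 5.125)]
```

Because the removed zones sit exactly on the reference boundaries, small timing disagreements at speaker turns do not count as errors, which is why DER with collar=0.25 is typically lower than with collar=0.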