From e5f4cfc42faf62835d9e9892b4c2e323959a1a6b Mon Sep 17 00:00:00 2001
From: Mikhmed Nabiev
Date: Wed, 3 Apr 2024 17:46:19 +0300
Subject: [PATCH] fixed code for getting embeddings

---
 src/mylib/utils/get_embeddings.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/mylib/utils/get_embeddings.py b/src/mylib/utils/get_embeddings.py
index 0149a3c..002db37 100644
--- a/src/mylib/utils/get_embeddings.py
+++ b/src/mylib/utils/get_embeddings.py
@@ -32,28 +32,33 @@
         'sr': data['fs']
     }
     rate = data['sr']
-    # data['data'] = data['data'][:5 * data['sr']]  # 5 seconds
     scaled = np.int16(data['data'] / np.max(np.abs(data['data'])) * 32767)
     write(path_to_save, rate, scaled)
 
 
 print("Getting embeddings")
 path_to_embedds = os.path.join(experiment_folder, "../../code/embeddings")
 os.makedirs(path_to_embedds, exist_ok=True)
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 for asr_model_name in ["Clementapa/wav2vec2-base-960h-phoneme-reco-dutch"]:
     if "wav2vec" in asr_model_name:
         path_asr_embedds = os.path.join(path_to_embedds, "wav2vec")
         os.makedirs(path_asr_embedds, exist_ok=True)
         print("Wav2Vec Embeddings")
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(asr_model_name)
-        asr_model = Wav2Vec2Model.from_pretrained(asr_model_name)
+        asr_model = Wav2Vec2Model.from_pretrained(asr_model_name).to(device)
         for audio in tqdm(glob.glob(os.path.join(path_to_audio, "*.wav"))):
+            if os.path.exists(os.path.join(path_asr_embedds, os.path.basename(audio)) + ".npy"):
+                continue
             input_audio, sr = librosa.load(audio, sr=16000)
+            sr = int(sr)
+            embed = []
+            for j in range(0, len(input_audio), sr):
+                part = input_audio[j: j + sr]
+                i = feature_extractor(part, return_tensors='pt', sampling_rate=sr).to(device)
-
-            i = feature_extractor(input_audio, return_tensors='pt', sampling_rate=sr)
-            with torch.no_grad():
-                output = asr_model(i.input_values)
-
-            np.save(os.path.join(path_asr_embedds, os.path.basename(audio)), output.last_hidden_state.numpy())
-
+                with torch.no_grad():
+                    output = asr_model(i.input_values)
+                embed.append(output.last_hidden_state.cpu().numpy())
+            embed = np.concatenate(embed, axis=1)
+            np.save(os.path.join(path_asr_embedds, os.path.basename(audio)), embed)