Skip to content

Commit

Permalink
fixed code for getting embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
mikhmed-nabiev committed Apr 3, 2024
1 parent c55dd92 commit e5f4cfc
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/mylib/utils/get_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,33 @@
'sr': data['fs']
}
rate = data['sr']
# data['data'] = data['data'][:5 * data['sr']] # 5 seconds
scaled = np.int16(data['data'] / np.max(np.abs(data['data'])) * 32767)
write(path_to_save, rate, scaled)

print("Getting embeddings")
path_to_embedds = os.path.join(experiment_folder, "../../code/embeddings")
os.makedirs(path_to_embedds, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
for asr_model_name in ["Clementapa/wav2vec2-base-960h-phoneme-reco-dutch"]:
if "wav2vec" in asr_model_name:
path_asr_embedds = os.path.join(path_to_embedds, "wav2vec")
os.makedirs(path_asr_embedds, exist_ok=True)
print("Wav2Vec Embeddings")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2Model.from_pretrained(asr_model_name)
asr_model = Wav2Vec2Model.from_pretrained(asr_model_name).to(device)

for audio in tqdm(glob.glob(os.path.join(path_to_audio, "*.wav"))):
if os.path.exists(os.path.join(path_asr_embedds, os.path.basename(audio))):
continue
input_audio, sr = librosa.load(audio, sr=16000)
sr = int(sr)
embed = []
for j in range(0, len(input_audio), sr):
part = input_audio[j: j + sr]
i = feature_extractor(part, return_tensors='pt', sampling_rate=sr).to(device)


i = feature_extractor(input_audio, return_tensors='pt', sampling_rate=sr)
with torch.no_grad():
output = asr_model(i.input_values)

np.save(os.path.join(path_asr_embedds, os.path.basename(audio)), output.last_hidden_state.numpy())

with torch.no_grad():
output = asr_model(i.input_values)
embed.append(output.last_hidden_state.cpu().numpy())
embed = np.concatenate(embed, axis=1)
np.save(os.path.join(path_asr_embedds, os.path.basename(audio)), embed)

0 comments on commit e5f4cfc

Please sign in to comment.