Skip to content

Commit

Permalink
changes for audio PR after reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiltseb committed Mar 12, 2024
1 parent 9ac4c80 commit 61dc28f
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 132 deletions.
20 changes: 0 additions & 20 deletions aana/configs/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,6 @@
],
streaming=True,
),
Endpoint(
name="whisper_transcribe_batch",
path="/video/transcribe_batch",
summary="Transcribe a batch of videos using Whisper Medium",
outputs=[
EndpointOutput(
name="transcription",
output="videos_transcriptions_whisper_medium",
),
EndpointOutput(
name="segments",
output="videos_transcriptions_segments_whisper_medium",
),
EndpointOutput(
name="info",
output="videos_transcriptions_info_whisper_medium",
),
EndpointOutput(name="transcription_id", output="transcription_id"),
],
),
Endpoint(
name="load_transcription",
path="/video/get_transcription",
Expand Down
1 change: 0 additions & 1 deletion aana/deployments/vad_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class VadConfig(BaseModel):
"""

model: str = Field(
# default="https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin",
description="The VAD model url.",
)

Expand Down
3 changes: 2 additions & 1 deletion aana/deployments/whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,9 @@ async def transcribe_stream(
transcription=asr_transcription,
)

@test_cache
async def transcribe_batch(
self, audio_batch: list[Audio], params: WhisperParams = None
self, audio_batch: list[Audio], params: WhisperParams | None = None
) -> WhisperBatchOutput:
"""Transcribe the batch of audios with the Whisper model.
Expand Down
52 changes: 39 additions & 13 deletions aana/models/core/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io, gc
import itertools
from pathlib import Path
from collections.abc import Generator
import torch, decord # noqa: F401 # See https://github.com/dmlc/decord/issues/263
from decord import DECORDError
import numpy as np
Expand All @@ -19,7 +20,7 @@ class AbstractAudioLibrary:

@classmethod
def read_file(cls, path: Path) -> np.ndarray:
"""Read an audio file from path.
"""Read an audio file from path and return as numpy audio array.
Args:
path (Path): The path of the file to read.
Expand Down Expand Up @@ -99,9 +100,9 @@ def read_file(cls, path: Path, sample_rate=16000) -> np.ndarray:

with av.open(str(path), mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)
frames = ignore_invalid_frames(frames)
frames = group_frames(frames, 500000)
frames = resample_frames(frames, resampler)

for frame in frames:
array = frame.to_ndarray()
Expand All @@ -119,19 +120,15 @@ def read_file(cls, path: Path, sample_rate=16000) -> np.ndarray:
return audio

@classmethod
def read_from_bytes(cls, content: bytes, sample_rate=16000) -> np.ndarray:
def read_from_bytes(cls, content: bytes) -> np.ndarray:
"""Read audio bytes as numpy array.
Args:
content (bytes): The content of the file to read.
sample_rate (int): sample rate of the audio
Returns:
np.ndarray: The file as a numpy array.
"""
# Open the audio stream
# container = av.open(BytesIO(content))
# frames = av.AudioFrame.from_ndarray(np.zeros(0, dtype=np.int16), format="s16")
# frames.planes[0].buffer = content
# Create an in-memory file-like object
content_io = io.BytesIO(content)

Expand Down Expand Up @@ -204,7 +201,18 @@ def write_audio_bytes(cls, path: Path, audio: bytes, sample_rate=16000):
wav_file.writeframes(audio)


def _ignore_invalid_frames(frames):
def ignore_invalid_frames(frames: Generator) -> Generator:
"""Filter out invalid frames from the input generator.
Args:
frames (Generator): The input generator of frames.
Yields:
av.audio.frame.AudioFrame: Valid audio frames.
Raises:
StopIteration: When the input generator is exhausted.
"""
iterator = iter(frames)

while True:
Expand All @@ -216,7 +224,16 @@ def _ignore_invalid_frames(frames):
continue


def _group_frames(frames, num_samples=None):
def group_frames(frames: Generator, num_samples: int | None = None) -> Generator:
"""Group audio frames and yield groups of frames based on the specified number of samples.
Args:
frames (Generator): The input generator of audio frames.
num_samples (int | None): The target number of samples for each group.
Yields:
av.audio.frame.AudioFrame: Grouped audio frames.
"""
fifo = av.audio.fifo.AudioFifo()

for frame in frames:
Expand All @@ -230,7 +247,16 @@ def _group_frames(frames, num_samples=None):
yield fifo.read()


def _resample_frames(frames, resampler):
def resample_frames(frames: Generator, resampler) -> Generator:
"""Resample audio frames using the provided resampler.
Args:
frames (Generator): The input generator of audio frames.
resampler: The audio resampler.
Yields:
av.audio.frame.AudioFrame: Resampled audio frames.
"""
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
Expand Down
3 changes: 3 additions & 0 deletions aana/tests/deployments/test_vad_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def setup_vad_deployment(setup_deployment, request):
return name, deployment, *setup_deployment(deployment, bind=True)


# Issue: test silent audio (add expected files): https://github.com/mobiusml/aana_sdk/issues/77


@pytest.mark.skipif(
not is_gpu_available() and not is_using_deployment_cache(),
reason="GPU is not available",
Expand Down
17 changes: 4 additions & 13 deletions aana/tests/deployments/test_whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def setup_whisper_deployment(setup_deployment, request):
return name, deployment, *setup_deployment(deployment, bind=True)


# Issue: test silent audio (add expected files): https://github.com/mobiusml/aana_sdk/issues/77
@pytest.mark.skipif(
not is_gpu_available() and not is_using_deployment_cache(),
reason="GPU is not available",
Expand Down Expand Up @@ -78,14 +79,15 @@ async def test_whisper_deployment(setup_whisper_deployment, audio_file):
audio = Audio(path=path, media_id=audio_file)

output = await handle.transcribe.remote(
media=audio, params=WhisperParams(word_timestamps=True, temperature=0.0)
audio=audio, params=WhisperParams(word_timestamps=True, temperature=0.0)
)
output = pydantic_to_dict(output)

compare_transcriptions(expected_output, output)

# Test transcribe_stream method)
stream = handle.options(stream=True).transcribe_stream.remote(
media=audio, params=WhisperParams(word_timestamps=True, temperature=0.0)
audio=audio, params=WhisperParams(word_timestamps=True, temperature=0.0)
)

# Combine individual segments and compare with the final dict
Expand All @@ -102,17 +104,6 @@ async def test_whisper_deployment(setup_whisper_deployment, audio_file):
compare_transcriptions(expected_output, dict(grouped_dict))

# Test transcribe_batch method
audios = [audio, audio]

batch_output = await handle.transcribe_batch.remote(
media_batch=audios,
params=WhisperParams(word_timestamps=True, temperature=0.0),
)
batch_output = pydantic_to_dict(batch_output)

for i in range(len(audios)):
output = {k: v[i] for k, v in batch_output.items()}
compare_transcriptions(expected_output, output)

# Test transcribe_in_chunks method: Note that the expected asr output is different
expectd_batched_output_path = resources.path(
Expand Down
32 changes: 22 additions & 10 deletions aana/tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def test_audio():
audio.cleanup()

# TODO: Test creation of audio object from a URL
# Test creation from content
try:
path = resources.path("aana.tests.files.audios", "physicsworks.wav")
content = path.read_bytes()
audio = Audio(content=content, save_on_disk=False)
assert audio.path is None
assert audio.content == content
assert audio.url is None

assert audio.get_content() == content
finally:
audio.cleanup()

# Test creation from content
try:
Expand Down Expand Up @@ -62,16 +74,14 @@ def test_save_audio():
audio.cleanup()
assert audio.path.exists() # Cleanup should NOT delete the file if path is provided

# TODO: Test saving from URL to disk (load audio physicsworks.wav to aws for audio)

# Test saving from content to disk
# Test saving from video URL to disk and read
try:
path = resources.path("aana.tests.files.audios", "physicsworks.wav")
content = path.read_bytes()
audio = Audio(content=content, save_on_disk=True)
assert audio.content == content
assert audio.url is None
assert audio.path.exists()
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/squirrel.mp4"

audio = Audio(url=url, save_on_disk=True)
assert audio.content is None
assert audio.url is not None
assert audio.get_numpy() is not None
finally:
audio.cleanup()

Expand Down Expand Up @@ -150,7 +160,7 @@ def test_extract_audio():
assert audio.content is not None
assert audio.path.exists()
assert video.media_id == video_input.media_id
assert audio.media_id == video.media_id
assert audio.media_id == "audio_" + video.media_id

finally:
video.cleanup()
Expand All @@ -174,3 +184,5 @@ def test_extract_audio():
# but the audio content should be the same
assert extracted_audio_1.path != extracted_audio_2.path
assert (extracted_audio_1.get_numpy().all()) == extracted_audio_2.get_numpy().all()

# Issue: test silent audio (add empty bytes expected result): https://github.com/mobiusml/aana_sdk/issues/77
1 change: 1 addition & 0 deletions aana/utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
device (torch.device | None): Device to perform the segmentation.
fscore (bool): Flag indicating whether to compute F-score during inference.
use_auth_token (str | None): Optional authentication token for model access.
inference_kwargs (dict): Optional additional arguments from VoiceActivityDetection pipeline.
"""
super().__init__(
segmentation=segmentation,
Expand Down
4 changes: 1 addition & 3 deletions aana/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def download_model(url: str, model_path: Path | None = None, check_sum=True) ->
model_path = Path(model_dir) / "pytorch_model.bin"

if Path(model_path).exists() and not Path(model_path).is_file():
raise RuntimeError(
f"Not a regular file: {model_path}"
) # exists and is not a regular file # noqa: TRY003
raise RuntimeError(f"Not a regular file: {model_path}") # noqa: TRY003

if not Path(model_path).exists():
try:
Expand Down
Loading

0 comments on commit 61dc28f

Please sign in to comment.