slice audio v3 - silero vad #36

Open · wants to merge 15 commits into main
5 changes: 3 additions & 2 deletions fish_audio_preprocess/cli/__main__.py
@@ -11,7 +11,7 @@
from .merge_short import merge_short
from .resample import resample
from .separate_audio import separate
from .slice_audio import slice_audio, slice_audio_v2
from .slice_audio import slice_audio, slice_audio_v2, slice_audio_v3
from .transcribe import transcribe


@@ -34,11 +34,12 @@ def cli(debug: bool):
cli.add_command(loudness_norm)
cli.add_command(slice_audio)
cli.add_command(slice_audio_v2)
cli.add_command(slice_audio_v3)
cli.add_command(resample)
cli.add_command(transcribe)
cli.add_command(merge_short)
cli.add_command(merge_lab)


if __name__ == "__main__":
to_wav()
cli()
8 changes: 2 additions & 6 deletions fish_audio_preprocess/cli/convert_to_wav.py
@@ -74,18 +74,14 @@ def to_wav(
skipped += 1
continue

command = ["ffmpeg", "-i", str(file)]
command = ["ffmpeg", "-y", "-nostats", "-loglevel", "error", "-i", str(file)]

if segment > 0:
command.extend(["-f", "segment", "-segment_time", str(segment)])

command.append(str(new_file))

sp.check_call(
command,
stdout=sp.DEVNULL,
stderr=sp.DEVNULL,
)
sp.check_call(command)

logger.info("Done!")
logger.info(f"Total: {len(files)}, Skipped: {skipped}")
12 changes: 10 additions & 2 deletions fish_audio_preprocess/cli/length.py
@@ -4,11 +4,20 @@
from typing import Optional

import click
import torchaudio
from loguru import logger
from tqdm import tqdm

from fish_audio_preprocess.utils.file import AUDIO_EXTENSIONS, list_files

backends = torchaudio.list_audio_backends()
if "sox" in backends:
backend = "sox"
elif "ffmpeg" in backends:
backend = "ffmpeg"
else:
backend = "soundfile"


def process_one(file, input_dir):
import soundfile as sf
@@ -28,10 +37,9 @@ def process_one(file, input_dir):


def process_one_accurate(file, input_dir):
import torchaudio

try:
y, sr = torchaudio.load(str(file), backend="sox")
y, sr = torchaudio.load(str(file), backend=backend)
return y.size(-1), sr, y.size(-1) / sr, file.relative_to(input_dir)
except Exception as e:
logger.warning(f"Error reading {file}: {e}")
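For reference, a quick way to see which decoding backends torchaudio reports in a given environment; this is the same list the fallback above inspects before choosing sox, ffmpeg, or soundfile. The snippet is illustrative and not part of the diff.

```python
# Illustrative check: list the torchaudio backends available locally; the
# fallback logic above picks sox, then ffmpeg, then soundfile from this list.
import torchaudio

available = torchaudio.list_audio_backends()
print(available)  # e.g. ['ffmpeg', 'soundfile'], depending on the installation
```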
137 changes: 136 additions & 1 deletion fish_audio_preprocess/cli/slice_audio.py
@@ -289,5 +289,140 @@ def slice_audio_v2(
logger.info(f"Output directory: {output_dir}")


@click.command()
@click.argument("input_dir", type=click.Path(exists=True, file_okay=False))
@click.argument("output_dir", type=click.Path(exists=False, file_okay=False))
@click.option("--recursive/--no-recursive", default=True, help="Search recursively")
@click.option(
"--overwrite/--no-overwrite", default=False, help="Overwrite existing files"
)
@click.option(
"--clean/--no-clean", default=False, help="Clean output directory before processing"
)
@click.option(
"--num-workers",
help="Number of workers to use for processing, defaults 1",
default=1,
show_default=True,
type=int,
)
@click.option(
"--min-duration",
help="Minimum duration of each slice",
default=0.5,
show_default=True,
type=float,
)
@click.option(
"--max-duration",
help="Maximum duration of each slice",
default=20.0,
show_default=True,
type=float,
)
@click.option(
"--min-silence-duration",
help="Minimum duration of each slice",
default=0.3,
show_default=True,
type=float,
)
@click.option(
"--speech-pad-duration",
help="final speech chunks are padded by speech pad duration(s) each side",
default=0.1,
show_default=True,
type=float,
)
@click.option(
"--flat-layout/--no-flat-layout", default=False, help="Use flat directory structure"
)
@click.option(
"--merge-short/--no-merge-short",
default=False,
help="Merge short slices automatically",
)
def slice_audio_v3(
input_dir: str,
output_dir: str,
recursive: bool,
overwrite: bool,
clean: bool,
num_workers: int,
min_duration: float,
max_duration: float,
min_silence_duration: float,
speech_pad_duration: float,
flat_layout: bool,
merge_short: bool,
):
"""(silero-vad version) Slice audio files into smaller chunks by silence."""

from fish_audio_preprocess.utils.slice_audio_v3 import slice_audio_file_v3

input_dir, output_dir = Path(input_dir), Path(output_dir)

if flat_layout:
logger.info("Using flat directory structure")

if merge_short:
logger.info("Merging short slices automatically")

if input_dir == output_dir and clean:
logger.error("You are trying to clean the input directory, aborting")
return

make_dirs(output_dir, clean)

files = list_files(input_dir, extensions=AUDIO_EXTENSIONS, recursive=recursive)
logger.info(f"Found {len(files)} files, processing...")

skipped = 0

with ProcessPoolExecutor(max_workers=num_workers) as executor:
tasks = []

for file in tqdm(files, desc="Preparing tasks"):
# Get relative path to input_dir
relative_path = file.relative_to(input_dir)
save_path = output_dir / relative_path.parent / relative_path.stem

if save_path.exists() and not overwrite:
skipped += 1
continue

target_dir = (
    output_dir / relative_path.parent
    if flat_layout
    else output_dir / relative_path.parent / relative_path.stem
)
if not target_dir.exists():
    target_dir.mkdir(parents=True)

tasks.append(
executor.submit(
slice_audio_file_v3,
input_file=str(file),
output_dir=save_path,
min_duration=min_duration,
max_duration=max_duration,
min_silence_duration=min_silence_duration,
speech_pad_s=speech_pad_duration,
flat_layout=flat_layout,
merge_short=merge_short,
)
)

for i in tqdm(as_completed(tasks), total=len(tasks), desc="Processing"):
assert i.exception() is None, i.exception()

logger.info("Done!")
logger.info(f"Total: {len(files)}, Skipped: {skipped}")
logger.info(f"Output directory: {output_dir}")


if __name__ == "__main__":
slice_audio()
slice_audio_v3()
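As a rough sketch (not part of this diff), the new subcommand could be exercised through click's test runner. The dashed command name "slice-audio-v3" and the data/ paths below are assumptions, not values taken from the PR.

```python
# Sketch only: drive the new subcommand through click's CliRunner.
# "slice-audio-v3" (dashed name) and the data/ paths are assumptions.
from click.testing import CliRunner

from fish_audio_preprocess.cli.__main__ import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "slice-audio-v3",
        "data/raw",      # hypothetical input directory
        "data/sliced",   # hypothetical output directory
        "--num-workers", "2",
        "--min-silence-duration", "0.3",
        "--merge-short",
    ],
)
print(result.exit_code)
print(result.output)
```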
125 changes: 125 additions & 0 deletions fish_audio_preprocess/utils/slice_audio_v3.py
@@ -0,0 +1,125 @@
# This file is adapted from Anya

from pathlib import Path
from typing import Iterable, Union

import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad

from fish_audio_preprocess.utils.slice_audio import slice_by_max_duration
from fish_audio_preprocess.utils.slice_audio_v2 import merge_short_chunks


def slice_audio_v3(
audio: np.ndarray,
rate: int,
min_duration: float = 1.0,
max_duration: float = 20.0,
min_silence_duration: float = 0.3,
speech_pad_s: float = 0.1,
merge_short: bool = False,
) -> Iterable[np.ndarray]:
"""Slice audio by silence

Args:
audio: audio data, in shape (samples, channels)
rate: sample rate
min_duration: minimum duration of each slice
max_duration: maximum duration of each slice
min_silence_duration: minimum duration of silence
speech_pad_s: final speech chunks are padded by speech_pad_s seconds on each side
merge_short: merge short slices automatically

Returns:
Iterable of sliced audio
"""

if len(audio) / rate < min_duration:
sliced_by_max_duration_chunk = slice_by_max_duration(audio, max_duration, rate)
yield from (
merge_short_chunks(sliced_by_max_duration_chunk, max_duration, rate)
if merge_short
else sliced_by_max_duration_chunk
)
return

vad_model = load_silero_vad()

wav = torch.from_numpy(audio)
if wav.dim() > 1:
wav = wav.mean(dim=0, keepdim=True)

sr = 16000
if sr != rate:
transform = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
wav = transform(wav)

speech_timestamps = get_speech_timestamps(
wav,
vad_model,
sampling_rate=sr,
min_silence_duration_ms=int(min_silence_duration * 1000),
min_speech_duration_ms=int(min_duration * 1000),
speech_pad_ms=int(speech_pad_s * 1000),
max_speech_duration_s=max_duration,
return_seconds=True,
)

sliced_audio = [
audio[int(timestamp["start"] * rate) : int(timestamp["end"] * rate)]
for timestamp in speech_timestamps
]

if merge_short:
sliced_audio = merge_short_chunks(sliced_audio, max_duration, rate)

for chunk in sliced_audio:
sliced_by_max_duration_chunk = slice_by_max_duration(chunk, max_duration, rate)
yield from sliced_by_max_duration_chunk


def slice_audio_file_v3(
input_file: Union[str, Path],
output_dir: Union[str, Path],
min_duration: float = 1.0,
max_duration: float = 20.0,
min_silence_duration: float = 0.1,
speech_pad_s: float = 0.1,
flat_layout: bool = False,
merge_short: bool = False,
) -> None:
"""
Slice audio by silence and save to output folder

Args:
input_file: input audio file
output_dir: output folder
min_duration: minimum duration of each slice
max_duration: maximum duration of each slice
min_silence_duration: minimum duration of silence
speech_pad_s: final speech chunks are padded by speech_pad_s seconds on each side
flat_layout: use flat directory structure
merge_short: merge short slices automatically
"""

output_dir = Path(output_dir)
audio, rate = librosa.load(str(input_file), sr=None, mono=True)
for idx, sliced in enumerate(
slice_audio_v3(
audio,
rate,
min_duration=min_duration,
max_duration=max_duration,
speech_pad_s=speech_pad_s,
min_silence_duration=min_silence_duration,
merge_short=merge_short,
)
):
if flat_layout:
sf.write(str(output_dir) + f"_{idx:04d}.wav", sliced, rate)
else:
sf.write(str(output_dir / f"{idx:04d}.wav"), sliced, rate)
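A minimal usage sketch for the helper above; the input file name and output folder are made up, and the keyword values simply mirror the CLI defaults.

```python
# Minimal sketch: slice one recording with the silero-vad based splitter and
# write numbered .wav chunks. "example.wav" and "sliced/example" are made up.
from pathlib import Path

from fish_audio_preprocess.utils.slice_audio_v3 import slice_audio_file_v3

out_dir = Path("sliced/example")
out_dir.mkdir(parents=True, exist_ok=True)

slice_audio_file_v3(
    input_file="example.wav",
    output_dir=out_dir,
    min_duration=0.5,
    max_duration=20.0,
    min_silence_duration=0.3,
    speech_pad_s=0.1,
    flat_layout=False,
    merge_short=True,
)
# Expect files like sliced/example/0000.wav, 0001.wav, ... one per detected chunk.
```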
6 changes: 3 additions & 3 deletions fish_audio_preprocess/utils/transcribe.py
@@ -59,7 +59,7 @@ def batch_transcribe(
model = AutoModel(
model=model_size,
vad_model="fsmn-vad",
punc_model="ct-punc" if model_size == "paraformer-zh" else None,
punc_model="ct-punc",
log_level="ERROR",
disable_pbar=True,
)
@@ -77,8 +77,8 @@
# print(result)
if isinstance(result, list):
results[str(file)] = "".join(
[re.sub(r"<\|.*?\|>", "", item["text"]) for item in result]
)
[re.sub(r"< \|.*?\| >", "", item["text"]) for item in result]
).strip()
else:
results[str(file)] = result["text"]
else:
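For illustration only (not taken from the diff), this is the effect of stripping FunASR-style inline tags and trailing whitespace; the sample string is invented and assumes tags of the form <|...|>.

```python
# Illustration: remove inline tags such as <|zh|> and trim whitespace.
# The sample string is invented, assuming tags of the form <|...|>.
import re

raw = "<|zh|><|NEUTRAL|><|Speech|>你好，世界 "
clean = re.sub(r"<\|.*?\|>", "", raw).strip()
print(clean)  # -> 你好，世界
```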
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ dependencies = [
"richuru>=0.1.1",
"praat-parselmouth>=0.4.3",
"click>=8.0.0",
"faster-whisper @ git+https://github.com/SYSTRAN/faster-whisper",
"faster-whisper",
"funasr",
"modelscope",
]