Skip to content

Commit

Permalink
example: add musicgen example (#2899)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Oct 26, 2023
1 parent ce807f3 commit 3d271c4
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 50 deletions.
Empty file removed example/LLM/musicgen/README.md
Empty file.
16 changes: 0 additions & 16 deletions example/LLM/musicgen/runtime.yaml

This file was deleted.

21 changes: 21 additions & 0 deletions example/text-to-music/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Text To Music
======

Examples of text-to-music models for offline batch evaluation and online demos.

Models
------

- [MusicGen](https://musicgen.com/): a powerful single Language Model (LM) redefining the boundaries of conditional music generation, with the ability to create high-quality music by taking cues from text descriptions or melodies.

Datasets
------

MusicGen samples from the <https://ai.honu.io/papers/musicgen/> website.

What can we learn from these examples?
------

- Build a Starwhale Model for text-to-music models.
- Log audio artifacts in the evaluation phase.
- Write an evaluation results summary with Starwhale Report.
18 changes: 18 additions & 0 deletions example/text-to-music/datasets/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Datasets for text-to-music
======

MusicGen mini
------

Run the following command to build the dataset:

```bash
python musicgen-mini.py
```

Run the following swcli command to show the built dataset:

```bash
swcli dataset info musicgen-mini
```

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from starwhale import dataset
from starwhale.utils.debug import init_logger

init_logger(4)
init_logger(3)

# data from https://ai.honu.io/papers/musicgen/
desc_samples = [
Expand Down Expand Up @@ -29,15 +29,15 @@
"An energetic hip-hop music piece, with synth sounds and strong bass. There is a rhythmic hi-hat patten in the drums."
"90s rock song with electric guitar and heavy drums",
"An 80s driving pop song with heavy drums and synth pads in the background",
" An energetic hip-hop music piece, with synth sounds and strong bass. There is a rhythmic hi-hat patten in the drums.",
"An energetic hip-hop music piece, with synth sounds and strong bass. There is a rhythmic hi-hat patten in the drums.",
]


def build_dataset() -> None:
print("Building musicgen dataset...")
with dataset("musicgen-mini") as ds:
for desc in desc_samples:
ds.append({"desc": desc})
for idx, desc in enumerate(desc_samples):
ds[idx] = {"desc": desc}
ds.commit()


Expand Down
File renamed without changes.
47 changes: 47 additions & 0 deletions example/text-to-music/musicgen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
MusicGen Example Guides
======

[MusicGen](https://ai.honu.io/papers/musicgen/) is a single Language Model (LM) that operates over several streams of compressed discrete music representation.

- 🏔️ Homepage: <https://ai.honu.io/papers/musicgen/>
- 🌋 Github: <https://github.com/facebookresearch/audiocraft>
- 🏕️ Size: small(300M), melody(1.5B), medium(1.5B), large(3.3B)

Log in to Starwhale Cloud
------

```bash
swcli instance login --token "${TOKEN}" --alias cloud-cn https://cloud.starwhale.cn/
```

Build Starwhale Runtime
------

```bash
swcli -vvv runtime build
swcli runtime cp musicgen https://cloud.starwhale.cn/project/starwhale:llm_text_to_audio
```

Build Starwhale Model
------

Model name choices: `melody`, `medium`, `small` and `large`.

```bash
python3 build.py ${model_name}

swcli runtime activate musicgen
python3 build.py small
swcli model cp musicgen-small https://cloud.starwhale.cn/project/starwhale:llm_text_to_audio
```

Run Starwhale Model
------

```bash
# use model src dir
swcli model run --workdir . --runtime musicgen --dataset musicgen-mini -m evaluation

# use model package
swcli model run --uri musicgen-small --runtime musicgen --dataset musicgen-mini
```
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from huggingface_hub import snapshot_download

from starwhale import model as starwhale_model
from starwhale.utils import debug
from starwhale import init_logger

try:
from .utils import (
Expand All @@ -22,7 +22,13 @@
)
from evaluation import music_predict

debug.init_logger(4)
# init_logger configures the logging level of starwhale:
# -> 0: ERROR (default)
# -> 1: WARNING
# -> 2: INFO
# -> 3: DEBUG
# -> 4: TRACE
init_logger(3)


def build_starwhale_model(model_name: str) -> None:
Expand All @@ -39,6 +45,8 @@ def build_starwhale_model(model_name: str) -> None:
)

prepare_build_model_package(model_name)

# Use Starwhale SDK to build Starwhale model Package, `swcli model build` is also available.
starwhale_model.build(name=f"musicgen-{model_name}", modules=[music_predict])


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import typing as t
from tempfile import NamedTemporaryFile

from audiocraft.models import MusicGen
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from audiocraft.data.audio import audio_write
from audiocraft.models.loaders import load_lm_model, load_compression_model

from starwhale import Audio, MIMEType, evaluation

Expand All @@ -15,36 +14,31 @@
except ImportError:
from utils import get_model_name, PRETRAINED_MODELS_DIR

duration = int(os.environ.get("DURATION", 10))
top_k = int(os.environ.get("TOP_K", 250))
top_p = int(os.environ.get("TOP_P", 0))
temperature = float(os.environ.get("TEMPERATURE", 1.0))
cfg_coef = float(os.environ.get("CFG_COEF", 3.0))
max_input_length = int(os.environ.get("MAX_INPUT_LENGTH", 512))
max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", 500)) # 500 tokens = 10 seconds
guidance_scale = float(os.environ.get("GUIDANCE_SCALE", 3.0))

_g_model = None
_g_processor = None
_g_model_name = None


def _load_model() -> t.Tuple[MusicGen, str]:
global _g_model, _g_model_name
def _load_model() -> t.Tuple[t.Any, t.Any, str]:
global _g_model, _g_model_name, _g_processor

if _g_model is None or _g_model_name is None:
model_name = get_model_name()
device = "cuda"
c_model = load_compression_model(
PRETRAINED_MODELS_DIR / model_name / "compression_state_dict.bin",
device=device,
if _g_model is None or _g_model_name is None or _g_processor is None:
_g_model_name = get_model_name()
_g_model = MusicgenForConditionalGeneration.from_pretrained(
PRETRAINED_MODELS_DIR / _g_model_name
)
l_model = load_lm_model(
PRETRAINED_MODELS_DIR / model_name / "state_dict.bin", device=device
_g_model.to("cuda")
_g_processor = AutoProcessor.from_pretrained(
PRETRAINED_MODELS_DIR / _g_model_name
)
if model_name == "melody":
l_model.condition_provider.conditioners["self_wav"].match_len_on_eval = True
_g_model = MusicGen(model_name, c_model, l_model)
_g_model_name = model_name

return _g_model, _g_model_name
return _g_model, _g_processor, _g_model_name


@evaluation.predict(
Expand All @@ -55,30 +49,36 @@ def _load_model() -> t.Tuple[MusicGen, str]:
)
def music_predict(data: dict) -> Audio:
# TODO: support batch prediction
model, _ = _load_model()
model.set_generation_params(
duration=duration,
model, processor, _ = _load_model()
inputs = processor(
text=[data["desc"][:max_input_length]],
padding=True,
return_tensors="pt",
)
# TODO: support melody
outputs = model.generate(
**inputs.to("cuda"),
do_sample=True,
guidance_scale=guidance_scale,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
temperature=temperature,
cfg_coef=cfg_coef,
)
# TODO: support melody
outputs = model.generate([data.desc[:max_input_length]])
output = outputs[0].detach().cpu().float()

with NamedTemporaryFile("wb", suffix=".wav", delete=True) as file:
fpath = audio_write(
stem_name=file.name,
wav=output,
sample_rate=model.sample_rate,
sample_rate=model.config.audio_encoder.sampling_rate,
strategy="loudness",
loudness_headroom_db=16,
loudness_compressor=True,
add_suffix=False,
)
audio = Audio(
fp=fpath,
fp=fpath.read_bytes(),
mime_type=MIMEType.WAV,
)
return audio
4 changes: 4 additions & 0 deletions example/text-to-music/musicgen/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch==2.0.1
transformers==4.31.0
huggingface-hub # download hf models
audiocraft==1.0.0
10 changes: 10 additions & 0 deletions example/text-to-music/musicgen/runtime.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: musicgen
mode: venv
environment:
arch: noarch
os: ubuntu:20.04
cuda: 11.7
python: 3.9
starwhale_version: 0.6.1 # Starwhale >= 0.6.0 supports log artifacts in the evaluation phase.
dependencies:
- requirements.txt
File renamed without changes.

0 comments on commit 3d271c4

Please sign in to comment.