# -*- coding: utf-8 -*-
"""Fish-Speech-Infer-Train.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1trBvrdgyI-Ntd45ZnlT5lhGsI_HnKjC1
<a href="https://colab.research.google.com/drive/1trBvrdgyI-Ntd45ZnlT5lhGsI_HnKjC1?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
Initialization
You can visit [this blog](http://blog.nonewiki.top) for some help.
"""
# Commented out IPython magic to ensure Python compatibility.
# Are you using Colab?
useColab = True #@param {type:"boolean"}
if useColab:
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    # %cd /content/drive/MyDrive
else:
    pass
    # %cd .
# Commented out IPython magic to ensure Python compatibility.
# Clone the repo
import os
if not os.path.exists("fish-speech"):
    !git clone https://github.com/fishaudio/fish-speech.git
    # The 1.4 line is currently the more stable release, so v1.4.3 is pinned here; replace the tag if you want another version.
    !cd fish-speech && git checkout tags/v1.4.3
else:
    print("The Fish Speech project is already in this directory.")
# Enter the project directory
# %cd fish-speech
# Install packages
!sudo apt install -y libasound-dev portaudio19-dev libportaudio2 libportaudiocpp0
!pip install pyaudio
!pip install huggingface_hub
!pip install triton
!pip install .
# !huggingface-cli login # If you want to use the version 1.5 model, uncomment this line
model_id = "fishaudio/fish-speech-1.4" #@param {type:"string"}
download_dir = "checkpoints/fish-speech-1.4" #@param {type:"string"}
!huggingface-cli download {model_id} --local-dir {download_dir} # Replace the model ID and directory above if needed
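# Optional sanity check (an addition, not part of the upstream workflow): confirm
# the download landed where later cells expect it. The listing is informational
# only; exact file names depend on the model version you downloaded.
import os
assert os.path.isdir(download_dir), f"{download_dir} was not created"
print(sorted(os.listdir(download_dir)))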
"""Run the necessary variables."""
# Value
# If you are using a version of the repository other than 1.4, please change the values here
vqgan_model = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" #@param {type: "string"}
vqgan_config_name = "firefly_gan_vq" #@param {type: "string"}
llama_model = "checkpoints/fish-speech-1.4" #@param {type: "string"}
device = "cuda" #@param ["cuda", "cpu"]
useCompile = True #@param {type: "boolean"}
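# Optional sanity check (an addition to the original notebook): fall back to CPU
# if CUDA is unavailable on this runtime, so later cells do not fail with a
# device error.
import torch
if device == "cuda" and not torch.cuda.is_available():
    print("CUDA is not available on this runtime; falling back to CPU.")
    device = "cpu"
print(f"Using device: {device}")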
"""Fine-tuning"""
# Datasets
# Batch extraction of semantic tokens
!python tools/vqgan/extract_vq.py data --num-workers 1 --batch-size 16 --config-name {vqgan_config_name} --checkpoint-path {vqgan_model}
# Pack the dataset into protobuf
input_dir = "data" #@param {type: "string"}
output_dir = "data/protos" #@param {type: "string"}
!python tools/llama/build_dataset.py --input {input_dir} --output {output_dir} --text-extension .lab --num-workers 16
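# Optional check: build_dataset.py reads .lab transcripts next to the audio
# (see --text-extension above), so every .wav under the input directory should
# have a matching .lab file. This pairing check is a sketch based on the flags
# used here, not part of the upstream tooling; re-run the packing step if it
# reports problems.
from pathlib import Path
missing = [p for p in Path(input_dir).rglob("*.wav") if not p.with_suffix(".lab").exists()]
print(f"{len(missing)} wav files without a .lab transcript")
for p in missing[:10]:
    print(" ", p)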
# Set up training configuration file
import yaml
project = "Speaker1" #@param {type: "string"}
train_config_name = "text2semantic_finetune" #@param {type: "string"}
lora_config_name = "r_8_alpha_16" #@param {type: "string"}
train_config_path = f"fish_speech/configs/{train_config_name}.yaml"
if not os.path.exists(train_config_path):
    print(f"The file {train_config_path} does not exist.")
    raise FileNotFoundError(train_config_path)
with open(train_config_path, "r", encoding="utf-8") as file:
    data = yaml.load(file, Loader=yaml.FullLoader)
pretrained_ckpt_path = "checkpoints/fish-speech-1.4" #@param {type: "string"}
protos_dir = "data/protos" #@param {type: "string"}
max_steps = 1000 #@param {type: "integer"}
num_workers = 1 #@param {type: "integer"}
batch_size = 1 #@param {type: "integer"}
data["project"] = project
data["pretrained_ckpt_path"] = pretrained_ckpt_path
data["trainer"]["max_steps"] = max_steps
data["data"]["num_workers"] = num_workers
data["data"]["batch_size"] = batch_size
data["model"]["model"]["lora_config"] = lora_config_name
data["train_dataset"]["proto_files"] = [protos_dir]
data["val_dataset"]["proto_files"] = [protos_dir]
with open(train_config_path, "w", encoding="utf-8") as file:
    yaml.dump(data, file, allow_unicode=True)
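# Optional: reload the config and echo the fields this cell changed, to confirm
# the YAML write took effect. This verification step is an addition.
with open(train_config_path, "r", encoding="utf-8") as file:
    check = yaml.load(file, Loader=yaml.FullLoader)
print("project:", check["project"])
print("max_steps:", check["trainer"]["max_steps"])
print("batch_size:", check["data"]["batch_size"])
print("lora_config:", check["model"]["model"]["lora_config"])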
# Load training configuration file
import yaml
train_config_name = "text2semantic_finetune" #@param {type: "string"}
train_config_path = f"fish_speech/configs/{train_config_name}.yaml"
if not os.path.exists(train_config_path):
    print(f"The file {train_config_path} does not exist.")
    raise FileNotFoundError(train_config_path)
with open(train_config_path, "r", encoding="utf-8") as file:
    data = yaml.load(file, Loader=yaml.FullLoader)
project = data["project"]
pretrained_ckpt_path = data["pretrained_ckpt_path"]
lora_config_name = data["model"]["model"]["lora_config"]
print(f"Project: {project}")
print(f"Pretrained ckpt path: {pretrained_ckpt_path}")
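# Optional: start TensorBoard before launching the fine-tuning cell below so
# you can watch the loss as it trains. The logdir is an assumption (Lightning
# event files typically land somewhere under results/); adjust it if your run
# writes elsewhere.
# %load_ext tensorboard
# %tensorboard --logdir results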
# Run the "# Load training configuration file" cell first
# Fine-tuning with LoRA
!python fish_speech/train.py --config-name {train_config_name}
# Run the "# Load training configuration file" cell first
# Convert the LoRA weights to regular weights
autoLoRA = False #@param {type: "boolean"}
if autoLoRA:
    import re
    # Pick the checkpoint with the highest step number
    lora_dir_path = f"results/{project}/checkpoints"
    files = [f for f in os.listdir(lora_dir_path) if re.match(r"step_\d+\.ckpt", f)]
    lora_weight = f"{lora_dir_path}/{max(files, key=lambda f: int(re.search(r'\d+', f).group()))}"
else:
    lora_weight = "results/Speaker1/checkpoints/step_000000010.ckpt" #@param {type: "string"}
output_path = "checkpoints/fish-speech-1.4-Speaker1-lora/" #@param {type: "string"}
!python tools/llama/merge_lora.py --lora-config {lora_config_name} --base-weight {pretrained_ckpt_path} --lora-weight {lora_weight} --output {output_path}
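# Optional: confirm the merge produced a populated checkpoint directory. This
# check is an addition; the listing is informational, and exact file names
# depend on the model version.
import os
assert os.path.isdir(output_path), f"{output_path} was not created by merge_lora.py"
print(sorted(os.listdir(output_path)))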
# Run the "# Load training configuration file" cell first
# Continue training
autoLatestLoRA = False #@param {type: "boolean"}
if autoLatestLoRA:
    import re
    # Pick the checkpoint with the highest step number
    lora_dir_path = f"results/{project}/checkpoints"
    files = [f for f in os.listdir(lora_dir_path) if re.match(r"step_\d+\.ckpt", f)]
    latest_lora_weight = f"{lora_dir_path}/{max(files, key=lambda f: int(re.search(r'\d+', f).group()))}"
else:
    latest_lora_weight = "results/Speaker1/checkpoints/step_000000010.ckpt" #@param {type: "string"}
import time
# Timestamped output directory; "." is replaced so the float timestamp stays path-friendly
output_path = f"checkpoints/{time.time()}/".replace(".", "_")
!python tools/llama/merge_lora.py --lora-config {lora_config_name} --base-weight {pretrained_ckpt_path} --lora-weight {latest_lora_weight} --output {output_path}
if not os.path.exists(train_config_path):
    print(f"The file {train_config_path} does not exist.")
    raise FileNotFoundError(train_config_path)
with open(train_config_path, "r", encoding="utf-8") as file:
    data = yaml.load(file, Loader=yaml.FullLoader)
pretrained_ckpt_path = output_path[:-1]  # strip the trailing slash
data["pretrained_ckpt_path"] = pretrained_ckpt_path
with open(train_config_path, "w", encoding="utf-8") as file:
    yaml.dump(data, file, allow_unicode=True)
input("Delete all the .ckpt files in the latest_lora_weight folder, then press Enter to continue.")
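# Optional alternative to deleting by hand: remove the stale LoRA checkpoints
# programmatically. Assumes they live under results/{project}/checkpoints, the
# same directory the auto-discovery branch above uses. Deletion is irreversible.
import glob
for ckpt in glob.glob(f"results/{project}/checkpoints/*.ckpt"):
    print(f"Removing {ckpt}")
    os.remove(ckpt)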
!python fish_speech/train.py --config-name {train_config_name}
"""Inference"""
# Generate a prompt from a reference voice
# If you plan to let the model choose a voice timbre at random, you can skip this step
prompt_wav = "Speaker1.wav" #@param {type: "string"}
output_npy = "fake.npy" #@param {type: "string"}
!python tools/vqgan/inference.py -i {prompt_wav} -o {output_npy} --checkpoint-path {vqgan_model} --device {device}
# Generate semantic tokens from text
output_wav_text = "Hello everyone! Welcome to Apple Park" #@param {type: "string"}
output_wav = "fake_Speaker1.wav" #@param {type: "string"}
prompt_wav_text = "Prompt wav text" #@param {type: "string"}
prompt_npy = "fake.npy" #@param {type: "string"}
# The text arguments are quoted so spaces and punctuation survive shell parsing
if useCompile:
    !python tools/llama/generate.py --text "{output_wav_text}" --prompt-text "{prompt_wav_text}" --prompt-tokens {prompt_npy} --checkpoint-path {llama_model} --num-samples 2 --compile --device {device}
else:
    !python tools/llama/generate.py --text "{output_wav_text}" --prompt-text "{prompt_wav_text}" --prompt-tokens {prompt_npy} --checkpoint-path {llama_model} --num-samples 2 --device {device}
# Decode the semantic tokens back into audio
input_npy = "codes_0.npy" #@param {type: "string"}
!python tools/vqgan/inference.py -i {input_npy} -o {output_wav} --device {device} --checkpoint-path {vqgan_model}
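# Optional: --num-samples 2 above writes more than one codes_*.npy. This sketch
# decodes every generated sample rather than just codes_0.npy; the output naming
# scheme is hypothetical, and it assumes generate.py wrote the codes files to
# the working directory.
import glob
for i, npy in enumerate(sorted(glob.glob("codes_*.npy"))):
    wav = f"fake_Speaker1_{i}.wav"
    !python tools/vqgan/inference.py -i {npy} -o {wav} --device {device} --checkpoint-path {vqgan_model}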
# Play
from IPython.display import Audio, display
display(Audio(output_wav, autoplay=False))
"""Webui"""
# Download Cloudflared
!wget https://github.com/cloudflare/cloudflared/releases/download/2024.11.1/cloudflared-linux-386 -O cloudflared
!chmod +x cloudflared
# Run the webui (open the *.trycloudflare.com URL that cloudflared prints in the cell output)
if useCompile:
!./cloudflared tunnel --url 127.0.0.1:7860 | python -m tools.webui --llama-checkpoint-path {llama_model} --decoder-checkpoint-path {vqgan_model} --decoder-config-name {vqgan_config_name} --compile --device {device}
else:
!./cloudflared tunnel --url 127.0.0.1:7860 | python -m tools.webui --llama-checkpoint-path {llama_model} --decoder-checkpoint-path {vqgan_model} --decoder-config-name {vqgan_config_name} --device {device}