-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathsolo.py
189 lines (136 loc) · 5.2 KB
/
solo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Not suited for average users.
# Meant to make the training process easier to understand.
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.utils.data import Dataset
from audiocraft.modules.conditioners import (
ClassifierFreeGuidanceDropout
)
import os
import wandb
# Load the pretrained "small" MusicGen checkpoint. Only its language model
# (`model.lm`) is optimised by the training loop in this script.
model = MusicGen.get_pretrained('small')
# Force the LM weights to float32 so the optimizer updates fp32 parameters;
# fp16 is used only inside the torch.autocast region of the training loop.
# NOTE(review): presumably needed because the checkpoint may load in half
# precision and GradScaler.unscale_ refuses fp16 gradients — confirm.
model.lm = model.lm.to(torch.float32) #important
class AudioDataset(Dataset):
    """Dataset of ``(audio_path, label_path)`` pairs found in one directory.

    Every ``.wav`` file in ``data_dir`` must have a sibling ``.txt`` file with
    the same stem holding its text description. Only paths are stored and
    returned; no audio or text is read here.
    """

    def __init__(self, data_dir):
        """Scan ``data_dir`` and pair each ``.wav`` with its ``.txt`` label.

        Raises:
            ValueError: if a ``.wav`` file has no matching ``.txt`` file.
        """
        self.data_dir = data_dir
        self.data_map = []
        for entry in os.listdir(data_dir):
            stem, suffix = os.path.splitext(entry)
            if suffix != '.wav':
                continue  # ignore label files and anything else
            label_path = os.path.join(data_dir, stem + '.txt')
            if not os.path.exists(label_path):
                raise ValueError(f'No label file for {stem}')
            self.data_map.append({
                "audio": os.path.join(data_dir, entry),
                "label": label_path,
            })

    def __len__(self):
        """Number of (audio, label) pairs discovered."""
        return len(self.data_map)

    def __getitem__(self, idx):
        """Return the ``(audio_path, label_path)`` tuple at position ``idx``."""
        entry = self.data_map[idx]
        return entry['audio'], entry['label']
# --- data -------------------------------------------------------------------
dataset = AudioDataset('/home/ubuntu/dataset')
train_dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

# --- optimisation -----------------------------------------------------------
learning_rate = 1e-4
model.lm.train()
scaler = torch.cuda.amp.GradScaler()
# AdamW betas / weight decay taken from the paper (per the original note).
optimizer = AdamW(
    params=model.lm.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
criterion = nn.CrossEntropyLoss()
# NOTE(review): `device` is computed but the rest of the script calls .cuda()
# directly, so CPU-only execution is not actually supported.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- run bookkeeping --------------------------------------------------------
run = wandb.init(project='audiocraft')
num_epochs = 10_000
save_step = 25  # checkpoint every `save_step` optimisation steps
save_models = True
def count_nans(tensor):
    """Return the number of NaN elements in ``tensor`` as a Python int."""
    return torch.isnan(tensor).sum().item()
def preprocess_audio(audio_path, model: MusicGen, duration: int = 30):
    """Load an audio file and encode its first ``duration`` seconds to tokens.

    The clip is resampled to ``model.sample_rate``, down-mixed to mono by
    averaging over dim 0 (the channel axis returned by ``torchaudio.load``),
    trimmed to ``duration`` seconds, moved to CUDA and passed through
    ``model.compression_model.encode`` without gradient tracking.

    Args:
        audio_path: path to an audio file readable by ``torchaudio.load``.
        model: MusicGen model providing ``sample_rate`` and ``compression_model``.
        duration: number of seconds to keep from the start of the clip.

    Returns:
        The ``codes`` tensor produced by ``model.compression_model.encode``.

    Raises:
        ValueError: if the clip is shorter than ``duration`` seconds after
            resampling, or if the compression model returns a non-None scale.
    """
    wav, sr = torchaudio.load(audio_path)
    wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
    wav = wav.mean(dim=0, keepdim=True)  # mono: shape (1, samples)
    end_sample = int(model.sample_rate * duration)
    wav = wav[:, :end_sample]
    # Compare against the same integer used for slicing (the old check used the
    # un-truncated float product) and raise instead of `assert`, which is
    # stripped under `python -O` and gave no hint about which file was short.
    if wav.shape[1] != end_sample:
        raise ValueError(
            f"{audio_path}: expected at least {duration}s "
            f"({end_sample} samples at {model.sample_rate} Hz), "
            f"got {wav.shape[1]} samples"
        )
    wav = wav.cuda()
    wav = wav.unsqueeze(1)  # add a channel axis: (1, 1, samples)
    with torch.no_grad():
        codes, scale = model.compression_model.encode(wav)
    if scale is not None:
        raise ValueError(
            "compression model returned a scale; only unscaled codes are supported"
        )
    return codes
def fixnan(tensor: torch.Tensor):
    """Return a new tensor equal to ``tensor`` with every NaN set to zero.

    The input is not modified; dtype and shape are preserved.
    """
    return tensor.masked_fill(tensor.isnan(), 0)
def one_hot_encode(tensor, num_classes=2048):
    """One-hot encode a tensor of integer class indices.

    Args:
        tensor: integer tensor of class indices in ``[0, num_classes)``. The
            original implementation required a 2-D ``(d0, d1)`` input; any
            rank is now accepted.
        num_classes: length of the appended one-hot axis (default 2048, the
            same size hard-coded elsewhere in this script).

    Returns:
        A ``float32`` tensor of shape ``(*tensor.shape, num_classes)`` holding
        1.0 at each index position and 0.0 elsewhere. It lives on the same
        device as ``tensor``; the previous version always built it on the
        default device, and callers that call ``.cuda()`` on the result are
        unaffected.

    Raises:
        RuntimeError: (from torch) if any index is negative or
            ``>= num_classes``.
    """
    # Vectorised replacement for a per-element Python loop that called
    # `.item()` once per entry (a host/device sync each time on CUDA inputs).
    encoded = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes)
    return encoded.to(torch.float32)
# --- training loop ----------------------------------------------------------
duration = 30  # seconds of audio encoded per training example
current_step = 0
if save_models:
    # torch.save does not create parent directories; without this the first
    # checkpoint raises FileNotFoundError when saved_models/ does not exist.
    os.makedirs("saved_models", exist_ok=True)
for epoch in range(num_epochs):
    for batch_idx, (audio, label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        # The dataset yields file paths; with batch_size=1 each collated field
        # is a one-element sequence, so unwrap it.
        audio = audio[0]
        label = label[0]
        audio = preprocess_audio(audio, model, duration=duration)  # token codes
        with open(label, 'r') as label_file:  # close the caption file promptly
            text = label_file.read().strip()
        # NOTE(review): relies on a private MusicGen helper; may change between
        # audiocraft versions.
        attributes, _ = model._prepare_tokens_and_attributes([text], None)
        conditions = attributes
        # Append a fully-dropped (null) copy of the condition, giving a
        # 2-row batch: row 0 conditioned on the text, row 1 unconditioned.
        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
        conditions = conditions + null_conditions
        tokenized = model.lm.condition_provider.tokenize(conditions)
        cfg_conditions = model.lm.condition_provider(tokenized)
        condition_tensors = cfg_conditions
        codes = torch.cat([audio, audio], dim=0)  # one copy per condition row
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            lm_output = model.lm.compute_predictions(
                codes=codes,
                conditions=[],
                condition_tensors=condition_tensors
            )
        # Only the text-conditioned row (index 0) contributes to the loss.
        # Assumed shapes (TODO confirm): codes (K, T), logits (K, T, card),
        # mask (K, T) — the same flattening the original view() calls used.
        codes = codes[0]
        logits = lm_output.logits[0].cuda()
        mask = lm_output.mask[0].cuda().reshape(-1)
        card = logits.shape[-1]
        masked_logits = logits.reshape(-1, card)[mask]
        # Class-index targets give exactly the same cross-entropy as one-hot
        # probability targets, without a per-element Python loop. Logits are
        # cast to float32 so the loss is not computed in autocast's fp16.
        masked_targets = codes.cuda().reshape(-1).long()[mask]
        loss = criterion(masked_logits.float(), masked_targets)
        assert count_nans(masked_logits) == 0
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping true grad norms
        torch.nn.utils.clip_grad_norm_(model.lm.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        print(f"Epoch: {epoch}/{num_epochs}, Batch: {batch_idx}/{len(train_dataloader)}, Loss: {loss.item()}")
        run.log({
            "loss": loss.item(),
            "step": current_step,
        })
        current_step += 1
        if save_models and current_step % save_step == 0:
            torch.save(model.lm.state_dict(), f"saved_models/lm_{current_step}.pt")