
Commit

Add lora support
leng-yue committed Dec 29, 2023
1 parent a86e61a commit cdecc2a
Showing 2 changed files with 149 additions and 2 deletions.
89 changes: 89 additions & 0 deletions fish_speech/configs/text2semantic_finetune_lora.yaml
@@ -0,0 +1,89 @@
defaults:
  - base
  - _self_

project: text2semantic_400m_finetune_lora
max_length: 4096
ckpt_path: checkpoints/text2semantic-400m-v0.3-4k.pth
resume_weights_only: true

# Lightning Trainer
trainer:
  accumulate_grad_batches: 2
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 1000
  precision: bf16-true
  limit_val_batches: 10
  log_every_n_steps: 10

# Tokenizer Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/speech-lm-v1

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  tokenizer: ${tokenizer}
  max_length: ${max_length}

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  tokenizer: ${tokenizer}
  max_length: ${max_length}

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic

  model:
    _target_: fish_speech.models.text2semantic.llama.Transformer
    config:
      _target_: fish_speech.models.text2semantic.llama.ModelArgs
      max_seq_len: 4096
      vocab_size: 36408
      n_layer: 24
      n_head: 16
      dim: 1024
      rope_base: 10000
      norm_eps: 1e-5
      num_codebooks: 4
      codebook_size: 168  # codebook size 160 + 2 special tokens

  lora_config:
    _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
    r: 8
    lora_alpha: 16

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 3e-4
    weight_decay: 0.1
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0.1

# Callbacks
callbacks:
  model_checkpoint:
    every_n_train_steps: 200
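For readers new to the Hydra instantiation keys used throughout this config, the minimal sketch below (an illustration, not part of the diff; it assumes the repository environment with Hydra >= 1.2 installed) shows how `_target_` and `_partial_` resolve: the lora_config block becomes a LoraConfig instance, while `_partial_: true` turns the optimizer block into a builder that the LightningModule can later complete with the model's parameters.

# Sketch only: resolving `_target_` / `_partial_` blocks like the ones above.
from functools import partial

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "lora_config": {
            "_target_": "fish_speech.models.text2semantic.lit_module.LoraConfig",
            "r": 8,
            "lora_alpha": 16,
        },
        "optimizer": {
            "_target_": "torch.optim.AdamW",
            "_partial_": True,
            "lr": 3e-4,
        },
    }
)

lora_config = instantiate(cfg.lora_config)      # -> LoraConfig(r=8, lora_alpha=16)
optimizer_builder = instantiate(cfg.optimizer)  # -> functools.partial(torch.optim.AdamW, lr=3e-4)
assert isinstance(optimizer_builder, partial)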
62 changes: 60 additions & 2 deletions fish_speech/models/text2semantic/lit_module.py
@@ -1,23 +1,81 @@
 import platform
-from typing import Any, Optional
+from dataclasses import dataclass
+from typing import Any, Dict, Optional

 import lightning as L
+import loralib as lora
 import torch
 import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler

 import fish_speech.utils as utils
+from fish_speech.models.text2semantic.llama import Transformer

 log = utils.RankedLogger(__name__, rank_zero_only=True)


+@dataclass
+class LoraConfig:
+    r: int
+    lora_alpha: float
+    lora_dropout: float = 0.0
+
+
 class TextToSemantic(L.LightningModule):
-    def __init__(self, model, optimizer: Any, lr_scheduler: Any):
+    def __init__(
+        self,
+        model: Transformer,
+        optimizer: Any,
+        lr_scheduler: Any,
+        lora_config: Optional[LoraConfig] = None,
+    ):
         super().__init__()

         self.model = model
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
+        self.lora_config = lora_config
+
+        if self.lora_config is not None:
+            self.setup_lora()
+
+    def setup_lora(self):
+        # Replace the embedding layer with a LoRA layer
+        self.model.embeddings = lora.Embedding(
+            num_embeddings=self.model.embeddings.num_embeddings,
+            embedding_dim=self.model.embeddings.embedding_dim,
+            padding_idx=self.model.embeddings.padding_idx,
+            r=self.lora_config.r,
+            lora_alpha=self.lora_config.lora_alpha,
+        )
+
+        # Replace output layer with a LoRA layer
+        linears = [(self.model, "output")]
+
+        # Replace all linear layers with LoRA layers
+        for layer in self.model.layers:
+            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+            linears.extend(
+                [
+                    (layer.feed_forward, "w1"),
+                    (layer.feed_forward, "w2"),
+                    (layer.feed_forward, "w3"),
+                ]
+            )
+
+        for module, layer in linears:
+            updated_linear = lora.Linear(
+                in_features=getattr(module, layer).in_features,
+                out_features=getattr(module, layer).out_features,
+                bias=getattr(module, layer).bias is not None,  # nn.Linear expects a bool here
+                r=self.lora_config.r,
+                lora_alpha=self.lora_config.lora_alpha,
+                lora_dropout=self.lora_config.lora_dropout,
+            )
+            setattr(module, layer, updated_linear)
+
+        # Mark only the LoRA layers as trainable
+        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

     def forward(self, x):
         return self.model(x)
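The new setup_lora() leans entirely on loralib's drop-in modules. As a standalone sketch of that pattern (not part of the diff; ToyBlock and d_model are made-up names), the snippet below swaps nn.Linear for lora.Linear, freezes everything except the low-rank factors, and checkpoints only the LoRA state dict.

import loralib as lora
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, d_model: int = 64):
        super().__init__()
        # r / lora_alpha mirror the values in text2semantic_finetune_lora.yaml
        self.proj = lora.Linear(d_model, d_model, bias=False, r=8, lora_alpha=16)

    def forward(self, x):
        return self.proj(x)


block = ToyBlock()
lora.mark_only_lora_as_trainable(block, bias="lora_only")

trainable = sum(p.numel() for p in block.parameters() if p.requires_grad)
total = sum(p.numel() for p in block.parameters())
print(f"trainable parameters: {trainable} / {total}")  # only lora_A / lora_B require grad

# After fine-tuning, only the small LoRA state dict needs to be checkpointed.
torch.save(lora.lora_state_dict(block, bias="lora_only"), "lora_only.ckpt")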
