init commit
floatingbigcat committed Jan 7, 2025
0 parents commit 56ca359
Showing 46 changed files with 3,006 additions and 0 deletions.
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
playground*
wandb
results
results_eval
saved_models
outputs
reference_code
messy_scripts
*_decomposed_params.pt
**/__pycache__/
69 changes: 69 additions & 0 deletions README.md
@@ -0,0 +1,69 @@
<h1 align="center">Transformer<sup>2</sup>: Self-adaptive LLMs 🦎</h1>
<p align="center">
📄 <a href="xxx">[Paper]</a> |
🤗 <a href="https://huggingface.co/SakanaAI">[Hugging Face]</a>
</p>

Self-adaptive large language models (LLMs) aim to solve the challenges posed by traditional fine-tuning methods, which are often computationally intensive and static in their ability to handle diverse tasks.

We are excited to introduce Transformer², a novel self-adaptation framework that adapts LLMs to unseen tasks in real time by selectively adjusting only the singular components of their weight matrices.
During inference, Transformer² employs a two-pass mechanism: first, a dispatch system identifies the task properties, and then task-specific "expert" vectors, trained using reinforcement learning, are dynamically mixed to obtain targeted behavior for the incoming prompt.
<h1 align="center">
<img width="500" src="assets/cover.gif">
</h1>
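To make the core idea concrete, here is a minimal, self-contained sketch of adjusting the singular components of a single weight matrix. The names (`adapt_weight`, `z`) are illustrative, not this repository's API:

```python
import torch

def adapt_weight(W: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Rescale the singular values of W by an expert vector z."""
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    return U @ torch.diag(S * z) @ Vh

W = torch.randn(16, 16)       # toy weight matrix
z = torch.full((16,), 1.1)    # toy expert vector: mildly amplify every component
W_adapted = adapt_weight(W, z)
```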


## Installation

### 1. Clone the Repo
```bash
git clone https://github.com/SakanaAI/self-adaptive-llms
cd self-adaptive-llms
```

### 2. Install Libraries
```bash
conda create -n t2 python=3.11 -y
conda activate t2
pip install --upgrade pip
pip install -r requirements.txt
```

### 3. Install Tasks Evaluator
```bash
cd evaluation/fishfarm
pip install -e .
```

## Usage
We provide example scripts for both training and evaluation.

Please change the arguments in the provided scripts to choose among models and tasks.

### Training

```bash
bash scripts/train_task_expert.sh
```

### Evaluation

#### Prompt-based evaluation
The classification expert can be loaded by specifying `CLS_EXPERT_PATH` in the script.
```bash
bash scripts/eval_prompt_based.sh
```

#### Few-shot evaluation
```bash
bash scripts/eval_few_shot.sh
```

## Citation
If you find **Transformer²** useful for your research, please cite using this BibTeX:
```
arxiv bib xxx
```
Binary file added assets/cover.gif
3 changes: 3 additions & 0 deletions base_model/__init__.py
@@ -0,0 +1,3 @@
from .base import BaseModel
from .llama3instruct import Llama3Instruct8B
from .mistral03instruct import MistralV03Instruct7B
18 changes: 18 additions & 0 deletions base_model/base.py
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Common interface for the supported base-model wrappers."""

    def __init__(self):
        pass

    @abstractmethod
    def get_model_id(self):
        """Return the Hugging Face model id."""
        raise NotImplementedError

    @abstractmethod
    def get_model_name(self):
        """Return the short model name (the part after the org prefix)."""
        raise NotImplementedError

    @abstractmethod
    def get_param_file(self, param_folder_path=""):
        """Return the path to this model's decomposed-parameters file."""
        raise NotImplementedError
18 changes: 18 additions & 0 deletions base_model/llama3instruct.py
@@ -0,0 +1,18 @@
import os

from .base import BaseModel


class Llama3Instruct8B(BaseModel):
    def __init__(self):
        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
        self.dec_param_file_n = "llama3_decomposed_params.pt"

    def get_model_id(self):
        return self.model_id

    def get_model_name(self):
        return self.model_id.split("/")[1]

    def get_param_file(self, param_folder_path=""):
        return os.path.join(param_folder_path, self.dec_param_file_n)
18 changes: 18 additions & 0 deletions base_model/mistral03instruct.py
@@ -0,0 +1,18 @@
import os

from .base import BaseModel


class MistralV03Instruct7B(BaseModel):
    def __init__(self):
        self.model_id = "mistralai/Mistral-7B-Instruct-v0.3"
        self.dec_param_file_n = "mistral_decomposed_params.pt"

    def get_model_id(self):
        return self.model_id

    def get_model_name(self):
        return self.model_id.split("/")[1]

    def get_param_file(self, param_folder_path=""):
        return os.path.join(param_folder_path, self.dec_param_file_n)
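Both wrappers share the `BaseModel` interface above; a quick usage sketch (the `ckpts` folder name is illustrative):

```python
from base_model import Llama3Instruct8B

model = Llama3Instruct8B()
model.get_model_id()           # "meta-llama/Meta-Llama-3-8B-Instruct"
model.get_model_name()         # "Meta-Llama-3-8B-Instruct"
model.get_param_file("ckpts")  # "ckpts/llama3_decomposed_params.pt"
```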
15 changes: 15 additions & 0 deletions cfgs/base_model/llama3i8b.yaml
@@ -0,0 +1,15 @@
base_model:
  _target_: base_model.Llama3Instruct8B


base_model_name: llama3i8b

# reference_params_results:
#   - 'saved_models/llama3i8b/gsm8k/learnable_params.pt'
#   - 'saved_models/llama3i8b/mbpp/learnable_params.pt'
#   - 'saved_models/llama3i8b/ai2arc/learnable_params.pt'

reference_params_results:
  - "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt"
  - "ckpts/learnable_params/llama3_8b_instruct_mbpp_pro_svd_pg_mlp.pt"
  - "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt"
10 changes: 10 additions & 0 deletions cfgs/base_model/mistral03i7b.yaml
@@ -0,0 +1,10 @@
base_model:
  _target_: base_model.MistralV03Instruct7B


base_model_name: mistral03i7b

reference_params_results:
  - 'saved_models/mistral03i7b/gsm8k/policy_params.pt'
  - 'saved_models/mistral03i7b/mbpp/policy_params.pt'
  - 'saved_models/mistral03i7b/ai2arc/policy_params.pt'
38 changes: 38 additions & 0 deletions cfgs/config.yaml
@@ -0,0 +1,38 @@
defaults:
  - _self_
  - policy@_global_: default
  - task@_global_: gsm8k
  - base_model@_global_: llama3i8b
  - optimization@_global_: reinforce
  - mode@_global_: training

num_iters: 2000
test_interval: 10
lr: 2e-3
batch_size: 256
seed: 42
init_val: 0.1
test_only: false
model_dir: null
save_legacy_params: false
use_lora: false
prompt_based_eval: false
experts_path_dict: null

run_name: null

load_ckpt: null
exp_suffix: 'st'

exp_name: ${base_model_name}/${optim_name}-${exp_suffix}

wandb_log: true # enabled by default
wandb_project: shakeoff
wandb_group_name: ${exp_name}
extract_svd: false

out_dir: results

hydra:
  run:
    dir: ${out_dir}/
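These configs follow Hydra's `_target_` convention: the defaults list composes the groups above, `${...}` interpolations resolve at load time, and `hydra.utils.instantiate` builds each object. A minimal sketch of an entry point, assuming a script name (`main.py`) that is not shown in this commit:

```python
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

@hydra.main(config_path="cfgs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Builds whatever object _target_ names, e.g. base_model.Llama3Instruct8B.
    base_model = instantiate(cfg.base_model)
    print(base_model.get_model_id())

if __name__ == "__main__":
    main()
```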
16 changes: 16 additions & 0 deletions cfgs/mode/eval.yaml
@@ -0,0 +1,16 @@
exp_name: eval_${base_model_name}/temp-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy}-${exp_suffix}-r

test_only: true
load_ckpt: null
use_lora: false

prompt_based_eval: false
experts_path_dict:
  code: null
  math: null
  reasoning: null
  other: null

wandb_project: T^2_eval
wandb_group_name: ${exp_name}
out_dir: results_eval
Empty file added cfgs/mode/training.yaml
20 changes: 20 additions & 0 deletions cfgs/optimization/cem.yaml
@@ -0,0 +1,20 @@

optimization_algorithm:
  _target_: optim_modules.CEM
  elite_ratio: ${elite_ratio}
  pop_size: ${pop_size}
  min_trainable_param: ${min_trainable_param}
  max_trainable_param: ${max_trainable_param}
  optim_ema: ${optim_ema}
  re_eval_best: ${re_eval_best}
  use_loglikelihood_for_ties: ${use_loglikelihood_for_ties}


pop_size: 32
elite_ratio: 0.2
min_trainable_param: 0
max_trainable_param: 1
optim_ema: 0
re_eval_best: True
use_loglikelihood_for_ties: true
optim_name: CEM-pop${pop_size}e${elite_ratio}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties}
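For orientation, a minimal sketch of the cross-entropy-method step these hyperparameters drive; `optim_modules.CEM` itself is not shown in this commit, so the mapping is an assumption:

```python
import numpy as np

def cem_step(mean, std, evaluate, pop_size=32, elite_ratio=0.2,
             lo=0.0, hi=1.0, ema=0.0):
    # Sample a population around the current mean, clipped to
    # [min_trainable_param, max_trainable_param].
    pop = np.clip(np.random.randn(pop_size, mean.size) * std + mean, lo, hi)
    scores = np.array([evaluate(p) for p in pop])
    # Refit the sampling distribution to the top elite_ratio fraction.
    elites = pop[np.argsort(scores)[-max(1, int(pop_size * elite_ratio)):]]
    # optim_ema blends old and new statistics (0 = replace outright).
    new_mean = ema * mean + (1 - ema) * elites.mean(axis=0)
    new_std = ema * std + (1 - ema) * elites.std(axis=0)
    return new_mean, new_std
```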
21 changes: 21 additions & 0 deletions cfgs/optimization/reinforce.yaml
@@ -0,0 +1,21 @@

optimization_algorithm:
  _target_: optim_modules.Reinforce
  # policy: ${policy}
  # gpu: ${gpu}
  max_grad_norm: ${max_grad_norm}
  lr: ${lr}
  rw_norm: ${rw_norm}
  rw_clip: ${rw_clip}
  kl_ref_coeff: ${kl_ref_coeff}


# policy:
# gpu:
max_grad_norm: 1e-3
lr: 2e-3
rw_norm: 0
rw_clip: null
kl_ref_coeff: 0
rw_strategy: rN${rw_norm}C${rw_clip}
optim_name: RL-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy}
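Similarly, a rough sketch of how these knobs typically enter a REINFORCE update: reward normalization (`rw_norm`), reward clipping (`rw_clip`), a KL penalty toward the reference model (`kl_ref_coeff`), and gradient clipping (`max_grad_norm`). `optim_modules.Reinforce` is not in this commit, so treat this as an assumption:

```python
import torch

def reinforce_step(params, optimizer, log_probs, rewards,
                   kl_to_ref=None, rw_norm=0, rw_clip=None,
                   kl_ref_coeff=0.0, max_grad_norm=1e-3):
    r = rewards
    if rw_norm:                        # standardize rewards across the batch
        r = (r - r.mean()) / (r.std() + 1e-8)
    if rw_clip is not None:            # bound reward magnitude
        r = r.clamp(-rw_clip, rw_clip)
    loss = -(log_probs * r).mean()     # policy-gradient objective
    if kl_ref_coeff and kl_to_ref is not None:
        loss = loss + kl_ref_coeff * kl_to_ref
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
    optimizer.step()
```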
20 changes: 20 additions & 0 deletions cfgs/optimization/rsm.yaml
@@ -0,0 +1,20 @@

optimization_algorithm:
  _target_: optim_modules.RandomShooting
  # policy: ${policy}
  # gpu: ${gpu}
  pop_size: ${pop_size}
  min_trainable_param: ${min_trainable_param}
  max_trainable_param: ${max_trainable_param}
  optim_ema: ${optim_ema}
  re_eval_best: ${re_eval_best}
  use_loglikelihood_for_ties: ${use_loglikelihood_for_ties}


pop_size: 32
min_trainable_param: 0
max_trainable_param: 1
optim_ema: 0
re_eval_best: True
use_loglikelihood_for_ties: false
optim_name: RSML-pop${pop_size}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties}
9 changes: 9 additions & 0 deletions cfgs/policy/default.yaml
@@ -0,0 +1,9 @@
shakeoff_policy:
  _target_: policy.Policy
  init_val: ${init_val}
  mode: ${policy_mode}
  max_mult: ${max_mult}

policy_mode: 1
max_mult: 1
policy_name: ${policy_mode}_mm${max_mult}
15 changes: 15 additions & 0 deletions cfgs/policy/wcomb.yaml
@@ -0,0 +1,15 @@


shakeoff_policy:
  _target_: policy.WeightedCombination
  base_policy_cfg: null
  params_paths: ${reference_params_results}
  norm_coeffs: ${norm_coeffs}
  per_layer: ${per_layer}
  init_values: ${init_values}

norm_coeffs: true
per_layer: false
init_values: null

policy_name: Wcomb_n${norm_coeffs}_p${per_layer}
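A small sketch of what a weighted combination over the saved reference parameters might look like; the loading code, key layout, and normalization semantics are assumptions, not `policy.WeightedCombination` itself:

```python
import torch

def weighted_combination(param_paths, coeffs, norm_coeffs=True):
    """Blend several saved parameter dicts with scalar coefficients."""
    w = torch.tensor(coeffs, dtype=torch.float32)
    if norm_coeffs:                    # normalize so the weights sum to 1
        w = w / w.sum()
    dicts = [torch.load(p) for p in param_paths]
    return {k: sum(wi * d[k] for wi, d in zip(w, dicts)) for k in dicts[0]}
```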
15 changes: 15 additions & 0 deletions cfgs/task/ablation_tasks/few_shot_arc_challenge_20.yaml
@@ -0,0 +1,15 @@
task_loader:
  _target_: tasks.FewShotTask
  wrapped_task:
    _target_: tasks.AI2ArcTask
  wrapped_split: ${wrapped_split}
  shots: ${task_shots}
  seed: ${task_loader_seed}


wrapped_split: transfer
task_shots: 20
task_loader_seed: 38

task_name: arc_chal_${task_shots}shots

15 changes: 15 additions & 0 deletions cfgs/task/ablation_tasks/few_shot_arc_challenge_3.yaml
@@ -0,0 +1,15 @@
task_loader:
  _target_: tasks.FewShotTask
  wrapped_task:
    _target_: tasks.AI2ArcTask
  wrapped_split: ${wrapped_split}
  shots: ${task_shots}
  seed: ${task_loader_seed}


wrapped_split: transfer
task_shots: 3
task_loader_seed: 38

task_name: arc_chal_${task_shots}shots

15 changes: 15 additions & 0 deletions cfgs/task/ablation_tasks/few_shot_arc_challenge_5.yaml
@@ -0,0 +1,15 @@
task_loader:
  _target_: tasks.FewShotTask
  wrapped_task:
    _target_: tasks.AI2ArcTask
  wrapped_split: ${wrapped_split}
  shots: ${task_shots}
  seed: ${task_loader_seed}


wrapped_split: transfer
task_shots: 5
task_loader_seed: 38

task_name: arc_chal_${task_shots}shots

6 changes: 6 additions & 0 deletions cfgs/task/ai2_arc.yaml
@@ -0,0 +1,6 @@
task_loader:
  _target_: tasks.AI2ArcTask


task_name: ai2_arc

6 changes: 6 additions & 0 deletions cfgs/task/cls.yaml
@@ -0,0 +1,6 @@
task_loader:
  _target_: tasks.ClsTask


task_name: Cls
