commit 56ca359
Showing 46 changed files with 3,006 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
playground* | ||
wandb | ||
results | ||
results_eval | ||
saved_models | ||
outputs | ||
reference_code | ||
messy_scripts | ||
*_decomposed_params.pt | ||
**/__pycache__/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
<h1 align="center"> | ||
Transformer<sup>2</sup>: Self-adaptive LLMs🦎 | ||
</h1> | ||
<p align="center"> | ||
📄 <a href="xxx">[Paper]</a> | | ||
🤗 <a href="https://huggingface.co/SakanaAI">[Hugging Face]</a> | ||
</p> | ||
|
||
Self-adaptive large language models (LLMs) aim to solve the challenges posed by traditional fine-tuning methods, which are often computationally intensive and static in their ability to handle diverse tasks. | ||
|
||
We are excited to introduce Transformer², a novel self-adaptation framework that adapts LLMs to unseen tasks in real time by selectively adjusting only the singular components of their weight matrices. | ||
During inference, Transformer² employs a two-pass mechanism: first, a dispatch system identifies the task properties; then task-specific "expert" vectors, trained using reinforcement learning, are dynamically mixed to obtain targeted behavior for the incoming prompt. | ||
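To make "adjusting only the singular components" concrete, here is a minimal pure-Python sketch. It assumes a weight matrix is already factored as `W = U · diag(s) · Vᵀ` and that adaptation multiplies each singular value by one entry of a vector `z`. The helper names (`matmul`, `diag`, `rotation`, `adapt_weight`) and the toy 2×2 factors are illustrative assumptions, not code from this repository.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def diag(values):
    n = len(values)
    return [[values[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def rotation(theta):
    """A 2x2 rotation matrix, used here as a toy orthogonal factor."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def adapt_weight(u, s, vt, z):
    """Recompose U @ diag(s * z) @ Vt, leaving U and Vt untouched."""
    scaled = [si * zi for si, zi in zip(s, z)]
    return matmul(matmul(u, diag(scaled)), vt)


# Toy factors (assumed values, chosen only for illustration).
u = rotation(0.3)
vt = transpose(rotation(1.1))
s = [2.0, 0.5]

w = adapt_weight(u, s, vt, [1.0, 1.0])          # z = 1 reproduces the original W
w_adapted = adapt_weight(u, s, vt, [1.5, 0.0])  # amplify one component, drop the other
```

Only `len(s)` numbers are learned per matrix in this picture, which is what keeps such an adaptation vector small compared with the full weight matrix.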
<h1 align="center"> | ||
<a> | ||
<img width="500" src="assets/cover.gif"></a><br> | ||
<br> | ||
</h1> | ||
|
||
|
||
## Installation | ||
|
||
### 1. Clone the Repo | ||
```bash | ||
git clone https://github.com/SakanaAI/self-adaptive-llms | ||
cd self-adaptive-llms | ||
``` | ||
|
||
### 2. Install Libraries | ||
```bash | ||
conda create -n t2 python=3.11 -y | ||
conda activate t2 | ||
pip install --upgrade pip | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### 3. Install Tasks Evaluator | ||
```bash | ||
cd evaluation/fishfarm | ||
pip install -e . | ||
``` | ||
|
||
## Usage | ||
We provide example scripts for both training and evaluation. | ||
|
||
Please change the arguments in the provided scripts to select the model and task. | ||
|
||
### Training | ||
|
||
```bash | ||
bash scripts/train_task_expert.sh | ||
``` | ||
|
||
### Evaluation | ||
|
||
#### Prompt-based evaluation | ||
A classification expert can be loaded by setting `CLS_EXPERT_PATH` in the script. | ||
```bash | ||
bash scripts/eval_prompt_based.sh | ||
``` | ||
|
||
#### Few-shot evaluation | ||
```bash | ||
bash scripts/eval_few_shot.sh | ||
``` | ||
|
||
## Citation | ||
If you find **Transformer^2** useful for your research, please cite using this BibTeX: | ||
``` | ||
arxiv bib xxx | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import BaseModel | ||
from .llama3instruct import Llama3Instruct8B | ||
from .mistral03instruct import MistralV03Instruct7B |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseModel(ABC): | ||
    def __init__(self): | ||
        pass | ||
|
||
    @abstractmethod | ||
    def get_model_id(self): | ||
        raise NotImplementedError | ||
|
||
    @abstractmethod | ||
    def get_model_name(self): | ||
        raise NotImplementedError | ||
|
||
    @abstractmethod | ||
    def get_param_file(self, param_folder_path=""): | ||
        raise NotImplementedError |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from .base import BaseModel | ||
|
||
|
||
class Llama3Instruct8B(BaseModel): | ||
    def __init__(self): | ||
        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
        self.dec_param_file_n = "llama3_decomposed_params.pt" | ||
|
||
    def get_model_id(self): | ||
        return self.model_id | ||
|
||
    def get_model_name(self): | ||
        return self.model_id.split("/")[1] | ||
|
||
    def get_param_file(self, param_folder_path=""): | ||
        return os.path.join(param_folder_path, self.dec_param_file_n) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from .base import BaseModel | ||
|
||
|
||
class MistralV03Instruct7B(BaseModel): | ||
    def __init__(self): | ||
        self.model_id = "mistralai/Mistral-7B-Instruct-v0.3" | ||
        self.dec_param_file_n = "mistral_decomposed_params.pt" | ||
|
||
    def get_model_id(self): | ||
        return self.model_id | ||
|
||
    def get_model_name(self): | ||
        return self.model_id.split("/")[1] | ||
|
||
    def get_param_file(self, param_folder_path=""): | ||
        return os.path.join(param_folder_path, self.dec_param_file_n) |
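The two wrappers above follow the same pattern, so supporting another base model would presumably mean adding one more subclass. The sketch below shows that pattern in isolation. `ExampleInstruct7B`, its model ID, its parameter file name, and the `ckpts` folder are all hypothetical and do not exist in this commit.

```python
import os
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Mirrors the abstract interface shown in the diff above."""

    @abstractmethod
    def get_model_id(self):
        raise NotImplementedError

    @abstractmethod
    def get_model_name(self):
        raise NotImplementedError

    @abstractmethod
    def get_param_file(self, param_folder_path=""):
        raise NotImplementedError


class ExampleInstruct7B(BaseModel):
    """Hypothetical wrapper; the ID and file name below are made up."""

    def __init__(self):
        self.model_id = "example-org/Example-7B-Instruct"
        self.dec_param_file_n = "example_decomposed_params.pt"

    def get_model_id(self):
        return self.model_id

    def get_model_name(self):
        # Drop the organisation prefix, as the Llama and Mistral wrappers do.
        return self.model_id.split("/")[1]

    def get_param_file(self, param_folder_path=""):
        return os.path.join(param_folder_path, self.dec_param_file_n)


model = ExampleInstruct7B()
model_name = model.get_model_name()
param_file = model.get_param_file("ckpts")  # "ckpts" is an assumed folder name
```

Because every method on `BaseModel` is abstract, instantiating `BaseModel()` directly raises `TypeError`; only complete subclasses can be constructed.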
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
base_model: | ||
  _target_: base_model.Llama3Instruct8B | ||
|
||
|
||
base_model_name: llama3i8b | ||
|
||
# reference_params_results: | ||
# - 'saved_models/llama3i8b/gsm8k/learnable_params.pt' | ||
# - 'saved_models/llama3i8b/mbpp/learnable_params.pt' | ||
# - 'saved_models/llama3i8b/ai2arc/learnable_params.pt' | ||
|
||
reference_params_results: | ||
- "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt" | ||
- "ckpts/learnable_params/llama3_8b_instruct_mbpp_pro_svd_pg_mlp.pt" | ||
- "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt" |
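The `_target_` key is Hydra's convention for naming the class to construct, with the remaining keys passed as keyword arguments. In a Hydra project this is normally done by `hydra.utils.instantiate`. The function below is a simplified standard-library stand-in that only illustrates the idea; it does not handle nested configs or interpolation, and it uses `fractions.Fraction` as a stand-in target so that it runs without the repository installed.

```python
import importlib


def instantiate(cfg):
    """Simplified stand-in for Hydra's instantiate (flat configs only)."""
    # Split "package.module.ClassName" into a module path and an attribute name.
    module_path, _, attr_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    # Every other key becomes a keyword argument of the target.
    kwargs = {key: value for key, value in cfg.items() if key != "_target_"}
    return target(**kwargs)


# Stand-in config; the real one would use `_target_: base_model.Llama3Instruct8B`.
example_cfg = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
obj = instantiate(example_cfg)
```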
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
base_model: | ||
  _target_: base_model.MistralV03Instruct7B | ||
|
||
|
||
base_model_name: mistral03i7b | ||
|
||
reference_params_results: | ||
- 'saved_models/mistral03i7b/gsm8k/policy_params.pt' | ||
- 'saved_models/mistral03i7b/mbpp/policy_params.pt' | ||
- 'saved_models/mistral03i7b/ai2arc/policy_params.pt' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
defaults: | ||
- _self_ | ||
- policy@_global_: default | ||
- task@_global_: gsm8k | ||
- base_model@_global_: llama3i8b | ||
- optimization@_global_: reinforce | ||
- mode@_global_: training | ||
|
||
num_iters: 2000 | ||
test_interval: 10 | ||
lr: 2e-3 | ||
batch_size: 256 | ||
seed: 42 | ||
init_val: 0.1 | ||
test_only: false | ||
model_dir: null | ||
save_legacy_params: false | ||
use_lora: false | ||
prompt_based_eval: false | ||
experts_path_dict: null | ||
|
||
run_name: null | ||
|
||
load_ckpt: null | ||
exp_suffix: 'st' | ||
|
||
exp_name: ${base_model_name}/${optim_name}-${exp_suffix} | ||
|
||
wandb_log: true # enabled by default | ||
wandb_project: shakeoff | ||
wandb_group_name: ${exp_name} | ||
extract_svd: false | ||
|
||
out_dir: results | ||
|
||
hydra: | ||
  run: | ||
    dir: ${out_dir}/ |
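Values such as `exp_name` are built from `${...}` interpolations that refer to other keys, some of which are themselves interpolated. The sketch below mimics that resolution with a small regex-based helper. It is an illustration only: Hydra relies on OmegaConf for this, and OmegaConf's handling of numeric and null values may format the final string differently. All values are kept as strings here to sidestep that, and the shortened `optim_name` is an assumption.

```python
import re

PATTERN = re.compile(r"\$\{([^}]+)\}")


def resolve(value, cfg, max_depth=10):
    """Repeatedly substitute ${key} references until none remain."""
    for _ in range(max_depth):
        if not PATTERN.search(value):
            return value
        value = PATTERN.sub(lambda match: str(cfg[match.group(1)]), value)
    raise ValueError("interpolation did not settle; possible cycle")


# A trimmed-down config using keys that appear in the YAML files of this commit.
cfg = {
    "base_model_name": "llama3i8b",
    "lr": "2e-3",
    "optim_name": "RL-lr${lr}",  # shortened for illustration
    "exp_suffix": "st",
    "exp_name": "${base_model_name}/${optim_name}-${exp_suffix}",
}

exp_name = resolve(cfg["exp_name"], cfg)
```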
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
exp_name: eval_${base_model_name}/temp-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy}-${exp_suffix}-r | ||
|
||
test_only: true | ||
load_ckpt: null | ||
use_lora: false | ||
|
||
prompt_based_eval: false | ||
experts_path_dict: | ||
  code: null | ||
  math: null | ||
  reasoning: null | ||
  other: null | ||
|
||
wandb_project: T^2_eval | ||
wandb_group_name: ${exp_name} | ||
out_dir: results_eval |
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
|
||
optimization_algorithm: | ||
  _target_: optim_modules.CEM | ||
  elite_ratio: ${elite_ratio} | ||
  pop_size: ${pop_size} | ||
  min_trainable_param: ${min_trainable_param} | ||
  max_trainable_param: ${max_trainable_param} | ||
  optim_ema: ${optim_ema} | ||
  re_eval_best: ${re_eval_best} | ||
  use_loglikelihood_for_ties: ${use_loglikelihood_for_ties} | ||
|
||
|
||
pop_size: 32 | ||
elite_ratio: 0.2 | ||
min_trainable_param: 0 | ||
max_trainable_param: 1 | ||
optim_ema: 0 | ||
re_eval_best: true | ||
use_loglikelihood_for_ties: true | ||
optim_name: CEM-pop${pop_size}e${elite_ratio}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties} |
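For readers unfamiliar with the cross-entropy method (CEM), the sketch below shows one generic CEM update using the `pop_size` and `elite_ratio` defaults from this config. It is a textbook-style illustration, not the `optim_modules.CEM` implementation: the function name, the Gaussian sampling, and the toy fitness function are assumptions, and options such as `optim_ema`, `re_eval_best`, and the trainable-parameter bounds are ignored.

```python
import random
import statistics


def cem_step(mean, std, fitness_fn, pop_size=32, elite_ratio=0.2, rng=None):
    """One generic CEM update: sample, keep the best fraction, refit."""
    rng = rng or random.Random(0)
    # Sample candidates from an independent Gaussian per dimension.
    population = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                  for _ in range(pop_size)]
    # Keep the top fraction by fitness (higher is better here).
    num_elites = max(1, int(pop_size * elite_ratio))
    elites = sorted(population, key=fitness_fn, reverse=True)[:num_elites]
    # Refit the sampling distribution to the elites.
    new_mean = [statistics.fmean(column) for column in zip(*elites)]
    new_std = [statistics.pstdev(column) for column in zip(*elites)]
    return new_mean, new_std, num_elites


def toy_fitness(candidate):
    """Negative squared distance to an arbitrary target point."""
    target = (1.0, -1.0)
    return -sum((c - t) ** 2 for c, t in zip(candidate, target))


new_mean, new_std, num_elites = cem_step([0.0, 0.0], [1.0, 1.0], toy_fitness)
```

With the defaults above, `int(32 * 0.2)` keeps 6 of the 32 candidates per iteration.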
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
optimization_algorithm: | ||
  _target_: optim_modules.Reinforce | ||
  # policy: ${policy} | ||
  # gpu: ${gpu} | ||
  max_grad_norm: ${max_grad_norm} | ||
  lr: ${lr} | ||
  rw_norm: ${rw_norm} | ||
  rw_clip: ${rw_clip} | ||
  kl_ref_coeff: ${kl_ref_coeff} | ||
|
||
|
||
# policy: | ||
# gpu: | ||
max_grad_norm: 1e-3 | ||
lr: 2e-3 | ||
rw_norm: 0 | ||
rw_clip: null | ||
kl_ref_coeff: 0 | ||
rw_strategy: rN${rw_norm}C${rw_clip} | ||
optim_name: RL-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
|
||
optimization_algorithm: | ||
  _target_: optim_modules.RandomShooting | ||
  # policy: ${policy} | ||
  # gpu: ${gpu} | ||
  pop_size: ${pop_size} | ||
  min_trainable_param: ${min_trainable_param} | ||
  max_trainable_param: ${max_trainable_param} | ||
  optim_ema: ${optim_ema} | ||
  re_eval_best: ${re_eval_best} | ||
  use_loglikelihood_for_ties: ${use_loglikelihood_for_ties} | ||
|
||
|
||
pop_size: 32 | ||
min_trainable_param: 0 | ||
max_trainable_param: 1 | ||
optim_ema: 0 | ||
re_eval_best: true | ||
use_loglikelihood_for_ties: false | ||
optim_name: RSML-pop${pop_size}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
shakeoff_policy: | ||
  _target_: policy.Policy | ||
  init_val: ${init_val} | ||
  mode: ${policy_mode} | ||
  max_mult: ${max_mult} | ||
|
||
policy_mode: 1 | ||
max_mult: 1 | ||
policy_name: ${policy_mode}_mm${max_mult} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
|
||
shakeoff_policy: | ||
  _target_: policy.WeightedCombination | ||
  base_policy_cfg: null | ||
  params_paths: ${reference_params_results} | ||
  norm_coeffs: ${norm_coeffs} | ||
  per_layer: ${per_layer} | ||
  init_values: ${init_values} | ||
|
||
norm_coeffs: true | ||
per_layer: false | ||
init_values: null | ||
|
||
policy_name: Wcomb_n${norm_coeffs}_p${per_layer} |
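The config above points a `WeightedCombination` policy at several saved expert parameter files. One plausible reading, sketched below, is that the expert vectors are mixed with a set of coefficients that are normalized to sum to one when `norm_coeffs` is true. This is an assumption about the intent, not the `policy.WeightedCombination` code: the function name, the sum-to-one normalization, and the toy vectors are all illustrative, and `per_layer` is not modelled.

```python
def combine_experts(expert_vectors, coeffs, norm_coeffs=True):
    """Mix equally sized expert vectors using one coefficient per expert."""
    if norm_coeffs:
        # Assumed normalization: rescale coefficients so they sum to 1.
        total = sum(coeffs)
        coeffs = [c / total for c in coeffs]
    size = len(expert_vectors[0])
    return [sum(c * vec[i] for c, vec in zip(coeffs, expert_vectors))
            for i in range(size)]


# Toy expert vectors (made-up numbers standing in for learned parameters).
math_expert = [1.0, 2.0]
code_expert = [3.0, 0.0]
reasoning_expert = [0.0, 4.0]

mixed = combine_experts(
    [math_expert, code_expert, reasoning_expert],
    coeffs=[2.0, 1.0, 1.0],  # normalized to [0.5, 0.25, 0.25]
)
```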
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
task_loader: | ||
  _target_: tasks.FewShotTask | ||
  wrapped_task: | ||
    _target_: tasks.AI2ArcTask | ||
  wrapped_split: ${wrapped_split} | ||
  shots: ${task_shots} | ||
  seed: ${task_loader_seed} | ||
|
||
|
||
wrapped_split: transfer | ||
task_shots: 20 | ||
task_loader_seed: 38 | ||
|
||
task_name: arc_chal_${task_shots}shots | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
task_loader: | ||
  _target_: tasks.FewShotTask | ||
  wrapped_task: | ||
    _target_: tasks.AI2ArcTask | ||
  wrapped_split: ${wrapped_split} | ||
  shots: ${task_shots} | ||
  seed: ${task_loader_seed} | ||
|
||
|
||
wrapped_split: transfer | ||
task_shots: 3 | ||
task_loader_seed: 38 | ||
|
||
task_name: arc_chal_${task_shots}shots | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
task_loader: | ||
  _target_: tasks.FewShotTask | ||
  wrapped_task: | ||
    _target_: tasks.AI2ArcTask | ||
  wrapped_split: ${wrapped_split} | ||
  shots: ${task_shots} | ||
  seed: ${task_loader_seed} | ||
|
||
|
||
wrapped_split: transfer | ||
task_shots: 5 | ||
task_loader_seed: 38 | ||
|
||
task_name: arc_chal_${task_shots}shots | ||
|
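The three few-shot configs above differ only in `task_shots` (20, 3, and 5) and share `task_loader_seed: 38`. A fixed seed suggests the few-shot examples are drawn reproducibly from the wrapped split. The sketch below shows one way such a draw could work; it is not the `tasks.FewShotTask` implementation, and the function name and the idea of returning the remaining examples for evaluation are assumptions.

```python
import random


def sample_few_shot(examples, shots, seed):
    """Reproducibly split examples into a few-shot set and the remainder."""
    rng = random.Random(seed)  # local RNG, so global random state is untouched
    indices = rng.sample(range(len(examples)), shots)
    chosen = set(indices)
    few_shot = [examples[i] for i in indices]
    remaining = [ex for i, ex in enumerate(examples) if i not in chosen]
    return few_shot, remaining


# Placeholder data standing in for a task split.
examples = [f"question-{i}" for i in range(100)]
few_shot, remaining = sample_few_shot(examples, shots=20, seed=38)
```

Calling the function again with the same seed returns the same few-shot set, which keeps runs comparable across models.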
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
task_loader: | ||
  _target_: tasks.AI2ArcTask | ||
|
||
|
||
task_name: ai2_arc | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
task_loader: | ||
  _target_: tasks.ClsTask | ||
|
||
|
||
task_name: Cls | ||
|