forked from meta-llama/llama-cookbook
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig_utils.py
61 lines (47 loc) · 2.31 KB
/
config_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import inspect
from dataclasses import fields
from peft import (
LoraConfig,
AdaptionPromptConfig,
PrefixTuningConfig,
)
import configs.datasets as datasets
from configs import lora_config, llama_adapter_config, prefix_config, train_config
from .dataset_utils import DATASET_PREPROC
def update_config(config, **kwargs):
    """Apply keyword overrides to a config, or to each config in a tuple/list.

    For every ``key=value`` pair in ``kwargs``:

    * if ``config`` already has an attribute called ``key``, it is set to
      ``value``;
    * otherwise, if ``key`` looks like ``<config_name>.<param_name>`` and
      ``<config_name>`` equals the name of ``config``'s class (or the name of
      ``config`` itself when a class is passed), ``<param_name>`` is set when
      that attribute exists and a warning is printed when it does not;
    * otherwise a warning is printed, but only when ``config`` is a
      ``train_config`` instance.

    Args:
        config: A config object, or a tuple/list of config objects (each one
            is updated recursively with the same ``kwargs``).
        **kwargs: Overrides keyed by plain or dotted parameter name.

    Returns:
        None. The config objects are modified in place.
    """
    if isinstance(config, (tuple, list)):
        for c in config:
            update_config(c, **kwargs)
        return

    # A config may be handed over as a class instead of an instance; in that
    # case type(config) is the metaclass, so use the class's own name for the
    # dotted-key comparison below.
    owner = config if isinstance(config, type) else type(config)

    for k, v in kwargs.items():
        if hasattr(config, k):
            setattr(config, k, v)
        elif "." in k:
            # allow --some_config.some_param=True
            # Split on the first dot only: a key containing several dots must
            # not break the two-name unpacking. Anything after the first dot
            # is treated as the parameter name and simply fails the hasattr
            # check below if it is not a real attribute.
            config_name, _, param_name = k.partition(".")
            if owner.__name__ == config_name:
                if hasattr(config, param_name):
                    setattr(config, param_name, v)
                else:
                    # In case of specialized config we can warn user
                    print(f"Warning: {config_name} does not accept parameter: {k}")
        elif isinstance(config, train_config):
            print(f"Warning: unknown parameter {k}")
def generate_peft_config(train_config, kwargs):
    """Build the PEFT config object selected by ``train_config.peft_method``.

    The local config definitions (``lora_config``, ``llama_adapter_config``,
    ``prefix_config``) are paired positionally with the PEFT classes
    (``LoraConfig``, ``AdaptionPromptConfig``, ``PrefixTuningConfig``). The
    accepted method names are the local config names with the ``_config``
    suffix removed, i.e. ``lora``, ``llama_adapter`` and ``prefix``.

    Args:
        train_config: Object whose ``peft_method`` attribute names the method.
        kwargs: Dict of overrides forwarded to ``update_config`` for the
            selected local config.

    Returns:
        An instance of the matching PEFT class, constructed from the
        dataclass fields of the selected local config.

    Raises:
        AssertionError: If ``train_config.peft_method`` is not one of the
            accepted names.
    """
    configs = (lora_config, llama_adapter_config, prefix_config)
    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)

    # str.rstrip("_config") would strip any trailing run of the characters
    # "_", "c", "o", "n", "f", "i", "g" rather than the literal suffix, which
    # silently mangles names whose stem ends in one of those characters.
    # Remove the exact suffix instead.
    suffix = "_config"
    names = tuple(
        c.__name__[: -len(suffix)] if c.__name__.endswith(suffix) else c.__name__
        for c in configs
    )

    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"

    selected = names.index(train_config.peft_method)
    config = configs[selected]
    # NOTE(review): ``config`` is the entry from ``configs`` as imported (it is
    # not instantiated here), so overrides appear to be written onto that
    # shared object - confirm this is intended.
    update_config(config, **kwargs)
    params = {k.name: getattr(config, k.name) for k in fields(config)}
    return peft_configs[selected](**params)
def generate_dataset_config(train_config, kwargs):
    """Fetch the dataset config named by ``train_config.dataset`` and apply overrides.

    The name must be a key of ``DATASET_PREPROC``. The config object itself is
    the member of the ``configs.datasets`` module that carries the same name.

    Args:
        train_config: Object whose ``dataset`` attribute names the dataset.
        kwargs: Dict of overrides forwarded to ``update_config``.

    Returns:
        The selected dataset config, after ``update_config`` has been applied.

    Raises:
        AssertionError: If ``train_config.dataset`` is not a key of
            ``DATASET_PREPROC``.
    """
    known_datasets = tuple(DATASET_PREPROC.keys())
    assert train_config.dataset in known_datasets, f"Unknown dataset: {train_config.dataset}"

    # Index every member of the datasets config module by its name, then pick
    # the one matching the requested dataset.
    members_by_name = dict(inspect.getmembers(datasets))
    selected = members_by_name[train_config.dataset]

    update_config(selected, **kwargs)
    return selected