Skip to content

Commit

Permalink
Add args to SFT example
Browse files Browse the repository at this point in the history
  • Loading branch information
lewtun committed Dec 11, 2023
1 parent d275cb4 commit 5dd8cf5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional

import torch
from accelerate import Accelerator
Expand Down Expand Up @@ -73,6 +73,8 @@ class ScriptArguments:
},
)
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
mixed_precision: Optional[str] = field(default="bf16", metadata={"help": "Mixed precision training"})
target_modules: Optional[List[str]] = field(default=None, metadata={"help": "Target modules for LoRA adapters"})


parser = HfArgumentParser(ScriptArguments)
Expand Down Expand Up @@ -135,6 +137,7 @@ class ScriptArguments:
lora_alpha=script_args.peft_lora_alpha,
bias="none",
task_type="CAUSAL_LM",
target_modules=script_args.target_modules,
)
else:
peft_config = None
Expand Down

0 comments on commit 5dd8cf5

Please sign in to comment.