From 1d638499f18a59a3044b245d3f87d1d66486e1e3 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Tue, 10 Dec 2024 13:37:24 -0800
Subject: [PATCH] overwrite pos for models loaded from hf (#836)

---
 src/levanter/main/sft.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py
index b3ff0e74c..3f8329a2b 100644
--- a/src/levanter/main/sft.py
+++ b/src/levanter/main/sft.py
@@ -1,3 +1,4 @@
+import dataclasses
 import logging
 import os
 from dataclasses import dataclass, field
@@ -99,6 +100,7 @@ def train(config: SFTConfig):
         converter = converter.replaced(tokenizer=tokenizer)
 
         model_config = converter.default_config
+        model_config = dataclasses.replace(converter.default_config, seq_len=config.max_seq_len)
     elif config.trainer.initialize_from is None:
         raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
     else: