Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 16, 2024
1 parent 57680a3 commit 31f96ba
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions optimum_benchmark/backends/py_txi/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ def download_pretrained_model(self) -> None:

def create_no_weights_model(self) -> None:
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
filename = os.path.join(self.no_weights_model, "model.safetensors")
os.makedirs(self.no_weights_model, exist_ok=True)

if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
if self.pretrained_processor is not None:
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)

filename = os.path.join(self.no_weights_model, "model.safetensors")
save_file(tensors=torch.nn.Linear(1, 1).state_dict(), filename=filename, metadata={"format": "pt"})
with fast_weights_init():
# unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
Expand All @@ -59,11 +64,6 @@ def create_no_weights_model(self) -> None:
del self.pretrained_model
torch.cuda.empty_cache()

if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
if self.pretrained_processor is not None:
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)

if self.config.task in TEXT_GENERATION_TASKS:
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
Expand All @@ -78,11 +78,11 @@ def load_model_with_no_weights(self) -> None:
def load_model_from_pretrained(self) -> None:
if self.config.task in TEXT_GENERATION_TASKS:
self.pretrained_model = TGI(
config=TGIConfig(self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
config=TGIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
)
elif self.config.task in TEXT_EMBEDDING_TASKS:
self.pretrained_model = TEI(
config=TEIConfig(self.config.model, **self.txi_kwargs, **self.tei_kwargs),
config=TEIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tei_kwargs),
)
else:
raise NotImplementedError(f"TXI does not support task {self.config.task}")
Expand Down

0 comments on commit 31f96ba

Please sign in to comment.