Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 16, 2024
1 parent a8c4159 commit 9c1cd0c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
25 changes: 11 additions & 14 deletions optimum_benchmark/backends/py_txi/backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List

import torch
from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download, snapshot_download
from py_txi import TEI, TGI, TEIConfig, TGIConfig
from safetensors.torch import save_model

Expand Down Expand Up @@ -45,33 +46,29 @@ def download_pretrained_model(self) -> None:
self.generation_config.save_pretrained(save_directory=model_snapshot_folder)

def create_no_weights_model(self) -> None:
# Prepare an on-disk model directory (config, processor, safetensors weights) for a
# model that is instantiated locally rather than loaded from pretrained weights.
#
# NOTE(review): this listing is a flattened diff of the method: indentation and +/-
# markers are missing, so two versions of several statements appear side by side.
# Presumably the statements using `self.no_weights_model` / `filename` are the removed
# version and the statements using `model_path` are the added version; confirm against
# the repository before treating this as runnable code.
#
# Variant using `self.no_weights_model`: create a "no_weights_model" directory under
# self.tmpdir.name and compute the path of the "model.safetensors" file inside it.
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
filename = os.path.join(self.no_weights_model, "model.safetensors")
os.makedirs(self.no_weights_model, exist_ok=True)
# Variant using `model_path`: fetch config.json for self.config.model with
# cache_dir=self.tmpdir.name and use the parent directory of the returned path.
model_path = Path(hf_hub_download(self.config.model, filename="config.json", cache_dir=self.tmpdir.name)).parent
# Write a tiny torch.nn.Linear(1, 1) as "model.safetensors" (presumably a placeholder
# weights file so the from_pretrained call below can run; TODO confirm).
save_model(model=torch.nn.Linear(1, 1), filename=model_path / "model.safetensors", metadata={"format": "pt"})
# Save the processor and config next to the placeholder weights.
self.pretrained_processor.save_pretrained(save_directory=model_path)
self.pretrained_config.save_pretrained(save_directory=model_path)

# Same config/processor saves, targeting `self.no_weights_model`.
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)

# Same placeholder save, targeting `filename`.
save_model(model=torch.nn.Linear(1, 1), filename=filename, metadata={"format": "pt"})
with fast_weights_init():
# unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
self.pretrained_model = self.automodel_loader.from_pretrained(
self.no_weights_model, **self.config.model_kwargs, device_map="auto", _fast_init=False
model_path, **self.config.model_kwargs, device_map="auto", _fast_init=False
)
# Overwrite the placeholder file with the tensors of the materialized model
# (one call per variant: `filename` vs. `model_path / "model.safetensors"`).
save_model(model=self.pretrained_model, filename=filename, metadata={"format": "pt"})
save_model(model=self.pretrained_model, filename=model_path / "model.safetensors", metadata={"format": "pt"})
# Drop the in-process copy of the model and ask torch to release cached CUDA memory.
del self.pretrained_model
torch.cuda.empty_cache()

# For text-generation tasks, clear eos/pad token ids on the generation config and save
# it into the model directory (one save call per variant).
if self.config.task in TEXT_GENERATION_TASKS:
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
self.generation_config.save_pretrained(save_directory=self.no_weights_model)
self.generation_config.save_pretrained(save_directory=model_path)

def load_model_with_no_weights(self) -> None:
# Load the model through load_model_from_pretrained() while self.config is temporarily
# pointed at the locally prepared no-weights files.
#
# NOTE(review): this listing is a flattened diff of the method: indentation and +/-
# markers are missing. Presumably the `original_model` lines are the removed version
# and the `original_volumes` lines are the added version; confirm against the repository.
#
# Variant using `original_model`: map self.no_weights_model to the bind path
# "/no_weights_model/" (mode "rw") and temporarily set self.config.model to that path.
self.config.volumes = {self.no_weights_model: {"bind": "/no_weights_model/", "mode": "rw"}}
original_model, self.config.model = self.config.model, "/no_weights_model/"
# Variant using `original_volumes`: temporarily replace self.config.volumes with a
# mapping of self.tmpdir.name to the bind path "/data" (mode "rw"), keeping the old value.
original_volumes, self.config.volumes = self.config.volumes, {self.tmpdir.name: {"bind": "/data", "mode": "rw"}}
self.load_model_from_pretrained()
# Restore the saved config value (one restore per variant).
# NOTE(review): the restore is not inside try/finally, so an exception raised by
# load_model_from_pretrained() leaves self.config modified.
self.config.model = original_model
self.config.volumes = original_volumes

def load_model_from_pretrained(self) -> None:
if self.config.task in TEXT_GENERATION_TASKS:
Expand Down
1 change: 1 addition & 0 deletions tests/configs/cuda_inference_py_txi_gpt2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defaults:
- _base_ # inherits from base config
- _cuda_ # inherits from cuda config
- _inference_ # inherits from inference config
- _no_weights_ # inherits from no weights config
- _gpt2_ # inherits from gpt2 config
- _self_ # hydra 1.1 compatibility
- override backend: py-txi
Expand Down

0 comments on commit 9c1cd0c

Please sign in to comment.