Skip to content

Commit

Permalink
change the model to gemma
Browse files Browse the repository at this point in the history
  • Loading branch information
MAfarrag committed Aug 20, 2024
1 parent 19356d2 commit 0b66bd0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@ def manual_pdf() -> List[str]:
@pytest.fixture(scope="module")
def num_manual_pages() -> int:
    """Module-scoped fixture supplying the fixed value 40.

    NOTE(review): judging by the name, this is the expected page count of
    the manual PDF used by the tests -- confirm against the consumers.
    """
    page_count = 40
    return page_count


@pytest.fixture(scope="module")
def model_id() -> str:
    """Module-scoped fixture supplying the model identifier used by the tests."""
    identifier = "google/gemma-2b-it"
    return identifier


def is_running_in_github_actions() -> bool:
    """Check if the tests are running in GitHub Actions.

    Returns:
        True when the ``GITHUB_ACTIONS`` environment variable is exactly the
        string ``"true"``; False otherwise, including when it is unset.
    """
    # Imported locally so the helper does not depend on a module-level
    # ``import os`` that is not visible in this part of the file.
    import os

    return os.getenv("GITHUB_ACTIONS") == "true"
18 changes: 8 additions & 10 deletions tests/test_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import pytest
from serapeum.chat_model import ChatModel
from transformers.models.bert.modeling_bert import BertLMHeadModel
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast

# Retrieve the Hugging Face access token from the HUGGINGFACE_TOKEN
# environment variable. os.getenv returns None when the variable is unset,
# in which case None is what gets passed on as the access token.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
Expand All @@ -11,24 +11,22 @@


@pytest.fixture(scope="module")
def test_create_chat_model(model_id: str) -> ChatModel:
    """Build a ``ChatModel`` for ``model_id``, sanity-check it, and return it.

    The model is constructed twice: once without ``tokenizer_kwargs`` (only
    to verify that construction succeeds) and once with them. The second
    instance is checked and returned.

    Args:
        model_id: Model identifier supplied by the ``model_id`` fixture.

    Returns:
        The ``ChatModel`` created with ``tokenizer_kwargs``.

    NOTE(review): this is a module-scoped fixture despite its ``test_``
    prefix; the name is kept because it may be requested by name elsewhere.
    """
    # First construction: no tokenizer_kwargs; the instance is discarded.
    ChatModel(model_id=model_id, access_token=huggingface_token, is_decoder=True)

    tokenizer_kwargs = {"clean_up_tokenization_spaces": True}

    chat_model = ChatModel(
        model_id=model_id,
        access_token=huggingface_token,
        tokenizer_kwargs=tokenizer_kwargs,
        is_decoder=True,
    )
    assert chat_model.model_id == model_id
    assert chat_model.device == "cpu"
    assert isinstance(chat_model.model, GemmaForCausalLM)
    assert isinstance(chat_model.tokenizer, GemmaTokenizerFast)

    return chat_model

Expand Down

0 comments on commit 0b66bd0

Please sign in to comment.