From 0b66bd061dd4198fcde0a94f894d627d688f2ceb Mon Sep 17 00:00:00 2001
From: Mostafa Farrag
Date: Tue, 20 Aug 2024 23:24:06 +0200
Subject: [PATCH] change the model to gemma

---
 tests/conftest.py        | 10 ++++++++++
 tests/test_chat_model.py | 18 ++++++++----------
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index e54db31..eddf28b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,3 +10,13 @@ def manual_pdf() -> List[str]:
 @pytest.fixture(scope="module")
 def num_manual_pages() -> int:
     return 40
+
+
+@pytest.fixture(scope="module")
+def model_id() -> str:
+    return "google/gemma-2b-it"
+
+
+def is_running_in_github_actions():
+    """Check if the tests are running in GitHub Actions."""
+    return os.getenv("GITHUB_ACTIONS") == "true"
diff --git a/tests/test_chat_model.py b/tests/test_chat_model.py
index 95521e0..e0ac87b 100644
--- a/tests/test_chat_model.py
+++ b/tests/test_chat_model.py
@@ -1,8 +1,8 @@
 import os
 import pytest
 from serapeum.chat_model import ChatModel
-from transformers.models.bert.modeling_bert import BertLMHeadModel
-from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
+from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
 
 # Retrieve the Hugging Face token from environment variables
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
@@ -11,24 +11,22 @@
 
 
 @pytest.fixture(scope="module")
-def test_create_chat_model() -> ChatModel:
+def test_create_chat_model(model_id: str) -> ChatModel:
     # first test without tokenizer_kwargs
-    ChatModel(
-        model_id="bert-base-uncased", access_token=huggingface_token, is_decoder=True
-    )
+    ChatModel(model_id=model_id, access_token=huggingface_token, is_decoder=True)
     tokenizer_kwargs = {"clean_up_tokenization_spaces": True}
     chat_mode = ChatModel(
-        model_id="bert-base-uncased",
+        model_id=model_id,
         access_token=huggingface_token,
         tokenizer_kwargs=tokenizer_kwargs,
         is_decoder=True,
     )
-    assert chat_mode.model_id == "bert-base-uncased"
+    assert chat_mode.model_id == model_id
     assert chat_mode.device == "cpu"
-    assert isinstance(chat_mode.model, BertLMHeadModel)
-    assert isinstance(chat_mode.tokenizer, BertTokenizerFast)
+    assert isinstance(chat_mode.model, GemmaForCausalLM)
+    assert isinstance(chat_mode.tokenizer, GemmaTokenizerFast)
     return chat_mode
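
Note (not part of the patch): a minimal sketch of how the new `model_id` fixture and `is_running_in_github_actions()` helper from tests/conftest.py might be wired together in a test module. The `pytest.mark.skipif` guard and the `test_model_loads` name are assumptions for illustration; the `ChatModel(...)` call mirrors the signature used in tests/test_chat_model.py.

    # Hypothetical usage sketch, assuming the helper is meant to gate
    # heavyweight model downloads on CI runners.
    import os

    import pytest
    from serapeum.chat_model import ChatModel
    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM

    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")


    def is_running_in_github_actions() -> bool:
        """Check if the tests are running in GitHub Actions (mirrors conftest.py)."""
        return os.getenv("GITHUB_ACTIONS") == "true"


    @pytest.mark.skipif(
        is_running_in_github_actions(),
        reason="google/gemma-2b-it is too large to download on GitHub Actions runners",
    )
    def test_model_loads(model_id: str):
        # `model_id` is injected by the module-scoped fixture added in conftest.py
        chat = ChatModel(
            model_id=model_id, access_token=huggingface_token, is_decoder=True
        )
        assert isinstance(chat.model, GemmaForCausalLM)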