From caf96579e3e2aae6fa85d04e9cd1347a4e370d2e Mon Sep 17 00:00:00 2001
From: Ashok Kumar Kannan <160501980+ashokkumarkannan1@users.noreply.github.com>
Date: Mon, 20 Jan 2025 11:00:41 +0000
Subject: [PATCH] Add llava model bringup

---
 .../pytorch/multimodal/llava/__init__.py      |  0
 .../pytorch/multimodal/llava/test_llava.py    | 69 +++++++++++++++++++
 .../multimodal/llava/utils/__init__.py        |  4 ++
 .../pytorch/multimodal/llava/utils/utils.py   | 60 ++++++++++++++++
 forge/test/models/utils.py                    |  1 +
 5 files changed, 134 insertions(+)
 create mode 100644 forge/test/models/pytorch/multimodal/llava/__init__.py
 create mode 100644 forge/test/models/pytorch/multimodal/llava/test_llava.py
 create mode 100644 forge/test/models/pytorch/multimodal/llava/utils/__init__.py
 create mode 100644 forge/test/models/pytorch/multimodal/llava/utils/utils.py

diff --git a/forge/test/models/pytorch/multimodal/llava/__init__.py b/forge/test/models/pytorch/multimodal/llava/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/forge/test/models/pytorch/multimodal/llava/test_llava.py b/forge/test/models/pytorch/multimodal/llava/test_llava.py
new file mode 100644
index 000000000..84b346bdb
--- /dev/null
+++ b/forge/test/models/pytorch/multimodal/llava/test_llava.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+import pytest
+import torch
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+import forge
+from forge.verify.verify import verify
+
+from .utils import load_inputs
+from test.models.utils import Framework, Source, Task, build_module_name
+
+
+class Wrapper(torch.nn.Module):
+    """Adapter whose ``forward`` takes the three model inputs and returns only the wrapped model's logits."""
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask, pixel_values):
+        # Forward by keyword and keep only the logits from the model output.
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
+        output = self.model(**inputs)
+        return output.logits
+
+
+def load_model(variant):
+    """Return ``(model, processor)`` for ``variant``, with the model wrapped in ``Wrapper``."""
+    processor = AutoProcessor.from_pretrained(variant)
+    model = LlavaForConditionalGeneration.from_pretrained(variant)
+    model = Wrapper(model)
+    return model, processor
+
+
+variants = ["llava-hf/llava-1.5-7b-hf"]
+
+
+@pytest.mark.nightly
+@pytest.mark.parametrize("variant", variants, ids=variants)
+def test_llava(record_forge_property, variant):
+    # Build Module Name
+    module_name = build_module_name(
+        framework=Framework.PYTORCH,
+        model="llava",
+        variant=variant,
+        task=Task.CONDITIONAL_GENERATION,
+        source=Source.HUGGINGFACE,
+    )
+
+    # Record Forge Property
+    record_forge_property("model_name", module_name)
+
+    framework_model, processor = load_model(variant)
+    image = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    text = "What’s shown in this image?"
+
+    # Input sample
+    input_ids, attn_mask, pixel_values = load_inputs(image, text, processor)
+    inputs = [input_ids, attn_mask, pixel_values]
+
+    # Forge compile framework model
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)
+
+    # Model Verification
+    verify(inputs, framework_model, compiled_model)
diff --git a/forge/test/models/pytorch/multimodal/llava/utils/__init__.py b/forge/test/models/pytorch/multimodal/llava/utils/__init__.py
new file mode 100644
index 000000000..f51d90ac6
--- /dev/null
+++ b/forge/test/models/pytorch/multimodal/llava/utils/__init__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+from .utils import load_inputs
diff --git a/forge/test/models/pytorch/multimodal/llava/utils/utils.py b/forge/test/models/pytorch/multimodal/llava/utils/utils.py
new file mode 100644
index 000000000..396bad987
--- /dev/null
+++ b/forge/test/models/pytorch/multimodal/llava/utils/utils.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import os
+import re
+
+import requests
+from PIL import Image
+
+
+def is_url(url):
+    """Return True if ``url`` matches the http(s) URL pattern used by this module."""
+    regex = r"^(https?)://[^\s/$.?#].[^\s]*$"
+    return bool(re.match(regex, url))
+
+
+def load_inputs(inp_image, text, processor):
+    """Build the model inputs for one image and one text prompt.
+
+    Args:
+        inp_image: HTTP(S) URL or local file path of the image.
+        text: User question placed after the image entry in the chat content.
+        processor: Object providing ``apply_chat_template`` and the
+            image/text preprocessing call.
+
+    Returns:
+        Tuple ``(input_ids, attention_mask, pixel_values)`` read from the
+        processor output.
+
+    Raises:
+        ValueError: If ``inp_image`` is neither a URL nor an existing file.
+        requests.RequestException: If the image download fails or times out.
+    """
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": text},
+            ],
+        }
+    ]
+    text_prompt = processor.apply_chat_template(conversation, padding=True, add_generation_prompt=True)
+    if is_url(inp_image):
+        # Bound the download so a stalled server cannot hang the test, and
+        # fail on HTTP errors instead of handing an error body to PIL.
+        response = requests.get(inp_image, stream=True, timeout=30)
+        response.raise_for_status()
+        image = Image.open(response.raw)
+    elif os.path.isfile(inp_image):
+        image = Image.open(inp_image)
+    else:
+        raise ValueError("Input is neither a valid URL nor a valid file path.")
+
+    inputs = processor(images=image, text=text_prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+    attn_mask = inputs["attention_mask"]
+    pixel_values = inputs["pixel_values"]
+
+    return input_ids, attn_mask, pixel_values
diff --git a/forge/test/models/utils.py b/forge/test/models/utils.py
index 9ebb4e54d..3789f36a7 100644
--- a/forge/test/models/utils.py
+++ b/forge/test/models/utils.py
@@ -33,6 +33,7 @@ class Task(StrEnum):
     OBJECT_DETECTION = "obj_det"
     SEMANTIC_SEGMENTATION = "sem_seg"
     MASKED_IMAGE_MODELLING = "masked_img"
+    CONDITIONAL_GENERATION = "cond_gen"
     IMAGE_ENCODING = "img_enc"
     VISUAL_BACKBONE = "visual_bb"
