From d4079889ad0f019914d1351d769ee9568b0da6c6 Mon Sep 17 00:00:00 2001 From: Nikola Vukobrat <124874832+nvukobratTT@users.noreply.github.com> Date: Wed, 29 Jan 2025 10:31:24 +0100 Subject: [PATCH] Update ResNet50 E2E test to represent a valid demo version with actual predictions (#1126) ### Ticket ### Problem description Update ResNet50 E2E test to represent a valid demo version with actual predictions ### What's changed - Loading dataset (currently single sample) & labels dictionary - Comparing CPU vs TT label predictions ### Checklist - [x] New/Existing tests provide coverage for changes --- .../pytorch/vision/resnet/test_resnet.py | 64 +++++++++---------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/forge/test/models/pytorch/vision/resnet/test_resnet.py b/forge/test/models/pytorch/vision/resnet/test_resnet.py index 8d36baec9..36e62dae7 100644 --- a/forge/test/models/pytorch/vision/resnet/test_resnet.py +++ b/forge/test/models/pytorch/vision/resnet/test_resnet.py @@ -5,11 +5,12 @@ import requests import timm import torch +from datasets import load_dataset from loguru import logger from PIL import Image from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform -from transformers import AutoFeatureExtractor, ResNetForImageClassification +from transformers import AutoImageProcessor, ResNetForImageClassification import forge from forge.verify.verify import verify @@ -17,33 +18,16 @@ from test.models.utils import Framework, Source, Task, build_module_name from test.utils import download_model - -def generate_model_resnet_imgcls_hf_pytorch(variant): - # Load ResNet feature extractor and model checkpoint from HuggingFace - model_ckpt = variant - feature_extractor = download_model(AutoFeatureExtractor.from_pretrained, model_ckpt) - model = download_model(ResNetForImageClassification.from_pretrained, model_ckpt) - - # Load data sample - try: - url = "https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIyLTA1L3BkMTA2LTA0Ny1jaGltXzEuanBn.jpg" - image = Image.open(requests.get(url, stream=True).raw) - except: - logger.warning( - "Failed to download the image file, replacing input with random tensor. Please check if the URL is up to date" - ) - image = torch.rand(1, 3, 256, 256) - - # Data preprocessing - inputs = feature_extractor(image, return_tensors="pt") - pixel_values = inputs["pixel_values"] - - return model, [pixel_values], {} +variants = [ + "microsoft/resnet-50", +] +@pytest.mark.push @pytest.mark.nightly -def test_resnet(record_forge_property): - # Build Module Name +@pytest.mark.parametrize("variant", variants, ids=variants) +def test_resnet_hf(variant, record_forge_property): + # Record model properties module_name = build_module_name( framework=Framework.PYTORCH, model="resnet", @@ -51,19 +35,31 @@ def test_resnet(record_forge_property): source=Source.HUGGINGFACE, task=Task.IMAGE_CLASSIFICATION, ) - - # Record Forge Property record_forge_property("model_name", module_name) - framework_model, inputs, _ = generate_model_resnet_imgcls_hf_pytorch( - "microsoft/resnet-50", - ) + # Load dataset + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] - # Forge compile framework model - compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name) + # Load Torch model, preprocess image, and label dictionary + processor = download_model(AutoImageProcessor.from_pretrained, variant) + framework_model = download_model(ResNetForImageClassification.from_pretrained, variant, return_dict=False) + label_dict = framework_model.config.id2label - # Model Verification - verify(inputs, framework_model, compiled_model) + inputs = processor(image, return_tensors="pt") + inputs = inputs["pixel_values"] + + compiled_model = forge.compile(framework_model, inputs) + + cpu_logits = framework_model(inputs)[0] + cpu_pred = label_dict[cpu_logits.argmax(-1).item()] + + tt_logits = compiled_model(inputs)[0] + tt_pred = label_dict[tt_logits.argmax(-1).item()] + + assert cpu_pred == tt_pred, f"Inference mismatch: CPU prediction: {cpu_pred}, TT prediction: {tt_pred}" + + verify([inputs], framework_model, compiled_model) def generate_model_resnet_imgcls_timm_pytorch(variant):