diff --git a/forge/csrc/passes/lower_to_mlir.cpp b/forge/csrc/passes/lower_to_mlir.cpp
index 64b814dc7..a0dd747dd 100644
--- a/forge/csrc/passes/lower_to_mlir.cpp
+++ b/forge/csrc/passes/lower_to_mlir.cpp
@@ -55,6 +55,7 @@ enum class TargetType
 {
     SourceType,
     UInt32,
+    Int64,
 };
 
 struct AttributeRemap
@@ -104,6 +105,7 @@ class AttributeMapper
     {
         add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
         add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
+        add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
 
         // Add more default mappings here
     }
@@ -234,6 +236,7 @@ class MLIRGenerator
        case TargetType::UInt32:
            TT_ASSERT(std::get<int>(value) >= 0, "Value must be an >= 0 for conversion to uint32");
            return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
+       case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));
        default:
            // If type not handled, throw an exception
            throw std::runtime_error("Unhandled target type conversion");
@@ -608,6 +611,7 @@ class MLIRGenerator
         lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
         lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
         lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
+        lowering_handler_map["cumsum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CumSumOp>;
         lowering_handler_map["embedding_bw"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingBackwardOp>;
         lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
 
diff --git a/forge/forge/op/eltwise_unary.py b/forge/forge/op/eltwise_unary.py
index 3f54f11e4..e28585481 100644
--- a/forge/forge/op/eltwise_unary.py
+++ b/forge/forge/op/eltwise_unary.py
@@ -453,7 +453,7 @@ def Tanh(name: str, operandA: Tensor) -> Tensor:
     return op("tanh", name, operandA).get_tensor()
 
 
-def CumSum(name: str, operandA: Tensor, axis: int, exclusive: bool = False) -> Tensor:
+def CumSum(name: str, operandA: Tensor, dim: int) -> Tensor:
     """
     Cumulative sum operation.
 
@@ -483,9 +483,7 @@ def CumSum(name: str, operandA: Tensor, axis: int, exclusive: bool = False) -> T
         Forge tensor
     """
 
-    assert not exclusive, "Currently not supported"
-
-    return op("cumsum", name, operandA, axis=axis, exclusive=exclusive).get_tensor()
+    return op("cumsum", name, operandA, dim=dim).get_tensor()
 
 
 def LogicalNot(name: str, operandA: Tensor) -> Tensor:
diff --git a/forge/forge/op/eval/forge/cumulativesum.py b/forge/forge/op/eval/forge/cumulativesum.py
index f72f3fc01..6718454b7 100644
--- a/forge/forge/op/eval/forge/cumulativesum.py
+++ b/forge/forge/op/eval/forge/cumulativesum.py
@@ -19,17 +19,16 @@ class CumulativeSum(PyEltwiseUnaryOp):
     @classmethod
-    def create(cls, axis, exclusive=False):
+    def create(cls, dim):
         self = cls("cumsum")
-        self.axis = axis
-        self.exclusive = exclusive
+        self.dim = dim[0]
         return self
 
     def eval(self, tensors):
         assert len(tensors) == 1, "Cumulative Sum should have one input"
         shape = tensors[0].shape
         original_types = [o.dtype for o in tensors]
-        ret = torch.cumsum(tensors[0], dim=self.axis)
+        ret = torch.cumsum(tensors[0], dim=self.dim)
         if ret.dtype != original_types[0]:
             ret = ret.type(original_types[0])
@@ -44,7 +43,7 @@ def shape(self, tensor_shapes):
     def backward(self, ac, operand, inputs, output, grad):
         assert len(inputs) == 1, "Cumulative Sum should have one input"
         assert operand == 0, "Invalid operand index"
-        dim = self.axis
+        dim = self.dim
         assert dim == 0, "Unsupported dim different then 0 for cumulative sum backward pass"
         if dim == 0:
             return ac.op(Nop.create(), (grad,))
diff --git a/forge/forge/tvm_to_python.py b/forge/forge/tvm_to_python.py
index fe4e00af0..b7013b096 100644
--- a/forge/forge/tvm_to_python.py
+++ b/forge/forge/tvm_to_python.py
@@ -661,19 +661,10 @@ def populate_cumsum_args(graph, nid, compiler_cfg):
     axis = node["attrs"]["axis"][0][0]
     args.append(
         (
-            "axis",
+            "dim",
             f"{axis}",
         )
     )
-
-    exclusive = node["attrs"]["exclusive"][0][0]
-    args.append(
-        (
-            "exclusive",
-            f"{exclusive}",
-        )
-    )
-
     return args
diff --git a/forge/test/mlir/operators/eltwise_unary/test_eltwise_unary.py b/forge/test/mlir/operators/eltwise_unary/test_eltwise_unary.py
index 4022c2e76..6494913c9 100644
--- a/forge/test/mlir/operators/eltwise_unary/test_eltwise_unary.py
+++ b/forge/test/mlir/operators/eltwise_unary/test_eltwise_unary.py
@@ -348,10 +348,27 @@ def forward(self, x):
 @pytest.mark.parametrize(
     "shape, dim",
     [
+        ((56), 0),
         ((1, 128), 1),
+        pytest.param(
+            (1, 64, 76),
+            2,
+            marks=pytest.mark.xfail(reason="ValueError: Data mismatch -> AutomaticValueChecker (compare_with_golden)"),
+        ),
+        pytest.param(
+            (1, 64, 76, 96),
+            3,
+            marks=pytest.mark.xfail(reason="ValueError: Data mismatch -> AutomaticValueChecker (compare_with_golden)"),
+        ),
+        pytest.param(
+            (1, 64, 86, 100, 120),
+            4,
+            marks=pytest.mark.xfail(
+                reason=" RuntimeError: (dim >= 0 && dim <= 3),info: dim should be 0 - 3, but got: 4"
+            ),
+        ),
     ],
 )
-@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 @pytest.mark.push
 def test_cumsum(shape, dim):
     class CumSum(nn.Module):
diff --git a/forge/test/models/pytorch/multimodal/llava/__init__.py b/forge/test/models/pytorch/multimodal/llava/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/forge/test/models/pytorch/multimodal/llava/test_llava.py b/forge/test/models/pytorch/multimodal/llava/test_llava.py
new file mode 100644
index 000000000..84b346bdb
--- /dev/null
+++ b/forge/test/models/pytorch/multimodal/llava/test_llava.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+import pytest
+import torch
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+import forge
+from forge.verify.verify import verify
+
+from .utils import load_inputs
+from test.models.utils import Framework, Source, Task, build_module_name
+
+
+class Wrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask, pixel_values):
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
+        output = self.model(**inputs)
+        return output.logits
+
+
+def load_model(variant):
+    processor = AutoProcessor.from_pretrained(variant)
+    model = LlavaForConditionalGeneration.from_pretrained(variant)
+    model = Wrapper(model)
+    return model, processor
+
+
+variants = ["llava-hf/llava-1.5-7b-hf"]
+
+
+@pytest.mark.nightly
+@pytest.mark.parametrize("variant", variants, ids=variants)
+def test_llava(record_forge_property, variant):
+    # Build Module Name
+    module_name = build_module_name(
+        framework=Framework.PYTORCH,
+        model="llava",
+        variant=variant,
+        task=Task.CONDITIONAL_GENERATION,
+        source=Source.HUGGINGFACE,
+    )
+
+    # Record Forge Property
+    record_forge_property("model_name", module_name)
+
+    framework_model, processor = load_model(variant)
+    image = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    text = "What’s shown in this image?"
+
+    # Input sample
+    input_ids, attn_mask, pixel_values = load_inputs(image, text, processor)
+    inputs = [input_ids, attn_mask, pixel_values]
+
+    # Forge compile framework model
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)
+
+    # Model Verification
+    verify(inputs, framework_model, compiled_model)
diff --git a/forge/test/models/pytorch/multimodal/llava/utils/__init__.py b/forge/test/models/pytorch/multimodal/llava/utils/__init__.py
new file mode 100644
index 000000000..f49890762
--- /dev/null
+++ b/forge/test/models/pytorch/multimodal/llava/utils/__init__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+from .utils import load_inputs
diff --git a/forge/test/models/pytorch/multimodal/llava/utils/utils.py b/forge/test/models/pytorch/multimodal/llava/utils/utils.py
new file mode 100644
index 000000000..396bad987
--- /dev/null
+++ b/forge/test/models/pytorch/multimodal/llava/utils/utils.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import os
+import re
+
+import requests
+from PIL import Image
+
+
+def is_url(url):
+    regex = r"^(https?)://[^\s/$.?#].[^\s]*$"
+    return bool(re.match(regex, url))
+
+
+def load_inputs(inp_image, text, processor):
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": text},
+            ],
+        }
+    ]
+    text_prompt = processor.apply_chat_template(conversation, padding=True, add_generation_prompt=True)
+    if is_url(inp_image):
+        image = Image.open(requests.get(inp_image, stream=True).raw)
+    else:
+        if os.path.isfile(inp_image):
+            image = Image.open(inp_image)
+        else:
+            raise ValueError("Input is neither a valid URL nor a valid file path.")
+
+    inputs = processor(images=image, text=text_prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+    attn_mask = inputs["attention_mask"]
+    pixel_values = inputs["pixel_values"]
+
+    return input_ids, attn_mask, pixel_values
diff --git a/forge/test/models/utils.py b/forge/test/models/utils.py
index 9ebb4e54d..3789f36a7 100644
--- a/forge/test/models/utils.py
+++ b/forge/test/models/utils.py
@@ -33,6 +33,7 @@ class Task(StrEnum):
     OBJECT_DETECTION = "obj_det"
     SEMANTIC_SEGMENTATION = "sem_seg"
     MASKED_IMAGE_MODELLING = "masked_img"
+    CONDITIONAL_GENERATION = "cond_gen"
     IMAGE_ENCODING = "img_enc"
     VISUAL_BACKBONE = "visual_bb"
diff --git a/third_party/tt-mlir b/third_party/tt-mlir
index 277836600..c86135d73 160000
--- a/third_party/tt-mlir
+++ b/third_party/tt-mlir
@@ -1 +1 @@
-Subproject commit 277836600c9e244bad6ce3139a7c9a2c781b255c
+Subproject commit c86135d737af19099465c0dc80b5558956ed5ca4