Skip to content

Commit

Permalink
Merge branch 'main' into akannan/op_support_maskedscatter
Browse files Browse the repository at this point in the history
  • Loading branch information
ashokkumarkannan1 authored Feb 8, 2025
2 parents 6cc38c1 + 5b71d92 commit d828ed5
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 21 deletions.
4 changes: 4 additions & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ enum class TargetType
{
SourceType,
UInt32,
Int64,
};

struct AttributeRemap
Expand Down Expand Up @@ -104,6 +105,7 @@ class AttributeMapper
{
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));

// Add more default mappings here
}
Expand Down Expand Up @@ -234,6 +236,7 @@ class MLIRGenerator
case TargetType::UInt32:
TT_ASSERT(std::get<int>(value) >= 0, "Value must be an >= 0 for conversion to uint32");
return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));
default:
// If type not handled, throw an exception
throw std::runtime_error("Unhandled target type conversion");
Expand Down Expand Up @@ -608,6 +611,7 @@ class MLIRGenerator
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
lowering_handler_map["cumsum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CumSumOp>;
lowering_handler_map["embedding_bw"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingBackwardOp>;
lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
Expand Down
6 changes: 2 additions & 4 deletions forge/forge/op/eltwise_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def Tanh(name: str, operandA: Tensor) -> Tensor:
return op("tanh", name, operandA).get_tensor()


def CumSum(name: str, operandA: Tensor, axis: int, exclusive: bool = False) -> Tensor:
def CumSum(name: str, operandA: Tensor, dim: int) -> Tensor:

"""
Cumulative sum operation.
Expand Down Expand Up @@ -483,9 +483,7 @@ def CumSum(name: str, operandA: Tensor, axis: int, exclusive: bool = False) -> T
Forge tensor
"""

assert not exclusive, "Currently not supported"

return op("cumsum", name, operandA, axis=axis, exclusive=exclusive).get_tensor()
return op("cumsum", name, operandA, dim=dim).get_tensor()


def LogicalNot(name: str, operandA: Tensor) -> Tensor:
Expand Down
9 changes: 4 additions & 5 deletions forge/forge/op/eval/forge/cumulativesum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@

class CumulativeSum(PyEltwiseUnaryOp):
@classmethod
def create(cls, axis, exclusive=False):
def create(cls, dim):
self = cls("cumsum")
self.axis = axis
self.exclusive = exclusive
self.dim = dim[0]
return self

def eval(self, tensors):
assert len(tensors) == 1, "Cumulative Sum should have one input"
shape = tensors[0].shape
original_types = [o.dtype for o in tensors]
ret = torch.cumsum(tensors[0], dim=self.axis)
ret = torch.cumsum(tensors[0], dim=self.dim)

if ret.dtype != original_types[0]:
ret = ret.type(original_types[0])
Expand All @@ -44,7 +43,7 @@ def shape(self, tensor_shapes):
def backward(self, ac, operand, inputs, output, grad):
assert len(inputs) == 1, "Cumulative Sum should have one input"
assert operand == 0, "Invalid operand index"
dim = self.axis
dim = self.dim
assert dim == 0, "Unsupported dim different then 0 for cumulative sum backward pass"
if dim == 0:
return ac.op(Nop.create(), (grad,))
Expand Down
11 changes: 1 addition & 10 deletions forge/forge/tvm_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,19 +661,10 @@ def populate_cumsum_args(graph, nid, compiler_cfg):
axis = node["attrs"]["axis"][0][0]
args.append(
(
"axis",
"dim",
f"{axis}",
)
)

exclusive = node["attrs"]["exclusive"][0][0]
args.append(
(
"exclusive",
f"{exclusive}",
)
)

return args


Expand Down
19 changes: 18 additions & 1 deletion forge/test/mlir/operators/eltwise_unary/test_eltwise_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,27 @@ def forward(self, x):
@pytest.mark.parametrize(
"shape, dim",
[
((56), 0),
((1, 128), 1),
pytest.param(
(1, 64, 76),
2,
marks=pytest.mark.xfail(reason="ValueError: Data mismatch -> AutomaticValueChecker (compare_with_golden)"),
),
pytest.param(
(1, 64, 76, 96),
3,
marks=pytest.mark.xfail(reason="ValueError: Data mismatch -> AutomaticValueChecker (compare_with_golden)"),
),
pytest.param(
(1, 64, 86, 100, 120),
4,
marks=pytest.mark.xfail(
reason=" RuntimeError: (dim >= 0 && dim <= 3),info: dim should be 0 - 3, but got: 4"
),
),
],
)
@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
@pytest.mark.push
def test_cumsum(shape, dim):
class CumSum(nn.Module):
Expand Down
Empty file.
65 changes: 65 additions & 0 deletions forge/test/models/pytorch/multimodal/llava/test_llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0


import pytest
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

import forge
from forge.verify.verify import verify

from .utils import load_inputs
from test.models.utils import Framework, Source, Task, build_module_name


class Wrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids, attention_mask, pixel_values):
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
output = self.model(**inputs)
return output.logits


def load_model(variant):
processor = AutoProcessor.from_pretrained(variant)
model = LlavaForConditionalGeneration.from_pretrained(variant)
model = Wrapper(model)
return model, processor


variants = ["llava-hf/llava-1.5-7b-hf"]


@pytest.mark.nightly
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_llava(record_forge_property, variant):
# Build Module Name
module_name = build_module_name(
framework=Framework.PYTORCH,
model="llava",
variant=variant,
task=Task.CONDITIONAL_GENERATION,
source=Source.HUGGINGFACE,
)

# Record Forge Property
record_forge_property("model_name", module_name)

framework_model, processor = load_model(variant)
image = "https://www.ilankelman.org/stopsigns/australia.jpg"
text = "What’s shown in this image?"

# Input sample
input_ids, attn_mask, pixel_values = load_inputs(image, text, processor)
inputs = [input_ids, attn_mask, pixel_values]

# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
4 changes: 4 additions & 0 deletions forge/test/models/pytorch/multimodal/llava/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
from .utils import load_inputs
39 changes: 39 additions & 0 deletions forge/test/models/pytorch/multimodal/llava/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import re

import requests
from PIL import Image


def is_url(url):
regex = r"^(https?)://[^\s/$.?#].[^\s]*$"
return bool(re.match(regex, url))


def load_inputs(inp_image, text, processor):
conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": text},
],
}
]
text_prompt = processor.apply_chat_template(conversation, padding=True, add_generation_prompt=True)
if is_url(inp_image):
image = Image.open(requests.get(inp_image, stream=True).raw)
else:
if os.path.isfile(inp_image):
image = Image.open(inp_image)
else:
raise ValueError("Input is neither a valid URL nor a valid file path.")

inputs = processor(images=image, text=text_prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attn_mask = inputs["attention_mask"]
pixel_values = inputs["pixel_values"]

return input_ids, attn_mask, pixel_values
1 change: 1 addition & 0 deletions forge/test/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Task(StrEnum):
OBJECT_DETECTION = "obj_det"
SEMANTIC_SEGMENTATION = "sem_seg"
MASKED_IMAGE_MODELLING = "masked_img"
CONDITIONAL_GENERATION = "cond_gen"
IMAGE_ENCODING = "img_enc"
VISUAL_BACKBONE = "visual_bb"

Expand Down
2 changes: 1 addition & 1 deletion third_party/tt-mlir
Submodule tt-mlir updated 29 files
+5 −5 .github/build-docker-images.sh
+4 −1 .github/workflows/build-and-test.yml
+25 −0 include/ttmlir/Target/LLVM/LLVMToDynamicLib.h
+5 −1 lib/Dialect/TTIR/Transforms/Allocate.cpp
+5 −1 lib/Dialect/TTNN/Transforms/Passes.cpp
+5 −1 lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp
+121 −13 lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
+1 −1 lib/OpModel/TTNN/Conversion.cpp
+1 −1 lib/OpModel/TTNN/MetalHeaders.h
+1 −0 lib/SharedLib/CMakeLists.txt
+1 −0 lib/Target/CMakeLists.txt
+12 −0 lib/Target/LLVM/CMakeLists.txt
+292 −0 lib/Target/LLVM/LLVMToDynamicLib.cpp
+31 −0 lib/Target/LLVM/LLVMToDynamicLibRegistration.cpp
+1 −1 runtime/lib/ttnn/runtime.cpp
+15 −1 runtime/tools/python/ttrt/common/perf.py
+145 −10 test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir
+1 −1 test/ttmlir/Conversion/StableHLOToTTIR/mnist_inference.mlir
+1 −2 test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir
+29 −0 test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_hoist_call.mlir
+16 −0 test/ttmlir/Silicon/TTNN/llmbox/perf/all_gather.mlir
+17 −0 test/ttmlir/Silicon/TTNN/llmbox/perf/all_reduce.mlir
+16 −0 test/ttmlir/Silicon/TTNN/n300/perf/all_gather.mlir
+17 −0 test/ttmlir/Silicon/TTNN/n300/perf/all_reduce.mlir
+16 −0 test/ttmlir/Silicon/TTNN/tg/perf/all_gather.mlir
+129 −0 test/ttmlir/Translate/LLVM/dylib.mlir
+1 −1 third_party/CMakeLists.txt
+5 −0 tools/ttmlir-translate/ttmlir-translate.cpp
+1 −1 tools/ttnn-standalone/ttnn-precompiled.hpp

0 comments on commit d828ed5

Please sign in to comment.