Skip to content

Commit

Permalink
[pytorch] add matmul sample
Browse files Browse the repository at this point in the history
  • Loading branch information
Avimitin committed Aug 9, 2024
1 parent 922fc85 commit 1908b0c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/pytorch/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

let

builder = makeBuilder { casePrefix = "mlir"; };
builder = makeBuilder { casePrefix = "pytorch"; };
build = { caseName, sourcePath }:
let
buddyBuildConfig = import (sourcePath + "/config.nix");
Expand Down
30 changes: 30 additions & 0 deletions tests/pytorch/matmul/config.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
includes = [
../memref.hpp
];

buddyOptArgs = [
[
"--pass-pipeline"
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)"
]
[
"--pass-pipeline"
"builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers))"
]
[
"--lower-affine"
"--convert-math-to-llvm"
"--convert-math-to-libm"
"--convert-scf-to-cf"
"--convert-arith-to-llvm"
"--expand-strided-metadata"
"--finalize-memref-to-llvm"
"--lower-vector-exp"
"--lower-rvv=rv32"
"--convert-vector-to-llvm"
"--convert-func-to-llvm"
"--reconcile-unrealized-casts"
]
];
}
22 changes: 22 additions & 0 deletions tests/pytorch/matmul/matmul.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "memref.hpp"

extern "C" void _mlir_ciface_forward(MemRef<float, 1> *output,
MemRef<float, 1> *arg1,
MemRef<float, 1> *arg2);

// One-dimension, with length 512
static const int32_t sizes[3] = {8, 8, 8};

__attribute((section(".vdata"))) float input_float_1[512];
MemRef<float, 1> input1(input_float_1, sizes);

__attribute((section(".vdata"))) float input_float_2[512];
MemRef<float, 1> input2(input_float_2, sizes);

__attribute((section(".vdata"))) float output_float_1[512];
MemRef<float, 1> output(output_float_1, sizes);

extern "C" int test() {
_mlir_ciface_forward(&output, &input1, &input2);
return 0;
}
26 changes: 26 additions & 0 deletions tests/pytorch/matmul/matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa

# Define the input data.
float32_in1 = torch.randn(8, 8, 8).to(torch.float32)
float32_in2 = torch.randn(8, 8, 8).to(torch.float32)

# Initialize the dynamo compiler.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

# Pass the function and input data to the dynamo compiler's importer, the
# importer will first build a graph. Then, lower the graph to top-level IR.
# (tosa, linalg, etc.). Finally, accepts the generated module and weight parameters.
graphs = dynamo_compiler.importer(torch.matmul, *(float32_in1, float32_in2))
graph = graphs[0]
graph.lower_to_top_level_ir()

with open("forward.mlir", "w") as mlir_module:
print(graph._imported_module, file = mlir_module)

0 comments on commit 1908b0c

Please sign in to comment.