Merge remote-tracking branch 'upstream/main' (#57)

* update PyTorch version to 2.1.0.dev20230523 (llvm#2148)

- torch version: 2.1.0.dev20230523
 - torch commit hash: 981d4c2578d10d8a96d173471802fc2812541fb1
 - torchvision version: 0.16.0.dev20230523

Co-authored-by: Roll PyTorch Action <[email protected]>

* [Torch Dialect] Add split.tensor support + recompose rules (llvm#2102)

* add split.tensor support + recompose rules

* add e2e test

* address comments

* address comments

* erase op in recomposeOp

---------

Co-authored-by: zhekun.zhang <[email protected]>
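
For context, a minimal illustration of the op in question (the values here are made up for this example); the recompose rules presumably fold the split-then-unpack pattern produced by code like this back into individual slice ops, though the exact rewrite is defined in the Torch dialect's recompose pass:

import torch

x = torch.arange(12).reshape(6, 2)

# aten.split.Tensor returns a list of tensors; unpacking that list is the
# pattern the recompose rules target.
a, b, c = torch.split(x, 2, dim=0)
print(a.shape, b.shape, c.shape)  # three (2, 2) chunks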

* [Stablehlo] Add `AtenIndexTensor` StableHlo support (llvm#2107)

* Add AtenIndexTensor StableHlo support

* clean up

* Empty commit, trigger test

* try to debug hanging test

* fix segfault

* fix bad include

---------

Co-authored-by: zhekun.zhang <[email protected]>

* [arm64] Fix release builds for ARM64 (llvm#2157)

Tested on Ubuntu 23.04 on an Ampere Altra instance.

* [Stablehlo] Add aten.uniform lowering (llvm#2101)

* add uniform stablehlo lowering

* add unit test

* new line

* rm redundant file

* Empty commit, trigger test

* fix include

* address comments

---------

Co-authored-by: zhekun.zhang <[email protected]>

* update PyTorch version to 2.1.0.dev20230525 (llvm#2167)

- torch version: 2.1.0.dev20230525
 - torch commit hash: eb2ef134b4e834a9b8a8b6de86ddd7d2780ce0ac
 - torchvision version: 0.16.0.dev20230525

Co-authored-by: Roll PyTorch Action <[email protected]>

* CI: disable caching for release builds (llvm#2168)

This patch adds a (default-true) input called `cache-enabled` to the
setup-build action, so that when the input is false, ccache is not set up
on the host machine. This patch also sets the input to false for the
release builds.

* Add alias analysis for cast-like ops to maximize-value-semantics (llvm#2160)

When `use_tracing=True` is used to import a model into Torch-MLIR,
several casts get inserted in the IR to bridge the untyped inputs and
outputs with the typed body of the computation. These casts create
extra aliases of tensors that cause the current analysis in
`maximize-value-semantics` to fail.

In particular, the `maximize-value-semantics` analysis assumes that the
only valid alias right after an overwrite is the overwritten
alias. So, if there is a use of a casted version of the overwritten
alias after the overwrite, the analysis fails.

This commit improves the analysis by identifying all cast-like aliases
of the overwritten alias and allowing such aliases to be used after an
overwrite.

Because this issue only arises when using tracing, it currently cannot be
tested e2e, so only a lit test is added.
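
For reference, a minimal sketch of the tracing-based import that produces these cast-like aliases; the module and shapes are illustrative, and the call assumes the torch_mlir.compile Python API with its use_tracing flag:

import torch
import torch_mlir

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

# With use_tracing=True the model is imported via TorchScript tracing; the
# trace's untyped inputs/outputs are bridged to the typed body with casts,
# which is exactly the aliasing pattern this commit teaches
# maximize-value-semantics to handle.
compiled = torch_mlir.compile(AddOne(), torch.ones(2, 3), use_tracing=True)
print(compiled)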

* only setup python for non-docker platforms (llvm#2171)

The original PR was accidentally merged to a branch. Re-landing the same PR to main now.

* Remove spurious pip in Release builds (llvm#2172)

(left over from a previous commit that was approved and landed in a branch by accident)

* [Torch Op] Add AtenChunkOp support (llvm#2152)

* add chunkOp support

* update LTC xfail list

* address comments

* address comments

---------

Co-authored-by: zhekun.zhang <[email protected]>

* Add ARM64 release builds (llvm#2159)

Creates a build_linux_arm64 job that builds the release on an arm64 self-hosted runner.
Drop Python 3.10 support
Pass TM_TORCH_VERSION to choose the Stable PyTorch version (since arm64 doesn't have nightly builds)

Borrows the nightly/stable PyTorch switch from the WIP llvm#2038

* Delete another spurious pip (llvm#2173)

* update PyTorch version to 2.1.0.dev20230526 (llvm#2175)

- torch version: 2.1.0.dev20230526
 - torch commit hash: 10b46f7c7f69f9bf705d2b6ea53efb9c59145685
 - torchvision version: 0.16.0.dev20230526

Co-authored-by: Roll PyTorch Action <[email protected]>

* [Stablehlo] Enable Stablehlo backend with arith dialect (llvm#2139)

* Add correct type checking for tm_tensor.attention

* [TM_TENSOR] Add `aten.scatter.[src|value]` op

This commit adds support for the `aten.scatter.src` and `aten.scatter.value`
ops.

Signed-Off-by: Gaurav Shukla <[email protected]>
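
For reference, the two overloads correspond to the standard PyTorch scatter calls; the shapes and values below are purely illustrative and not taken from this commit:

import torch

x = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2]])

# aten.scatter.src scatters a tensor of source values along dim 0.
out_src = torch.ops.aten.scatter.src(x, 0, index, torch.ones(1, 3))

# aten.scatter.value scatters a single scalar value along dim 0.
out_value = torch.ops.aten.scatter.value(x, 0, index, 2.0)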

* [MLIR][TORCH] Add support for the total_weight for aten.nll_loss_forward op

Signed-Off By: Vivek Khandelwal <[email protected]>

* Add Stable PyTorch CI Pipeline (llvm#2038)

* feat: split pytorch requirements into stable and nightly

* fix: add true to tests to see full output

* refactor: add comments to explain true statement

* feat: move some tests to experimental mode

* refactor: refactor pipeline into more fine-grained differences

* feat: add version differentiation for some tests

* feat: activate more configs

* refactor: change implementation to use less requirement files

* refactor: remove constraints used for testing

* fix: revert some requirement file names

* refactor: remove unnecessary ninja install

* fix: fix version parsing

* refactor: remove dependency on torchvision in main requirements file

* refactor: remove index url

* style: remove unnecessary line switch

* fix: re-add index url

* Add `ReadOnly` trait to `copy.to_vtensor` (llvm#2179)

Before inlining a global slot, the users of the global slot are
checked to see if they are `ReadOnly` or `MemoryEffectFree` to make
sure that the global slot is not being mutated. Because the op
`copy.to_vtensor` currently does not have the `ReadOnly` trait, if a
global slot is passed to `copy.to_vtensor`, the pass
`InlineGlobalSlots` will fail.

The op `copy.to_vtensor` is `ReadOnly`, since it does not modify the
contents of the input tensor; it simply makes a new copy. This commit
adds the trait as well as an e2e test that generates the case of a
global slot being passed to a `copy.to_vtensor`.
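
A hedged sketch of the kind of module that exercises this path (the class name and values are illustrative, not the actual e2e test added by this commit): the tensor attribute becomes a global slot, and reading it in forward() routes it through copy.to_vtensor.

import torch

class GlobalSlotReadModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Imported as a global slot.
        self.tensor = torch.ones(3)

    def forward(self):
        # Reading the slot goes through copy.to_vtensor in the imported IR;
        # with the ReadOnly trait, InlineGlobalSlots can inline the slot.
        return self.tensor + self.tensor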

* [Importer] import constant tuple (llvm#2132)

* [Importer] import constant tuple

* update

* update

* update

* update PyTorch version to 2.1.0.dev20230531 (llvm#2188)

- torch version: 2.1.0.dev20230531
 - torch commit hash: 48552338649ccc467060f5f93dbe19e2acbc4d1a
 - torchvision version: 0.16.0.dev20230531

Co-authored-by: Roll PyTorch Action <[email protected]>

* [Torch Dialect] Add support for AtenScalarTensorOp (llvm#2085)

* add scalar_tensor op

* add dynamo pass test; needs PR2062

* try to fix

* Empty commit, trigger test

* Empty commit, trigger test

* address comments

* use dtype function

* fix decompose rule

* remove unused include

* Empty commit, trigger test

* fix test

* disable ltc

* fix dtype

---------

Co-authored-by: zhekun.zhang <[email protected]>

---------

Signed-off-by: Gaurav Shukla <[email protected]>
Co-authored-by: Sean Silva <[email protected]>
Co-authored-by: Roll PyTorch Action <[email protected]>
Co-authored-by: Zhekun Zhang <[email protected]>
Co-authored-by: zhekun.zhang <[email protected]>
Co-authored-by: powderluv <[email protected]>
Co-authored-by: Ashay Rane <[email protected]>
Co-authored-by: Ramiro Leal-Cavazos <[email protected]>
Co-authored-by: Yuanqiang Liu <[email protected]>
Co-authored-by: George Petterson <[email protected]>
Co-authored-by: Gaurav Shukla <[email protected]>
Co-authored-by: Vivek Khandelwal <[email protected]>
Co-authored-by: maxbartel <[email protected]>
13 people authored Jun 2, 2023
1 parent b25c53a commit e53d054
Showing 11 changed files with 105 additions and 22 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/buildAndTest.yml
@@ -45,6 +45,9 @@ jobs:
llvm-build: out-of-tree
- os-arch: windows-x86_64
torch-version: stable
# For PyTorch stable builds, we don't build PyTorch from source
- torch-version: stable
torch-binary: OFF
include:
# Specify OS versions
- os-arch: ubuntu-x86_64
@@ -88,7 +91,7 @@ jobs:
arch: x64

- name: Try to Restore PyTorch Build Cache
if: matrix.os-arch != 'windows-x86_64'
if: ${{ matrix.torch-binary == 'OFF' }}
id: cache-pytorch
uses: actions/cache/restore@v3
with:
@@ -146,7 +149,7 @@ jobs:
run: ./build_tools/python_deploy/build_windows_ci.sh

- name: Save PyTorch Build Cache
if: ${{ github.ref_name == 'main' && matrix.torch-binary == 'OFF' && matrix.os-arch != 'windows-x86_64' }}
if: ${{ github.ref_name == 'main' && matrix.torch-binary == 'OFF' }}
uses: actions/cache/save@v3
with:
path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse
8 changes: 6 additions & 2 deletions e2e_testing/xfail_sets.py
@@ -619,8 +619,10 @@
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorIntModule_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"SelectScattertModule_basic",
"SelectScattertStaticModule_basic",
"SliceStaticModule_basic",
@@ -1135,8 +1137,10 @@
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"DetachModule_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorIntModule_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
"TensorsConcatStaticModule_basic",
14 changes: 7 additions & 7 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7188,22 +7188,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._shape_as_tensor\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
@@ -8573,6 +8568,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple<int, int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int11 = torch.constant.int 11\n"
8 changes: 5 additions & 3 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4379,10 +4379,12 @@ class DecomposeAtenScalarTensor : public OpRewritePattern<AtenScalarTensorOp> {

Value cstNone = rewriter.create<ConstantNoneOp>(op.getLoc());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
Value dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype());
Value toDTypeLayout = rewriter.create<AtenToDtypeLayoutOp>(
op.getLoc(), resultTy, numToTensor, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), /*non_blocking*/ cstFalse,
/*copy*/ cstFalse, /*memory_format*/ cstNone);
op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(),
op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse,
/*copy=*/cstFalse, /*memory_format=*/cstNone);
rewriter.replaceOp(op, toDTypeLayout);
return success();
}
6 changes: 5 additions & 1 deletion lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -47,6 +47,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
return failure();

Value newStart = sliceOp.getStart();
Value newEnd = sliceOp.getEnd();
Value dimSize = rewriter.create<AtenSizeIntOp>(
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
@@ -56,6 +57,9 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
}
newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);

newStart = rewriter.create<PrimMinIntOp>(op.getLoc(), newStart, dimSize);
newEnd = rewriter.create<PrimMinIntOp>(op.getLoc(), newEnd, dimSize);

Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

@@ -64,7 +68,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
Type rangeType = tensorType.getWithSizesAndDtype(
{kUnknownSize}, tensorType.getOptionalDtype());
Value range = rewriter.create<AtenArangeStartStepOp>(
op.getLoc(), rangeType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
op.getLoc(), rangeType, newStart, newEnd, sliceOp.getStep(),
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
/*pin_memory=*/noneVal);

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -774,11 +774,12 @@ def aten〇tensor〇bool〡shape(t: bool, dtype: Optional[int] = None, device: O
def aten〇scalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return []

@check_dtype_function([Invocation(-1), Invocation(-1.0)])
def aten〇scalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
return dtype
else:
return get_dtype_of_scalar(s)
return torch.float32

@check_shape_function([
Invocation(TensorOfShape()),
51 changes: 48 additions & 3 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3965,7 +3965,29 @@ def ScalarTensorFloat32Module_basic(module, tu: TestUtils):
# ==============================================================================


class ScalarTensorIntModule(torch.nn.Module):
class ScalarTensorDefaultDtypeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
scalar = torch.ops.aten.scalar_tensor(1.0)
return scalar


@register_test_case(module_factory=lambda: ScalarTensorDefaultDtypeModule())
def ScalarTensorDefaultDtypeModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


class ScalarTensorInt64Module(torch.nn.Module):

def __init__(self):
super().__init__()
@@ -3979,10 +4001,33 @@ def forward(self):
return scalar


@register_test_case(module_factory=lambda: ScalarTensorIntModule())
def ScalarTensorIntModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: ScalarTensorInt64Module())
def ScalarTensorInt64Module_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


class ScalarTensorInt32Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int32)
return scalar


@register_test_case(module_factory=lambda: ScalarTensorInt32Module())
def ScalarTensorInt32Module_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


24 changes: 24 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -591,6 +591,30 @@ def SliceCopyMax_Module_basic(module, tu: TestUtils):
# ==============================================================================


class SliceCopyStartGreaterThanDimSize_Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 100, 10, 1)
xslice.copy_(y)
return x


@register_test_case(module_factory=lambda: SliceCopyStartGreaterThanDimSize_Module())
def SliceCopyStartGreaterThanDimSize_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(0, 4, 4))


# ==============================================================================


class SliceCopyEndGreaterThanDimSize_Module(torch.nn.Module):
def __init__(self):
super().__init__()
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
10b46f7c7f69f9bf705d2b6ea53efb9c59145685
a14be7981bcef6186441a6c5780976e27e6246ea
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.1.0.dev20230526
torch==2.1.0.dev20230601
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.16.0.dev20230526
torchvision==0.16.0.dev20230601
