[Tcp] Merge main into mlir-tcp #2518

Merged: 42 commits, Oct 18, 2023

Commits
7a7be60
Fix python package install instructions (#2464)
sogartar Sep 14, 2023
b03efdf
build: manually update PyTorch version
vivekkhandelwal1 Sep 19, 2023
278c41e
Bump llvm-project to f66cd9e9556a53142a26a5c21a72e21f1579217c. (#2466)
stellaraccident Sep 19, 2023
20ea1c9
Revert accidental change to submodule origin. (#2477)
stellaraccident Sep 20, 2023
023fc90
[Torch Dialect] add avg_pool 2d and 3d op variants (#2473)
davidgens-cerebras Sep 20, 2023
b9847b1
Fixing implicit double to float casts. (#2476)
benvanik Sep 20, 2023
059041e
[LTC] Support torch.ones/zeros/arange ops (#2440)
GlebKazantaev Sep 21, 2023
6699cbc
build: manually update PyTorch version (#2480)
vivekkhandelwal1 Sep 22, 2023
5f772e8
CI: reconcile differences between RollPyTorch and pre-merge checks (#…
ashay Sep 23, 2023
a520d39
[MLIR][TORCH] Add device "cpu" support for aten.to.dtype_layout op (…
brucekimrokcmu Sep 25, 2023
c9fd789
[NFC] Clean-up `ConvertAtenViewOp` in linalg backend (#2470)
ramiro050 Sep 26, 2023
ff7f8b2
update llvm-project to d13da154a7c7eff77df8686b2de1cfdfa7cc7029 (#2483)
dan-garvey Sep 26, 2023
7760bda
build: manually update PyTorch version
vivekkhandelwal1 Sep 27, 2023
e69266a
update PyTorch version to 2.2.0.dev20230927 (#2489)
stellaraccident Sep 27, 2023
7c6b9d2
[linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)
ramiro050 Sep 27, 2023
8abfa5b
Use PyTorch nightly for Arm release build (#2488)
vivekkhandelwal1 Sep 27, 2023
4e1dd3b
add e2e support for torch.log10 (#2479)
saienduri Sep 28, 2023
860be09
Elide dynamic broadcast checks when in strict symbolic shapes mode. (…
stellaraccident Sep 29, 2023
71ac62f
build: manually update PyTorch version
vivekkhandelwal1 Sep 29, 2023
c434736
[MLIR][TORCH] Add support for conversion to int8 dtype
vivekkhandelwal1 Sep 29, 2023
9293326
[MLIR][TORCH] Add support for bitwise_right_shift and bitwise_and.Scal…
vivekkhandelwal1 Sep 28, 2023
b75c208
update PyTorch version to 2.2.0.dev20231002 (#2497)
stellaraccident Oct 2, 2023
d10a86f
Disable LTC for arm release
vivekkhandelwal1 Sep 28, 2023
32d9b20
Add linspace/cumprod/roll ops (#2498)
antoniojkim Oct 3, 2023
ca6ce89
[MLIR][TORCH] Add support for int8 dtype for sub, add, and bitwise_an…
vivekkhandelwal1 Oct 3, 2023
4892ed4
update PyTorch version to 2.2.0.dev20231003 (#2500)
stellaraccident Oct 3, 2023
1c508af
Revert "[linalg] Fix handling of trailing size-1 dimensions in aten.v…
ramiro050 Oct 3, 2023
2e5d650
[linalg] Add handling for leading and trailing size-1 dims in ViewOp
ramiro050 Oct 3, 2023
14e6da8
update PyTorch version to 2.2.0.dev20231004 (#2502)
stellaraccident Oct 4, 2023
ae72eec
Improve aten.broadcast_to folder when in strict symbol mode (#2504)
qedawkins Oct 5, 2023
42b6c0a
update PyTorch version to 2.2.0.dev20231005 (#2506)
stellaraccident Oct 5, 2023
6f81ad7
[TorchToLinalg] Improve broadcast lowerings in strict symbolic modes …
qedawkins Oct 5, 2023
26ea13d
update PyTorch version to 2.2.0.dev20231006 (#2507)
stellaraccident Oct 6, 2023
9b5a4af
Update README to include new meeting schedule (#2503)
ramiro050 Oct 10, 2023
e649e06
Add aten.unflatten.int support and its torch-to-tosa lowering (#2509)
zezhang Oct 14, 2023
f2c53b8
Add aten.isclose support and its torch-to-tosa lowering (#2512)
zezhang Oct 16, 2023
14a4da9
Update llvm-project to b44b3494f60296db6aca38a14cab061d9b747a0a (#2511)
Oct 17, 2023
4279b75
update AtenClampOp in torch-to-tosa to handle fp inputs (#2516)
zezhang Oct 17, 2023
52abae1
Bump LLVM to get bazel fixes (#2517)
sjain-stanford Oct 18, 2023
86cf909
Merge branch 'main' into raghavanr/torch-mlir-upgrade
navahgar Oct 18, 2023
b846437
Fix the names of arith MaximumF and MinimumF ops
navahgar Oct 18, 2023
9624268
[Tcp] Add new e2e tests to pass list
navahgar Oct 18, 2023
19 changes: 15 additions & 4 deletions .github/workflows/RollPyTorch.yml
@@ -24,9 +24,21 @@ jobs:
- name: Get torch-mlir
uses: actions/checkout@v3
with:
submodules: 'true'
submodules: 'false'
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

- name: Get LLVM and StableHlo submodules
run: |
set -eo pipefail
cd ${GITHUB_WORKSPACE}

# Fetching the submodules concurrently may cause problems, so we fetch
# them one after another.
rm -f .git/modules/externals/llvm-project/index.lock
rm -f .git/modules/externals/stablehlo/index.lock
git submodule update --init --recursive externals/llvm-project
git submodule update --init --recursive externals/stablehlo

- name: Setup ccache
uses: ./.github/actions/setup-build
with:
@@ -71,15 +83,14 @@ jobs:
echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV}
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}

- name: Build and test (in-tree), also update ODS and abstract interpretation library
- name: Build and test (out-of-tree), also update ODS and abstract interpretation library
if: env.PT_HASH_CHANGED != '0'
run: |
cd ${GITHUB_WORKSPACE}
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \
TM_PYTHON_VERSIONS="cp311-cp311" \
./build_tools/python_deploy/build_linux_packages.sh

- name: Post issue comment on build failure
2 changes: 1 addition & 1 deletion .github/workflows/buildRelease.yml
@@ -115,7 +115,7 @@ jobs:
cd $GITHUB_WORKSPACE
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TM_TORCH_VERSION="stable" ./build_tools/python_deploy/build_linux_packages.sh
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh

# If we were given a release_id, then upload the package we just built
# to the github releases page.
15 changes: 12 additions & 3 deletions README.md
@@ -38,8 +38,17 @@ We have few paths to lower down to the Torch MLIR Dialect.
- `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel
- Github issues [here](https://github.com/llvm/torch-mlir/issues)
- [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse
- Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
- Weekly op office hours on Thursdays 8:30-9:30AM PST. See [here](https://discourse.llvm.org/t/announcing-torch-mlir-office-hours/63973/2) for more information.

### Meetings

Community Meeting / Developer Hour:
- 1st and 3rd Monday of the month at 9 am PST
- 2nd and 4th Monday of the month at 5 pm PST

Office Hours:
- Every Thursday at 8:30 am PST

Meeting links can be found [here](https://discourse.llvm.org/t/new-community-meeting-developer-hour-schedule/73868).

## Install torch-mlir snapshot

@@ -61,7 +70,7 @@ python -m pip install --upgrade pip
Then, we can install torch-mlir with the corresponding torch and torchvision nightlies.
```
pip install --pre torch-mlir torchvision \
-f https://llvm.github.io/torch-mlir/package-index/
-f https://llvm.github.io/torch-mlir/package-index/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
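As a quick sanity check after installing the snapshot (not part of this diff; the smoke test below is only a sketch and assumes the wheel exposes the `torch_mlir` Python module):
```shell
# Smoke test: confirm the nightly torch and the torch-mlir snapshot both import,
# and print the torch version that was pulled in.
python -c "import torch, torch_mlir; print(torch.__version__)"
```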

3 changes: 2 additions & 1 deletion build_tools/autogen_ltc_backend.py
@@ -467,7 +467,8 @@ def gen_fallback_code(*args, **kwargs):
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
tensor_class=self.tensor_class,
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
lazy_ir_generator=GenMlirLazyIr,
)
35 changes: 7 additions & 28 deletions build_tools/autogen_ltc_backend.yaml
@@ -3,12 +3,6 @@ blacklist:
# It also doesn't have confusing `unsafe` argument.
- _index_put_impl

# Ops with list of tensors output
- split.Tensor
- split_with_sizes
- unbind.int
- chunk

# Additional ops which autogen is supported for but don't compile yet
- _convolution
- detach
@@ -18,42 +12,28 @@ blacklist:

# Disabled for consistency with TS backend
- lift_fresh_copy
- new_empty
- rsub
- slice.Tensor # Disabled in favour of slice_copy.Tensor
- zeros
- ones
- arange
- arange.start
- arange.start_step
- fill.Scalar
- scalar_tensor

# Disabled in favour of functionalized alternatives
- _reshape_alias
- expand
- permute
- select.int
- squeeze
- squeeze.dim
- t
- transpose.int
- expand
- squeeze
- unsqueeze
- view
- slice.Tensor
- split.Tensor
- split_with_sizes
- unbind.int

whitelist:
# Enabled for consistency with TS backend
- arange.start_out

# List of supported ops that we don't want to do the full codegen for
supported:
# - bernoulli
# - bernoulli_
- _to_copy
- clone
- empty.memory_format
- empty_strided
- fill_.Scalar
- _unsafe_view
- unbind_copy.int
- split_copy.Tensor
@@ -80,18 +60,17 @@ supported:
- _trilinear
- linalg_pinv.atol_rtol_tensor
- logsumexp.out
- t

# List of ops that will take in symints for the size instead of ints
symint:
- empty.memory_format
- new_empty_strided
- expand_copy
- narrow_copy
- slice_backward
- slice_copy.Tensor
- split_copy.Tensor
- slice_scatter
- view
- view_copy
- as_strided_copy
- as_strided_scatter
6 changes: 6 additions & 0 deletions build_tools/python_deploy/build_linux_packages.sh
@@ -178,6 +178,12 @@ function run_in_docker() {
out-of-tree)
setup_venv "$python_version" "$TM_TORCH_VERSION"
build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" "$TM_TORCH_VERSION"
if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then
pushd /main_checkout/torch-mlir
TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_torch_ods.sh
TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_abstract_interp_lib.sh
popd
fi
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
test_out_of_tree
fi
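For context, this new branch is what the updated RollPyTorch workflow above drives. A rough local equivalent, with the checkout path as an assumption, looks like:
```shell
# Sketch: out-of-tree build that also regenerates ODS and the abstract
# interpretation library, mirroring the RollPyTorch workflow step above.
cd ~/src/torch-mlir   # assumed checkout location
TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \
  TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \
  ./build_tools/python_deploy/build_linux_packages.sh
```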
28 changes: 21 additions & 7 deletions e2e_testing/xfail_sets.py
@@ -17,6 +17,9 @@
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
}

TORCHDYNAMO_XFAIL_SET = {
@@ -288,6 +291,12 @@

# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
"AtenEmbeddingBagStaticModule_basic",

# Lowering not present for this case
"ElementwiseToDtypeI64ToUI8Module_basic",

# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
"ElementwiseAddScalarInt8Module_basic",
}

if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@@ -827,7 +836,6 @@
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"RollModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
@@ -1022,6 +1030,8 @@
"TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionSameCategoryZeroRankWider_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseSubTensorInt8Module_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseMulScalarModule_int",
@@ -1039,13 +1049,17 @@
"BatchNorm1DStaticShapeModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
"ToDtypeLayoutCPUModule_basic",
"TupleModule_basic",
"TypeAsDifferentModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
@@ -1175,6 +1189,7 @@
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"FlattenStaticModule_basic",
"UnflattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
@@ -1383,6 +1398,8 @@
"SoftmaxIntNegDimModule_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseSubTensorInt8Module_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1441,10 +1458,6 @@
"_ConvolutionDeprecated2DBenchmarkModule_basic",
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic",
@@ -1480,7 +1493,6 @@
"NeFloatIntModule_basic",
"NeIntModule_basic",
"QuantizedMLP_basic",
"RollModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SliceEndSleStartModule_basic",
@@ -1512,7 +1524,6 @@
"ConvolutionBackwardModule2DPadded_basic",
"VarMeanCorrectionModule_basic",
"VarMeanCorrectionNoneModule_basic",
"PrimsConvertElementTypeModule_basic",
"ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",
@@ -1547,4 +1558,7 @@
"UniformStaticShapeModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"EmptyStridedModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
}
@@ -72,11 +72,11 @@ createLinalgPayloadForElementwiseOp(Operation *op,
auto minFloat = clampOp.getMinFloat();
auto maxFloat = clampOp.getMaxFloat();
if (minFloat)
result = b.create<arith::MaxFOp>(
result = b.create<arith::MaximumFOp>(
loc, result,
b.create<arith::ConstantFloatOp>(loc, *minFloat, b.getF32Type()));
if (maxFloat)
result = b.create<arith::MinFOp>(
result = b.create<arith::MinimumFOp>(
loc, result,
b.create<arith::ConstantFloatOp>(loc, *maxFloat, b.getF32Type()));
} else if (elemType.isa<mlir::IntegerType>()) {
@@ -233,7 +233,7 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
loc, init,
[&](OpBuilder &b, Location loc, Value elem, Value acc) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value max = b.create<arith::MaxFOp>(loc, x, acc);
Value max = b.create<arith::MaximumFOp>(loc, x, acc);
b.create<scf::ReduceReturnOp>(loc, max);
});
})
@@ -43,9 +43,9 @@ func.func @tanh(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[CST0:.*]] = arith.constant 1.000000e-01 : f32
// CHECK: %[[MAX:.*]] = arith.maxf %[[BBARG0]], %[[CST0]] : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[BBARG0]], %[[CST0]] : f32
// CHECK: %[[CST1:.*]] = arith.constant 1.024000e+03 : f32
// CHECK: %[[MIN:.*]] = arith.minf %[[MAX]], %[[CST1]] : f32
// CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[CST1]] : f32
// CHECK: linalg.yield %[[MIN]] : f32
// CHECK: } -> tensor<?x?xf32>
// CHECK: return %[[GENERIC]] : tensor<?x?xf32>
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 12654 files
3 changes: 2 additions & 1 deletion include/torch-mlir/Conversion/Utils/Utils.h
@@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt);
std::optional<Type> srcOriginalDtype = std::nullopt,
std::optional<Type> dstOriginalDtype = std::nullopt);

Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,