-
Notifications
You must be signed in to change notification settings - Fork 12.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][LLVM]
LLVMTypeConverter
: Tighten materialization checks (#11…
…6532) This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a `MemRefType`/`UnrankedMemRefType` from the unpacked elements of a MemRef descriptor or from a bare pointer. The extra checks ensure that the inputs to the materialization function are correct. It is possible that a user added extra type conversion rules that convert MemRef types in a different way and the extra checks ensure that we construct a MemRef descriptor only if the inputs are what we expect. This commit also drops a check around bare pointer materializations: ``` // This is a bare pointer. We allow bare pointers only for function entry // blocks. ``` This check should not be part of the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided in the lowering pattern. At the point of time when materialization functions are executed, we already made that decision and we should just materialize regardless of the input format.
- Loading branch information
1 parent
ed1d90c
commit a0ef12c
Showing
5 changed files
with
154 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | ||
|
||
// Test the argument materializer for ranked MemRef types. | ||
|
||
// CHECK-LABEL: func @construct_ranked_memref_descriptor( | ||
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> | ||
// CHECK-COUNT-7: llvm.insertvalue | ||
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32> | ||
func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) { | ||
%0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>) | ||
"test.legal_op"(%0) : (memref<5x4xf32>) -> () | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// The argument materializer for ranked MemRef types is called with incorrect | ||
// input types. Make sure that the materializer is skipped and we do not | ||
// generate invalid IR. | ||
|
||
// CHECK-LABEL: func @invalid_ranked_memref_descriptor( | ||
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32> | ||
// CHECK: "test.legal_op"(%[[cast]]) | ||
func.func @invalid_ranked_memref_descriptor(%arg0: i1) { | ||
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>) | ||
"test.legal_op"(%0) : (memref<5x4xf32>) -> () | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// Test the argument materializer for unranked MemRef types. | ||
|
||
// CHECK-LABEL: func @construct_unranked_memref_descriptor( | ||
// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)> | ||
// CHECK-COUNT-2: llvm.insertvalue | ||
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32> | ||
func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) { | ||
%0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>) | ||
"test.legal_op"(%0) : (memref<*xf32>) -> () | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// The argument materializer for unranked MemRef types is called with incorrect | ||
// input types. Make sure that the materializer is skipped and we do not | ||
// generate invalid IR. | ||
|
||
// CHECK-LABEL: func @invalid_unranked_memref_descriptor( | ||
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32> | ||
// CHECK: "test.legal_op"(%[[cast]]) | ||
func.func @invalid_unranked_memref_descriptor(%arg0: i1) { | ||
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>) | ||
"test.legal_op"(%0) : (memref<*xf32>) -> () | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
/// Replace this op (which is expected to have 1 result) with the operands. | ||
struct TestDirectReplacementOp : public ConversionPattern { | ||
TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter) | ||
: ConversionPattern(converter, "test.direct_replacement", 1, ctx) {} | ||
LogicalResult | ||
matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
if (op->getNumResults() != 1) | ||
return failure(); | ||
rewriter.replaceOpWithMultiple(op, {operands}); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct TestLLVMLegalizePatternsPass | ||
: public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> { | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass) | ||
|
||
StringRef getArgument() const final { return "test-llvm-legalize-patterns"; } | ||
StringRef getDescription() const final { | ||
return "Run LLVM dialect legalization patterns"; | ||
} | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<LLVM::LLVMDialect>(); | ||
} | ||
|
||
void runOnOperation() override { | ||
MLIRContext *ctx = &getContext(); | ||
LLVMTypeConverter converter(ctx); | ||
mlir::RewritePatternSet patterns(ctx); | ||
patterns.add<TestDirectReplacementOp>(ctx, converter); | ||
|
||
// Define the conversion target used for the test. | ||
ConversionTarget target(*ctx); | ||
target.addLegalOp(OperationName("test.legal_op", ctx)); | ||
|
||
// Handle a partial conversion. | ||
DenseSet<Operation *> unlegalizedOps; | ||
ConversionConfig config; | ||
config.unlegalizedOps = &unlegalizedOps; | ||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns), config))) | ||
getOperation()->emitError() << "applyPartialConversion failed"; | ||
} | ||
}; | ||
} // namespace | ||
|
||
//===----------------------------------------------------------------------===// | ||
// PassRegistration | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace mlir { | ||
namespace test { | ||
void registerTestLLVMLegalizePatternsPass() { | ||
PassRegistration<TestLLVMLegalizePatternsPass>(); | ||
} | ||
} // namespace test | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters