From 6413e4a6875f055389794a2957d7b24517e2f2c0 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Thu, 7 Nov 2024 16:27:51 +0800
Subject: [PATCH] [Torch] support float_power and threshold ops (#3854)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 24 ++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 60 +++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        | 10 +---
 .../build_tools/torch_ods_gen.py              |  1 +
 .../test_suite/elementwise.py                 | 23 +++++++
 5 files changed, 110 insertions(+), 8 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 374f0581c059..496a725dd0b8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5187,6 +5187,30 @@ def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [
   }];
 }
 
+def Torch_AtenFloatPowerTensorTensorOp : Torch_Op<"aten.float_power.Tensor_Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$exponent
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFloatPowerTensorTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenFloatPowerTensorTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 46c71698d350..667d04e556e5 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9916,6 +9916,63 @@ class DecomposeAtenFMaxMinOp : public OpRewritePattern {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenThresholdOp : public OpRewritePattern<AtenThresholdOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenThresholdOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    auto selfType = dyn_cast<BaseTensorType>(self.getType());
+    if (!selfType || !selfType.hasSizes()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "requires input is tensor with sizes");
+    }
+
+    Value threshold = op.getThreshold();
+    Value value = op.getValue();
+
+    auto comOp = rewriter.create<AtenGtScalarOp>(
+        loc,
+        selfType.getWithSizesAndDtype(selfType.getSizes(),
+                                      rewriter.getI1Type()),
+        self, threshold);
+
+    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, op.getType(), comOp,
+                                                        self, value);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenFloatPowerTensorTensorOp
+    : public OpRewritePattern<AtenFloatPowerTensorTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFloatPowerTensorTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value exp = op.getExponent();
+
+    auto selfTy = dyn_cast<BaseTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasDtype() || !selfTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "requires input is tensor with dtype and sizes");
+    }
+
+    Value selfF64 =
+        convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
+    rewriter.replaceOpWithNewOp<AtenPowTensorTensorOp>(op, op.getType(),
+                                                       selfF64, exp);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -10181,6 +10238,9 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenThresholdOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenFloatPowerTensorTensorOp>(
+        patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenFMaxMinOp>(patterns);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 2847ceeee39a..4e1151196363 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -854,14 +854,6 @@
     "TensorToFloatZeroRank_basic",
     "TensorToFloat_basic",
     "TensorToInt_basic",
-    "TestMultipleTensorAndPrimitiveTypesReturn_basic",
-    "Threshold1dFloatModule_basic",
-    "Threshold1dIntI32Module_basic",
-    "Threshold1dIntModule_basic",
-    "Threshold2dFloatModule_basic",
-    "Threshold2dIntModule_basic",
-    "Threshold3dFloatModule_basic",
-    "Threshold3dIntModule_basic",
     "ThresholdBackward1dFloatModule_basic",
     "ThresholdBackward1dIntModule_basic",
     "ThresholdBackward1dMixedModule_basic",
@@ -2367,6 +2359,7 @@
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
     "Exp2StaticModule_basic",
+    "FloatPowerTensorTensorStaticModule_basic",
     "MultinomialModule2D_basic",
     "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
@@ -2390,6 +2383,7 @@
     "SliceCopy_Module_basic",
     "StdCorrectionLargeInputModule_basic",
     "TupleModule_basic",
+    "ThresholdStaticModule_basic",
     "VarCorrectionLargeInputModule_basic",
     # Failure - incorrect shape
     "ArangeStartOutDtypeModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index ac505735a17e..1d2f19b9cf98 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -498,6 +498,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)")
+    emit("aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
     emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index fa24feb1d37f..489680d438cb 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -491,6 +491,29 @@ def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class FloatPowerTensorTensorStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.float_power(x, torch.tensor(2))
+
+
+@register_test_case(module_factory=lambda: FloatPowerTensorTensorStaticModule())
+def FloatPowerTensorTensorStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseWhereScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()