From d77b9cf7ae061f39f0c938116fe7546b6994b46d Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Fri, 18 Aug 2023 17:15:54 +0200 Subject: [PATCH] [TOSA] Fix conversion for depthwise convolutions (#2398) * [TOSA] Fix conversion for depthwise convolutions * Add e2e tests for depthwise and grouped convolutions Co-authored-by: Lucas Camphausen --- e2e_testing/xfail_sets.py | 16 ++- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 135 ++++++++++++++---- python/torch_mlir_e2e_test/test_suite/conv.py | 40 ++++-- 3 files changed, 156 insertions(+), 35 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 2c121e5ab9ba..c9cd0b8cbb27 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -13,7 +13,11 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS from torch_mlir._version import torch_version_for_comparison, version -LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS +LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { + # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed + # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier" +} TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors @@ -276,6 +280,10 @@ # AssertionError: Unregistered operation: torch.aten._unsafe_index_put "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed + # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + } TORCHDYNAMO_CRASHING_SET = { @@ -640,6 +648,10 @@ "AvgPool1dStaticModule_basic", "AvgPool2dStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Convolution2DStaticModule_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "ElementwiseCloneContiguousModule_basic", @@ -989,6 +1001,8 @@ "ElementwiseIsnanModule_basic", "TypePromotionAlphaWiderModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "BatchNorm1DModule_basic", "BatchNorm1DWith2DInputModule_basic", "BatchNorm2DModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index f701df29cfa4..bf2f20d8202a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1898,6 +1898,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto biasElemTy = inputElemTy.isa() ? inputElemTy : rewriter.getI32Type(); + int64_t groups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) { + return rewriter.notifyMatchFailure(op, "non-const group size unsupported"); + } else if (groups != 1 && weightShape[1] != 1) { + return rewriter.notifyMatchFailure( + op, "group size must be 1 (convolution) or weight.dim(1) must be 1 " + "(depthwise convolution)"); + } + SmallVector stride; if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride))) return rewriter.notifyMatchFailure(op, "non-const stride list unsupported"); @@ -1918,7 +1927,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-const dilation list unsupported"); - // TOSA works in NHWC and takes OHWI weights. Perform the necessary transpose. + // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. + // Perform the necessary transformations. std::optional nchwToNhwcTransposeConst = tosa::getConstTensor(rewriter, op, /*vec=*/{0, 2, 3, 1}, @@ -1935,26 +1945,82 @@ LogicalResult ConvertAtenOp::matchAndRewrite( nchwToNhwcTransposeConst.value()) .getResult(); - SmallVector transposedWeightShape( - {weightShape[0], weightShape[2], weightShape[3], weightShape[1]}); - auto transposedWeightType = RankedTensorType::get( - makeShapeLLVMCompatible(transposedWeightShape), weightElemTy); - auto transposedWeight = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transposedWeightType), weight, - nchwToNhwcTransposeConst.value()) - .getResult(); + SmallVector transformedWeightShape; + RankedTensorType transformedWeightType; + Value transformedWeight; + int64_t outputCDim; + if (groups == 1) { + // full convolution: O(I/G)HW-> OHWI + transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3], + weightShape[1]}; + transformedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); + transformedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), weight, + nchwToNhwcTransposeConst.value()) + .getResult(); + outputCDim = transformedWeightShape[0]; + } else if (weightShape[1] == 1) { + // depthwise convolution: O(I/G)HW-> HWIM) + // transpose: O(I/G)HW -> HWO(I/G) + std::optional transposeConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/{2, 3, 0, 1}, + /*shape=*/{static_cast(4)}); + SmallVector transposedWeightShape = { + weightShape[2], weightShape[3], weightShape[0], weightShape[1]}; + auto transposedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedWeightShape), weightElemTy); + auto transposedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transposedWeightType), weight, + transposeConst.value()) + .getResult(); + + // reshape: HWO(I/G) -> HWIM + outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1]; + if (outputCDim == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "number of output channels must be statically known for " + "depthwise convolutions"); + } + transformedWeightShape = { + transposedWeightShape[0], + transposedWeightShape[1], + groups, + outputCDim / groups, + }; + transformedWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); + transformedWeight = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), + transposedWeight, + rewriter.getDenseI64ArrayAttr(transformedWeightShape)) + .getResult(); + } else { + llvm_unreachable("Unhandled convolution type"); + } int64_t outputHDim, outputWDim; if (inputTy.hasStaticShape()) { - outputHDim = (transposedInputShape[1] + padding[0] + padding[1] - - dilation[0] * (transposedWeightShape[1] - 1) - 1) / + int64_t inputHDim = inputShape[2]; + int64_t inputWDim = inputShape[3]; + int64_t weightHDim = weightShape[2]; + int64_t weightWDim = weightShape[3]; + outputHDim = (inputHDim + padding[0] + padding[1] - + dilation[0] * (weightHDim - 1) - 1) / stride[0] + 1; - outputWDim = (transposedInputShape[2] + padding[2] + padding[3] - - dilation[1] * (transposedWeightShape[2] - 1) - 1) / + outputWDim = (inputWDim + padding[2] + padding[3] - + dilation[1] * (weightWDim - 1) - 1) / stride[1] + 1; } else { @@ -1965,19 +2031,36 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Output shape is NHWC, to be transposed back to NCHW. Output elemTy for // quantized input is i32, which gets rescaled down to quantized output range. SmallVector outputShape = {transposedInputShape[0], outputHDim, - outputWDim, transposedWeightShape[0]}; + outputWDim, outputCDim}; auto convOpTy = RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); - Value convOpResult = - rewriter - .create(op->getLoc(), - getTypeConverter()->convertType(convOpTy), - transposedInput, transposedWeight, bias, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation)) - .getResult(); + Value convOpResult; + if (groups == 1) { + // full convolution + convOpResult = + rewriter + .create(op->getLoc(), + getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation)) + .getResult(); + } else if (weightShape[1] == 1) { + // depthwise convolution + convOpResult = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation)) + .getResult(); + } else { + llvm_unreachable("Unhandled convolution type"); + } std::optional nhwcToNchwTransposeConst = tosa::getConstTensor(rewriter, op, diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index 006301b9fc79..5fc443d98605 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -112,32 +112,56 @@ def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils): class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module): - def __init__(self): + def __init__(self, out_channels, groups): super().__init__() torch.manual_seed(0) - self.conv = torch.nn.Conv2d(in_channels=2, - out_channels=10, + self.conv = torch.nn.Conv2d(in_channels=4, + out_channels=out_channels, kernel_size=3, padding=3, stride=2, dilation=3, - bias=False) + bias=False, + groups=groups) self.train(False) @export @annotate_args([ None, - ([5, 2, 10, 20], torch.float32, True), + ([5, 4, 10, 20], torch.float32, True), ]) def forward(self, x): return self.conv(x) @register_test_case( - module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule()) + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=10, groups=1)) def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): - t = tu.rand(5, 2, 10, 20) - module.forward(t) + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=4)) +def Conv2dWithPaddingDilationStrideStaticModule_depthwise(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=4)) +def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=2)) +def Conv2dWithPaddingDilationStrideStaticModule_grouped(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) + + +@register_test_case( + module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=2)) +def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 10, 20)) # ==============================================================================