[TOSA] Fix conversion for depthwise convolutions (#2398)
* [TOSA] Fix conversion for depthwise convolutions

* Add e2e tests for depthwise and grouped convolutions

Co-authored-by: Lucas Camphausen <[email protected]>
simon-camp and lucas-camp authored Aug 18, 2023
1 parent 594a1fa commit d77b9cf
Showing 3 changed files with 156 additions and 35 deletions.
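
For context (not part of the patch), a minimal PyTorch sketch of the convolution variants involved: a full convolution, a depthwise convolution (groups == in_channels, so weight.dim(1) == 1), and a depthwise convolution with a channel multiplier — the case whose TOSA lowering this commit fixes. The shapes below mirror the new e2e tests in conv.py; everything else is illustrative.

# Illustrative only; not part of the commit.
import torch

x = torch.rand(5, 4, 10, 20)                                               # NCHW input, 4 channels

full       = torch.nn.Conv2d(4, 10, kernel_size=3, groups=1, bias=False)   # weight: 10x4x3x3
depthwise  = torch.nn.Conv2d(4, 4,  kernel_size=3, groups=4, bias=False)   # weight:  4x1x3x3
multiplier = torch.nn.Conv2d(4, 8,  kernel_size=3, groups=4, bias=False)   # weight:  8x1x3x3

for m in (full, depthwise, multiplier):
    print(tuple(m.weight.shape), tuple(m(x).shape))
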
16 changes: 15 additions & 1 deletion e2e_testing/xfail_sets.py
@@ -13,7 +13,11 @@
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
from torch_mlir._version import torch_version_for_comparison, version

LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier"
}

TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors
@@ -276,6 +280,10 @@
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",

# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",

}

TORCHDYNAMO_CRASHING_SET = {
@@ -640,6 +648,10 @@
"AvgPool1dStaticModule_basic",
"AvgPool2dStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic",
@@ -989,6 +1001,8 @@
"ElementwiseIsnanModule_basic",
"TypePromotionAlphaWiderModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
135 changes: 109 additions & 26 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1898,6 +1898,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
auto biasElemTy =
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type();

int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
return rewriter.notifyMatchFailure(op, "non-const group size unsupported");
} else if (groups != 1 && weightShape[1] != 1) {
return rewriter.notifyMatchFailure(
op, "group size must be 1 (convolution) or weight.dim(1) must be 1 "
"(depthwise convolution)");
}

SmallVector<int64_t, 2> stride;
if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride)))
return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");
@@ -1918,7 +1927,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");

// TOSA works in NHWC and takes OHWI weights. Perform the necessary transpose.
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
// Perform the necessary transformations.
std::optional<Value> nchwToNhwcTransposeConst =
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{0, 2, 3, 1},
@@ -1935,26 +1945,82 @@
nchwToNhwcTransposeConst.value())
.getResult();

SmallVector<int64_t> transposedWeightShape(
{weightShape[0], weightShape[2], weightShape[3], weightShape[1]});
auto transposedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
auto transposedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedWeightType), weight,
nchwToNhwcTransposeConst.value())
.getResult();
SmallVector<int64_t> transformedWeightShape;
RankedTensorType transformedWeightType;
Value transformedWeight;
int64_t outputCDim;
if (groups == 1) {
// full convolution: O(I/G)HW -> OHWI
transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3],
weightShape[1]};
transformedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transformedWeightShape), weightElemTy);
transformedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transformedWeightType), weight,
nchwToNhwcTransposeConst.value())
.getResult();
outputCDim = transformedWeightShape[0];
} else if (weightShape[1] == 1) {
// depthwise convolution: O(I/G)HW -> HWIM
// transpose: O(I/G)HW -> HWO(I/G)
std::optional<Value> transposeConst =
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{2, 3, 0, 1},
/*shape=*/{static_cast<int32_t>(4)});
SmallVector<int64_t> transposedWeightShape = {
weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
auto transposedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
auto transposedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedWeightType), weight,
transposeConst.value())
.getResult();

// reshape: HWO(I/G) -> HWIM
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
if (outputCDim == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "number of output channels must be statically known for "
"depthwise convolutions");
}
transformedWeightShape = {
transposedWeightShape[0],
transposedWeightShape[1],
groups,
outputCDim / groups,
};
transformedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transformedWeightShape), weightElemTy);
transformedWeight =
rewriter
.create<tosa::ReshapeOp>(
op->getLoc(),
getTypeConverter()->convertType(transformedWeightType),
transposedWeight,
rewriter.getDenseI64ArrayAttr(transformedWeightShape))
.getResult();
} else {
llvm_unreachable("Unhandled convolution type");
}

int64_t outputHDim, outputWDim;
if (inputTy.hasStaticShape()) {
outputHDim = (transposedInputShape[1] + padding[0] + padding[1] -
dilation[0] * (transposedWeightShape[1] - 1) - 1) /
int64_t inputHDim = inputShape[2];
int64_t inputWDim = inputShape[3];
int64_t weightHDim = weightShape[2];
int64_t weightWDim = weightShape[3];
outputHDim = (inputHDim + padding[0] + padding[1] -
dilation[0] * (weightHDim - 1) - 1) /
stride[0] +
1;
outputWDim = (transposedInputShape[2] + padding[2] + padding[3] -
dilation[1] * (transposedWeightShape[2] - 1) - 1) /
outputWDim = (inputWDim + padding[2] + padding[3] -
dilation[1] * (weightWDim - 1) - 1) /
stride[1] +
1;
} else {
@@ -1965,19 +2031,36 @@
// Output shape is NHWC, to be transposed back to NCHW. Output elemTy for
// quantized input is i32, which gets rescaled down to quantized output range.
SmallVector<int64_t> outputShape = {transposedInputShape[0], outputHDim,
outputWDim, transposedWeightShape[0]};
outputWDim, outputCDim};
auto convOpTy =
RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy);

Value convOpResult =
rewriter
.create<tosa::Conv2DOp>(op->getLoc(),
getTypeConverter()->convertType(convOpTy),
transposedInput, transposedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.getResult();
Value convOpResult;
if (groups == 1) {
// full convolution
convOpResult =
rewriter
.create<tosa::Conv2DOp>(op->getLoc(),
getTypeConverter()->convertType(convOpTy),
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.getResult();
} else if (weightShape[1] == 1) {
// depthwise convolution
convOpResult =
rewriter
.create<tosa::DepthwiseConv2DOp>(
op->getLoc(), getTypeConverter()->convertType(convOpTy),
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.getResult();
} else {
llvm_unreachable("Unhandled convolution type");
}

std::optional<Value> nhwcToNchwTransposeConst =
tosa::getConstTensor<int32_t>(rewriter, op,
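
The weight handling in the lowering above is easier to see outside the diff. Below is a rough NumPy sketch of the layout transformations it performs, using illustrative names and the shapes from the depthwise-multiplier test (a hedged sketch, assuming out_channels is statically known; none of this code is in the commit):

import numpy as np

# PyTorch weight layout is O(I/G)HW; here O = out_channels = 8, I/G = 1, kernel 3x3.
groups = 4
out_channels = 8
weight = np.arange(8 * 1 * 3 * 3, dtype=np.float32).reshape(8, 1, 3, 3)

if groups == 1:
    # full convolution: a single transpose, O(I/G)HW -> OHWI
    tosa_weight = weight.transpose(0, 2, 3, 1)
else:
    # depthwise convolution: transpose O(I/G)HW -> HWO(I/G), then reshape to HWIM,
    # where I = groups and M = out_channels // groups (the channel multiplier)
    hw_o_ig = weight.transpose(2, 3, 0, 1)                           # (3, 3, 8, 1)
    tosa_weight = hw_o_ig.reshape(3, 3, groups, out_channels // groups)

print(tosa_weight.shape)                                             # (3, 3, 4, 2)
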
40 changes: 32 additions & 8 deletions python/torch_mlir_e2e_test/test_suite/conv.py
@@ -112,32 +112,56 @@ def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):

class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):

def __init__(self):
def __init__(self, out_channels, groups):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(in_channels=2,
out_channels=10,
self.conv = torch.nn.Conv2d(in_channels=4,
out_channels=out_channels,
kernel_size=3,
padding=3,
stride=2,
dilation=3,
bias=False)
bias=False,
groups=groups)
self.train(False)

@export
@annotate_args([
None,
([5, 2, 10, 20], torch.float32, True),
([5, 4, 10, 20], torch.float32, True),
])
def forward(self, x):
return self.conv(x)


@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule())
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=10, groups=1))
def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
module.forward(tu.rand(5, 4, 10, 20))


@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=4))
def Conv2dWithPaddingDilationStrideStaticModule_depthwise(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))


@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=4))
def Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))


@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=4, groups=2))
def Conv2dWithPaddingDilationStrideStaticModule_grouped(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))


@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule(out_channels=8, groups=2))
def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 10, 20))


# ==============================================================================
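
As a sanity check on the static shapes used by these tests, the output spatial dimensions follow the standard convolution formula — the same one the TOSA lowering computes above. A small sketch (not part of the test suite; helper name is illustrative):

def conv_out_dim(in_dim, pad_total, dilation, kernel, stride):
    # (in + pads - dilation*(kernel-1) - 1) // stride + 1
    return (in_dim + pad_total - dilation * (kernel - 1) - 1) // stride + 1

# static test input is 5x4x10x20 with kernel 3, padding 3 per side, stride 2, dilation 3
h = conv_out_dim(10, 2 * 3, 3, 3, 2)   # -> 5
w = conv_out_dim(20, 2 * 3, 3, 3, 2)   # -> 10
print(h, w)                            # outputs are N x out_channels x 5 x 10
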
