diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 3f4e6ed66354..c49646e2f1c0 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
   if (!isUnsignedType)
     return;
   int64_t minSI = -(1 << (numBits - 1));
-  Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
+  Value minSIValue = rewriter.create<arith::ConstantIntOp>(
+      loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
   zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
   minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
   arg = torch_to_linalg::createElementwiseLinalgGeneric(
@@ -797,6 +798,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     auto resultTy = cast<ValueTensorType>(op.getType());
 
     Value inputZp, weightZp;
+    bool inputUnsigned = false;
+    bool weightUnsigned = false;
     if (auto make = op.getInput()
                         .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
       input = make.getSelf();
@@ -806,6 +809,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       inputZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(inputZp.getType()),
           inputZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
     if (auto make = op.getWeight()
                         .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
@@ -818,6 +823,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       weightZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(weightZp.getType()),
           weightZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
 
     if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
@@ -916,15 +923,35 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     SmallVector<Value> strideIntValues =
         getAsConstantIntValues(rewriter, loc, strideInts);
 
+    // convert any uint8 quantization to int8 quantization
+    if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
+      int64_t width = integerType.getWidth();
+      signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
+    }
+
+    if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
+      int64_t width = integerType.getWidth();
+      signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
+    }
     // Pad the input tensor according to padding.
     SmallVector<Value> outDims{inBatch, weightBatch};
     Value paddedInput;
-    if (transposed) {
-      if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
-          !isa<mlir::FloatType>(resultDTy))
-        return rewriter.notifyMatchFailure(
-            op, "transpose does not support non-fp type yet");
+    Value pad = inputZp;
+    if (!pad) {
+      if (isa<mlir::FloatType>(inputDTy))
+        pad = rewriter.create<arith::ConstantOp>(
+            op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
+      if (isa<mlir::IntegerType>(inputDTy))
+        pad = rewriter.create<arith::ConstantOp>(
+            op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
+    }
+    if (pad.getType() != inputDTy) {
+      if (isa<mlir::FloatType>(inputDTy))
+        pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
+      if (isa<mlir::IntegerType>(inputDTy))
+        pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
+    }
+    if (transposed) {
       Value c0 =
           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
       Value c1 =
@@ -994,7 +1021,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 
       // Allocate padded input tensor
       Value initTensor =
-          createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
+          createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
 
       // Insert input into allocated tensor
       SmallVector<Value> strideIndexValues{c1, c1};
@@ -1017,24 +1044,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       strideInts.clear();
       strideInts.append(numSpatialDims, 1);
     } else {
-      Value pad = inputZp;
-      if (!pad) {
-        if (isa<mlir::FloatType>(inputDTy))
-          pad = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
-        if (isa<mlir::IntegerType>(inputDTy))
-          pad = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
-      }
-
-      if (pad.getType() != inputDTy) {
-        if (isa<mlir::FloatType>(inputDTy))
-          pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
-
-        if (isa<mlir::IntegerType>(inputDTy))
-          pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
-      }
-
       // Pad input
       paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
           op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 10c24b657128..b5d2f4ed580f 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -272,6 +272,7 @@
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
     "FloatImplicitModule_basic",
@@ -372,6 +373,7 @@
     "Conv2dQInt8Module_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "ConvTbcModule_basic",
+    "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
@@ -544,6 +546,7 @@
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
     "ConvTbcModule_basic",
+    "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
@@ -2097,6 +2100,7 @@
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "ConvTranspose2DQInt8_basic",
 }
 
 ONNX_XFAIL_SET = {
@@ -2251,6 +2255,7 @@
     "Conv2dWithPaddingModule_basic",
     "Conv3dModule_basic",
     "ConvTbcModule_basic",
+    "ConvTranspose2DQInt8_basic",
     "Conv_Transpose2dModule_basic",
     "Convolution2DModule_basic",
     "Convolution2DStridedModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
index 9600b090032e..e99525c32d88 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
     weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
     bias = torch.rand(3)
     module.forward(inputVec, weight, bias)
+
+
+N = 10
+Cin = 5
+Cout = 7
+Hin = 10
+Win = 8
+Hker = 3
+Wker = 2
+
+
+class ConvTranspose2DQInt8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.int8, True),
+            ([-1, -1, -1, -1], torch.int8, True),
+            ([-1], torch.float, True),
+        ]
+    )
+    def forward(self, input, weight, bias):
+        qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
+        qinput = torch.dequantize(qinput)
+        qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
+        qweight = torch.dequantize(qweight)
+        qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
+        qbias = torch.dequantize(qbias)
+        qz = torch.ops.aten.convolution(
+            qinput,
+            qweight,
+            bias=qbias,
+            stride=[2, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=True,
+            output_padding=[0, 0],
+            groups=1,
+        )
+        return qz
+
+
+@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
+def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
+        tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
+        torch.rand(Cout),
+    )
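Note (editor's illustration, not part of the patch): the signShift helper used above folds a uint8-quantized operand into int8 by adding -2^(bits-1) = -128 to both the stored values and the zero point, which leaves scale * (value - zero_point), and hence the dequantized result, unchanged. A minimal sketch of that identity, using NumPy for brevity:

    # Illustrative sketch only -- not part of the patch. Checks the identity
    # signShift relies on: shifting uint8 values and their zero point by
    # -2^(bits-1) reinterprets them as int8 without changing the dequantized
    # value scale * (value - zero_point).
    import numpy as np

    bits = 8
    min_si = -(1 << (bits - 1))  # -128

    scale = 0.01
    zp_u8 = 128
    values_u8 = np.array([0, 25, 200, 255], dtype=np.uint8)

    # Shift both the stored values and the zero point into the signed range.
    values_i8 = (values_u8.astype(np.int16) + min_si).astype(np.int8)
    zp_i8 = zp_u8 + min_si  # 0

    dequant_u8 = scale * (values_u8.astype(np.int32) - zp_u8)
    dequant_i8 = scale * (values_i8.astype(np.int32) - zp_i8)
    assert np.allclose(dequant_u8, dequant_i8)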