Skip to content

Commit

Permalink
black reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
zjgarvey committed Apr 29, 2024
1 parent e70d835 commit da9f005
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,46 +1047,55 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
bias = torch.rand(3)
module.forward(inputVec, weight, bias)


N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2


class ConvTranspose2DQInt8Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
])
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, input, weight, bias):
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
qinput = torch.dequantize(qinput)
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
qweight = torch.dequantize(qweight)
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
qbias = torch.dequantize(qbias)
qz = torch.ops.aten.convolution(qinput,
qweight,
bias=qbias,
stride=[2,1],
padding=[1,1],
dilation=[1,1],
transposed=True,
output_padding=[0,0],
groups=1)
qz = torch.ops.aten.convolution(
qinput,
qweight,
bias=qbias,
stride=[2, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=True,
output_padding=[0, 0],
groups=1,
)
return qz


@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
module.forward(tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)
module.forward(
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)

0 comments on commit da9f005

Please sign in to comment.