ptq.py (forked from meituan/YOLOv6)
161 lines (137 loc) · 6.17 KB

import torch
import torch.nn as nn
import copy
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import tensor_quant
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from tools.partial_quantization.utils import set_module, module_quant_disable


def collect_stats(model, data_loader, batch_number, device='cuda'):
    """Feed data to the network and collect statistics."""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, data_tuple in enumerate(data_loader):
        image = data_tuple[0]
        image = image.float() / 255.0
        model(image.to(device))
        if i + 1 >= batch_number:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
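
# Note: with quantization disabled and calibration enabled above, the forward
# passes only record activation statistics; the first batch_number batches of
# the loader (uint8 images scaled to [0, 1]) are used for calibration.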


def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            print(f"{name:40}: {module}")
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
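
# With the histogram calibrators configured in quant_model_init below,
# load_calib_amax(**kwargs) forwards its kwargs to the calibrator, so the
# calibration method can be chosen ('entropy', 'mse', 'percentile');
# do_ptq passes method='entropy'. MaxCalibrator takes no extra arguments.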


def quantable_op_check(k, quantable_ops):
    if quantable_ops is None:
        return True
    if k in quantable_ops:
        return True
    else:
        return False


def quant_model_init(model, device):
    model_ptq = copy.deepcopy(model)
    model_ptq.eval()
    model_ptq.to(device)
    conv2d_weight_default_desc = tensor_quant.QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL
    conv2d_input_default_desc = QuantDescriptor(num_bits=8, calib_method='histogram')

    convtrans2d_weight_default_desc = tensor_quant.QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL
    convtrans2d_input_default_desc = QuantDescriptor(num_bits=8, calib_method='histogram')

    for k, m in model_ptq.named_modules():
        if 'proj_conv' in k:
            print("Skip Layer {}".format(k))
            continue
        if isinstance(m, nn.Conv2d):
            in_channels = m.in_channels
            out_channels = m.out_channels
            kernel_size = m.kernel_size
            stride = m.stride
            padding = m.padding
            quant_conv = quant_nn.QuantConv2d(in_channels,
                                              out_channels,
                                              kernel_size,
                                              stride,
                                              padding,
                                              quant_desc_input=conv2d_input_default_desc,
                                              quant_desc_weight=conv2d_weight_default_desc)
            quant_conv.weight.data.copy_(m.weight.detach())
            if m.bias is not None:
                quant_conv.bias.data.copy_(m.bias.detach())
            else:
                quant_conv.bias = None
            set_module(model_ptq, k, quant_conv)
        elif isinstance(m, nn.ConvTranspose2d):
            in_channels = m.in_channels
            out_channels = m.out_channels
            kernel_size = m.kernel_size
            stride = m.stride
            padding = m.padding
            quant_convtrans = quant_nn.QuantConvTranspose2d(in_channels,
                                                            out_channels,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            quant_desc_input=convtrans2d_input_default_desc,
                                                            quant_desc_weight=convtrans2d_weight_default_desc)
            quant_convtrans.weight.data.copy_(m.weight.detach())
            if m.bias is not None:
                quant_convtrans.bias.data.copy_(m.bias.detach())
            else:
                quant_convtrans.bias = None
            set_module(model_ptq, k, quant_convtrans)
        elif isinstance(m, nn.MaxPool2d):
            kernel_size = m.kernel_size
            stride = m.stride
            padding = m.padding
            dilation = m.dilation
            ceil_mode = m.ceil_mode
            quant_maxpool2d = quant_nn.QuantMaxPool2d(kernel_size,
                                                      stride,
                                                      padding,
                                                      dilation,
                                                      ceil_mode,
                                                      quant_desc_input=conv2d_input_default_desc)
            set_module(model_ptq, k, quant_maxpool2d)
        else:
            # module can not be quantized, continue
            continue
    return model_ptq.to(device)
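
# Illustrative check, assuming a YOLOv6 model instance named `model`:
#
#   model_q = quant_model_init(model, 'cuda')
#   n_quant = sum(isinstance(m, (quant_nn.QuantConv2d,
#                                quant_nn.QuantConvTranspose2d,
#                                quant_nn.QuantMaxPool2d))
#                 for m in model_q.modules())
#   print(f"{n_quant} modules replaced with quantized counterparts")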


def do_ptq(model, train_loader, batch_number, device):
    model_ptq = quant_model_init(model, device)
    # It is a bit slow since we collect histograms on CPU
    with torch.no_grad():
        collect_stats(model_ptq, train_loader, batch_number, device)
        compute_amax(model_ptq, method='entropy')
    return model_ptq
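
# Illustrative usage sketch; `model` and `train_loader` are assumed to come from
# the YOLOv6 training pipeline, with the loader yielding tuples whose first
# element is a uint8 image batch, as collect_stats expects:
#
#   model_ptq = do_ptq(model, train_loader, batch_number=32, device='cuda')
#   torch.save({'model': model_ptq}, 'yolov6_ptq_calib.pt')  # reloadable via load_ptq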


def load_ptq(model, calib_path, device):
    model_ptq = quant_model_init(model, device)
    model_ptq.load_state_dict(torch.load(calib_path)['model'].state_dict())
    return model_ptq


def partial_quant(model_ptq, quantable_ops=None):
    # Ops not listed in quantable_ops are kept in full precision.
    for k, m in model_ptq.named_modules():
        if quantable_op_check(k, quantable_ops):
            continue
        # fall back to full precision
        if isinstance(m, quant_nn.QuantConv2d) or \
           isinstance(m, quant_nn.QuantConvTranspose2d) or \
           isinstance(m, quant_nn.QuantMaxPool2d):
            module_quant_disable(model_ptq, k)
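
# Illustrative partial-quantization sketch; the layer names below are hypothetical
# stand-ins for a sensitivity-ranked list (the repo's partial_quantization tools
# produce the real one):
#
#   quantable_ops = ['backbone.ERBlock_2.0.rbr_dense.conv',  # hypothetical names
#                    'neck.Rep_p4.conv1.conv']
#   partial_quant(model_ptq, quantable_ops=quantable_ops)
#   # modules whose names are not in quantable_ops are switched back to full precision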