Hierarchical scales #1038

Open · wants to merge 2 commits into base: dev
4 changes: 3 additions & 1 deletion src/brevitas/core/quant/int.py
@@ -8,6 +8,7 @@
 from torch.nn import Module
 
 import brevitas
+from brevitas.core.function_wrapper.misc import Identity
 from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import round_ste
@@ -138,7 +139,8 @@ def __init__(
             scaling_impl: Module,
             int_scaling_impl: Module,
             zero_point_impl: Module,
-            bit_width_impl: Module):
+            bit_width_impl: Module,
+            scaling_int_quant: Optional[Module] = None):
         super(RescalingIntQuant, self).__init__()
         self.int_quant = int_quant
         self.scaling_impl = scaling_impl
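The visible change to RescalingIntQuant is small: an optional scaling_int_quant module is threaded into the constructor so that the scale the quantizer computes can itself come out of a nested quantizer (the "hierarchical scales" of the PR title). A minimal sketch of the idea, with hypothetical grids and names rather than the diff's internals:

import torch

def quantize(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Snap x onto the grid defined by scale, then rescale back.
    return torch.round(x / scale) * scale

weight = torch.randn(8, 8)
float_scale = weight.abs().max() / 127.  # ordinary 8-bit per-tensor scale
scale_scale = float_scale / 15.  # hypothetical grid for a 5-bit quantizer of the scale
quant_scale = quantize(float_scale, scale_scale)  # the scale is itself quantized
quant_weight = quantize(weight, quant_scale)  # the weights then use the quantized scale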
27 changes: 27 additions & 0 deletions src/brevitas/core/restrict_val.py
@@ -170,3 +170,30 @@ def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
+
+
+class QuantRestrictValue(brevitas.jit.ScriptModule):
+
+    def __init__(self, restrict_value_float_to_int_impl: Module):
+        super(QuantRestrictValue, self).__init__()
+        self.float_to_int_impl = restrict_value_float_to_int_impl
+
+    def restrict_init_float(self, x: float):
+        return Identity()
+
+    def restrict_init_tensor(self, x: torch.Tensor):
+        return Identity()
+
+    def restrict_init_module(self):
+        return Identity()
+
+    def restrict_init_inplace_module(self):
+        return Identity()
+
+    def retrocompatibility_op(self, x):
+        return Identity()
+
+    @brevitas.jit.script_method
+    def forward(self, x: torch.Tensor):
+        o, *_ = self.float_to_int_impl(x)
+        return o
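QuantRestrictValue turns the scale restriction step into a full quantizer: whatever module is injected as restrict_value_float_to_int_impl is applied to the scale, and only its first output, the quantized value, is kept; the nested scale, zero point, and bit width are discarded. The restrict_init_* hooks all return Identity() so initialization values pass through unchanged. A minimal sketch, assuming this branch is installed and using a stand-in quantizer in place of RescalingIntQuant:

import torch
from brevitas.core.restrict_val import QuantRestrictValue

class FakeIntQuant(torch.nn.Module):
    # Stand-in mimicking RescalingIntQuant's (value, scale, zero_point, bit_width) output.
    def forward(self, x):
        scale = torch.tensor(0.1)
        return torch.round(x / scale) * scale, scale, torch.tensor(0.), torch.tensor(5.)

restrict = QuantRestrictValue(FakeIntQuant())
out = restrict(torch.rand(4))  # only the quantized tensor comes back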
22 changes: 20 additions & 2 deletions src/brevitas/core/zero_point.py
@@ -60,6 +60,18 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
         return out
 
 
+class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule):
+
+    def __init__(self, zp_int_quant: Module) -> None:
+        super(_ScaleShiftQuantZeroPoint, self).__init__()
+        self.zp_int_quant = zp_int_quant
+
+    @brevitas.jit.script_method
+    def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
+        quant_zp, *_ = self.zp_int_quant(zero_point)
+        return quant_zp
+
+
 class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule):
 
     def __init__(
@@ -70,15 +82,21 @@ def __init__(
             zero_point_stats_input_concat_dim: int,
             zero_point_stats_impl: Module,
             zero_point_shape: Tuple[int, ...],
-            tracked_parameter_list: List[torch.nn.Parameter]) -> None:
+            tracked_parameter_list: List[torch.nn.Parameter],
+            scale_shift_zero_point_impl: Optional[Module] = None) -> None:
         super(StatsFromParameterZeroPoint, self).__init__()
         self.parameter_list_stats = _ParameterListStats(
             zero_point_stats_impl,
             zero_point_shape,
             zero_point_stats_input_view_shape_impl,
             zero_point_stats_input_concat_dim,
             tracked_parameter_list)
-        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+        # This is for backward compatibility. Requiring int_quant/quantize_zero_point for
+        # this branch but not for the other seems a bit off and might need some clean-up.
+        if scale_shift_zero_point_impl is None:
+            self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+        else:
+            self.scale_shift_zero_point = scale_shift_zero_point_impl
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
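Unlike _ScaleShiftZeroPoint, which shifts and scales the zero point using the incoming scale and bit width, _ScaleShiftQuantZeroPoint ignores those two arguments: it delegates entirely to the injected zp_int_quant and keeps only the quantized value. A sketch in the same spirit as above (stand-in quantizer, hypothetical grid):

import torch
from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint

class FakeZPQuant(torch.nn.Module):
    # Stand-in zero-point quantizer on a fixed hypothetical grid.
    def forward(self, zp):
        s = torch.tensor(0.25)
        return torch.round(zp / s) * s, s, torch.tensor(0.), torch.tensor(6.)

shift = _ScaleShiftQuantZeroPoint(FakeZPQuant())
quant_zp = shift(torch.randn(4), torch.tensor(1.), torch.tensor(8.))  # scale/bit_width unused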
131 changes: 131 additions & 0 deletions tests/brevitas/core/test_scaling_quant.py
@@ -0,0 +1,131 @@
from dependencies import this
from dependencies import value
import torch

from brevitas.core.quant.int import RescalingIntQuant
from brevitas.core.restrict_val import QuantRestrictValue
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint
from brevitas.inject.enum import ScalingPerOutputType
import brevitas.nn as qnn
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat

ZP_BIT_WIDTH = 6
SCALE_BIT_WIDTH = 5


class QuantScalingInt(Int8WeightPerTensorFloat):
    bit_width = SCALE_BIT_WIDTH
    # (this << 1) resolves the attribute one injector level up, i.e. on the
    # weight quantizer that nests this scale quantizer.
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    rescaling_int_quant = RescalingIntQuant

@value
def scaling_shape(
scaling_per_output,
scaling_per_output_channel_shape,
expanded_groupwise_shape,
group_dim,
upstream_scaling):
if scaling_per_output == ScalingPerOutputType.TENSOR:
scaling = SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
scaling = scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
            # The scaling shape matches expanded_groupwise_shape, but with 1 at position group_dim + 1
assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
assert group_dim is not None, "Per Group scaling not correctly configured"
size = list(expanded_groupwise_shape)
size[group_dim + 1] = 1
scaling = tuple(size)

        # When quantizing the scale of a groupwise quantizer, there is one extra dim compared to the normal case
if upstream_scaling == ScalingPerOutputType.GROUP:
scaling = list(scaling)
scaling.insert(-1, 1)
scaling = tuple(scaling)
return scaling


class QuantZPInt(Int8WeightPerTensorFloat):
module = (this << 1).module
tracked_parameter_list = (this << 1).tracked_parameter_list
upstream_scaling = (this << 1).scaling_per_output_type
rescaling_int_quant = RescalingIntQuant
bit_width = ZP_BIT_WIDTH
quantize_zero_point = True
scaling_per_output_type = ScalingPerOutputType.CHANNEL

@value
def scaling_shape(
scaling_per_output,
scaling_per_output_channel_shape,
expanded_groupwise_shape,
group_dim,
upstream_scaling):
if scaling_per_output == ScalingPerOutputType.TENSOR:
scaling = SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
scaling = scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
            # The scaling shape matches expanded_groupwise_shape, but with 1 at position group_dim + 1
assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
assert group_dim is not None, "Per Group scaling not correctly configured"
size = list(expanded_groupwise_shape)
size[group_dim + 1] = 1
scaling = tuple(size)

        # When quantizing the scale of a groupwise quantizer, there is one extra dim compared to the normal case
if upstream_scaling == ScalingPerOutputType.GROUP:
scaling = list(scaling)
scaling.insert(-1, 1)
scaling = tuple(scaling)
return scaling


class QuantScaleQuantZPInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat):
proxy_class = GroupwiseWeightQuantProxyFromInjector
scaling_int_quant = QuantScalingInt
zp_int = QuantZPInt
restrict_scaling_impl = QuantRestrictValue
scaling_per_output_type = ScalingPerOutputType.GROUP
scale_shift_zero_point_impl = _ScaleShiftQuantZeroPoint
group_size = 32

@value
def restrict_value_float_to_int_impl():
return this.scaling_int_quant.rescaling_int_quant

@value
def zp_int_quant():
return this.zp_int.rescaling_int_quant


def test_quant_scale():

    def hook_scale(module, inp):
        inp = inp[0]
        quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp)
        assert bit_width == SCALE_BIT_WIDTH
        # The quantized scale must lie on the integer grid defined by its own scale.
        assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))

    def hook_zp(module, inp):
        inp = inp[0]
        quant_zp, scale, zp, bit_width = module.zp_int_quant(inp)
        assert bit_width == ZP_BIT_WIDTH
        # The quantized zero point must lie on the integer grid defined by its own scale.
        assert torch.allclose(quant_zp / scale, torch.round(quant_zp / scale))

    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat)
    for module in linear.modules():
        if isinstance(module, QuantRestrictValue):
            module.register_forward_pre_hook(hook_scale)
        if isinstance(module, _ScaleShiftQuantZeroPoint):
            module.register_forward_pre_hook(hook_zp)

linear(torch.randn(1, 64))
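
The test wires everything together through dependency injection: QuantScaleQuantZPInt8WeightPerTensorFloat nests a 5-bit quantizer (QuantScalingInt) for the scale and a 6-bit one (QuantZPInt) for the zero point inside a groupwise 8-bit weight quantizer, and the two hooks assert that scale and zero point actually land on their respective integer grids. Beyond the hooks, the resulting quantizer can also be inspected directly; a quick check, not part of the test, assuming quant_weight() behaves here as it does on dev:

linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat)
quant_weight = linear.quant_weight()
print(quant_weight.scale.shape, quant_weight.zero_point.shape)  # per-group layout of scale/zero point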