Feat (core): quant scale support
Giuseppe5 committed Nov 14, 2024
1 parent ee157bc commit 79aea69
Showing 6 changed files with 183 additions and 7 deletions.
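In short, the change lets the scale (and, analogously, the zero point) be produced by a nested integer quantizer instead of being used as a raw float: RescalingIntQuant gains an optional scaling_int_quant module, the new QuantRestrictValue and _ScaleShiftQuantZeroPoint route the scale and zero point through such a quantizer, and a new test asserts that the resulting values are integer multiples of the nested quantizer's own scale. The snippet below is a conceptual sketch only, not code from this commit; fake_int_quant is a hypothetical stand-in that just follows RescalingIntQuant's (value, scale, zero_point, bit_width) return convention.

import torch

def fake_int_quant(x, inner_scale=torch.tensor(0.05)):
    # Hypothetical stand-in for a nested RescalingIntQuant.
    q = torch.round(x / inner_scale) * inner_scale
    return q, inner_scale, torch.tensor(0.), torch.tensor(8.)

raw_scale = torch.rand(4) + 0.01  # scale coming from statistics or a learned parameter
quant_scale, inner_scale, *_ = fake_int_quant(raw_scale)
# The quantized scale is an integer multiple of the nested quantizer's scale,
# which is what the new test at the bottom of this commit asserts.
assert torch.allclose(quant_scale / inner_scale, torch.round(quant_scale / inner_scale))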
8 changes: 7 additions & 1 deletion src/brevitas/core/quant/int.py
@@ -8,6 +8,7 @@
from torch.nn import Module

import brevitas
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import round_ste
@@ -138,13 +139,18 @@ def __init__(
            scaling_impl: Module,
            int_scaling_impl: Module,
            zero_point_impl: Module,
-            bit_width_impl: Module):
+            bit_width_impl: Module,
+            scaling_int_quant: Optional[Module] = None):
        super(RescalingIntQuant, self).__init__()
        self.int_quant = int_quant
        self.scaling_impl = scaling_impl
        self.int_scaling_impl = int_scaling_impl
        self.zero_point_impl = zero_point_impl
        self.msb_clamp_bit_width_impl = bit_width_impl
+        if scaling_int_quant is None:
+            self.scaling_int_quant = Identity()
+        else:
+            self.scaling_int_quant = scaling_int_quant

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
27 changes: 27 additions & 0 deletions src/brevitas/core/restrict_val.py
@@ -179,3 +179,30 @@ def forward(self, x: torch.Tensor):
        x = self.float_to_int_impl(x)
        x = self.power_of_two(x)
        return x


class QuantRestrictValue(brevitas.jit.ScriptModule):

    def __init__(self, restrict_value_float_to_int_impl: Module):
        super(QuantRestrictValue, self).__init__()
        self.float_to_int_impl = restrict_value_float_to_int_impl

    # The init-time restriction hooks below all return an Identity module;
    # the actual value restriction is delegated to the wrapped quantizer in forward.
    def restrict_init_float(self, x: float):
        return Identity()

    def restrict_init_tensor(self, x: torch.Tensor):
        return Identity()

    def restrict_init_module(self):
        return Identity()

    def restrict_init_inplace_module(self):
        return Identity()

    def retrocompatibility_op(self, x):
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        o, *_ = self.float_to_int_impl(x)
        return o
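For context, QuantRestrictValue simply runs its input through the wrapped quantizer and keeps only the quantized value. A minimal usage sketch follows; FakeIntQuant is a made-up stand-in for the nested quantizer, and eager (non-JIT) execution, the Brevitas default, is assumed.

import torch

from brevitas.core.restrict_val import QuantRestrictValue

class FakeIntQuant(torch.nn.Module):
    # Made-up quantizer following the (value, scale, zero_point, bit_width) return convention.
    def forward(self, x):
        scale = torch.tensor(0.1)
        return torch.round(x / scale) * scale, scale, torch.tensor(0.), torch.tensor(8.)

restrict = QuantRestrictValue(FakeIntQuant())
out = restrict(torch.randn(4))  # only the quantized value is kept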
3 changes: 1 addition & 2 deletions src/brevitas/core/scaling/runtime.py
@@ -50,8 +50,7 @@ def __init__(
            dtype,
            device)

-    def forward(
-            self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        stats = self.parameter_list_stats(x)
        if threshold is None:
            threshold = torch.ones(1).type_as(stats)
3 changes: 1 addition & 2 deletions src/brevitas/core/scaling/standalone.py
@@ -209,8 +209,7 @@ def __init__(
        self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))

    @brevitas.jit.script_method
-    def forward(
-            self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        if threshold is None:
            threshold = torch.ones(1).type_as(x)
        # Threshold division must happen after we update self.value, but before we apply restrict_preprocess
22 changes: 20 additions & 2 deletions src/brevitas/core/zero_point.py
@@ -60,6 +60,20 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        return out


class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule):
    __constants__ = ['quantize_zero_point']

    def __init__(self, zp_int_quant: Module, quantize_zero_point: bool) -> None:
        super(_ScaleShiftQuantZeroPoint, self).__init__()
        self.zp_int_quant = zp_int_quant
        self.quantize_zero_point = quantize_zero_point

    @brevitas.jit.script_method
    def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        # scale and bit_width are kept for interface compatibility with _ScaleShiftZeroPoint,
        # but the zero point is quantized by its own nested quantizer instead.
        quant_zp, scale, *_ = self.zp_int_quant(zero_point)
        return quant_zp


class StatsFromParameterZeroPoint(brevitas.jit.ScriptModule):

    def __init__(
@@ -70,15 +84,19 @@ def __init__(
            zero_point_stats_input_concat_dim: int,
            zero_point_stats_impl: Module,
            zero_point_shape: Tuple[int, ...],
-            tracked_parameter_list: List[torch.nn.Parameter]) -> None:
+            tracked_parameter_list: List[torch.nn.Parameter],
+            scale_shit_zero_point_impl: Optional[Module] = None) -> None:
        super(StatsFromParameterZeroPoint, self).__init__()
        self.parameter_list_stats = _ParameterListStats(
            zero_point_stats_impl,
            zero_point_shape,
            zero_point_stats_input_view_shape_impl,
            zero_point_stats_input_concat_dim,
            tracked_parameter_list)
-        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+        if scale_shit_zero_point_impl is None:
+            self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
+        else:
+            self.scale_shift_zero_point = scale_shit_zero_point_impl

    @brevitas.jit.script_method
    def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
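_ScaleShiftQuantZeroPoint mirrors the _ScaleShiftZeroPoint interface but routes the zero point through its own nested quantizer and ignores the incoming scale and bit width; StatsFromParameterZeroPoint now accepts it through the new optional argument. A short illustrative sketch, again with a made-up nested quantizer and assuming eager (non-JIT) execution:

import torch

from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint

class FakeZpQuant(torch.nn.Module):
    # Made-up nested quantizer: returns (value, scale, zero_point, bit_width).
    def forward(self, x):
        scale = torch.tensor(0.25)
        return torch.round(x / scale) * scale, scale, torch.tensor(0.), torch.tensor(6.)

zp_impl = _ScaleShiftQuantZeroPoint(zp_int_quant=FakeZpQuant(), quantize_zero_point=True)
zp = zp_impl(torch.full((4,), 3.14), torch.tensor(0.1), torch.tensor(8.))
# zp is the quantized zero point; the incoming scale and bit_width are ignored.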
127 changes: 127 additions & 0 deletions tests/brevitas/core/test_scaling_quant.py
@@ -0,0 +1,127 @@
from dependencies import this
from dependencies import value
import torch

from brevitas.core.quant.int import RescalingIntQuant
from brevitas.core.restrict_val import QuantRestrictValue
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
from brevitas.core.zero_point import _ScaleShiftQuantZeroPoint
from brevitas.inject.enum import ScalingPerOutputType
import brevitas.nn as qnn
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat


class QuantScalingInt(Int8WeightPerTensorFloat):
    bit_width = 8
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    rescaling_int_quant = RescalingIntQuant

    @value
    def scaling_shape(
            scaling_per_output,
            scaling_per_output_channel_shape,
            expanded_groupwise_shape,
            group_dim,
            upstream_scaling):
        if scaling_per_output == ScalingPerOutputType.TENSOR:
            scaling = SCALAR_SHAPE
        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
            scaling = scaling_per_output_channel_shape
        elif scaling_per_output == ScalingPerOutputType.GROUP:
            # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
            assert group_dim is not None, "Per Group scaling not correctly configured"
            size = list(expanded_groupwise_shape)
            size[group_dim + 1] = 1
            scaling = tuple(size)

        # When quantizing the scale of a groupwise quantizer, there is one extra dim compared to the normal case
        if upstream_scaling == ScalingPerOutputType.GROUP:
            scaling = list(scaling)
            scaling.insert(-1, 1)
            scaling = tuple(scaling)
        return scaling


class QuantZPInt(Int8WeightPerTensorFloat):
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    rescaling_int_quant = RescalingIntQuant
    bit_width = 6
    quantize_zero_point = True
    scaling_per_output_type = ScalingPerOutputType.CHANNEL

    @value
    def scaling_shape(
            scaling_per_output,
            scaling_per_output_channel_shape,
            expanded_groupwise_shape,
            group_dim,
            upstream_scaling):
        if scaling_per_output == ScalingPerOutputType.TENSOR:
            scaling = SCALAR_SHAPE
        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
            scaling = scaling_per_output_channel_shape
        elif scaling_per_output == ScalingPerOutputType.GROUP:
            # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
            assert group_dim is not None, "Per Group scaling not correctly configured"
            size = list(expanded_groupwise_shape)
            size[group_dim + 1] = 1
            scaling = tuple(size)

        # When quantizing the scale of a groupwise quantizer, there is one extra dim compared to the normal case
        if upstream_scaling == ScalingPerOutputType.GROUP:
            scaling = list(scaling)
            scaling.insert(-1, 1)
            scaling = tuple(scaling)
        return scaling


class QuantScaleInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat):
    proxy_class = GroupwiseWeightQuantProxyFromInjector
    scaling_int_quant = QuantScalingInt
    zp_int = QuantZPInt
    restrict_scaling_impl = QuantRestrictValue
    scaling_per_output_type = ScalingPerOutputType.GROUP
    scale_shit_zero_point_impl = _ScaleShiftQuantZeroPoint
    group_size = 32

    @value
    def restrict_value_float_to_int_impl():
        return this.scaling_int_quant.rescaling_int_quant

    @value
    def zp_int_quant():
        return this.zp_int.rescaling_int_quant


def test_quant_scale():

    def hook_scale(module, inp):
        inp = inp[0]
        quant_scale, scale, *_ = module.float_to_int_impl(inp)
        assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))

    def hook_zp(module, inp):
        inp = inp[0]
        quant_zp, scale, *_ = module.zp_int_quant(inp)
        assert torch.allclose(quant_zp / scale, torch.round(quant_zp / scale))

    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleInt8WeightPerTensorFloat)
    for module in linear.modules():
        if isinstance(module, QuantRestrictValue):
            module.register_forward_pre_hook(hook_scale)
    for module in linear.modules():
        if isinstance(module, _ScaleShiftQuantZeroPoint):
            module.register_forward_pre_hook(hook_zp)

    linear(torch.randn(1, 64))
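To make the scaling_shape rule defined in QuantScalingInt and QuantZPInt above concrete, here is the same logic run on a hypothetical expanded_groupwise_shape; the numbers are illustrative and not taken from the test.

expanded_groupwise_shape = (768, 2, 32)  # (out_channels, n_groups, group_size), hypothetical
group_dim = 1
size = list(expanded_groupwise_shape)
size[group_dim + 1] = 1  # one scale per group -> [768, 2, 1]
scaling = tuple(size)
# If the upstream quantizer is itself groupwise, an extra singleton dim is inserted:
scaling = list(scaling)
scaling.insert(-1, 1)
scaling = tuple(scaling)  # -> (768, 2, 1, 1)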
