diff --git a/.github/workflows/base.yml.template b/.github/workflows/base.yml.template
index bf296e597..465cf1c41 100644
--- a/.github/workflows/base.yml.template
+++ b/.github/workflows/base.yml.template
@@ -20,7 +20,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/base_reduced.yml.template b/.github/workflows/base_reduced.yml.template
index c50903499..46b916895 100644
--- a/.github/workflows/base_reduced.yml.template
+++ b/.github/workflows/base_reduced.yml.template
@@ -22,7 +22,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/develop_install.yml b/.github/workflows/develop_install.yml
index cea916825..bdc0df76b 100644
--- a/.github/workflows/develop_install.yml
+++ b/.github/workflows/develop_install.yml
@@ -30,7 +30,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/end_to_end.yml b/.github/workflows/end_to_end.yml
index dba8a2911..a83ba8899 100644
--- a/.github/workflows/end_to_end.yml
+++ b/.github/workflows/end_to_end.yml
@@ -32,7 +32,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/examples_llm_pytest.yml b/.github/workflows/examples_llm_pytest.yml
index e939a93b2..065a06738 100644
--- a/.github/workflows/examples_llm_pytest.yml
+++ b/.github/workflows/examples_llm_pytest.yml
@@ -34,7 +34,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/examples_pytest.yml b/.github/workflows/examples_pytest.yml
index d514c6c33..262625e01 100644
--- a/.github/workflows/examples_pytest.yml
+++ b/.github/workflows/examples_pytest.yml
@@ -34,7 +34,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/finn_integration.yml b/.github/workflows/finn_integration.yml
index cb1946b84..f62876a45 100644
--- a/.github/workflows/finn_integration.yml
+++ b/.github/workflows/finn_integration.yml
@@ -30,7 +30,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/notebook.yml b/.github/workflows/notebook.yml
index 7678a9d15..4e8c06e78 100644
--- a/.github/workflows/notebook.yml
+++ b/.github/workflows/notebook.yml
@@ -32,7 +32,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/ort_integration.yml b/.github/workflows/ort_integration.yml
index 519873f05..02c75fcf2 100644
--- a/.github/workflows/ort_integration.yml
+++ b/.github/workflows/ort_integration.yml
@@ -30,7 +30,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index c7e407799..4505d8df0 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -34,7 +34,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_develop_install.yml b/.github/workflows/reduced_develop_install.yml
index b23fa97b6..b5ead15ee 100644
--- a/.github/workflows/reduced_develop_install.yml
+++ b/.github/workflows/reduced_develop_install.yml
@@ -32,7 +32,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_end_to_end.yml b/.github/workflows/reduced_end_to_end.yml
index f06c29872..c52fb0ceb 100644
--- a/.github/workflows/reduced_end_to_end.yml
+++ b/.github/workflows/reduced_end_to_end.yml
@@ -34,7 +34,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_examples_llm_pytest.yml b/.github/workflows/reduced_examples_llm_pytest.yml
index b9c3deffe..44b0de612 100644
--- a/.github/workflows/reduced_examples_llm_pytest.yml
+++ b/.github/workflows/reduced_examples_llm_pytest.yml
@@ -33,7 +33,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_examples_pytest.yml b/.github/workflows/reduced_examples_pytest.yml
index 62541236f..b4d42540c 100644
--- a/.github/workflows/reduced_examples_pytest.yml
+++ b/.github/workflows/reduced_examples_pytest.yml
@@ -33,7 +33,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_finn_integration.yml b/.github/workflows/reduced_finn_integration.yml
index b4e0e62d1..342a01c34 100644
--- a/.github/workflows/reduced_finn_integration.yml
+++ b/.github/workflows/reduced_finn_integration.yml
@@ -32,7 +32,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_notebook.yml b/.github/workflows/reduced_notebook.yml
index 159d4a9a4..5d2dc1b6f 100644
--- a/.github/workflows/reduced_notebook.yml
+++ b/.github/workflows/reduced_notebook.yml
@@ -34,7 +34,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_ort_integration.yml b/.github/workflows/reduced_ort_integration.yml
index 06219f128..9fcb678d2 100644
--- a/.github/workflows/reduced_ort_integration.yml
+++ b/.github/workflows/reduced_ort_integration.yml
@@ -32,7 +32,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/.github/workflows/reduced_pytest.yml b/.github/workflows/reduced_pytest.yml
index 8af119e15..f3d0763e3 100644
--- a/.github/workflows/reduced_pytest.yml
+++ b/.github/workflows/reduced_pytest.yml
@@ -33,7 +33,7 @@ jobs:
     steps:

    - name: Checkout repo
-      uses: actions/checkout@v2
+      uses: actions/checkout@v3

    - name: Set up Python
      uses: actions/setup-python@v2
diff --git a/src/brevitas/config.py b/src/brevitas/config.py
index a5685721c..082a6508b 100644
--- a/src/brevitas/config.py
+++ b/src/brevitas/config.py
@@ -25,3 +25,4 @@ def env_to_bool(name, default):
 _FULL_STATE_DICT = False
 _IS_INSIDE_QUANT_LAYER = None
 _ONGOING_EXPORT = None
+_RETROCOMPATIBLE_SCALING = False
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index f4fd79f1a..145f5ca06 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -67,12 +67,13 @@ def __init__(

     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)

         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-            scale = scale / float_scaling_impl_value
+        else:
+            float_scaling_impl_value = None
+        scale = self.scaling_impl(x, float_scaling_impl_value)
         x = self.input_view_impl(x)
         scaled_x = x / scale
         internal_scale = float_internal_scale(
diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py
index cdb75df74..e7c5560f8 100644
--- a/src/brevitas/core/quant/int.py
+++ b/src/brevitas/core/quant/int.py
@@ -149,9 +149,8 @@ def __init__(
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         bit_width = self.msb_clamp_bit_width_impl()
-        threshold = self.scaling_impl(x)
         int_threshold = self.int_scaling_impl(bit_width)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
@@ -184,8 +183,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
         pre_threshold = self.pre_scaling_impl(x)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
@@ -250,8 +248,7 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 449318765..0720e595e 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -90,6 +90,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

+    def retrocompatibility_op(self, x):
+        return x
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> Tensor:
         return x
@@ -113,6 +116,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

+    def retrocompatibility_op(self, x):
+        return self.power_of_two(x)
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         x = self.power_of_two(x)
@@ -137,6 +143,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

+    def retrocompatibility_op(self, x):
+        return x
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         x = self.float_to_int_impl(x)
@@ -162,6 +171,9 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

+    def retrocompatibility_op(self, x):
+        return self.power_of_two(x)
+
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         x = self.float_to_int_impl(x)
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index e8a6b04ee..53d9c67f8 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -50,10 +50,12 @@ def __init__(
             dtype,
             device)

-    @brevitas.jit.script_method
-    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(
+            self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.parameter_list_stats(x)
-        return self.stats_scaling_impl(stats)
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        return self.stats_scaling_impl(stats, threshold)


 class _StatsScaling(brevitas.jit.ScriptModule):
@@ -80,8 +82,11 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()

     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor) -> torch.Tensor:
-        stats = self.restrict_scaling_pre(stats)
+    def forward(
+            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        stats = self.restrict_scaling_pre(stats / threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -120,9 +125,9 @@ def __init__(
             device)

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.runtime_stats(x)
-        return self.stats_scaling_impl(stats)
+        return self.stats_scaling_impl(stats, threshold)


 class _AffineRescaling(brevitas.jit.ScriptModule):
@@ -179,9 +184,14 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

     @brevitas.jit.script_method
-    def forward(self, stats_input) -> torch.Tensor:
+    def forward(
+            self,
+            stats_input: torch.Tensor,
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped)
+        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
         out = self.restrict_clamp_scaling(out)
         return out
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 7e7dca944..391ddca67 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -77,8 +77,10 @@ def __init__(
             self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))

     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = self.value()
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
+        value = self.value() / threshold
         restricted_value = self.restrict_clamp_scaling(value)
         return restricted_value

@@ -149,8 +151,10 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
+        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
         return value

     def _load_from_state_dict(
@@ -190,29 +194,40 @@ def __init__(
             scaling_stats_input_view_shape_impl,
             scaling_stats_input_concat_dim,
             tracked_parameter_list)
+        self.restrict_scaling_impl = restrict_scaling_impl
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
         if restrict_scaling_impl is not None:
             self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         else:
             self.restrict_inplace_preprocess = Identity()
+            self.restrict_preprocess = Identity()
+
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))

     @brevitas.jit.script_method
-    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(
+            self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(x)
+        # Threshold division must happen after we update self.value, but before we apply restrict_preprocess.
+        # This is because we don't want to store a parameter that depends on a runtime value (threshold),
+        # and because restrict needs to happen after we divide by threshold.
         if self.init_done:
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_preprocess(self.value / threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
             stats = self.parameter_list_stats(x)
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
-                return self.stats_scaling_impl(stats)
-            stats = self.restrict_inplace_preprocess(stats)
+                return self.stats_scaling_impl(stats, threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_preprocess(self.value / threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value
@@ -228,9 +243,18 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
+        value_key = prefix + 'value'
+
+        # Previously, the parameter was stored after restrict_preprocess (e.g., in the Log2 domain).
+        # When loading, if retrocompatibility is enabled, we apply the inverse operation (e.g., Po2),
+        # knowing that the forward pass will re-apply restrict_preprocess (e.g., Log2 again).
+        if config._RETROCOMPATIBLE_SCALING:
+            if not isinstance(self.restrict_scaling_impl, Identity):
+                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
+                    state_dict[value_key])
+
         super(ParameterFromStatsFromParameterScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-        value_key = prefix + 'value'
         # disable stats collection when a pretrained value is loaded
         if value_key not in missing_keys:
             self.init_done = True
@@ -305,6 +329,7 @@ def __init__(
             scaling_stats_momentum, Optional[float])
         self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
+        self.restrict_scaling_impl = restrict_scaling_impl
         self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
         self.clamp_scaling = _ClampValue(scaling_min_val)
         self.local_loss_mode: bool = brevitas.jit.Attribute(
@@ -317,7 +342,10 @@ def __init__(
             self.restrict_preprocess = Identity()

     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor) -> Tensor:
+    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+        # Threshold division must happen after we update self.value, but before we apply restrict_preprocess.
+        # This is because we don't want to store a parameter that depends on a runtime value (threshold),
+        # and because restrict needs to happen after we divide by threshold.
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -327,32 +355,37 @@ def training_forward(self, stats_input: Tensor) -> Tensor:
             new_counter = self.counter + 1
             # Whenever we are in local loss mode, we don't update the counter nor the buffer
             if self.local_loss_mode:
-                return abs_binary_sign_grad(clamped_stats)
+                # In local loss mode, we exit early and divide by threshold
+                return abs_binary_sign_grad(clamped_stats / threshold)
             if self.counter == 0:
                 inplace_tensor_mul(self.buffer, clamped_stats.detach())
             else:
                 inplace_momentum_update(
                     self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter)
             self.counter = new_counter
-            return abs_binary_sign_grad(clamped_stats)
+            return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
-            self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
+            value = self.restrict_preprocess(self.value / threshold)
             self.counter = self.counter + 1
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            value = self.restrict_preprocess(self.value / threshold)
+            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))

     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         if self.training:
-            return self.training_forward(stats_input)
+            # Threshold division is handled inside training_forward
+            return self.training_forward(stats_input, threshold)
         else:
             if self.counter <= self.collect_stats_steps:
-                out = self.buffer
+                out = self.buffer / threshold
                 out = self.restrict_preprocess(out)
             else:
-                out = self.value
+                out = self.restrict_preprocess(self.value / threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
             return out

@@ -378,6 +411,14 @@ def _load_from_state_dict(
         if retrocomp_value_key in state_dict:
             state_dict[value_key] = state_dict.pop(retrocomp_value_key)

+        # Previously, the parameter was stored after restrict_preprocess (e.g., in the Log2 domain).
+        # When loading, if retrocompatibility is enabled, we apply the inverse operation (e.g., Po2),
+        # knowing that the forward pass will re-apply restrict_preprocess (e.g., Log2 again).
+        if config._RETROCOMPATIBLE_SCALING:
+            if not isinstance(self.restrict_scaling_impl, Identity):
+                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
+                    state_dict[value_key])
+
         super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
         # Buffer is supposed to be always missing
diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py
index 93cc235e2..776f1f6b2 100644
--- a/src/brevitas_examples/common/generative/quant_blocks.py
+++ b/src/brevitas_examples/common/generative/quant_blocks.py
@@ -25,10 +25,10 @@ def __init__(
         self.stats_impl = scaling_stats_impl
         self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn

-    def forward(self, x) -> Tensor:
+    def forward(self, x, threshold) -> Tensor:
         shape = x.shape
         x = self.scaling_stats_input_view_shape_impl(x)
-        x = self.stats_impl(x)
+        x = self.stats_impl(x) / threshold
         x = self.dynamic_scaling_broadcastable_fn(x, shape)
         return x

diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 16b8a4b5f..a5f597586 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -109,9 +109,10 @@ def test_float_to_quant_float(inp, minifloat_format):
 @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
 @jit_disabled_for_mock()
 def test_scaling_impls_called_once(inp, minifloat_format):
+    float_scaling_impl_return = 1.
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
-    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
-    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: 1.)
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: float_scaling_impl_return)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
@@ -148,7 +149,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
             torch.tensor(exponent_bit_width),
             torch.tensor(mantissa_bit_width),
             torch.tensor(exponent_bias))
-        scaling_impl.assert_called_once_with(inp)
+        scaling_impl.assert_called_once_with(inp, float_scaling_impl_return)


 @given(
@@ -160,7 +161,7 @@ def test_inner_scale(inp, minifloat_format, scale):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
     float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
-    scaling_impl = mock.Mock(side_effect=lambda x: scale)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: scale)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index b22994275..10d8f7e7c 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,7 @@
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import load_quant_model_mode
+from brevitas.inject.enum import RestrictValueType
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -27,7 +28,9 @@
 BATCH = 1
 REFERENCE_SCALES = {
     'int_quant': (0.00935234408825635910, 0.01362917013466358185),
-    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
+    'fp_quant': (0.00249395845457911491, 0.00363444536924362183),
+    'int_po2_quant': (0.015625, 0.015625),
+    'fp_po2_quant': (0.001953125, 0.00390625),}
 REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010],
                                   [1.4573, -0.9074, -0.2708]])
@@ -44,9 +47,9 @@ def reference_implementation_scale_factors_po2(
     quant = compute_quantile(x, q)
     quant = torch.max(min_val, quant)
     quant_float_to_int = torch.ceil(
-        torch.log2(quant))  # Float to Int Implementation for PowerOfTwo scale
+        torch.log2(quant / int_scale))  # Float to Int Implementation for PowerOfTwo scale

-    scale = torch.pow(torch.tensor(2.), quant_float_to_int) / int_scale
+    scale = torch.pow(torch.tensor(2.), quant_float_to_int)

     return scale

@@ -75,7 +78,15 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)


-QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
+    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
+
+
+QUANTS = {
+    'int_quant': Int8ActPerTensorFloat,
+    'fp_quant': Fp8e4m3ActPerTensorFloat,
+    'int_po2_quant': Int8ActPerTensorFixedPoint,
+    'fp_po2_quant': Fp8e4m3ActPerTensorFixedPoint}


 @pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
diff --git a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
index dded0e276..c72a2d1f8 100644
--- a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
+++ b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
@@ -12,9 +12,11 @@
 from qonnx.transformation.infer_shapes import InferShapes
 import torch

+import brevitas.config as config
 from brevitas.export import export_qonnx
 from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b

+config._RETROCOMPATIBLE_SCALING = True
 QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256)  # B, features, sequence
 MIN_INP_VAL = 0.0
 MAX_INP_VAL = 200.0
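
Editor's note, not part of the patch: the hunks above change the scaling call contract from computing scale as scaling_impl(x) divided by a threshold at the call site, to passing the threshold into the scaling implementation (scale = self.scaling_impl(x, int_threshold)), and they add an opt-in config._RETROCOMPATIBLE_SCALING flag so checkpoints that stored a learned scale after restrict_preprocess (e.g., in the log2 domain) can still be loaded. A minimal illustrative sketch of both behaviours follows; it assumes ConstScaling is the module touched in the first standalone.py hunk, and the commented-out checkpoint path and model are hypothetical placeholders, not part of this diff.

# Illustrative sketch only; based on the new forward(placeholder, threshold) signature
# added to the const scaling implementation in src/brevitas/core/scaling/standalone.py.
import torch

import brevitas.config as config
from brevitas.core.scaling.standalone import ConstScaling

scaling = ConstScaling(2.0)
placeholder = torch.empty(1)

# Old contract still works: no threshold means it defaults to 1, so scale == 2.0
print(scaling(placeholder))
# New contract: the threshold is passed in and divided inside the impl,
# e.g. threshold = 4.0 gives scale == 2.0 / 4.0 == 0.5
print(scaling(placeholder, torch.tensor(4.0)))

# Opt-in retrocompatibility when loading checkpoints saved before this change,
# where learned scales were stored after restrict_preprocess (e.g., log2 domain).
config._RETROCOMPATIBLE_SCALING = True
# model.load_state_dict(torch.load('old_checkpoint.pth'))  # hypothetical path and model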