diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index 85fd491..e16e471 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -1,6 +1,7 @@
 import gc
 import re
 from typing import List, Tuple
+import copy
 
 import torch
 import tqdm
@@ -47,7 +48,7 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
         )
     else:
         min_val, max_val = tensor.aminmax()
-    amax = min_val.abs().max(max_val.abs())
+    amax = torch.maximum(min_val.abs(), max_val.abs())
     scale = finfo.max / amax.clamp(min=1e-12)
     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
@@ -202,8 +203,8 @@ def quantize_weights(
             or name in quantize_config.ignored_layers
         ):
             continue
-        quant_weight, quant_scale = per_tensor_quantize(linear.weight.clone())
-        bias = linear.bias.clone() if linear.bias is not None else None
+        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
+        bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
         replace_module(model, name, quant_linear)
         del linear.weight
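
Note, outside the diff: a minimal sketch checking that the new torch.maximum form of the amax computation matches the previous Tensor.max(other) call; the tensor shape and variable names here are illustrative only and not part of auto_fp8.

    import torch

    # Equivalence check for the amax change: Tensor.max(other) and torch.maximum
    # compute the same element-wise maximum, but torch.maximum is unambiguous,
    # since Tensor.max also has a reduction overload Tensor.max(dim).
    t = torch.randn(4, 8)
    min_val, max_val = t.aminmax()
    amax_old = min_val.abs().max(max_val.abs())
    amax_new = torch.maximum(min_val.abs(), max_val.abs())
    assert torch.equal(amax_old, amax_new)

The other change presumably targets peak memory: quantizing linear.weight directly avoids materializing a full extra copy of each weight, while copy.deepcopy(linear.bias) keeps the bias independent of the original layer before it is deleted.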