diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index e16e471..4c1b580 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -236,11 +236,12 @@ def quantize_activations(
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
-    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
-        for row_idx in range(calibration_tokens.shape[0]):
-            model(calibration_tokens[row_idx].reshape(1, -1))
-            cleanup_memory()
-            pbar.update(1)
+    with torch.inference_mode():
+        with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
+            for row_idx in range(calibration_tokens.shape[0]):
+                model(calibration_tokens[row_idx].reshape(1, -1))
+                cleanup_memory()
+                pbar.update(1)
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
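
The change wraps the calibration forward passes in `torch.inference_mode()`, which disables autograd tracking entirely, so no computation graph or intermediate activations are retained across the passes and peak memory during calibration drops. Below is a minimal standalone sketch of the effect, using a toy `nn.Sequential` as a stand-in for `model` and random float inputs in place of the token-ID tensor `calibration_tokens` from the diff; the shapes and model are illustrative assumptions, not the project's actual setup.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the model and calibration data in the diff.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
calibration_tokens = torch.randn(4, 16)

# Without inference_mode, each forward pass builds an autograd graph that
# keeps intermediate activations alive until it is garbage-collected.
out = model(calibration_tokens[0].reshape(1, -1))
print(out.requires_grad)  # True: graph (and its saved activations) retained

# Inside inference_mode, autograd is disabled, so calibration-only forward
# passes allocate no graph state at all, lowering peak memory.
with torch.inference_mode():
    for row_idx in range(calibration_tokens.shape[0]):
        out = model(calibration_tokens[row_idx].reshape(1, -1))
        print(out.requires_grad)  # False: no graph built
```

Since the loop only needs the observers' side effects (recorded activation scales) and never calls `backward()`, `inference_mode()` is safe here and somewhat cheaper than `torch.no_grad()`, as it also skips view/version-counter bookkeeping.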