Skip to content

Commit

Permalink
[float8] Allow specifying arbitrary dtype for each tensor
Browse files Browse the repository at this point in the history
ghstack-source-id: 4b3a2f0007d74e3453cefde1307f2a9c5271e83e
Pull Request resolved: #1326
  • Loading branch information
lw committed Nov 22, 2024
1 parent bc0a29a commit 7d28acf
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 108 deletions.
4 changes: 2 additions & 2 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

from torchao.float8.config import (
CastConfig,
e4m3_dtype,
e5m2_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
recipe_name_to_linear_config,
Expand All @@ -51,8 +53,6 @@
)
from torchao.float8.float8_utils import (
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
Expand Down
2 changes: 1 addition & 1 deletion test/float8/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn as nn
from torchao.float8.config import (
CastConfig,
e4m3_dtype,
Float8LinearConfig,
ScalingType,
Float8LinearRecipeName,
Expand All @@ -41,7 +42,6 @@
GemmInputRole,
ScaledMMConfig,
)
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config

from torch._dynamo.test_case import TestCase as DynamoTestCase
Expand Down
8 changes: 4 additions & 4 deletions test/float8/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

from torchao.float8.config import CastConfig, ScalingType
from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
Expand All @@ -40,7 +40,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torchao.float8.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
Expand Down Expand Up @@ -197,7 +197,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
loss = torch.sum(torch.abs(out - dist_target))
loss.backward()
Expand Down
30 changes: 28 additions & 2 deletions torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None

def short_str(self):
return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
Expand All @@ -75,6 +76,9 @@ def __post_init__(self):
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"
assert self.dtype is None or (
self.dtype.is_floating_point and self.dtype.itemsize == 1
), "must specify a 8-bit floating-point dtype"


@dataclass(frozen=True)
Expand Down Expand Up @@ -124,6 +128,12 @@ def __post_init__(self):
self.e5m2_dtype = torch.float8_e5m2fnuz


# User defined type for using the individual F8 type based on config
type_config = Float8TypeConfig()
e4m3_dtype = type_config.e4m3_dtype
e5m2_dtype = type_config.e5m2_dtype


@dataclass(frozen=True)
class Float8GemmConfig:
"""
Expand Down Expand Up @@ -279,6 +289,20 @@ def __post_init__(self):
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

for cc1, cc2, operand_name, default_dtype in [
(cc_i, cc_i_gw, "input", e4m3_dtype),
(cc_w, cc_w_gi, "weight", e4m3_dtype),
(cc_go, cc_go_gw, "grad_output", e5m2_dtype),
]:
# Override the dataclass being frozen
if cc1.dtype is None:
object.__setattr__(cc1, "dtype", default_dtype)
if cc2.dtype is None:
object.__setattr__(cc2, "dtype", default_dtype)
assert (
cc1.dtype == cc2.dtype
), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

Expand Down Expand Up @@ -343,12 +367,14 @@ def recipe_name_to_linear_config(
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)

return Float8LinearConfig(
cast_config_input=cc_i,
Expand Down
83 changes: 44 additions & 39 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
NoopFwToFloat8E5M2BwStatic,
NoopFwToFloat8BwDelayed,
NoopFwToFloat8BwDynamic,
NoopFwToFloat8BwStatic,
_maybe_initialize_amaxes_scales_for_float8_cast,
get_maybe_axiswise_dim,
hp_tensor_to_float8_delayed,
Expand All @@ -31,8 +31,6 @@
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)
Expand Down Expand Up @@ -135,7 +133,7 @@ def forward(
else:
input_maybe_fp8 = hp_tensor_to_float8_dynamic(
input_hp,
e4m3_dtype,
c.cast_config_input.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input.scaling_granularity,
Expand All @@ -149,7 +147,7 @@ def forward(
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
e4m3_dtype,
c.cast_config_weight.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight.scaling_granularity,
Expand Down Expand Up @@ -185,7 +183,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
e5m2_dtype,
c.cast_config_grad_output.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output.scaling_granularity,
Expand All @@ -203,7 +201,7 @@ def backward(ctx, grad_output):
# the entire tensor.
weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
weight_hp_t,
e4m3_dtype,
c.cast_config_weight_for_grad_input.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
Expand Down Expand Up @@ -235,7 +233,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
e5m2_dtype,
c.cast_config_grad_output_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
Expand All @@ -249,7 +247,7 @@ def backward(ctx, grad_output):
else:
input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
input_hp_reshaped,
e4m3_dtype,
c.cast_config_input_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
Expand Down Expand Up @@ -354,11 +352,9 @@ def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
device = self.weight.device
# TODO(future PR): dtype values below don't have the other float8
# flavors, fix it
default_input = torch.finfo(torch.float8_e4m3fn).max
default_weight = torch.finfo(torch.float8_e4m3fn).max
default_grad_output = torch.finfo(torch.float8_e5m2).max
default_input = torch.finfo(self.config.cast_config_input.dtype).max
default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max

# Note: for now, create all the buffers if any are needed, to postpone
# the work to make the scale and amax syncing and history calculation
Expand Down Expand Up @@ -445,29 +441,32 @@ def cast_input_to_float8(
self.fp8_amax_history_input,
self.fp8_scale_input,
scale_fn_name,
e4m3_dtype,
self.config.cast_config_input.dtype,
is_amax_initialized,
reduce_amax=True,
)
input_fp8 = hp_tensor_to_float8_delayed(
input,
self.fp8_scale_input,
e4m3_dtype,
self.config.cast_config_input.dtype,
self.fp8_amax_input,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
elif self.scaling_type_input is ScalingType.DYNAMIC:
input_fp8 = hp_tensor_to_float8_dynamic(
input,
e4m3_dtype,
self.config.cast_config_input.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
else:
assert self.scaling_type_input is ScalingType.STATIC
input_fp8 = hp_tensor_to_float8_static(
input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
input,
self.fp8_static_scale_input,
self.config.cast_config_input.dtype,
self.linear_mm_config,
)

return input_fp8
Expand All @@ -483,14 +482,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
self.config.cast_config_weight.dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
return tensor_to_scale(weight, e4m3_dtype)
return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
return self.fp8_static_scale_weight
Expand All @@ -506,7 +505,7 @@ def cast_weight_to_float8_t(
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
e4m3_dtype,
self.config.cast_config_weight.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
Expand All @@ -521,23 +520,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
output = NoopFwToFloat8E5M2BwDelayed.apply(
output = NoopFwToFloat8BwDelayed.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
self.fp8_scale_grad_output,
scale_fn_name,
self.is_amax_initialized,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
output = NoopFwToFloat8BwDynamic.apply(
output,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
output = NoopFwToFloat8E5M2BwStatic.apply(
output = NoopFwToFloat8BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
return output

Expand All @@ -563,19 +568,16 @@ def float8_post_forward(self):
self.amax_and_scale_synced = False

def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = (
self.config.cast_config_input.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_weight.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_grad_output.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_input_for_grad_weight.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_weight_for_grad_input.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
is ScalingGranularity.AXISWISE
has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
self.config.cast_config_input,
self.config.cast_config_weight,
self.config.cast_config_grad_output,
self.config.cast_config_input_for_grad_weight,
self.config.cast_config_weight_for_grad_input,
self.config.cast_config_grad_output_for_grad_weight,
]
)

if not has_any_axiswise_scaling:
Expand Down Expand Up @@ -698,6 +700,7 @@ def from_float(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
)
)
elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
Expand All @@ -708,6 +711,7 @@ def from_float(
new_mod.fp8_amax_history_weight,
new_mod.fp8_scale_weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
new_mod.is_amax_initialized,
)
)
Expand All @@ -718,6 +722,7 @@ def from_float(
new_mod.weight,
new_mod.fp8_static_scale_weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
)
)

Expand Down
Loading

0 comments on commit 7d28acf

Please sign in to comment.