[float8] Allow specifying arbitrary dtype for each tensor #1326
base: gh/lw/2/base
First changed file:
@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None

     def short_str(self):
         return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"

Reviewer comment on short_str: can we also add the dtype here, so it appears when we print an instance of this config?
@@ -75,6 +76,9 @@ def __post_init__(self):
         assert (
             self.scaling_type is ScalingType.DYNAMIC
         ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.dtype is None or (
+            self.dtype.is_floating_point and self.dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
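For context, a minimal usage sketch of the new field (not part of the PR; it assumes a torchao build that already contains this change and only uses names visible in the diff above):

```python
import torch

from torchao.float8.config import CastConfig, ScalingGranularity, ScalingType

# Explicitly request e4m3 for this tensor instead of relying on the
# None -> per-tensor default that the rest of this PR fills in.
cc = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.AXISWISE,
    dtype=torch.float8_e4m3fn,
)

# Anything that is not a 1-byte floating-point dtype trips the new assert.
try:
    CastConfig(dtype=torch.bfloat16)
except AssertionError as err:
    print(err)  # must specify a 8-bit floating-point dtype
```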
@@ -124,6 +128,12 @@ def __post_init__(self):
             self.e5m2_dtype = torch.float8_e5m2fnuz


+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
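As an aside, the 8-bit check added above accepts all four float8 flavors, including the ROCm "fnuz" variants that Float8TypeConfig can select; a quick stand-alone check (plain torch, no torchao required):

```python
import torch

# Every float8 flavor is a 1-byte floating-point dtype, so each of them
# satisfies the `is_floating_point and itemsize == 1` check in CastConfig.
for dt in (
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
):
    assert dt.is_floating_point and dt.itemsize == 1
```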
@@ -279,6 +289,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.dtype is None:
+                object.__setattr__(cc1, "dtype", default_dtype)
+            if cc2.dtype is None:
+                object.__setattr__(cc2, "dtype", default_dtype)
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
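The object.__setattr__ calls above are the standard way to fill in values on a frozen dataclass after construction; here is a self-contained sketch of the same pattern using a hypothetical Cfg class (illustration only, not torchao code):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class Cfg:
    # None means "use the library default", mirroring CastConfig.dtype above.
    dtype: Optional[torch.dtype] = None


def fill_default(cfg: Cfg, default: torch.dtype) -> None:
    # frozen=True blocks normal attribute assignment, so bypass the dataclass
    # __setattr__ guard directly, as Float8LinearConfig.__post_init__ does.
    if cfg.dtype is None:
        object.__setattr__(cfg, "dtype", default)


cfg = Cfg()
fill_default(cfg, torch.float8_e4m3fn)
assert cfg.dtype is torch.float8_e4m3fn
```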
@@ -343,12 +367,14 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)

         return Float8LinearConfig(
             cast_config_input=cc_i,

Reviewer comment on the cc_go dtype: nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?
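Tying the recipe back to the defaulting logic in __post_init__, a hedged end-to-end sketch (it assumes the remaining Float8LinearConfig fields keep their defaults, and that the build maps e4m3_dtype/e5m2_dtype to the non-fnuz variants):

```python
import torch

from torchao.float8.config import CastConfig, Float8LinearConfig

# Leave every dtype as None and let __post_init__ fill in the defaults:
# e4m3 for input and weight, e5m2 for grad_output.
config = Float8LinearConfig(
    cast_config_input=CastConfig(),
    cast_config_weight=CastConfig(),
    cast_config_grad_output=CastConfig(),
)
print(config.cast_config_input.dtype)        # torch.float8_e4m3fn
print(config.cast_config_grad_output.dtype)  # torch.float8_e5m2
```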
Second changed file:
@@ -14,9 +14,9 @@

 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
+    NoopFwToFloat8BwDynamic,
+    NoopFwToFloat8BwStatic,
     _maybe_initialize_amaxes_scales_for_float8_cast,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_delayed,

Reviewer comment on the renamed imports: thanks for updating these!
@@ -31,8 +31,6 @@
     hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    e4m3_dtype,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )

@@ -135,7 +133,7 @@ def forward(
         else:
             input_maybe_fp8 = hp_tensor_to_float8_dynamic(
                 input_hp,
-                e4m3_dtype,
+                c.cast_config_input.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input.scaling_granularity,

@@ -149,7 +147,7 @@ def forward(
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight.scaling_granularity,

@@ -185,7 +183,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output.scaling_granularity,

@@ -203,7 +201,7 @@ def backward(ctx, grad_output):
             # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight_for_grad_input.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,

@@ -235,7 +233,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,

@@ -249,7 +247,7 @@ def backward(ctx, grad_output):
         else:
             input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 input_hp_reshaped,
-                e4m3_dtype,
+                c.cast_config_input_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -354,11 +352,9 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
-        # TODO(future PR): dtype values below don't have the other float8
-        # flavors, fix it
-        default_input = torch.finfo(torch.float8_e4m3fn).max
-        default_weight = torch.finfo(torch.float8_e4m3fn).max
-        default_grad_output = torch.finfo(torch.float8_e5m2).max
+        default_input = torch.finfo(self.config.cast_config_input.dtype).max
+        default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
+        default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max

         # Note: for now, create all the buffers if any are needed, to postpone
         # the work to make the scale and amax syncing and history calculation
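For reference, the buffer seed values above are just the largest finite value of the chosen dtype, so switching dtypes changes them accordingly (plain torch, independent of this PR):

```python
import torch

# Largest finite values of the two common float8 flavors, as used to seed
# the default amax/history buffers above.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
```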
@@ -445,29 +441,32 @@ def cast_input_to_float8(
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 is_amax_initialized,
                 reduce_amax=True,
             )
             input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         elif self.scaling_type_input is ScalingType.DYNAMIC:
             input_fp8 = hp_tensor_to_float8_dynamic(
                 input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is ScalingType.STATIC
             input_fp8 = hp_tensor_to_float8_static(
-                input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+                input,
+                self.fp8_static_scale_input,
+                self.config.cast_config_input.dtype,
+                self.linear_mm_config,
             )

         return input_fp8

@@ -483,14 +482,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
                 self.fp8_amax_history_weight,
                 self.fp8_scale_weight,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_weight.dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
             self.fp8_amax_weight.fill_(tensor_to_amax(weight))
             return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            return tensor_to_scale(weight, e4m3_dtype)
+            return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
             return self.fp8_static_scale_weight

@@ -506,7 +505,7 @@ def cast_weight_to_float8_t(
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
             weight_scale,
-            e4m3_dtype,
+            self.config.cast_config_weight.dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )

@@ -521,23 +520,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2BwDelayed.apply(
+            output = NoopFwToFloat8BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
-            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+            output = NoopFwToFloat8BwDynamic.apply(
+                output,
+                self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
+            )
         else:
             assert self.scaling_type_grad_output is ScalingType.STATIC
-            output = NoopFwToFloat8E5M2BwStatic.apply(
+            output = NoopFwToFloat8BwStatic.apply(
                 output,
                 self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         return output
@@ -563,19 +568,16 @@ def float8_post_forward(self):
         self.amax_and_scale_synced = False

     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        has_any_axiswise_scaling = (
-            self.config.cast_config_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_input_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight_for_grad_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
+        has_any_axiswise_scaling = any(
+            cc.scaling_granularity is ScalingGranularity.AXISWISE
+            for cc in [
+                self.config.cast_config_input,
+                self.config.cast_config_weight,
+                self.config.cast_config_grad_output,
+                self.config.cast_config_input_for_grad_weight,
+                self.config.cast_config_weight_for_grad_input,
+                self.config.cast_config_grad_output_for_grad_weight,
+            ]
         )

         if not has_any_axiswise_scaling:
@@ -698,6 +700,7 @@ def from_float(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
         elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:

@@ -708,6 +711,7 @@ def from_float(
                     new_mod.fp8_amax_history_weight,
                     new_mod.fp8_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                     new_mod.is_amax_initialized,
                 )
             )

@@ -718,6 +722,7 @@ def from_float(
                     new_mod.weight,
                     new_mod.fp8_static_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
Reviewer comment on the new dtype field: nit: so None means the default e4m3|e5m2 value will be used? Maybe call this target_dtype, lowp_dtype, etc.? dtype is a bit ambiguous across torchao, unfortunately :(