
Commit

new changes
xrsrke committed Nov 28, 2024
1 parent fbbbf4d commit b764b97
Showing 5 changed files with 41 additions and 19 deletions.
3 changes: 2 additions & 1 deletion src/nanotron/fp8/linear.py
@@ -80,7 +80,8 @@ def _set_and_quantize_weights(self, data: torch.Tensor, recipe: FP8LinearRecipe
# in [torch.int8, torch.uint8] dtype, then we can assign int|uint8 gradient to it
# TODO(xrsrke): keep the metadata of the original NanotronParameter
# setattr(self, "weight", NanotronParameter(tensor=quant_w))
setattr(self, "weight", NanotronParameter.create_param_that_share_metadata(quant_w, self.weight))
new_param = NanotronParameter.create_param_that_share_metadata(quant_w, param=self.weight)
setattr(self, "weight", new_param)

# if self.name == "model.decoder.0.attention.qkv_proj":
# assert 1 == 1
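
For orientation, a minimal standalone sketch of the pattern this hunk applies, written in plain PyTorch: quantize_to_int8 and set_and_quantize_weight are hypothetical stand-ins for FP8 quantization and for NanotronParameter.create_param_that_share_metadata, which additionally carries over the original parameter's metadata.

import torch
import torch.nn as nn

def quantize_to_int8(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical symmetric per-tensor quantization standing in for FP8 quantization.
    scale = w.abs().max().clamp(min=1e-8) / 127.0
    return torch.clamp((w / scale).round(), -127, 127).to(torch.int8)

def set_and_quantize_weight(module: nn.Linear) -> None:
    quant_w = quantize_to_int8(module.weight.data)
    # The commit wraps quant_w so the new parameter shares the old weight's metadata;
    # this sketch simply swaps in a fresh (non-trainable) parameter.
    new_param = nn.Parameter(quant_w, requires_grad=False)
    setattr(module, "weight", new_param)

linear = nn.Linear(8, 8)
set_and_quantize_weight(linear)
assert linear.weight.dtype == torch.int8
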
16 changes: 8 additions & 8 deletions src/nanotron/fp8/utils.py
@@ -338,11 +338,9 @@ def convert_model_to_fp8(model: NanotronModel, config: FP8Args) -> NanotronModel

assert 1 == 1
# NOTE: convert to FP8
from nanotron.fp8.tensor import FP8Tensor

# from nanotron import constants
from nanotron.fp8.utils import find_fp8_config_by_module_name
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.tensor_parallel.nn import (
FP8TensorParallelColumnLinear,
FP8TensorParallelRowLinear,
@@ -367,16 +365,18 @@ def convert_model_to_fp8(model: NanotronModel, config: FP8Args) -> NanotronModel
# TODO(xrsrke): retrieve custom recipe
module._set_and_quantize_weights(module.weight.data)

assert isinstance(module.weight, NanotronParameter)
assert isinstance(module.weight.data, FP8Tensor)
assert module.weight.data.dtype in [
torch.uint8,
torch.int8,
], f"got {module.weight.data.dtype}, name: {name}"
# assert isinstance(module.weight, NanotronParameter)
# assert module.weight.data.__class__ == FP8Tensor
# assert module.weight.data.dtype in [
# torch.uint8,
# torch.int8,
# ], f"got {module.weight.data.dtype}, name: {name}"
else:
# NOTE: convert it to the residual stream's dtype
# for p in module.parameters():
# p.data = p.data.to(self.config.model.dtype)
module.to(dtype=config.resid_dtype)
# pass
# assert module.weight.data.__class__ == torch.Tensor

return model
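
Taken together, the hunks above leave convert_model_to_fp8 with roughly this shape; the sketch below is a simplified plain-PyTorch rendering that assumes quantized leaves are nn.Linear and casts everything else to the residual dtype (nanotron's recipe lookup and the FP8TensorParallel* classes are omitted).

import torch
import torch.nn as nn

def convert_model_to_int8_sketch(model: nn.Module, resid_dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    for _name, module in model.named_modules():
        if list(module.children()):
            continue  # only leaf modules, mirroring get_leaf_modules
        if isinstance(module, nn.Linear):
            # stand-in for module._set_and_quantize_weights(module.weight.data)
            with torch.no_grad():
                scale = module.weight.abs().max().clamp(min=1e-8) / 127.0
                q = torch.clamp((module.weight / scale).round(), -127, 127).to(torch.int8)
            module.weight = nn.Parameter(q, requires_grad=False)
        else:
            # non-quantized modules are cast to the residual stream's dtype
            module.to(dtype=resid_dtype)
    return model
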
16 changes: 11 additions & 5 deletions src/nanotron/parallel/parameters.py
@@ -264,8 +264,8 @@ def is_sharded(self) -> bool:
self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
)

def __repr__(self):
return f"NanotronParameter({super().__repr__()})"
# def __repr__(self):
# return f"NanotronParameter({super().__repr__()})"

@property
def data(self):
@@ -291,12 +291,18 @@ def data(self, data):

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
from nanotron.fp8.tensor import FP8Tensor

print(f"__torch_dispatch__ called with func: {func}, args: {args}, kwargs: {kwargs}")

if func in {torch._tensor_str._str, repr}:
return super().__torch_dispatch__(func, types, args, kwargs)

def unwrap(e):
print(f"Unwrapping: {e} (type: {type(e)})")
return e._data if e.__class__ == NanotronParameter else e

def wrap(e):
from nanotron.fp8.tensor import FP8Tensor

if not e.__class__ == NanotronParameter and e.__class__ in [torch.Tensor, FP8Tensor]:
return cls(e)
else:
@@ -323,7 +329,7 @@ def wrap(e):
torch.ops.aten._to_copy.default,
]

if func == torch.ops.aten.detach.default and unwrapped_args[0].__class__ == FP8Parameter:
if func == torch.ops.aten.detach.default and unwrapped_args[0].__class__ == FP8Tensor:
# NOTE: this is for parameter.data or parameter.detach()
# NOTE: because we already retrieved the data from unwrap, we don't need to do it again
# data = args[0].data
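
The changes above follow the usual wrapper-subclass recipe: short-circuit repr, unwrap the subclass before calling into ATen, then re-wrap plain tensor outputs. A self-contained sketch of that recipe (WrappedTensor is a hypothetical class, not NanotronParameter itself; the repr short-circuit and the FP8/detach special cases are left out):

import torch
import torch.utils._pytree as pytree

class WrappedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            # Strip the wrapper so ATen sees the underlying storage.
            return e._data if isinstance(e, WrappedTensor) else e

        def wrap(e):
            # Re-wrap plain tensor outputs so the subclass propagates.
            return cls(e) if isinstance(e, torch.Tensor) and not isinstance(e, WrappedTensor) else e

        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs or {}))
        return pytree.tree_map(wrap, out)

x = WrappedTensor(torch.randn(2, 2))
y = x + 1  # routed through __torch_dispatch__
assert isinstance(y, WrappedTensor)
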
2 changes: 1 addition & 1 deletion src/nanotron/trainer.py
@@ -230,7 +230,7 @@ def __init__(
assert 1 == 1
print("before quantize")
print_sanity_params(self.model)
self.model = convert_model_to_fp8(self.model)
self.model = convert_model_to_fp8(self.model, config=constants.CONFIG.fp8)
print("after quantize")
print_sanity_params(self.model)
assert 1 == 1
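
print_sanity_params is used here purely as a before/after probe around the conversion. A hypothetical helper in the same spirit (name and output format are illustrative, not nanotron's API):

import torch.nn as nn

def print_param_summary(model: nn.Module) -> None:
    # Dump each parameter's wrapper class, dtype, and shape to spot layers that
    # were (or were not) converted to quantized storage.
    for name, p in model.named_parameters():
        print(f"{name}: class={p.data.__class__.__name__} dtype={p.dtype} shape={tuple(p.shape)}")
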
23 changes: 19 additions & 4 deletions tests/fp8/test_fp8_model.py
@@ -5,8 +5,10 @@
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.utils import convert_model_to_fp8
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.testing.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config
from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use
from torch import nn


# NOTE: fp8 quantization should be parametrization-method-agnostic
@@ -36,11 +38,24 @@ def _test_initialize_fp8_model(parallel_context: ParallelContext, fp8_config: FP

for name, module in get_leaf_modules(llama):
recipe = find_fp8_config_by_module_name(name, fp8_config)

assert all(p.__class__ == NanotronParameter for p in module.parameters())
if recipe is None:
assert all(p.dtype == fp8_config.resid_dtype for p in module.parameters())
assert all(isinstance(p.data, torch.Tensor) for p in module.parameters())
assert all(
p.dtype == fp8_config.resid_dtype for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
try:
assert all(
p.data.__class__ == nn.Parameter for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
except:
assert 1 == 1
else:
assert all(isinstance(p.data, FP8Tensor) for p in module.parameters())

assert all(
isinstance(p.data, FP8Tensor) for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
assert all(
p.dtype in [torch.int8, torch.uint8] for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
# NOTE: check the expected parameters have fp8 dtype
# NOTE: check the dtype of non-fp8 parameters
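
Condensed, the new test body amounts to the check below; the sketch assumes get_leaf_modules lives in nanotron.fp8.utils next to find_fp8_config_by_module_name, and drops the try/except scaffolding around the nn.Parameter assertion.

import torch
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.utils import find_fp8_config_by_module_name, get_leaf_modules  # get_leaf_modules location assumed
from nanotron.parallel.parameters import NanotronParameter

def check_fp8_conversion(model, fp8_config) -> None:
    for name, module in get_leaf_modules(model):
        recipe = find_fp8_config_by_module_name(name, fp8_config)
        params = list(module.parameters())
        assert all(p.__class__ == NanotronParameter for p in params), name
        if recipe is None:
            # non-FP8 modules keep the residual-stream dtype
            assert all(p.dtype == fp8_config.resid_dtype for p in params), name
        else:
            # FP8 modules hold quantized int8/uint8 data wrapped in FP8Tensor
            assert all(isinstance(p.data, FP8Tensor) for p in params), name
            assert all(p.dtype in (torch.int8, torch.uint8) for p in params), name
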
