Commit

Merge branch 'dist_common_fix' into 'main'
Check common state dict consistency across ranks and log a warning in case of mismatch.

See merge request ADLR/megatron-lm!2085
Shanmugam Ramasamy committed Nov 13, 2024
2 parents 0e29f58 + 2e7030e commit 64cbae5
Showing 9 changed files with 152 additions and 16 deletions.
5 changes: 4 additions & 1 deletion megatron/core/dist_checkpointing/dict_utils.py
@@ -104,7 +104,10 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
_is_mismatch = not torch.all(x1 == x2)
if x1.device != x2.device:
_is_mismatch = not torch.all(x1.cpu() == x2.cpu())
else:
_is_mismatch = not torch.all(x1 == x2)
# TODO: change with concrete type that has both replica_id and data attrs
elif hasattr(x1, 'replica_id') and hasattr(x2, 'replica_id'):
assert type(x1) == type(x2)
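For reference, a minimal standalone sketch of the device-tolerant comparison added above; the helper name `tensors_mismatch` is illustrative and not part of `dict_utils`:

```python
import torch


def tensors_mismatch(x1: torch.Tensor, x2: torch.Tensor) -> bool:
    # Elementwise comparison raises a RuntimeError when the operands live on
    # different devices (e.g. one on CUDA, one on CPU), so move both to CPU
    # first in that case, mirroring the branch added to diff() above.
    if x1.device != x2.device:
        return bool(not torch.all(x1.cpu() == x2.cpu()))
    return bool(not torch.all(x1 == x2))
```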
1 change: 1 addition & 0 deletions megatron/core/dist_checkpointing/mapping.py
@@ -24,6 +24,7 @@
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
CommonStateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]

15 changes: 12 additions & 3 deletions megatron/core/dist_checkpointing/serialization.py
@@ -10,7 +10,7 @@

import logging
from pathlib import Path
from typing import Dict, Optional, Set, Tuple, Union
from typing import Callable, Dict, Optional, Set, Tuple, Union

import torch

Expand All @@ -19,6 +19,7 @@
from .dict_utils import extract_matching_values, merge
from .mapping import (
CheckpointingException,
CommonStateDict,
ShardedObject,
ShardedStateDict,
StateDict,
@@ -287,6 +288,7 @@ def save(
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
) -> Optional[AsyncRequest]:
"""Saving entrypoint.
@@ -320,11 +322,16 @@
common_strategy (SaveCommonStrategy, Tuple[str, int], optional):
configures common data saving behavior and backend
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
exactly once (as main replica) by some process.
It also makes sure the common state dict is consistent across all ranks.
async_sharded_save (bool, optional): if True, for the sharded state dict part
an async save implementation will be called, with the AsyncRequest
being returned to the caller. Note that it is the caller responsibility to
actually schedule the async save. Defaults to False.
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None):
A callable function that will preprocess the common state dict (i.e. it can be used to
remove keys that we expect to be different in the state dict). The function must not
modify the original state dict.
Returns:
AsyncRequest (optional): if `async_sharded_save` is True, returns
@@ -359,7 +366,9 @@ def save(
assert isinstance(common_strategy, tuple), type(common_strategy)
common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy)

sharded_state_dict, state_dict = save_preprocess(sharded_state_dict, validate_access_integrity)
sharded_state_dict, state_dict = save_preprocess(
sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check
)

common_strategy.save_common(state_dict, checkpoint_dir)

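Taken together, a hedged usage sketch of the new `save()` argument. The hook name and the `rank` key are illustrative, and `sharded_state_dict` / `ckpt_dir` are assumed to come from the caller inside an initialized `torch.distributed` job:

```python
import copy

from megatron.core import dist_checkpointing


def drop_rank_specific_keys(common_state_dict):
    # Return a cleaned copy; the hook must not mutate the original dict.
    cleaned = copy.deepcopy(common_state_dict)
    cleaned.pop('rank', None)  # hypothetical key expected to differ per rank
    return cleaned


dist_checkpointing.save(
    sharded_state_dict,  # provided by the caller
    ckpt_dir,            # provided by the caller
    validate_access_integrity=True,
    preprocess_common_before_consistancy_check=drop_rank_specific_keys,
)
```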
23 changes: 20 additions & 3 deletions megatron/core/dist_checkpointing/state_dict_transformation.py
@@ -4,17 +4,19 @@

import logging
from time import time
from typing import Any, Optional
from typing import Any, Callable, Optional

import torch

from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values
from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution
from .mapping import (
CommonStateDict,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
)
@@ -29,14 +31,21 @@
logger = logging.getLogger(__name__)


def save_preprocess(sharded_state_dict: ShardedStateDict, validate_access_integrity: bool = True):
def save_preprocess(
sharded_state_dict: ShardedStateDict,
validate_access_integrity: bool = True,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
):
"""Preprocesses the given state dictionary by applying factories,
discarding non-persistent data and extracting the common state dictionary.
Optionally, it can validate sharding integrity.
Args:
sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed.
validate_access_integrity (bool): If True, triggers validation of sharding integrity.
preprocess_common_before_consistancy_check (callable, None): A callable function
that will preprocess the common state dict (i.e. it can be used to remove keys
that we expect to be different in the state dict)
Returns:
Tuple[ShardedStateDict, dict]:
@@ -46,7 +55,15 @@ def save_preprocess(sharded_state_dict: ShardedStateDict, validate_access_integr
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
preprocessed_common_state_dict = common_state_dict
if preprocess_common_before_consistancy_check:
preprocessed_common_state_dict = preprocess_common_before_consistancy_check(
common_state_dict
)
validate_sharding_integrity(
determine_global_metadata(sharded_part)[1],
common_state_dict=preprocessed_common_state_dict,
)
return sharded_part, common_state_dict


39 changes: 36 additions & 3 deletions megatron/core/dist_checkpointing/validation.py
@@ -11,11 +11,13 @@
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import (
diff,
extract_matching_values,
map_reduce,
nested_values,
)
from megatron.core.dist_checkpointing.mapping import (
CommonStateDict,
ShardedBase,
ShardedObject,
ShardedStateDict,
@@ -34,10 +36,10 @@
from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata

logger = logging.getLogger(__name__)

# pylint: disable=line-too-long
# list of local saved/loaded ShardedBase objects
_LocalMetadata = List[Union[ShardedTensor, ShardedObject]]
# list of lists of global saved/loaded ShardedBase objects (each list element corresponds to global rank)
# list of lists of global saved/loaded ShardedBase objects (each element corresponds to global rank)
_GlobalMetadata = List[_LocalMetadata]


@@ -362,7 +364,33 @@ def maybe_report_missing_and_unexpected_keys(
logger.warning(error_msg)


def validate_sharding_integrity(global_metadata: _GlobalMetadata) -> None:
def _validate_common_state_dict(common_state_dict: CommonStateDict):
"""Validate consistancy across ranks for the common state dict
We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving.
Args:
common_state_dict: The common state dict present in all ransk
"""
other_rank_state_dicts = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(other_rank_state_dicts, common_state_dict)
common_state_dict_diff = {}
if torch.distributed.get_rank() == 0:
main_rank_state_dict = common_state_dict
for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
if only_left or only_right or mismatch:
common_state_dict_diff[rank] = (only_left, only_right, mismatch)

if len(common_state_dict_diff) != 0:
logger.warning(
f'There is difference in the common state dict in different ranks. The differences are {common_state_dict_diff}'
)


def validate_sharding_integrity(
global_metadata: _GlobalMetadata, common_state_dict: CommonStateDict = None
) -> None:
"""Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.
Local ShardedTensors and ShardedObject metadata is exchanged with `torch.distributed.all_gather_object`
@@ -372,13 +400,18 @@ def validate_sharding_integrity(global_metadata: _GlobalMetadata) -> None:
Args:
global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks.
common_state_dict (CommonStateDict): The common state dict stored by rank 0
Returns:
None
Raises:
CheckpointingException for invalid access pattern
"""

if common_state_dict:
_validate_common_state_dict(common_state_dict)

if torch.distributed.get_rank() != 0:
return

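In isolation, the cross-rank check follows this pattern. This is a sketch assuming `torch.distributed` is already initialized and the common dict holds only plain Python values; the real implementation compares with `dict_utils.diff`, which also handles tensors:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def check_common_dict_consistency(common_state_dict: dict) -> None:
    # Every rank contributes its copy; only rank 0 inspects the result.
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, common_state_dict)

    if torch.distributed.get_rank() != 0:
        return

    for rank, other in enumerate(gathered[1:], start=1):
        # Plain equality here; dict_utils.diff returns a structured report instead.
        if other != common_state_dict:
            logger.warning('Common state dict on rank %d differs from rank 0.', rank)
```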
5 changes: 3 additions & 2 deletions megatron/training/checkpointing.py
@@ -304,7 +304,7 @@ class CheckpointType(Enum):

def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far,
checkpointing_context=None, pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None, non_persistent_ckpt=False,
train_data_iterator=None, ft_client=None):
train_data_iterator=None, ft_client=None, preprocess_common_state_dict_fn = None):
"""Save a model, optimizer and optionally dataloader checkpoint.
Checkpointing context is used to persist some checkpointing state
@@ -436,7 +436,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
async_sharded_save=args.async_save,
validate_access_integrity=validate_sharding_integrity)
validate_access_integrity=validate_sharding_integrity,
preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn)
# [ModelOpt]: save sharded modelopt_state
if has_nvidia_modelopt:
save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
15 changes: 13 additions & 2 deletions megatron/training/training.py
@@ -196,6 +196,17 @@ def _get_field(string, type):
start_num_floating_point_operations


def preprocess_common_state_dict(common_state_dict):
import copy
# Convert args key of type namespace to dictionary
preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
# Remove rank and local rank from state dict if they exist, since they are expected to be different
preprocessed_common_state_dict['args'].pop('local_rank', None)
preprocessed_common_state_dict['args'].pop('rank', None)
return preprocessed_common_state_dict


def pretrain(
train_valid_test_dataset_provider,
model_provider,
@@ -369,7 +380,7 @@ def pretrain(
num_floating_point_operations_so_far, checkpointing_context,
train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.SAVE_CHECKPOINT))
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)

one_logger and one_logger.log_metrics({
'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
@@ -1095,7 +1106,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context,
non_persistent_ckpt=non_persistent_ckpt, train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.SAVE_CHECKPOINT))
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
timers(timer_key).stop(barrier=True)
16 changes: 15 additions & 1 deletion tests/unit_tests/dist_checkpointing/test_optimizer.py
@@ -402,6 +402,16 @@ def teardown_method(self, method):
)
def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp):
# sync=True to make sure other ranks wait for rank 0 to finish creating directory.

def preprocess_fn(optim_common_dict):
import copy

preprocessed_optimzier_common_dict = copy.deepcopy(optim_common_dict)
list = preprocessed_optimzier_common_dict['optimizer']['param_groups']
for dict_item in list:
del dict_item['wd_mult']
return preprocessed_optimzier_common_dict

Utils.initialize_model_parallel(*src_tp_pp)
with TempNamedDir(
tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=True
@@ -418,7 +428,11 @@ def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_
bf16=False,
)

save(optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()), ckpt_dir_A)
save(
optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()),
ckpt_dir_A,
preprocess_common_before_consistancy_check=preprocess_fn,
)
Utils.destroy_model_parallel()

# Load checkpoint A with different TP/PP and save as checkpoint B
49 changes: 48 additions & 1 deletion tests/unit_tests/dist_checkpointing/test_serialization.py
@@ -79,11 +79,22 @@ def test_multi_process_save(self, tmp_path_dist_ckpt):
'sd_keyB': ShardedTensor.from_rank_offsets(
'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
),
'lr': 0.01,
'rank': torch.distributed.get_rank(),
}

def preprocess_fn(x):
del x['rank']
return x

# sync=True to make sure other ranks wait for rank 0 to finish creating directory.
with TempNamedDir(tmp_path_dist_ckpt / 'test_multi_process_save', sync=True) as ckpt_dir:
save(state_dict, ckpt_dir)
save(
state_dict,
ckpt_dir,
validate_access_integrity=True,
preprocess_common_before_consistancy_check=preprocess_fn,
)

saved_config = maybe_load_config(ckpt_dir)
if saved_config.sharded_backend == 'zarr':
@@ -94,6 +105,42 @@ def test_multi_process_save(self, tmp_path_dist_ckpt):

Utils.destroy_model_parallel()

def test_multi_process_save_log_difference(self, tmp_path_dist_ckpt, caplog):
Utils.initialize_model_parallel(2, 4)

state_dict = {
'sd_keyA': ShardedTensor.from_rank_offsets(
'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)
),
'sd_keyB': ShardedTensor.from_rank_offsets(
'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
),
'rank': torch.distributed.get_rank(),
}

def preprocess_fn(x):
return x

with caplog.at_level(logging.WARNING):
# sync=True to make sure other ranks wait for rank 0 to finish creating directory.
with TempNamedDir(
tmp_path_dist_ckpt / 'test_multi_process_save', sync=True
) as ckpt_dir:
save(
state_dict,
ckpt_dir,
validate_access_integrity=True,
preprocess_common_before_consistancy_check=preprocess_fn,
)
# pylint: disable=line-too-long
if torch.distributed.get_rank() == 0:
assert (
"There is difference in the common state dict in different ranks. The differences are {1: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 2: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 3: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 4: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 5: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 6: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 7: ([], [], [(('rank',), <class 'int'>, <class 'int'>)])}"
in caplog.text
)

Utils.destroy_model_parallel()

def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
Utils.initialize_model_parallel(2, 4)

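The assertion above also documents the shape of `dict_utils.diff` output. A small illustration, inferred from that assertion and using throwaway dicts:

```python
from megatron.core.dist_checkpointing.dict_utils import diff

only_left, only_right, mismatch = diff({'rank': 0, 'lr': 0.01}, {'rank': 3, 'lr': 0.01})
# only_left  -> []   entries present only in the first dict
# only_right -> []   entries present only in the second dict
# mismatch   -> [(('rank',), <class 'int'>, <class 'int'>)]
#               (key path, type of the left value, type of the right value)
```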
