Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zero checkpoint #6792

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,19 +514,20 @@ def to_torch_tensor(state_dict, return_empty_tensor=False):
"""
Convert state_dict of GatheredTensor to torch tensor
"""
torch_state_dict = {}
converted_tensors = {}
for name, tensor in state_dict.items():
tensor_id = id(tensor)
if tensor_id in converted_tensors:
shared_tensor = state_dict[converted_tensors[tensor_id]]
state_dict[name] = shared_tensor
if tensor_id in converted_tensors: # shared tensors
shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
torch_state_dict[name] = shared_tensor
else:
converted_tensors[tensor_id] = name
if return_empty_tensor:
state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
else:
state_dict[name] = tensor.contiguous()
return state_dict
torch_state_dict[name] = tensor.contiguous()
return torch_state_dict


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
Expand Down Expand Up @@ -660,8 +661,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
else:
torch.save(shard_state_dict, output_path)
# release the memory of current shard
for tensor_name in shard_state_dict:
for tensor_name in list(shard_state_dict.keys()):
del state_dict[tensor_name]
del shard_state_dict[tensor_name]
del shard_state_dict
gc.collect()

Expand Down
Loading