diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index e93cb1c95f15..0e759146cadd 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -514,19 +514,20 @@ def to_torch_tensor(state_dict, return_empty_tensor=False):
     """
     Convert state_dict of GatheredTensor to torch tensor
     """
+    torch_state_dict = {}
     converted_tensors = {}
     for name, tensor in state_dict.items():
         tensor_id = id(tensor)
-        if tensor_id in converted_tensors:
-            shared_tensor = state_dict[converted_tensors[tensor_id]]
-            state_dict[name] = shared_tensor
+        if tensor_id in converted_tensors:  # shared tensors
+            shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
+            torch_state_dict[name] = shared_tensor
         else:
             converted_tensors[tensor_id] = name
             if return_empty_tensor:
-                state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
+                torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
             else:
-                state_dict[name] = tensor.contiguous()
-    return state_dict
+                torch_state_dict[name] = tensor.contiguous()
+    return torch_state_dict
 
 
 def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
@@ -660,8 +661,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         else:
             torch.save(shard_state_dict, output_path)
 
         # release the memory of current shard
-        for tensor_name in shard_state_dict:
+        for tensor_name in list(shard_state_dict.keys()):
             del state_dict[tensor_name]
+            del shard_state_dict[tensor_name]
         del shard_state_dict
         gc.collect()
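
For context, here is a minimal, self-contained sketch of the conversion loop as patched, using plain `torch.Tensor` inputs instead of `GatheredTensor` (the helper name `to_torch_tensor_sketch` and the toy state dict are illustrative, not part of the patch). The point of the first hunk: aliased entries are detected via `id()` and resolved against the *output* dict, so weight tying survives conversion, and the caller's `state_dict` is no longer mutated in place.

```python
# Illustrative sketch only -- mirrors the patched to_torch_tensor logic
# with plain torch.Tensor inputs instead of GatheredTensor.
import torch


def to_torch_tensor_sketch(state_dict, return_empty_tensor=False):
    torch_state_dict = {}   # freshly built output; the input stays untouched
    converted_tensors = {}  # id(tensor) -> name under which it was first converted
    for name, tensor in state_dict.items():
        tensor_id = id(tensor)
        if tensor_id in converted_tensors:  # shared tensors
            # Reuse the already-converted tensor from the output dict.
            # The pre-patch code read from the input dict instead, which
            # returned an unconverted entry and mutated the caller's dict.
            torch_state_dict[name] = torch_state_dict[converted_tensors[tensor_id]]
        else:
            converted_tensors[tensor_id] = name
            if return_empty_tensor:
                torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
            else:
                torch_state_dict[name] = tensor.contiguous()
    return torch_state_dict


shared = torch.randn(8, 8)
sd = {"embed.weight": shared, "lm_head.weight": shared, "bias": torch.zeros(8)}
out = to_torch_tensor_sketch(sd)
assert out["lm_head.weight"] is out["embed.weight"]  # tying preserved
assert sd["embed.weight"] is shared                  # input not mutated
```

The second hunk is a complementary memory fix: iterating over `list(shard_state_dict.keys())` makes it safe to delete each entry from both `state_dict` and `shard_state_dict` inside the loop, so the last references to a shard's tensors are dropped before `gc.collect()` runs instead of lingering until the dict object itself is deleted.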