Merge branch 'pad_each_bucket' into 'main'
Pad each bucket so that its size is divisible by the dp_size when using the distributed optimizer

See merge request ADLR/megatron-lm!889
jaredcasper committed Nov 9, 2023
2 parents e4ef38e + 2bba0f9 commit 443ce9f
Showing 7 changed files with 141 additions and 106 deletions.
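
The core change: with the distributed optimizer enabled, every gradient bucket is padded so that its element count is divisible by the data-parallel world size, which lets each bucket be sharded evenly across ranks for the reduce-scatter. A minimal sketch of the rounding rule follows; pad_to_multiple is a hypothetical stand-in for the _pad_if_needed helper added to GradBuffer.__init__ in this diff.

import math

def pad_to_multiple(numel: int, data_parallel_world_size: int) -> int:
    # Round numel up to the nearest multiple of the data-parallel world size
    # (sketch of the padding rule used by _pad_if_needed below).
    return int(math.ceil(numel / data_parallel_world_size)) * data_parallel_world_size

# Example: a 10-element bucket with data-parallel size 4 is padded to 12,
# so each of the 4 ranks owns an equal 3-element shard.
assert pad_to_multiple(10, 4) == 12
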
31 changes: 1 addition & 30 deletions megatron/core/distributed/distributed_data_parallel.py
@@ -1,6 +1,5 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import math
from contextlib import contextmanager
from typing import Dict

@@ -76,7 +75,6 @@ def __init__(

# Group parameters by their gradient type.
grad_dtype_to_params = {}
grad_dtype_to_numel = {}
param_to_name = {}
for name, param in self.module.named_parameters():
if param.requires_grad and getattr(param, 'allreduce', True):
@@ -88,24 +86,10 @@ def __init__(
params.append(param)
grad_dtype_to_params[dtype] = params

# Calculate number of elements per dtype.
grad_dtype_to_numel[dtype] = (
grad_dtype_to_numel.get(dtype, 0) + param.data.nelement()
)

# Allocate the grad buffers and map the grads.
# The grad buffer under the hood creates buckets as appropriate based on bucket_size.
data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group)
for dtype, params in grad_dtype_to_params.items():
# Pad so size is divisible by the data parallel size.
numel = grad_dtype_to_numel[dtype]
numel_padded = (
int(math.ceil(numel / data_parallel_world_size)) * data_parallel_world_size
)

self.grad_buffers[dtype] = GradBuffer(
numel,
numel_padded,
dtype,
params,
data_parallel_group,
@@ -114,22 +98,9 @@ def __init__(
self.overlap_grad_reduce,
self.use_distributed_optimizer,
)

# Parameters are laid out in the corresponding grad_buffer in reverse
# order, so count indices from the back.
index = grad_dtype_to_numel[dtype]
self.grad_buffer_param_index_map[dtype] = self.grad_buffers[dtype].param_index_map
for param in params:
self.param_to_grad_buffer[param] = self.grad_buffers[dtype]
if dtype not in self.grad_buffer_param_index_map:
self.grad_buffer_param_index_map[dtype] = {}

index -= param.data.nelement()
# Store the indices / bucket of each param.
self.grad_buffer_param_index_map[dtype][param] = (
index,
index + param.data.nelement(),
self.grad_buffers[dtype].param_to_bucket_index[param],
)

# Allocate discrete buffer for MoE params' grads
for param in self.module.parameters():
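
With the index bookkeeping moved into GradBuffer, the DDP wrapper simply copies each buffer's param_index_map of (start_index, end_index, bucket_id) triples instead of re-deriving them in reverse order. A toy, self-contained illustration of the map's shape; the parameter names and sizes are hypothetical, not taken from the diff.

# Parameters are laid out in reverse order inside one grad buffer; each entry
# mirrors the (start_index, end_index, bucket_id) triple stored in
# GradBuffer.param_index_map and exposed via grad_buffer_param_index_map.
param_index_map = {
    "decoder.layer1.weight": (0, 4096, 0),
    "decoder.layer0.weight": (4096, 8192, 0),
    "embedding.weight": (8192, 24576, 1),
}
for name, (start, end, bucket_id) in param_index_map.items():
    print(f"{name}: elements [{start}, {end}) of the buffer, bucket {bucket_id}")
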
187 changes: 124 additions & 63 deletions megatron/core/distributed/grad_buffer.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import math
from logging import getLogger
from typing import Dict, List

@@ -10,13 +11,10 @@
logger = getLogger(__name__)


def shard_buffer(buffer: torch.Tensor):
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
"""
Shard buffer into dp_size chunks of equal size.
Shard buffer into data_parallel_world_size chunks of equal size.
"""
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
assert buffer.numel() % data_parallel_world_size == 0
shard_size = buffer.numel() // data_parallel_world_size
sharded_buffer = [
@@ -36,6 +34,7 @@ class Bucket:
data: View in larger GradBuffer that this bucket is responsible for.
offset: Offset of this bucket's view in the larger GradBuffer.
data_parallel_group: Data-parallel process group.
data_parallel_world_size: World size of the data-parallel group.
overlap_grad_reduce: If true, overlap communication with backprop computation by
breaking up grads into buckets. If false, single synchronous communication call
is used instead.
@@ -49,6 +48,7 @@ def __init__(
data: torch.Tensor,
offset: int,
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_world_size: int,
overlap_grad_reduce: bool,
use_distributed_optimizer: bool,
):
@@ -64,12 +64,11 @@ def __init__(
# within the full grad_buffer.
self.offset = offset
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
self.overlap_grad_reduce = overlap_grad_reduce
self.use_distributed_optimizer = use_distributed_optimizer

self.data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group)
self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)

self.reset()

def reset(self):
@@ -96,7 +95,9 @@ def start_grad_sync(self):
self.data /= self.data_parallel_world_size
# Use async_op only when overlap_grad_reduce is True.
if self.use_distributed_optimizer:
local_data_view = shard_buffer(self.data)[self.data_parallel_rank]
local_data_view = shard_buffer(self.data, self.data_parallel_world_size)[
self.data_parallel_rank
]
self.communication_handle = torch.distributed._reduce_scatter_base(
local_data_view,
self.data,
@@ -151,8 +152,6 @@ class GradBuffer:
roughly `bucket_size` parameters each.
Arguments:
numel: True number of elements.
numel_padded: Number of elements in underlying tensor.
dtype: Type of underlying tensor.
params: List of parameters whose gradients are collated in the underlying tensor.
data_parallel_group: Data-parallel process group.
@@ -167,8 +166,6 @@ class GradBuffer:

def __init__(
self,
numel: int,
numel_padded: int,
dtype: torch.dtype,
params: List[torch.nn.Parameter],
data_parallel_group: torch.distributed.ProcessGroup,
@@ -177,23 +174,6 @@ def __init__(
overlap_grad_reduce: bool,
use_distributed_optimizer: bool,
):
self.numel = numel
self.numel_padded = numel_padded
self.dtype = dtype
self.data = torch.zeros(
self.numel_padded,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)

self.buckets = []
self.param_to_bucket = {}
self.param_to_bucket_index = {}
self.overlap_grad_reduce = overlap_grad_reduce
self.use_distributed_optimizer = use_distributed_optimizer

self.is_last_microbatch = True

# Check that params are unique.
unique_params = set()
@@ -202,65 +182,111 @@ def __init__(
unique_params.add(param)
del unique_params

# Helper function to create new bucket, add it to list of buckets, and
# also update param->bucket mapping.
def _set_bucket(
bucket_params: List[torch.nn.Parameter], data_start_index: int, data_end_index: int
):
# Store attributes that will be needed later.
self.dtype = dtype
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = torch.distributed.get_world_size(
group=self.data_parallel_group
)
self.overlap_grad_reduce = overlap_grad_reduce
self.use_distributed_optimizer = use_distributed_optimizer
self.is_last_microbatch = True

# Get appropriate view into global GradBuffer.
bucket_data = self._get(
torch.Size([data_end_index - data_start_index]), data_start_index
)
bucket = Bucket(
bucket_params,
bucket_data,
data_start_index,
data_parallel_group,
self.overlap_grad_reduce,
self.use_distributed_optimizer,
)
self.buckets.append(bucket)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
assert bucket_param not in self.param_to_bucket_index
self.param_to_bucket[bucket_param] = bucket
self.param_to_bucket_index[bucket_param] = len(self.buckets) - 1

# Map the grads to the buffer and bucket them.
# Data structures to store underlying buckets and relevant indexing data.
self.buckets = []
self.param_to_bucket = {} # Param -> bucket mapping.
self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).

def _pad_if_needed(data_index: int):
"""Pads data indices if using distributed optimizer (to ensure uniform sharding)."""
if use_distributed_optimizer:
return (
int(math.ceil(data_index / self.data_parallel_world_size))
* self.data_parallel_world_size
)
return data_index

# First, figure out how many elements should be in the underlying buffer storage.
# Note that if we need to split the buffer into smaller buckets, each of these
# might need to be padded as well (if using the distributed optimizer).
data_start_index = 0
bucket_data_start_index = data_start_index
bucket_params = set()

# Iterate through parameters in reverse order to roughly follow backprop order.
self.bucket_indices = []
bucket_id = 0
for param in params[::-1]:
# Skip parameters that don't require gradients.
# Iterate through parameters in reverse order to roughly follow backprop order,
# and skip parameters that don't require gradients.
if not param.requires_grad:
continue
this_numel = param.data.nelement()
data_end_index = data_start_index + this_numel
param.main_grad = self._get(param.data.shape, data_start_index)
self.param_index_map[param] = (
data_start_index,
data_end_index,
bucket_id,
)
bucket_params.add(param)

# If we have enough elements already, form a new buffer.
# If we have enough elements already, form a new bucket.
# If bucket_size is None, accumulate everything into a single bucket.
if bucket_size is not None:
if (data_end_index - bucket_data_start_index) >= bucket_size:
_set_bucket(bucket_params, bucket_data_start_index, data_end_index)
data_end_index = _pad_if_needed(data_end_index)
self.bucket_indices.append((bucket_data_start_index, data_end_index))
bucket_data_start_index = data_end_index
bucket_params = set()
bucket_id += 1
data_start_index = data_end_index

# Add remaining params to a new bucket.
if len(bucket_params) > 0:
_set_bucket(bucket_params, bucket_data_start_index, data_end_index)
data_end_index = _pad_if_needed(data_end_index)
self.bucket_indices.append((bucket_data_start_index, data_end_index))

# Next, create underlying storage for buffer (with numel elements that includes
# padding as necessary).
self.numel = data_end_index
if use_distributed_optimizer:
assert self.numel % self.data_parallel_world_size == 0
self.data = torch.zeros(
self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False,
)

# Finally, map main_grad fields for each parameter with a .grad field.
bucket_params = set()
bucket_data_start_index = 0
cur_bucket_id = 0
for param in params[::-1]:
if not param.requires_grad:
continue
data_start_index, data_end_index, bucket_id = self.param_index_map[param]
param.main_grad = self._get(param.data.shape, data_start_index)
if bucket_id != cur_bucket_id:
bucket_data_end_index = _pad_if_needed(data_start_index)
self._set_bucket(
bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id
)
bucket_data_start_index = bucket_data_end_index
bucket_params = set()
assert cur_bucket_id + 1 == len(self.buckets)
assert bucket_id == cur_bucket_id + 1
cur_bucket_id = bucket_id
bucket_params.add(param)

# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_data_end_index = _pad_if_needed(data_end_index)
self._set_bucket(
bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id
)

if not overlap_grad_reduce:
assert len(bucket_params) == len(
params
), 'All params should be in one bucket when overlap_grad_reduce is False'

# Print buckets for all PP stages.
# Log buckets for all PP stages.
if (
parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
and parallel_state.get_tensor_model_parallel_rank() == 0
@@ -287,6 +313,41 @@ def _get(self, shape: torch.Size, start_index: int) -> torch.Tensor:
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor

def _set_bucket(
self,
bucket_params: List[torch.nn.Parameter],
start_index: int,
end_index: int,
bucket_id: int,
):
"""
Helper function to create new bucket, add it to list of buckets, and
also update param->bucket mapping.
"""

# Assert that indices are correctly padded (if needed), and that bucket
# position is same as originally computed.
if self.use_distributed_optimizer:
assert start_index % self.data_parallel_world_size == 0
assert end_index % self.data_parallel_world_size == 0
assert (start_index, end_index) == self.bucket_indices[bucket_id]

# Get appropriate view into global GradBuffer.
bucket_data = self._get(torch.Size([end_index - start_index]), start_index)
bucket = Bucket(
params=bucket_params,
data=bucket_data,
offset=start_index,
data_parallel_group=self.data_parallel_group,
data_parallel_world_size=self.data_parallel_world_size,
overlap_grad_reduce=self.overlap_grad_reduce,
use_distributed_optimizer=self.use_distributed_optimizer,
)
self.buckets.append(bucket)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
self.param_to_bucket[bucket_param] = bucket

def reset(self):
"""
Zero out the underlying buffer and reset all buckets in preparation for the next
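
Because every bucket is now padded to a multiple of the data-parallel world size, shard_buffer can always split a bucket's data into equal per-rank views. A self-contained sketch of that split (CPU tensors, no process group); in the diff, Bucket.start_grad_sync passes shard_buffer(self.data, self.data_parallel_world_size)[self.data_parallel_rank] as the output view of the reduce-scatter.

import torch

def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
    # Split a padded buffer into data_parallel_world_size equal views (no copy).
    assert buffer.numel() % data_parallel_world_size == 0
    shard_size = buffer.numel() // data_parallel_world_size
    return [
        buffer[r * shard_size : (r + 1) * shard_size]
        for r in range(data_parallel_world_size)
    ]

# A padded 12-element bucket with data-parallel size 4 yields four 3-element
# shards; rank r would reduce-scatter into shard_buffer(data, 4)[r].
data = torch.zeros(12)
shards = shard_buffer(data, 4)
assert all(shard.numel() == 3 for shard in shards)
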
21 changes: 12 additions & 9 deletions megatron/optimizer/distrib_optimizer.py
@@ -388,10 +388,12 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,

# Model grad buffer ranges.
self.model_gbuf_ranges = []
self.bucket_sizes = []
for model_index, model in enumerate(self.models):
self.bucket_sizes.append(model.bucket_size)
self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
self.per_bucket_numel = []
for _, model_chunk in enumerate(self.models):
self.per_bucket_numel.append(
{dtype: [bucket.data.numel() for bucket in model_chunk.grad_buffers[dtype].buckets]
for dtype in model_chunk.grad_buffers})
self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model_chunk))
self.model_param_gbuf_map = \
self.build_model_param_gbuf_map(self.model_gbuf_ranges)

@@ -607,7 +609,7 @@ def save_parameter_state(self, filename):
data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP)

# Collect param states.
state = {"bucket_sizes": self.bucket_sizes}
state = {"per_bucket_numel": self.per_bucket_numel}
for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):

# Iterate grad buffers (by data type).
@@ -706,10 +708,11 @@ def load_parameter_state(self, filename):
# Load on DP rank 0.
if data_parallel_rank == 0:
loaded_state = torch.load(filename)
if "bucket_sizes" in loaded_state:
bucket_sizes_in_checkpoint = loaded_state["bucket_sizes"]
assert self.bucket_sizes == bucket_sizes_in_checkpoint, \
f"Bucket sizes need to be the same in current run ({self.bucket_sizes}) and checkpoint ({bucket_sizes_in_checkpoint})"
if "per_bucket_numel" in loaded_state:
per_bucket_numel_in_checkpoint = loaded_state["per_bucket_numel"]
assert self.per_bucket_numel == per_bucket_numel_in_checkpoint, \
(f"Number of elements in each bucket need to be the same in current run "
f"({self.per_bucket_numel}) and checkpoint ({per_bucket_numel_in_checkpoint})")

# Scatter tensors to all DP ranks.
for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):
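
The optimizer now records the padded element count of every bucket (per model chunk and dtype) and, when that metadata is present in a parameter-state checkpoint, asserts that the checkpoint's bucket layout matches the current run. A standalone sketch of that consistency check; check_bucket_layout and the example sizes are hypothetical.

import torch

def check_bucket_layout(current, checkpoint):
    # Mirrors the load_parameter_state assertion: per-bucket element counts
    # must match between the current run and the checkpoint.
    assert current == checkpoint, (
        f"Number of elements in each bucket need to be the same in current run "
        f"({current}) and checkpoint ({checkpoint})")

# One model chunk with two fp32 buckets whose padded sizes must match.
per_bucket_numel = [{torch.float32: [1048576, 262144]}]
check_bucket_layout(per_bucket_numel, per_bucket_numel)  # passes
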
(Diffs for the remaining 4 changed files are not shown.)
