From a8a551294befc27430f09b0ce6e6ceb408c4518b Mon Sep 17 00:00:00 2001
From: Deepak Narayanan
Date: Wed, 25 Oct 2023 17:41:22 -0700
Subject: [PATCH 1/3] Pad each bucket to ensure any dp_size can be used with distributed optimizer / overlap_grad_reduce

---
 .../distributed/distributed_data_parallel.py |  31 +--
 megatron/core/distributed/grad_buffer.py     | 187 ++++++++++++------
 2 files changed, 125 insertions(+), 93 deletions(-)

diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
index 4c2c2ee525..4f7278a4b3 100644
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -1,6 +1,5 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

-import math
 from contextlib import contextmanager
 from typing import Dict

@@ -76,7 +75,6 @@ def __init__(

         # Group parameters by their gradient type.
         grad_dtype_to_params = {}
-        grad_dtype_to_numel = {}
         param_to_name = {}
         for name, param in self.module.named_parameters():
             if param.requires_grad and getattr(param, 'allreduce', True):
@@ -88,24 +86,10 @@ def __init__(
                 params.append(param)
                 grad_dtype_to_params[dtype] = params

-                # Calculate number of elements per dtype.
-                grad_dtype_to_numel[dtype] = (
-                    grad_dtype_to_numel.get(dtype, 0) + param.data.nelement()
-                )
-
         # Allocate the grad buffers and map the grads.
         # The grad buffer under the hood creates buckets as appropriate based on bucket_size.
-        data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group)
         for dtype, params in grad_dtype_to_params.items():
-            # Pad so size is divisible by the data parallel size.
-            numel = grad_dtype_to_numel[dtype]
-            numel_padded = (
-                int(math.ceil(numel / data_parallel_world_size)) * data_parallel_world_size
-            )
-
             self.grad_buffers[dtype] = GradBuffer(
-                numel,
-                numel_padded,
                 dtype,
                 params,
                 data_parallel_group,
@@ -114,22 +98,9 @@ def __init__(
                 self.overlap_grad_reduce,
                 self.use_distributed_optimizer,
             )
-
-            # Parameters are laid out in the corresponding grad_buffer in reverse
-            # order, so count indices from the back.
-            index = grad_dtype_to_numel[dtype]
+            self.grad_buffer_param_index_map[dtype] = self.grad_buffers[dtype].param_index_map
             for param in params:
                 self.param_to_grad_buffer[param] = self.grad_buffers[dtype]
-                if dtype not in self.grad_buffer_param_index_map:
-                    self.grad_buffer_param_index_map[dtype] = {}
-
-                index -= param.data.nelement()
-                # Store the indices / bucket of each param.
-                self.grad_buffer_param_index_map[dtype][param] = (
-                    index,
-                    index + param.data.nelement(),
-                    self.grad_buffers[dtype].param_to_bucket_index[param],
-                )

         # Allocate discreate buffer for MoE params' grads
         for param in self.module.parameters():
diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py
index 223c2bef18..77b4a40f8e 100644
--- a/megatron/core/distributed/grad_buffer.py
+++ b/megatron/core/distributed/grad_buffer.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

+import math
 from logging import getLogger
 from typing import Dict, List

@@ -10,13 +11,10 @@
 logger = getLogger(__name__)


-def shard_buffer(buffer: torch.Tensor):
+def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
     """
-    Shard buffer into dp_size chunks of equal size.
+    Shard buffer into data_parallel_world_size chunks of equal size.
""" - data_parallel_world_size = parallel_state.get_data_parallel_world_size( - with_context_parallel=True - ) assert buffer.numel() % data_parallel_world_size == 0 shard_size = buffer.numel() // data_parallel_world_size sharded_buffer = [ @@ -36,6 +34,7 @@ class Bucket: data: View in larger GradBuffer that this bucket is responsible for. offset: Offset of this bucket's view in the larger GradBuffer. data_parallel_group: Data-parallel process group. + data_parallel_world_size: World size using the data-parallel group group. overlap_grad_reduce: If true, overlap communication with backprop computation by breaking up grads into buckets. If false, single synchronous communication call is used instead. @@ -49,6 +48,7 @@ def __init__( data: torch.Tensor, offset: int, data_parallel_group: torch.distributed.ProcessGroup, + data_parallel_world_size: int, overlap_grad_reduce: bool, use_distributed_optimizer: bool, ): @@ -64,12 +64,11 @@ def __init__( # within the full grad_buffer. self.offset = offset self.data_parallel_group = data_parallel_group + self.data_parallel_world_size = data_parallel_world_size + self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group) self.overlap_grad_reduce = overlap_grad_reduce self.use_distributed_optimizer = use_distributed_optimizer - self.data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group) - self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group) - self.reset() def reset(self): @@ -96,7 +95,9 @@ def start_grad_sync(self): self.data /= self.data_parallel_world_size # Use async_op only when overlap_grad_reduce is True. if self.use_distributed_optimizer: - local_data_view = shard_buffer(self.data)[self.data_parallel_rank] + local_data_view = shard_buffer(self.data, self.data_parallel_world_size)[ + self.data_parallel_rank + ] self.communication_handle = torch.distributed._reduce_scatter_base( local_data_view, self.data, @@ -151,8 +152,6 @@ class GradBuffer: roughly `bucket_size` parameters each. Arguments: - numel: True number of elements. - numel_padded: Number of elements in underlying tensor. dtype: Type of underlying tensor. params: List of parameters whose gradients are collated in the underlying tensor. data_parallel_group: Data-parallel process group. @@ -167,8 +166,6 @@ class GradBuffer: def __init__( self, - numel: int, - numel_padded: int, dtype: torch.dtype, params: List[torch.nn.Parameter], data_parallel_group: torch.distributed.ProcessGroup, @@ -177,23 +174,6 @@ def __init__( overlap_grad_reduce: bool, use_distributed_optimizer: bool, ): - self.numel = numel - self.numel_padded = numel_padded - self.dtype = dtype - self.data = torch.zeros( - self.numel_padded, - dtype=self.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - self.buckets = [] - self.param_to_bucket = {} - self.param_to_bucket_index = {} - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - self.is_last_microbatch = True # Check that params are unique. unique_params = set() @@ -202,65 +182,111 @@ def __init__( unique_params.add(param) del unique_params - # Helper function to create new bucket, add it to list of buckets, and - # also update param->bucket mapping. - def _set_bucket( - bucket_params: List[torch.nn.Parameter], data_start_index: int, data_end_index: int - ): + # Store attributes that will be needed later. 
+        self.dtype = dtype
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = torch.distributed.get_world_size(
+            group=self.data_parallel_group
+        )
+        self.overlap_grad_reduce = overlap_grad_reduce
+        self.use_distributed_optimizer = use_distributed_optimizer
+        self.is_last_microbatch = True

-            # Get appropriate view into global GradBuffer.
-            bucket_data = self._get(
-                torch.Size([data_end_index - data_start_index]), data_start_index
-            )
-            bucket = Bucket(
-                bucket_params,
-                bucket_data,
-                data_start_index,
-                data_parallel_group,
-                self.overlap_grad_reduce,
-                self.use_distributed_optimizer,
-            )
-            self.buckets.append(bucket)
-            for bucket_param in bucket_params:
-                assert bucket_param not in self.param_to_bucket
-                assert bucket_param not in self.param_to_bucket_index
-                self.param_to_bucket[bucket_param] = bucket
-                self.param_to_bucket_index[bucket_param] = len(self.buckets) - 1
-
-        # Map the grads to the buffer and bucket them.
+        # Data structures to store underlying buckets and relevant indexing data.
+        self.buckets = []
+        self.param_to_bucket = {}  # Param -> bucket mapping.
+        self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).
+
+        def _pad_if_needed(data_index: int):
+            """Pads data indices if using distributed optimizer (to ensure uniform sharding)."""
+            if use_distributed_optimizer:
+                return (
+                    int(math.ceil(data_index / self.data_parallel_world_size))
+                    * self.data_parallel_world_size
+                )
+            return data_index
+
+        # First, figure out how many elements should be in the underlying buffer storage.
+        # Note that if we need to split the buffer into smaller buckets, each of these
+        # might need to be padded as well (if using the distributed optimizer).
         data_start_index = 0
         bucket_data_start_index = data_start_index
         bucket_params = set()
-
-        # Iterate through parameters in reverse order to roughly follow backprop order.
+        self.bucket_indices = []
+        bucket_id = 0
         for param in params[::-1]:
-            # Skip parameters that don't require gradients.
+            # Iterate through parameters in reverse order to roughly follow backprop order,
+            # and skip parameters that don't require gradients.
             if not param.requires_grad:
                 continue
             this_numel = param.data.nelement()
             data_end_index = data_start_index + this_numel
-            param.main_grad = self._get(param.data.shape, data_start_index)
+            self.param_index_map[param] = (
+                data_start_index,
+                data_end_index,
+                bucket_id,
+            )
             bucket_params.add(param)

             # If we have enough elements already, form a new bucket.
             # If bucket_size is None, accumulate everything into a single bucket.
             if bucket_size is not None:
                 if (data_end_index - bucket_data_start_index) >= bucket_size:
-                    _set_bucket(bucket_params, bucket_data_start_index, data_end_index)
+                    data_end_index = _pad_if_needed(data_end_index)
+                    self.bucket_indices.append((bucket_data_start_index, data_end_index))
                     bucket_data_start_index = data_end_index
                     bucket_params = set()
+                    bucket_id += 1
             data_start_index = data_end_index

         # Add remaining params to a new bucket.
         if len(bucket_params) > 0:
-            _set_bucket(bucket_params, bucket_data_start_index, data_end_index)
+            data_end_index = _pad_if_needed(data_end_index)
+            self.bucket_indices.append((bucket_data_start_index, data_end_index))
+
+        # Next, create underlying storage for buffer (with numel elements that includes
+        # padding as necessary).
+        self.numel = data_end_index
+        if use_distributed_optimizer:
+            assert self.numel % self.data_parallel_world_size == 0
+        self.data = torch.zeros(
+            self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False,
+        )
+
+        # Finally, map main_grad fields for each parameter with a .grad field.
+        bucket_params = set()
+        bucket_data_start_index = 0
+        cur_bucket_id = 0
+        for param in params[::-1]:
+            if not param.requires_grad:
+                continue
+            data_start_index, data_end_index, bucket_id = self.param_index_map[param]
+            param.main_grad = self._get(param.data.shape, data_start_index)
+            if bucket_id != cur_bucket_id:
+                bucket_data_end_index = _pad_if_needed(data_start_index)
+                self._set_bucket(
+                    bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id
+                )
+                bucket_data_start_index = bucket_data_end_index
+                bucket_params = set()
+                assert cur_bucket_id + 1 == len(self.buckets)
+                assert bucket_id == cur_bucket_id + 1
+                cur_bucket_id = bucket_id
+            bucket_params.add(param)
+
+        # Add remaining params to a new bucket.
+        if len(bucket_params) > 0:
+            bucket_data_end_index = _pad_if_needed(data_end_index)
+            self._set_bucket(
+                bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id
+            )

         if not overlap_grad_reduce:
             assert len(bucket_params) == len(
                 params
             ), 'All params should be in one bucket when overlap_grad_reduce is False'

-        # Print buckets for all PP stages.
+        # Log buckets for all PP stages.
         if (
             parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
             and parallel_state.get_tensor_model_parallel_rank() == 0
@@ -287,6 +313,41 @@ def _get(self, shape: torch.Size, start_index: int) -> torch.Tensor:
         buffer_tensor = buffer_tensor.view(shape)
         return buffer_tensor

+    def _set_bucket(
+        self,
+        bucket_params: List[torch.nn.Parameter],
+        start_index: int,
+        end_index: int,
+        bucket_id: int,
+    ):
+        """
+        Helper function to create new bucket, add it to list of buckets, and
+        also update param->bucket mapping.
+        """
+
+        # Assert that indices are correctly padded (if needed), and that bucket
+        # position is same as originally computed.
+        if self.use_distributed_optimizer:
+            assert start_index % self.data_parallel_world_size == 0
+            assert end_index % self.data_parallel_world_size == 0
+        assert (start_index, end_index) == self.bucket_indices[bucket_id]
+
+        # Get appropriate view into global GradBuffer.
+        bucket_data = self._get(torch.Size([end_index - start_index]), start_index)
+        bucket = Bucket(
+            params=bucket_params,
+            data=bucket_data,
+            offset=start_index,
+            data_parallel_group=self.data_parallel_group,
+            data_parallel_world_size=self.data_parallel_world_size,
+            overlap_grad_reduce=self.overlap_grad_reduce,
+            use_distributed_optimizer=self.use_distributed_optimizer,
+        )
+        self.buckets.append(bucket)
+        for bucket_param in bucket_params:
+            assert bucket_param not in self.param_to_bucket
+            self.param_to_bucket[bucket_param] = bucket
+
     def reset(self):
         """
         Zero out the underlying buffer and reset all buckets in preparation for the next

From 0904a051ac22ab39340102a4a09fec57aeb4478b Mon Sep 17 00:00:00 2001
From: Deepak Narayanan
Date: Sat, 4 Nov 2023 17:19:43 -0700
Subject: [PATCH 2/3] Make sure padding is the same across checkpoint and current run

---
 megatron/optimizer/distrib_optimizer.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py
index a45a3f101e..9875d192d9 100644
--- a/megatron/optimizer/distrib_optimizer.py
+++ b/megatron/optimizer/distrib_optimizer.py
@@ -388,10 +388,12 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,

         # Model grad buffer ranges.
         self.model_gbuf_ranges = []
-        self.bucket_sizes = []
-        for model_index, model in enumerate(self.models):
-            self.bucket_sizes.append(model.bucket_size)
-            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
+        self.per_bucket_numel = []
+        for _, model_chunk in enumerate(self.models):
+            self.per_bucket_numel.append(
+                {dtype: [bucket.data.numel() for bucket in model_chunk.grad_buffers[dtype].buckets]
+                 for dtype in model_chunk.grad_buffers})
+            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model_chunk))
         self.model_param_gbuf_map = \
             self.build_model_param_gbuf_map(self.model_gbuf_ranges)

@@ -607,7 +609,7 @@ def save_parameter_state(self, filename):
         data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP)

         # Collect param states.
-        state = {"bucket_sizes": self.bucket_sizes}
+        state = {"per_bucket_numel": self.per_bucket_numel}
         for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):

             # Iterate grad buffers (by data type).
@@ -706,10 +708,11 @@ def load_parameter_state(self, filename):
         # Load on DP rank 0.
         if data_parallel_rank == 0:
             loaded_state = torch.load(filename)
-            if "bucket_sizes" in loaded_state:
-                bucket_sizes_in_checkpoint = loaded_state["bucket_sizes"]
-                assert self.bucket_sizes == bucket_sizes_in_checkpoint, \
-                    f"Bucket sizes need to be the same in current run ({self.bucket_sizes}) and checkpoint ({bucket_sizes_in_checkpoint})"
+            if "per_bucket_numel" in loaded_state:
+                per_bucket_numel_in_checkpoint = loaded_state["per_bucket_numel"]
+                assert self.per_bucket_numel == per_bucket_numel_in_checkpoint, \
+                    (f"Number of elements in each bucket need to be the same in current run "
+                     f"({self.per_bucket_numel}) and checkpoint ({per_bucket_numel_in_checkpoint})")

         # Scatter tensors to all DP ranks.
         for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):

From 2bba0f995423e3b432c4bbc1dba7e9abdf03302f Mon Sep 17 00:00:00 2001
From: Deepak Narayanan
Date: Mon, 30 Oct 2023 09:29:59 -0700
Subject: [PATCH 3/3] Update gold values for distributed optimizer CI tests

Gold values changed because order of parameters in DistOpt data structures
changed, changing the grad norm slightly
---
 .../gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer.json        | 2 +-
 ...1_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json | 2 +-
 ...eaved_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json | 2 +-
 ...4_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer.json b/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer.json
index 1bd8968a88..1363208e68 100644
--- a/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer.json
+++ b/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer.json
@@ -1 +1 @@
-{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.87174, 10.89545, 10.88847, 10.88533, 10.893, 10.84895, 10.70048, 10.64124, 10.53839, 10.3107]}, "num-zeros": {"start_step": 0, "end_step": 32, "step_interval": 5, "values": [1238.0, 1318.0, 1774.0, 1416.0, 1549.0, 1271.0, 1270.0]}, "iteration_timing_avg": 0.05975970588235295}
\ No newline at end of file
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.87174, 10.89545, 10.88847, 10.88533, 10.893, 10.84895, 10.70048, 10.64124, 10.53839, 10.3107]}, "num-zeros": {"start_step": 0, "end_step": 32, "step_interval": 5, "values": [1238.0, 1318.0, 1648.0, 1423.0, 1535.0, 1350.0, 1271.0]}, "iteration_timing_avg": 0.06013999999999999}
\ No newline at end of file
diff --git a/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json b/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
index 6127288581..36ee6cf395 100644
--- a/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
+++ b/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
@@ -1 +1 @@
-{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.87174, 10.89545, 10.88847, 10.88533, 10.893, 10.84895, 10.70048, 10.64124, 10.53839, 10.3107]}, "num-zeros": {"start_step": 0, "end_step": 32, "step_interval": 5, "values": [1238.0, 1318.0, 1774.0, 1416.0, 1549.0, 1271.0, 1270.0]}, "iteration_timing_avg": 0.06060647058823528}
\ No newline at end of file
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.87174, 10.89545, 10.88847, 10.88533, 10.893, 10.84895, 10.70048, 10.64124, 10.53839, 10.3107]}, "num-zeros": {"start_step": 0, "end_step": 32, "step_interval": 5, "values": [1238.0, 1318.0, 1648.0, 1423.0, 1535.0, 1350.0, 1271.0]}, "iteration_timing_avg": 0.05914823529411765}
\ No newline at end of file
diff --git a/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp4_interleaved_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json b/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp4_interleaved_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
index 40e7b9ea0a..4e0217e20f 100644
--- a/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp4_interleaved_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
+++ b/tests/functional_tests/test_results/gpt3/gpt3_tp1_pp4_interleaved_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
@@ -1 +1 @@
-{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.7951, 10.84939, 10.87411, 10.83459, 10.82865, 10.78677, 10.56492, 10.57063, 10.48544, 10.19547]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [2586.0, 2686.0, 2148.0, 2589.0, 2703.0, 2403.0, 3020.0]}, "iteration_timing_avg": 0.12560235294117644}
\ No newline at end of file
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.7951, 10.84939, 10.87411, 10.83459, 10.82865, 10.78676, 10.56492, 10.57063, 10.48544, 10.19547]}, "num-zeros": {"start_step": 0, "end_step": 34, "step_interval": 5, "values": [2586.0, 2828.0, 2105.0, 2725.0, 2711.0, 2428.0, 2946.0]}, "iteration_timing_avg": 0.11526}
\ No newline at end of file
diff --git a/tests/functional_tests/test_results/gpt3/gpt3_tp4_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json b/tests/functional_tests/test_results/gpt3/gpt3_tp4_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
index b780ad3981..e22ec7e5bd 100644
--- a/tests/functional_tests/test_results/gpt3/gpt3_tp4_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
+++ b/tests/functional_tests/test_results/gpt3/gpt3_tp4_pp1_1nodes_50steps_dist_optimizer_overlap_grad_reduce.json
@@ -1 +1 @@
-{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.85921, 10.8797, 10.87381, 10.88658, 10.88912, 10.84826, 10.68571, 10.62947, 10.5429, 10.26917]}, "num-zeros": {"start_step": 0, "end_step": 33, "step_interval": 5, "values": [2288.0, 2283.0, 2422.0, 2061.0, 2147.0, 2418.0, 2400.0]}, "iteration_timing_avg": 0.19536911764705878}
\ No newline at end of file
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.85921, 10.8797, 10.87381, 10.88658, 10.88912, 10.84826, 10.68571, 10.62947, 10.54289, 10.26918]}, "num-zeros": {"start_step": 0, "end_step": 33, "step_interval": 5, "values": [2288.0, 2326.0, 2454.0, 2011.0, 2111.0, 2436.0, 2446.0]}, "iteration_timing_avg": 0.18781294117647054}
\ No newline at end of file
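
Note on the padding scheme in PATCH 1/3: each bucket boundary is rounded up to a multiple of the data-parallel world size, so that shard_buffer() can always split a bucket into equal per-rank shards regardless of dp_size. The following is a minimal, standalone sketch of that arithmetic; it is not part of the patches, and the helper names pad_to_multiple and shard_ranges are illustrative only.

    import math

    def pad_to_multiple(index, world_size):
        # Round index up to the nearest multiple of world_size,
        # mirroring _pad_if_needed in grad_buffer.py.
        return int(math.ceil(index / world_size)) * world_size

    def shard_ranges(numel, world_size):
        # Split a padded bucket of numel elements into equal per-rank (start, end)
        # index ranges, mirroring what shard_buffer does with tensor views.
        assert numel % world_size == 0
        shard_size = numel // world_size
        return [(r * shard_size, (r + 1) * shard_size) for r in range(world_size)]

    # Example: a 1000-element bucket with data-parallel world size 6 is padded to
    # 1002 elements, which then splits evenly into 6 shards of 167 elements each.
    print(pad_to_multiple(1000, 6))   # 1002
    print(shard_ranges(1002, 6))      # [(0, 167), (167, 334), ..., (835, 1002)]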