Merge pull request #1 from YerevaNN/remove_unneeded_features
remove tensor and pipeline parallelism related code
philippguevorguian authored Aug 20, 2024
2 parents dad421f + 625c8fd commit 4eb849d
Showing 9 changed files with 34 additions and 224 deletions.
3 changes: 1 addition & 2 deletions .ci/docker/requirements.txt
@@ -1,6 +1,5 @@
torch >= 2.3.0
torchdata >= 0.8.0
datasets >= 2.19.0
datasets >= 2.21.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -10,7 +10,7 @@ set -ex
# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
NGPU=${NGPU:-"8"}
NGPU=${NGPU:-"2"}
LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

8 changes: 2 additions & 6 deletions torchtitan/metrics.py
@@ -116,12 +116,7 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
Returns global rank 0 in non-pipeline-parallel configs, and returns the global
rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
"""
if parallel_dims.pp_enabled:
world_size = parallel_dims.world_size
pp_size = parallel_dims.pp
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
else:
metrics_log_rank = 0
metrics_log_rank = 0

return metrics_log_rank

@@ -154,3 +149,4 @@ def build_metric_logger(
log_dir = os.path.join(log_dir, rank_str)

return MetricLogger(log_dir, tag, enable_tb)
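
With the pipeline-parallel branch removed, `_get_metrics_rank` always resolves to global rank 0. A minimal sketch of the resulting behaviour; the `is_metrics_rank` wrapper and the argument-free signature are illustrative assumptions, not the committed code, which keeps the `parallel_dims` parameter:

```python
import os


def _get_metrics_rank() -> int:
    # No pipeline stages remain, so global rank 0 always owns metrics logging.
    return 0


def is_metrics_rank() -> bool:
    # Assumed convenience helper: torchrun exports RANK for every worker process.
    return int(os.environ.get("RANK", "0")) == _get_metrics_rank()
```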

1 change: 0 additions & 1 deletion torchtitan/optimizer.py
@@ -11,7 +11,6 @@
from torchtitan.config_manager import JobConfig


# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
"""Wrap one optimizer per model part in an OptimizersContainer which provides a single
step() and zero_grad() method for all the child optimizers.
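
The docstring above describes the container pattern that `build_optimizers` returns. A self-contained sketch of that idea follows; the AdamW choice and the `lr` argument are assumptions for illustration, since the committed function reads its optimizer settings from `job_config`:

```python
import torch
import torch.nn as nn


class OptimizersContainer:
    """Exposes a single step()/zero_grad() over one optimizer per model part."""

    def __init__(self, optimizers: list[torch.optim.Optimizer]):
        self.optimizers = optimizers

    def step(self) -> None:
        for opt in self.optimizers:
            opt.step()

    def zero_grad(self) -> None:
        for opt in self.optimizers:
            opt.zero_grad()


def build_optimizers_sketch(model_parts: list[nn.Module], lr: float = 3e-4) -> OptimizersContainer:
    # With pipeline parallelism removed, model_parts is effectively a single-element
    # list, but the container keeps the same interface either way.
    return OptimizersContainer(
        [torch.optim.AdamW(m.parameters(), lr=lr) for m in model_parts]
    )
```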
25 changes: 7 additions & 18 deletions torchtitan/parallelisms/parallel_dims.py
@@ -14,8 +14,6 @@
@dataclass
class ParallelDims:
dp: int
tp: int
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str
@@ -25,22 +23,20 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
dp = self.dp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
self.dp = dp = self.world_size
assert dp >= 1, dp
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
dp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.dp], ["dp"], strict=True
):
if d > 1:
dims.append(d)
@@ -53,18 +49,11 @@ def build_mesh(self, device_type):
def dp_enabled(self):
return self.dp > 1

@property
def tp_enabled(self):
return self.tp > 1

@property
def pp_enabled(self):
return self.pp > 1

@property
def loss_parallel_enabled(self):
return self.tp > 1 and self.enable_loss_parallel
return False # requires tensor parallelism


@cached_property
def model_parallel_size(self):
return self.tp * self.pp
return 1
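
After this diff the class describes a one-dimensional, data-parallel-only mesh. A condensed sketch of how the trimmed class reads; logging, the `dp_type` check, and the `d > 1` filtering in the real `build_mesh` are omitted for brevity:

```python
from dataclasses import dataclass

from torch.distributed.device_mesh import init_device_mesh


@dataclass
class ParallelDims:
    dp: int
    world_size: int
    enable_loss_parallel: bool
    dp_type: str

    def __post_init__(self):
        # dp == -1 means "use every rank for data parallelism".
        if self.dp == -1:
            self.dp = self.world_size
        assert self.dp == self.world_size, (self.dp, self.world_size)

    def build_mesh(self, device_type: str):
        # The device mesh is now one-dimensional: only the "dp" axis remains.
        return init_device_mesh(device_type, (self.dp,), mesh_dim_names=("dp",))

    @property
    def dp_enabled(self) -> bool:
        return self.dp > 1
```

Hard-coding `loss_parallel_enabled` to `False` and `model_parallel_size` to `1` keeps callers that still query those properties working without further changes.
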
132 changes: 4 additions & 128 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -46,19 +46,6 @@ def parallelize_llama(
the model must fit on GPU or CPU memory.
"""

if parallel_dims.tp_enabled:
if (
job_config.experimental.enable_async_tensor_parallel
and not job_config.training.compile
):
raise RuntimeError("Async TP requires --training.compile")
apply_tp(
model,
world_mesh["tp"],
loss_parallel=parallel_dims.loss_parallel_enabled,
enable_float8=job_config.float8.enable_float8_linear,
enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
@@ -84,8 +71,6 @@
reduce_dtype=TORCH_DTYPE_MAP[
job_config.training.mixed_precision_reduce
],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)
else:
if world_mesh.ndim > 1:
@@ -98,102 +83,6 @@
)


def apply_tp(
model: nn.Module,
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Parallelize the final linear output layer
parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
},
)

# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears
if enable_float8:
# TODO(vkuzo): once float8 configuration supports delayed scaling,
# add a check here to enforce supported float8 all-gather configurations
# TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)

rowwise_parallel, colwise_parallel, prepare_module_input = (
Float8RowwiseParallel,
Float8ColwiseParallel,
PrepareFloat8ModuleInput,
)
else:
rowwise_parallel, colwise_parallel, prepare_module_input = (
RowwiseParallel,
ColwiseParallel,
PrepareModuleInput,
)

# Apply tensor + sequence parallelism to every transformer block
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": colwise_parallel(),
"attention.wk": colwise_parallel(),
"attention.wv": colwise_parallel(),
"attention.wo": rowwise_parallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": colwise_parallel(),
"feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
"feed_forward.w3": colwise_parallel(),
}

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
"Tensor Parallelism to the model"
)


# for selective op activation checkpointing
_save_list = {
torch.ops.aten.mm.default,
@@ -291,36 +180,23 @@ def apply_fsdp(
dp_mesh: DeviceMesh,
param_dtype: torch.dtype,
reduce_dtype: torch.dtype,
tp_enabled: bool,
pp_enabled: bool,
):
"""
Apply data parallelism to the model. FSDP2 is used here.
"""
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
if tp_enabled:
# check if strided sharding is enabled, which is necessary for 2D/3D DCP
check_strided_sharding_enabled()

for layer_id, transformer_block in model.layers.items():
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
    fully_shard(model, **fsdp_config, reshard_after_forward=True)  # upstream torchtitan used "not pp_enabled" here

logger.info("Applied FSDP to the model")

6 changes: 0 additions & 6 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -5,12 +5,6 @@
# LICENSE file in the root directory of this source tree.
from typing import Tuple

from torch.distributed.pipelining import (
Schedule1F1B,
ScheduleFlexibleInterleaved1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
)
from torchtitan.logging import logger


(diffs for the remaining 2 changed files were not rendered)
