Not really functional: due to resource limitations, this project can only emulate solutions proposed for large models.
Virtual implementation of distributed training/serving parallelism. Virtual hardware is implemented as:
`cluster.py`

- `VirtualDevice`: Emulates a compute device/machine using a `multiprocessing.Process`. Has all collective ops (e.g. `AllScatter`, `AllGather`, `AllReduce`, `ReduceScatter`, etc.) built in; see the sketch after this list for the basic idea.
- `VirtualChannel`: Emulates intra-device, host-to-device, and device-to-host communication through `multiprocessing.Queue`.
- `VirtualCluster`: Emulates a compute cluster through a pool of `multiprocessing.Process`es and `multiprocessing.Queue`s.
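To give a feel for how collectives can be emulated on top of `multiprocessing`, here is a minimal, self-contained sketch of an `AllReduce` across worker processes connected by queues. The names (`worker`, `to_root`, `from_root`) are illustrative only and do not reflect the actual `VirtualDevice`/`VirtualChannel` API.

```python
import multiprocessing as mp

def worker(rank, world_size, to_root, from_root, results):
    """One emulated device contributing a local value to an all-reduce."""
    local_value = float(rank + 1)        # stand-in for a local partial result
    to_root.put(local_value)             # every rank ships its contribution to rank 0
    if rank == 0:                        # rank 0 plays the reduction root...
        total = sum(to_root.get() for _ in range(world_size))
        for queue in from_root:          # ...then broadcasts the reduced value
            queue.put(total)
    results.put((rank, from_root[rank].get()))  # every rank ends up with the same sum

if __name__ == "__main__":
    world_size = 4
    to_root = mp.Queue()
    from_root = [mp.Queue() for _ in range(world_size)]
    results = mp.Queue()
    procs = [mp.Process(target=worker, args=(r, world_size, to_root, from_root, results))
             for r in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(sorted(results.get() for _ in range(world_size)))  # each rank sees 1+2+3+4 = 10.0
```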
`mesh.py`

- `Mesh`: Hardware topology configuration.
- `MeshIndex`: Hardware topology index.
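As a rough illustration of what a topology configuration and index boil down to (the actual fields of `Mesh` and `MeshIndex` may differ), a 2-D mesh that maps flat device ranks to coordinates could be sketched as:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MeshIndex:
    x: int
    y: int

@dataclass(frozen=True)
class Mesh:
    x: int  # e.g. number of hosts
    y: int  # e.g. devices per host

    def index_of(self, rank: int) -> MeshIndex:
        """Map a flat device rank to a coordinate in the mesh."""
        return MeshIndex(rank // self.y, rank % self.y)

    def rank_of(self, index: MeshIndex) -> int:
        """Map a mesh coordinate back to a flat device rank."""
        return index.x * self.y + index.y

mesh = Mesh(x=2, y=4)  # 2 x 4 = 8 emulated devices
assert mesh.rank_of(mesh.index_of(6)) == 6
```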
`operation.py`

- `Operation`: An operation to be executed by the compute cluster.
- `ShardedOperation`: An operation run with tensor parallelism.
- `PipelinedOperation`: An operation run with pipeline parallelism.
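Conceptually, an operation is just something every emulated device can run against its own slice of the work. A hypothetical sketch of that interface (not the repo's actual signatures):

```python
from abc import ABC, abstractmethod

class Operation(ABC):
    """An operation submitted to the virtual cluster: each emulated device
    calls it with its own mesh index and local inputs."""

    @abstractmethod
    def __call__(self, index, *tensors):
        ...

class ShardedOperation(Operation):
    """Runs the same math on every device, over locally sharded tensors,
    using collectives to combine partial results."""

class PipelinedOperation(Operation):
    """Runs a different stage of the model on each device, exchanging
    activations and gradients with neighbouring stages."""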
`sharding.py`

- `DimSharding`: Sharding configuration along one dimension.
- `TensorSharding`: Sharding configuration for a whole tensor.
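A minimal sketch of what per-dimension and per-tensor sharding configuration amounts to, assuming even splits (the `slice_for` / `shard` helpers are illustrative, not the actual API):

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class DimSharding:
    num_shards: int  # how many ways this dimension is split

    def slice_for(self, shard: int, dim_size: int) -> slice:
        """The index range owned by `shard` along a dimension of `dim_size`."""
        assert dim_size % self.num_shards == 0, "uneven sharding not handled here"
        chunk = dim_size // self.num_shards
        return slice(shard * chunk, (shard + 1) * chunk)

@dataclass(frozen=True)
class TensorSharding:
    dims: tuple  # one DimSharding per tensor dimension

    def shard(self, tensor: np.ndarray, shard_index: tuple) -> np.ndarray:
        """Extract the local shard of `tensor` owned by `shard_index`."""
        slices = tuple(d.slice_for(i, s)
                       for d, i, s in zip(self.dims, shard_index, tensor.shape))
        return tensor[slices]

x = np.arange(16).reshape(4, 4)
sharding = TensorSharding((DimSharding(2), DimSharding(1)))  # split rows 2-way
print(sharding.shard(x, (1, 0)))  # rows 2..3, all columns
```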
`pipelining.py`

- `Pipeline`: A pipeline configuration, mostly the total number of concurrent batches, the stages, and the hardware mesh mapping.
- `PipelineStage`: A pipeline stage's configuration and status, mostly which cycle the stage is currently running, plus helper functions to decide whether it is in `forward`, `backward`, or `idle` status (see the schedule sketch after this list).
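The cycle bookkeeping can be reduced to a pure function of (stage, cycle). Below is a sketch of a GPipe-style schedule in which every forward and backward step takes one cycle and all forwards drain before the backwards start; the real `PipelineStage` helpers may differ in the details.

```python
def stage_status(stage: int, cycle: int, num_stages: int, num_batches: int) -> str:
    """Return 'forward', 'backward', or 'idle' for a simple GPipe-style schedule."""
    # Stage s runs the forward pass of micro-batch b at cycle s + b.
    if 0 <= cycle - stage < num_batches:
        return "forward"
    # Backwards start once the last forward has drained, in reverse stage order:
    # stage s runs the backward of micro-batch b at first_backward + (S-1-s) + b.
    first_backward = num_stages + num_batches - 1
    offset = cycle - first_backward - (num_stages - 1 - stage)
    if 0 <= offset < num_batches:
        return "backward"
    return "idle"

# Print the schedule for 4 stages and 3 micro-batches (rows = stages, columns = cycles).
S, B = 4, 3
for s in range(S):
    row = [stage_status(s, t, S, B)[0].upper() for t in range(2 * (S + B - 1))]
    print(f"stage {s}: {' '.join(row)}")
```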
Intra-op model parallelism through sharding tensors and computation.
`matmul.py`

- `MatMulSharding`: Tensor parallelism sharding configuration.
- `MatMulShardingPolicy`: Tensor parallelism sharding policies for function dispatch.
  - `Unsharded`: No sharding; duplicated runs.
  - `MatchedInnerSharding`: Matched inner dimension sharding. Runs with `AllReduce` (see the sketch after this list).
  - `UnmatchedInnerSharding`: Unmatched inner dimension sharding. Runs with `AllGather`.
  - `OuterSharding`: Outer (`m`) and inner (`k`) dimension sharding. Runs with `AllGather` -> `AllGather`.
  - `FullSharding`: All outer (`m`, `n`) and inner (`k`) dimension sharding. Runs with `AllGather` -> `AllGather`.
- `MatMul`: Tensor parallelism based operation. The implementation is dispatched based on `MatMulShardingPolicy`.
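For the `MatchedInnerSharding` case, the communication pattern is easy to emulate in plain NumPy: both operands are split along the contraction dimension `k`, each emulated device computes a partial product, and an `AllReduce` (here just a Python `sum`) combines the partials. This is a sketch of the idea, not the dispatched implementation itself.

```python
import numpy as np

def matmul_matched_inner_sharding(a: np.ndarray, b: np.ndarray, shards: int) -> np.ndarray:
    """Emulate a k-sharded matmul: per-shard partial products combined by an all-reduce."""
    a_shards = np.split(a, shards, axis=1)   # A: (m, k) split along k
    b_shards = np.split(b, shards, axis=0)   # B: (k, n) split along k
    partials = [a_i @ b_i for a_i, b_i in zip(a_shards, b_shards)]  # per-device work
    return sum(partials)                     # AllReduce: every device gets the sum

rng = np.random.default_rng(0)
a, b = rng.normal(size=(4, 8)), rng.normal(size=(8, 6))
np.testing.assert_allclose(matmul_matched_inner_sharding(a, b, shards=4), a @ b, atol=1e-12)
```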
`ffn.py`

- `FeedForwardSharding`: Tensor parallelism sharding configuration.
- `FeedForwardShardingPolicy`: Tensor parallelism sharding policies for function dispatch.
  - `Unsharded`: No sharding; duplicated runs.
  - `GSPMD`: https://arxiv.org/abs/2105.04663. Runs with `AllGather` -> `AllGather` -> `ReduceScatter` (see the sketch after this list).
- `FeedForward`: Tensor parallelism based operation. The implementation is dispatched based on `FeedForwardShardingPolicy`.
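A sketch of one common tensor-parallel feed-forward layout: the input is gathered, the first matmul is column-parallel over the hidden dimension, the second is row-parallel, and the partial outputs are combined with a reduce-scatter so the result stays sharded along the batch dimension. This only illustrates the communication pattern; the exact `AllGather` -> `AllGather` -> `ReduceScatter` sequence in `ffn.py` may differ.

```python
import numpy as np

def ffn_tensor_parallel(x_shards, w1, w2, devices):
    """Emulate one tensor-parallel feed-forward block over `devices` shards."""
    w1_cols = np.split(w1, devices, axis=1)    # W1 split over the hidden dim (column-parallel)
    w2_rows = np.split(w2, devices, axis=0)    # W2 split over the hidden dim (row-parallel)
    x_full = np.concatenate(x_shards, axis=0)  # AllGather: every device sees the full input
    # Per-device work: a slice of the hidden layer, then a partial output.
    partials = [np.maximum(x_full @ w1_i, 0.0) @ w2_i
                for w1_i, w2_i in zip(w1_cols, w2_rows)]
    y_full = sum(partials)                     # the "reduce" part
    return np.split(y_full, devices, axis=0)   # the "scatter": re-shard along batch

rng = np.random.default_rng(0)
devices, batch, d, hidden = 4, 8, 16, 32
x = rng.normal(size=(batch, d))
w1, w2 = rng.normal(size=(d, hidden)), rng.normal(size=(hidden, d))
y_shards = ffn_tensor_parallel(np.split(x, devices, axis=0), w1, w2, devices)
np.testing.assert_allclose(np.concatenate(y_shards, axis=0),
                           np.maximum(x @ w1, 0.0) @ w2, atol=1e-9)
```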
Inter-op model parallelism through running different stages of the model asynchronously.
`mlp.py`

- `MLP`: A simple MLP model with 4 dense layers running as different pipeline stages. Both the forward and backward passes are managed. Activations, parameters, gradients, and states are managed, but optimizers are not.
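A single-process sketch of the state management involved: each stage caches its inputs per micro-batch so the backward pass can run later and out of order, and accumulates parameter gradients without taking an optimizer step. The `DenseStage` class is illustrative; the real `MLP` runs its stages on virtual devices following the pipeline schedule rather than this sequential loop.

```python
import numpy as np

class DenseStage:
    """One pipeline stage: a dense layer that stashes its inputs per micro-batch."""

    def __init__(self, in_dim, out_dim, rng):
        self.w = rng.normal(scale=0.1, size=(in_dim, out_dim))
        self.grad_w = np.zeros_like(self.w)   # accumulated across micro-batches
        self.cache = {}                       # micro-batch id -> stashed input

    def forward(self, batch_id, x):
        self.cache[batch_id] = x
        return x @ self.w

    def backward(self, batch_id, grad_out):
        x = self.cache.pop(batch_id)
        self.grad_w += x.T @ grad_out         # parameter gradient (no optimizer step)
        return grad_out @ self.w.T            # gradient for the previous stage

rng = np.random.default_rng(0)
stages = [DenseStage(8, 8, rng) for _ in range(4)]   # 4 dense layers = 4 stages
micro_batches = [rng.normal(size=(2, 8)) for _ in range(3)]

outputs = {}
for b, x in enumerate(micro_batches):                # forward: stage by stage
    for stage in stages:
        x = stage.forward(b, x)
    outputs[b] = x
for b in outputs:                                    # backward: reverse stage order
    grad = np.ones_like(outputs[b])                  # stand-in loss gradient
    for stage in reversed(stages):
        grad = stage.backward(b, grad)
```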