Commit 53143f5
init commit
Shenggan committed Feb 28, 2022
1 parent d5f3875 commit 53143f5
Showing 7 changed files with 46 additions and 20 deletions.
29 changes: 27 additions & 2 deletions README.md
@@ -13,9 +13,9 @@ FastFold provides a **high-performance implementation of Evoformer** with the fo
1. Excellent kernel performance on GPU platform
2. Supporting Dynamic Axial Parallelism(DAP)
* Break the memory limit of single GPU and reduce the overall training time
-* Distributed inference can significantly speed up inference and make extremely long sequence inference possible
+* DAP can significantly speed up inference and make ultra-long sequence inference possible
3. Ease of use
-* Replace a few lines and you can use FastFold in your project
+* Huge performance gains with a few lines of changes
* You don't need to care about how the parallel part is implemented

## Installation
@@ -38,6 +38,24 @@ cd FastFold
python setup.py install --cuda_ext
```

## Usage

You can use `Evoformer` as an `nn.Module` in your project after `from fastfold.model import Evoformer`:

```python
from fastfold.model import Evoformer
evoformer_layer = Evoformer()
```

If you want to use Dynamic Axial Parallelism, add one line to initialize DAP with `fastfold.distributed.init_dap` after `torch.distributed.init_process_group`:

```python
from fastfold.distributed import init_dap

torch.distributed.init_process_group(backend='nccl', init_method='env://')
init_dap(args.dap_size)
```
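
The two snippets above compose naturally. Below is a minimal end-to-end sketch (not from the repository) of a script meant to be launched with `torchrun`; the `--dap-size` argument, the `LOCAL_RANK` handling, and the bare `Evoformer()` construction are assumptions based on the README and the benchmark script, not a documented API.

```python
# sketch.py -- hypothetical usage example, launched e.g. with:
#   torchrun --nproc_per_node=2 sketch.py --dap-size 2
import argparse
import os

import torch
import torch.distributed

from fastfold.distributed import init_dap
from fastfold.model import Evoformer

parser = argparse.ArgumentParser()
parser.add_argument("--dap-size", default=1, type=int)
args = parser.parse_args()

# torchrun sets RANK / WORLD_SIZE / LOCAL_RANK, so env:// initialization works
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# create the DAP process groups before building the model
init_dap(args.dap_size)

evoformer_layer = Evoformer().cuda()
```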

## Performance Benchmark

We have included a performance benchmark script in `./benchmark`. You can benchmark the performance of Evoformer using different settings.
@@ -47,6 +65,13 @@ cd ./benchmark
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256
```

Benchmark Dynamic Axial Parallelism with 2 GPUs:

```shell
cd ./benchmark
torchrun --nproc_per_node=2 perf.py --msa-length 128 --res-length 256 --dap-size 2
```

If you want to benchmark with [OpenFold](https://github.com/aqlaboratory/openfold), you need to install OpenFold first and benchmark with option `--openfold`:

```shell
24 changes: 12 additions & 12 deletions benchmark/perf.py
@@ -4,34 +4,34 @@
import torch
import torch.nn as nn

-from fastfold.distributed import init_shadowcore
+from fastfold.distributed import init_dap
from fastfold.model import Evoformer


def main():

-    parser = argparse.ArgumentParser(description='MSA Attention Standalone Perf Benchmark')
-    parser.add_argument("--dap-size", default=1, type=int)
+    parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
+    parser.add_argument("--dap-size", default=1, type=int, help='batch size')
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
-    parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of Input')
+    parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
parser.add_argument('--res-length',
default=256,
type=int,
-                        help='Start Range of Number of Sequences')
+                        help='Sequence Length of Residues')
parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers',
default=12,
type=int,
-                        help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
+                        help='Evoformer Layers to Execute')
parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--openfold',
action='store_true',
-                        help='torch.nn.MultitheadAttention Version.')
+                        help='Benchmark with Evoformer Implementation from OpenFold.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
-    parser.add_argument('--prof', action='store_true', help='Only execute Fwd Pass.')
+    parser.add_argument('--prof', action='store_true', help='run with profiler.')

args = parser.parse_args()

@@ -48,10 +48,10 @@ def main():
print(
'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.global_rank, args.world_size))
-        init_shadowcore(args.tensor_model_parallel_size)
+        init_dap(args.dap_size)

precision = torch.bfloat16
-    if args.tensor_model_parallel_size > 1:
+    if args.dap_size > 1:
# (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
precision = torch.float16

@@ -111,13 +111,13 @@ def forward(self, node, pair, node_mask, pair_mask):
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

inputs_node = torch.randn(args.batch_size,
-                              args.msa_length // args.tensor_model_parallel_size,
+                              args.msa_length // args.dap_size,
args.res_length,
args.cm,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
inputs_pair = torch.randn(args.batch_size,
-                              args.res_length // args.tensor_model_parallel_size,
+                              args.res_length // args.dap_size,
args.res_length,
args.cz,
dtype=precision,
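
The shape changes above are the core of DAP at the data level: the MSA (row) dimension of the node tensor and the first residue dimension of the pair tensor are divided by `--dap-size`, so each GPU holds only a slice of the activations. A small illustration of those shapes (hypothetical values mirroring the benchmark's defaults, not code from the repository):

```python
# How the benchmark's inputs are sharded across one DAP group.
import torch

batch, msa_len, res_len = 1, 128, 256
cm, cz = 256, 128          # MSA / pair hidden dims (--cm / --cz defaults)
dap_size = 2               # GPUs in one DAP group

# Full activations would be:
#   node: [batch, msa_len, res_len, cm]
#   pair: [batch, res_len, res_len, cz]
# Each rank only materializes its slice along one axis:
node_shard = torch.empty(batch, msa_len // dap_size, res_len, cm)
pair_shard = torch.empty(batch, res_len // dap_size, res_len, cz)

print(node_shard.shape)  # torch.Size([1, 64, 256, 256])
print(pair_shard.shape)  # torch.Size([1, 128, 256, 128])
```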
1 change: 1 addition & 0 deletions fastfold/__init__.py
@@ -0,0 +1 @@
VERSION = "0.1.0-beta"
4 changes: 2 additions & 2 deletions fastfold/distributed/__init__.py
@@ -1,11 +1,11 @@
-from .core import (init_shadowcore, shadowcore_is_initialized, get_tensor_model_parallel_group,
+from .core import (init_dap, dap_is_initialized, get_tensor_model_parallel_group,
get_data_parallel_group, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank, get_data_parallel_world_size,
get_data_parallel_rank, get_tensor_model_parallel_src_rank)
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)

__all__ = [
-    'init_shadowcore', 'shadowcore_is_initialized', 'get_tensor_model_parallel_group',
+    'init_dap', 'dap_is_initialized', 'get_tensor_model_parallel_group',
'get_data_parallel_group', 'get_tensor_model_parallel_world_size',
'get_tensor_model_parallel_rank', 'get_data_parallel_world_size', 'get_data_parallel_rank',
'get_tensor_model_parallel_src_rank', '_reduce', '_split', '_gather', 'copy', 'scatter',
4 changes: 2 additions & 2 deletions fastfold/distributed/core.py
@@ -15,7 +15,7 @@ def ensure_divisibility(numerator, denominator):
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)


-def init_shadowcore(tensor_model_parallel_size_=1):
+def init_dap(tensor_model_parallel_size_=1):

assert dist.is_initialized()

@@ -51,7 +51,7 @@ def init_shadowcore(tensor_model_parallel_size_=1):
print('> initialize data parallel with size {}'.format(data_parallel_size_))


-def shadowcore_is_initialized():
+def dap_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
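
From the visible parts of `core.py`, `init_dap` expects `torch.distributed` to be initialized, checks that the world size is divisible by the DAP size, and then builds the tensor-model-parallel (DAP) and data-parallel groups; the actual group construction is collapsed in this diff. As a rough mental model only (a Megatron-style sketch, not the FastFold implementation), the rank partitioning could look like:

```python
# Illustration of partitioning world ranks into DAP groups and
# data-parallel groups. NOT the FastFold code, just a sketch.
import torch.distributed as dist


def build_groups(dap_size: int = 1):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % dap_size == 0, \
        '{} is not divisible by {}'.format(world_size, dap_size)
    data_parallel_size = world_size // dap_size

    dap_group, dp_group = None, None

    # consecutive ranks form one DAP group: [0..dap_size-1], [dap_size..2*dap_size-1], ...
    for i in range(data_parallel_size):
        ranks = list(range(i * dap_size, (i + 1) * dap_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dap_group = group

    # ranks at the same offset within their DAP group form one data-parallel group
    for i in range(dap_size):
        ranks = list(range(i, world_size, dap_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group

    return dap_group, dp_group
```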
2 changes: 1 addition & 1 deletion fastfold/model/evoformer.py
@@ -6,7 +6,7 @@

class Evoformer(nn.Module):

-    def __init__(self, d_node, d_pair):
+    def __init__(self, d_node=256, d_pair=128):
super(Evoformer, self).__init__()

self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
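
With these defaults, the bare `Evoformer()` call in the README's usage section is valid, and the dimensions line up with the benchmark's `--cm 256` / `--cz 128` defaults. Passing other sizes (hypothetical values below) just overrides the keywords:

```python
from fastfold.model import Evoformer

# README usage: defaults to d_node=256, d_pair=128
evoformer_layer = Evoformer()

# hypothetical smaller configuration, e.g. for quick debugging
small_layer = Evoformer(d_node=128, d_pair=64)
```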
2 changes: 1 addition & 1 deletion setup.py
@@ -141,7 +141,7 @@ def cuda_ext_helper(name, sources, extra_cuda_flags):

setup(
name='fastfold',
-    version='0.0.1-beta',
+    version='0.1.0-beta',
packages=find_packages(exclude=(
'assets',
'benchmark',
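
The bumped version string matches the new `fastfold/__init__.py`, so after installation it can be checked at runtime (a small sanity check, assuming the package imports cleanly):

```python
import fastfold

# setup.py and fastfold/__init__.py both now report 0.1.0-beta
print(fastfold.VERSION)
```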
