
ESM-2 Public Release (#252)
Source code and model weights release of the ESM-2 protein language model from the paper "Language models of protein sequences at the scale of evolution enable accurate structure prediction" (Z. Lin, H. Akin, R. Rao, B. Hie, Z. Zhu, W. Lu, A.S. Costa, M. Fazel-Zarandi, T. Sercu, S. Candido, A. Rives, 2022)
nikita-smetanin authored Aug 22, 2022
1 parent 723e858 commit 4e0ebb7
Showing 16 changed files with 2,023 additions and 1,669 deletions.
312 changes: 179 additions & 133 deletions README.md

Large diffs are not rendered by default.
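
The updated README (large diff collapsed above) documents usage of the released weights. A minimal loading sketch, assuming the 650M-parameter checkpoint name esm2_t33_650M_UR50D published with this release:

import esm

# Load ESM-2 weights and the matching alphabet (checkpoint name assumed, not part of this diff).
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()  # inference mode: disable dropout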

4 changes: 3 additions & 1 deletion esm/__init__.py
@@ -6,5 +6,7 @@
from .version import version as __version__ # noqa

from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa
-from .model import ProteinBertModel, MSATransformer # noqa
+from .model.esm1 import ProteinBertModel # noqa
+from .model.esm2 import ESM2 # noqa
+from .model.msa_transformer import MSATransformer # noqa
from . import pretrained # noqa
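
With the reorganized esm/__init__.py above, the model classes are importable both from the package root (via the re-exports) and from their new submodules; a quick sketch:

# Package-root re-exports and submodule paths refer to the same classes.
from esm import ESM2, ProteinBertModel, MSATransformer
from esm.model.esm2 import ESM2 as ESM2Direct

assert ESM2 is ESM2Direct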
200 changes: 200 additions & 0 deletions esm/model/esm1.py
@@ -0,0 +1,200 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..modules import (
TransformerLayer,
LearnedPositionalEmbedding,
SinusoidalPositionalEmbedding,
RobertaLMHead,
ESM1bLayerNorm,
ContactPredictionHead,
)


class ProteinBertModel(nn.Module):
@classmethod
def add_args(cls, parser):
parser.add_argument(
"--num_layers", default=36, type=int, metavar="N", help="number of layers"
)
parser.add_argument(
"--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
)
parser.add_argument(
"--logit_bias", action="store_true", help="whether to apply bias to logits"
)
parser.add_argument(
"--ffn_embed_dim",
default=5120,
type=int,
metavar="N",
help="embedding dimension for FFN",
)
parser.add_argument(
"--attention_heads",
default=20,
type=int,
metavar="N",
help="number of attention heads",
)

def __init__(self, args, alphabet):
super().__init__()
self.args = args
self.alphabet_size = len(alphabet)
self.padding_idx = alphabet.padding_idx
self.mask_idx = alphabet.mask_idx
self.cls_idx = alphabet.cls_idx
self.eos_idx = alphabet.eos_idx
self.prepend_bos = alphabet.prepend_bos
self.append_eos = alphabet.append_eos
self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
if self.args.arch == "roberta_large":
self.model_version = "ESM-1b"
self._init_submodules_esm1b()
else:
self.model_version = "ESM-1"
self._init_submodules_esm1()

def _init_submodules_common(self):
self.embed_tokens = nn.Embedding(
self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
)
self.layers = nn.ModuleList(
[
TransformerLayer(
self.args.embed_dim,
self.args.ffn_embed_dim,
self.args.attention_heads,
add_bias_kv=(self.model_version != "ESM-1b"),
use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
)
for _ in range(self.args.layers)
]
)

self.contact_head = ContactPredictionHead(
self.args.layers * self.args.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)

def _init_submodules_esm1b(self):
self._init_submodules_common()
self.embed_scale = 1
self.embed_positions = LearnedPositionalEmbedding(
self.args.max_positions, self.args.embed_dim, self.padding_idx
)
self.emb_layer_norm_before = (
ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
)
self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.args.embed_dim,
output_dim=self.alphabet_size,
weight=self.embed_tokens.weight,
)

def _init_submodules_esm1(self):
self._init_submodules_common()
self.embed_scale = math.sqrt(self.args.embed_dim)
self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
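# ESM-1 uses a separate, untied output projection (embed_out, applied with F.linear in forward);
# ESM-1b instead ties the LM head to embed_tokens.weight via RobertaLMHead.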
self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
self.embed_out_bias = None
if self.args.final_bias:
self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))

def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
if return_contacts:
need_head_weights = True

assert tokens.ndim == 2
padding_mask = tokens.eq(self.padding_idx) # B, T

x = self.embed_scale * self.embed_tokens(tokens)

if getattr(self.args, "token_dropout", False):
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
# x: B x T x C
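# Rescale unmasked embeddings: 0.15 * 0.8 is the expected masked fraction seen in training
# (15% of positions selected, 80% of those replaced by <mask>), so scaling by
# (1 - expected) / (1 - observed) keeps embedding magnitudes consistent at inference.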
mask_ratio_train = 0.15 * 0.8
src_lengths = (~padding_mask).sum(-1)
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

x = x + self.embed_positions(tokens)

if self.model_version == "ESM-1b":
if self.emb_layer_norm_before:
x = self.emb_layer_norm_before(x)
if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x

if need_head_weights:
attn_weights = []

# (B, T, E) => (T, B, E)
x = x.transpose(0, 1)

if not padding_mask.any():
padding_mask = None

for layer_idx, layer in enumerate(self.layers):
x, attn = layer(
x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
)
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
if need_head_weights:
# (H, B, T, T) => (B, H, T, T)
attn_weights.append(attn.transpose(1, 0))

if self.model_version == "ESM-1b":
x = self.emb_layer_norm_after(x)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)

# last hidden representation should have layer norm applied
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
x = self.lm_head(x)
else:
x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)

result = {"logits": x, "representations": hidden_representations}
if need_head_weights:
# attentions: B x L x H x T x T
attentions = torch.stack(attn_weights, 1)
if self.model_version == "ESM-1":
# ESM-1 models have an additional null-token for attention, which we remove
attentions = attentions[..., :-1]
if padding_mask is not None:
attention_mask = 1 - padding_mask.type_as(attentions)
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts

return result

def predict_contacts(self, tokens):
return self(tokens, return_contacts=True)["contacts"]

@property
def num_layers(self):
return self.args.layers
147 changes: 147 additions & 0 deletions esm/model/esm2.py
@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union
import torch
import torch.nn as nn

import esm
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer


class ESM2(nn.Module):
def __init__(
self,
num_layers: int = 33,
embed_dim: int = 1280,
attention_heads: int = 20,
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
token_dropout: bool = True,
):
super().__init__()
self.num_layers = num_layers
self.embed_dim = embed_dim
self.attention_heads = attention_heads
if not isinstance(alphabet, esm.data.Alphabet):
alphabet = esm.data.Alphabet.from_architecture(alphabet)
self.alphabet = alphabet
self.alphabet_size = len(alphabet)
self.padding_idx = alphabet.padding_idx
self.mask_idx = alphabet.mask_idx
self.cls_idx = alphabet.cls_idx
self.eos_idx = alphabet.eos_idx
self.prepend_bos = alphabet.prepend_bos
self.append_eos = alphabet.append_eos
self.token_dropout = token_dropout

self._init_submodules()

def _init_submodules(self):
self.embed_scale = 1
self.embed_tokens = nn.Embedding(
self.alphabet_size,
self.embed_dim,
padding_idx=self.padding_idx,
)
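# Note: no separate positional-embedding module; rotary position embeddings are applied
# inside each TransformerLayer (use_rotary_embeddings=True below).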

self.layers = nn.ModuleList(
[
TransformerLayer(
self.embed_dim,
4 * self.embed_dim,
self.attention_heads,
add_bias_kv=False,
use_esm1b_layer_norm=True,
use_rotary_embeddings=True,
)
for _ in range(self.num_layers)
]
)

self.contact_head = ContactPredictionHead(
self.num_layers * self.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

self.lm_head = RobertaLMHead(
embed_dim=self.embed_dim,
output_dim=self.alphabet_size,
weight=self.embed_tokens.weight,
)

def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
if return_contacts:
need_head_weights = True

assert tokens.ndim == 2
padding_mask = tokens.eq(self.padding_idx) # B, T

x = self.embed_scale * self.embed_tokens(tokens)

if self.token_dropout:
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
# x: B x T x C
mask_ratio_train = 0.15 * 0.8
src_lengths = (~padding_mask).sum(-1)
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x

if need_head_weights:
attn_weights = []

# (B, T, E) => (T, B, E)
x = x.transpose(0, 1)

if not padding_mask.any():
padding_mask = None

for layer_idx, layer in enumerate(self.layers):
x, attn = layer(
x,
self_attn_padding_mask=padding_mask,
need_head_weights=need_head_weights,
)
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
if need_head_weights:
# (H, B, T, T) => (B, H, T, T)
attn_weights.append(attn.transpose(1, 0))

x = self.emb_layer_norm_after(x)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)

# last hidden representation should have layer norm applied
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
x = self.lm_head(x)

result = {"logits": x, "representations": hidden_representations}
if need_head_weights:
# attentions: B x L x H x T x T
attentions = torch.stack(attn_weights, 1)
if padding_mask is not None:
attention_mask = 1 - padding_mask.type_as(attentions)
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts

return result

def predict_contacts(self, tokens):
return self(tokens, return_contacts=True)["contacts"]
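
forward() returns a dict with "logits", "representations" keyed by layer index, and, when requested, "attentions" and "contacts". A minimal inference sketch, reusing the loader from the sketch near the top and the batch-converter API from esm.data (both outside this diff):

import torch
import esm

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # 33-layer ESM-2; name assumed from this release
batch_converter = alphabet.get_batch_converter()
model.eval()

# Toy input; the label is an arbitrary identifier.
data = [("seq1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
_, _, tokens = batch_converter(data)

with torch.no_grad():
    out = model(tokens, repr_layers=[33], return_contacts=True)

per_residue = out["representations"][33]  # B x T x C final-layer embeddings
contacts = out["contacts"]                # predicted contact map with BOS/EOS positions stripped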
(The remaining 12 changed files are not rendered here.)
