This repository has been archived by the owner on Aug 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 645
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Source code and model weights release of ESM-2 Protein Language Model from the paper Language models of protein sequences at the scale of evolution enable accurate structure prediction (Z. Lin, H. Akin, R. Rao, B. Hie, Z. Zhu, W. Lu, A.S. Costa, M. Fazel-Zarandi, T. Sercu, S. Candido, A. Rives, 2022)
- Loading branch information
1 parent
723e858
commit 4e0ebb7
Showing
16 changed files
with
2,023 additions
and
1,669 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import math | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from ..modules import ( | ||
TransformerLayer, | ||
LearnedPositionalEmbedding, | ||
SinusoidalPositionalEmbedding, | ||
RobertaLMHead, | ||
ESM1bLayerNorm, | ||
ContactPredictionHead, | ||
) | ||
|
||
|
||
class ProteinBertModel(nn.Module): | ||
@classmethod | ||
def add_args(cls, parser): | ||
parser.add_argument( | ||
"--num_layers", default=36, type=int, metavar="N", help="number of layers" | ||
) | ||
parser.add_argument( | ||
"--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension" | ||
) | ||
parser.add_argument( | ||
"--logit_bias", action="store_true", help="whether to apply bias to logits" | ||
) | ||
parser.add_argument( | ||
"--ffn_embed_dim", | ||
default=5120, | ||
type=int, | ||
metavar="N", | ||
help="embedding dimension for FFN", | ||
) | ||
parser.add_argument( | ||
"--attention_heads", | ||
default=20, | ||
type=int, | ||
metavar="N", | ||
help="number of attention heads", | ||
) | ||
|
||
def __init__(self, args, alphabet): | ||
super().__init__() | ||
self.args = args | ||
self.alphabet_size = len(alphabet) | ||
self.padding_idx = alphabet.padding_idx | ||
self.mask_idx = alphabet.mask_idx | ||
self.cls_idx = alphabet.cls_idx | ||
self.eos_idx = alphabet.eos_idx | ||
self.prepend_bos = alphabet.prepend_bos | ||
self.append_eos = alphabet.append_eos | ||
self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False) | ||
if self.args.arch == "roberta_large": | ||
self.model_version = "ESM-1b" | ||
self._init_submodules_esm1b() | ||
else: | ||
self.model_version = "ESM-1" | ||
self._init_submodules_esm1() | ||
|
||
def _init_submodules_common(self): | ||
self.embed_tokens = nn.Embedding( | ||
self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx | ||
) | ||
self.layers = nn.ModuleList( | ||
[ | ||
TransformerLayer( | ||
self.args.embed_dim, | ||
self.args.ffn_embed_dim, | ||
self.args.attention_heads, | ||
add_bias_kv=(self.model_version != "ESM-1b"), | ||
use_esm1b_layer_norm=(self.model_version == "ESM-1b"), | ||
) | ||
for _ in range(self.args.layers) | ||
] | ||
) | ||
|
||
self.contact_head = ContactPredictionHead( | ||
self.args.layers * self.args.attention_heads, | ||
self.prepend_bos, | ||
self.append_eos, | ||
eos_idx=self.eos_idx, | ||
) | ||
|
||
def _init_submodules_esm1b(self): | ||
self._init_submodules_common() | ||
self.embed_scale = 1 | ||
self.embed_positions = LearnedPositionalEmbedding( | ||
self.args.max_positions, self.args.embed_dim, self.padding_idx | ||
) | ||
self.emb_layer_norm_before = ( | ||
ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None | ||
) | ||
self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim) | ||
self.lm_head = RobertaLMHead( | ||
embed_dim=self.args.embed_dim, | ||
output_dim=self.alphabet_size, | ||
weight=self.embed_tokens.weight, | ||
) | ||
|
||
def _init_submodules_esm1(self): | ||
self._init_submodules_common() | ||
self.embed_scale = math.sqrt(self.args.embed_dim) | ||
self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx) | ||
self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim))) | ||
self.embed_out_bias = None | ||
if self.args.final_bias: | ||
self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size)) | ||
|
||
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False): | ||
if return_contacts: | ||
need_head_weights = True | ||
|
||
assert tokens.ndim == 2 | ||
padding_mask = tokens.eq(self.padding_idx) # B, T | ||
|
||
x = self.embed_scale * self.embed_tokens(tokens) | ||
|
||
if getattr(self.args, "token_dropout", False): | ||
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) | ||
# x: B x T x C | ||
mask_ratio_train = 0.15 * 0.8 | ||
src_lengths = (~padding_mask).sum(-1) | ||
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths | ||
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] | ||
|
||
x = x + self.embed_positions(tokens) | ||
|
||
if self.model_version == "ESM-1b": | ||
if self.emb_layer_norm_before: | ||
x = self.emb_layer_norm_before(x) | ||
if padding_mask is not None: | ||
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) | ||
|
||
repr_layers = set(repr_layers) | ||
hidden_representations = {} | ||
if 0 in repr_layers: | ||
hidden_representations[0] = x | ||
|
||
if need_head_weights: | ||
attn_weights = [] | ||
|
||
# (B, T, E) => (T, B, E) | ||
x = x.transpose(0, 1) | ||
|
||
if not padding_mask.any(): | ||
padding_mask = None | ||
|
||
for layer_idx, layer in enumerate(self.layers): | ||
x, attn = layer( | ||
x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights | ||
) | ||
if (layer_idx + 1) in repr_layers: | ||
hidden_representations[layer_idx + 1] = x.transpose(0, 1) | ||
if need_head_weights: | ||
# (H, B, T, T) => (B, H, T, T) | ||
attn_weights.append(attn.transpose(1, 0)) | ||
|
||
if self.model_version == "ESM-1b": | ||
x = self.emb_layer_norm_after(x) | ||
x = x.transpose(0, 1) # (T, B, E) => (B, T, E) | ||
|
||
# last hidden representation should have layer norm applied | ||
if (layer_idx + 1) in repr_layers: | ||
hidden_representations[layer_idx + 1] = x | ||
x = self.lm_head(x) | ||
else: | ||
x = F.linear(x, self.embed_out, bias=self.embed_out_bias) | ||
x = x.transpose(0, 1) # (T, B, E) => (B, T, E) | ||
|
||
result = {"logits": x, "representations": hidden_representations} | ||
if need_head_weights: | ||
# attentions: B x L x H x T x T | ||
attentions = torch.stack(attn_weights, 1) | ||
if self.model_version == "ESM-1": | ||
# ESM-1 models have an additional null-token for attention, which we remove | ||
attentions = attentions[..., :-1] | ||
if padding_mask is not None: | ||
attention_mask = 1 - padding_mask.type_as(attentions) | ||
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) | ||
attentions = attentions * attention_mask[:, None, None, :, :] | ||
result["attentions"] = attentions | ||
if return_contacts: | ||
contacts = self.contact_head(tokens, attentions) | ||
result["contacts"] = contacts | ||
|
||
return result | ||
|
||
def predict_contacts(self, tokens): | ||
return self(tokens, return_contacts=True)["contacts"] | ||
|
||
@property | ||
def num_layers(self): | ||
return self.args.layers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Union | ||
import torch | ||
import torch.nn as nn | ||
|
||
import esm | ||
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer | ||
|
||
|
||
class ESM2(nn.Module): | ||
def __init__( | ||
self, | ||
num_layers: int = 33, | ||
embed_dim: int = 1280, | ||
attention_heads: int = 20, | ||
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b", | ||
token_dropout: bool = True, | ||
): | ||
super().__init__() | ||
self.num_layers = num_layers | ||
self.embed_dim = embed_dim | ||
self.attention_heads = attention_heads | ||
if not isinstance(alphabet, esm.data.Alphabet): | ||
alphabet = esm.data.Alphabet.from_architecture(alphabet) | ||
self.alphabet = alphabet | ||
self.alphabet_size = len(alphabet) | ||
self.padding_idx = alphabet.padding_idx | ||
self.mask_idx = alphabet.mask_idx | ||
self.cls_idx = alphabet.cls_idx | ||
self.eos_idx = alphabet.eos_idx | ||
self.prepend_bos = alphabet.prepend_bos | ||
self.append_eos = alphabet.append_eos | ||
self.token_dropout = token_dropout | ||
|
||
self._init_submodules() | ||
|
||
def _init_submodules(self): | ||
self.embed_scale = 1 | ||
self.embed_tokens = nn.Embedding( | ||
self.alphabet_size, | ||
self.embed_dim, | ||
padding_idx=self.padding_idx, | ||
) | ||
|
||
self.layers = nn.ModuleList( | ||
[ | ||
TransformerLayer( | ||
self.embed_dim, | ||
4 * self.embed_dim, | ||
self.attention_heads, | ||
add_bias_kv=False, | ||
use_esm1b_layer_norm=True, | ||
use_rotary_embeddings=True, | ||
) | ||
for _ in range(self.num_layers) | ||
] | ||
) | ||
|
||
self.contact_head = ContactPredictionHead( | ||
self.num_layers * self.attention_heads, | ||
self.prepend_bos, | ||
self.append_eos, | ||
eos_idx=self.eos_idx, | ||
) | ||
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim) | ||
|
||
self.lm_head = RobertaLMHead( | ||
embed_dim=self.embed_dim, | ||
output_dim=self.alphabet_size, | ||
weight=self.embed_tokens.weight, | ||
) | ||
|
||
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False): | ||
if return_contacts: | ||
need_head_weights = True | ||
|
||
assert tokens.ndim == 2 | ||
padding_mask = tokens.eq(self.padding_idx) # B, T | ||
|
||
x = self.embed_scale * self.embed_tokens(tokens) | ||
|
||
if self.token_dropout: | ||
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) | ||
# x: B x T x C | ||
mask_ratio_train = 0.15 * 0.8 | ||
src_lengths = (~padding_mask).sum(-1) | ||
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths | ||
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] | ||
|
||
if padding_mask is not None: | ||
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) | ||
|
||
repr_layers = set(repr_layers) | ||
hidden_representations = {} | ||
if 0 in repr_layers: | ||
hidden_representations[0] = x | ||
|
||
if need_head_weights: | ||
attn_weights = [] | ||
|
||
# (B, T, E) => (T, B, E) | ||
x = x.transpose(0, 1) | ||
|
||
if not padding_mask.any(): | ||
padding_mask = None | ||
|
||
for layer_idx, layer in enumerate(self.layers): | ||
x, attn = layer( | ||
x, | ||
self_attn_padding_mask=padding_mask, | ||
need_head_weights=need_head_weights, | ||
) | ||
if (layer_idx + 1) in repr_layers: | ||
hidden_representations[layer_idx + 1] = x.transpose(0, 1) | ||
if need_head_weights: | ||
# (H, B, T, T) => (B, H, T, T) | ||
attn_weights.append(attn.transpose(1, 0)) | ||
|
||
x = self.emb_layer_norm_after(x) | ||
x = x.transpose(0, 1) # (T, B, E) => (B, T, E) | ||
|
||
# last hidden representation should have layer norm applied | ||
if (layer_idx + 1) in repr_layers: | ||
hidden_representations[layer_idx + 1] = x | ||
x = self.lm_head(x) | ||
|
||
result = {"logits": x, "representations": hidden_representations} | ||
if need_head_weights: | ||
# attentions: B x L x H x T x T | ||
attentions = torch.stack(attn_weights, 1) | ||
if padding_mask is not None: | ||
attention_mask = 1 - padding_mask.type_as(attentions) | ||
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) | ||
attentions = attentions * attention_mask[:, None, None, :, :] | ||
result["attentions"] = attentions | ||
if return_contacts: | ||
contacts = self.contact_head(tokens, attentions) | ||
result["contacts"] = contacts | ||
|
||
return result | ||
|
||
def predict_contacts(self, tokens): | ||
return self(tokens, return_contacts=True)["contacts"] |
Oops, something went wrong.