Commit 16ac01a

add mplugowl model

qingtian5 committed Aug 27, 2023
1 parent fc2754f · commit 16ac01a
Showing 1 changed file with 8 additions and 14 deletions.
mmpretrain/models/multimodal/mplugowl/mplugowl.py (8 additions, 14 deletions)
@@ -562,9 +562,9 @@ def forward(
 
 
 class MplugOwlVisualAbstractorLayer(BaseModel):
-    def __init__(self, layer_idx, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024):
+    def __init__(self, layer_idx, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024, chunk_size_feed_forward=None):
         super().__init__()
-        self.chunk_size_feed_forward = None
+        self.chunk_size_feed_forward = chunk_size_feed_forward
         self.seq_len_dim = 1
 
         self.layer_idx = layer_idx
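
After this change, feed-forward chunking is configurable per layer instead of being hardcoded to None. A minimal construction sketch under the new signature (the values are illustrative; passing chunk_size_feed_forward=None reproduces the previous default):

    # Sketch: assumes the module's own imports (BaseModel and the
    # attention/feed-forward submodules) are available.
    layer = MplugOwlVisualAbstractorLayer(
        layer_idx=0,
        hidden_size=1024,
        num_attention_heads=16,
        chunk_size_feed_forward=None,  # None: run the feed-forward unchunked
    )
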
@@ -661,12 +661,11 @@ def custom_forward(*inputs):
 
 
 class MplugOwlVisualAbstractorModel(BaseModel):
-    def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size):
-        super().__init__(config)
-        self.config = config
+    def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024):
+        super().__init__()
 
-        self.encoder = MplugOwlVisualAbstractorEncoder(config)
-        self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
+        self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size, num_attention_heads, intermediate_size, attention_probs_dropout_prob, layer_norm_eps, encoder_hidden_size)
+        self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
         self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
         nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
         self.post_init()
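
A caveat with this hunk: the unchanged context lines still call self.config.initializer_range and self.post_init(), but the new constructor no longer assigns self.config, so construction would raise an AttributeError as written (and post_init() belongs to the Hugging Face PreTrainedModel API, not mmengine's BaseModel). A minimal sketch of one way to finish the migration, using a hypothetical initializer_range keyword that is not in this commit:

    # Sketch (assumption, not in the commit): thread initializer_range
    # through explicitly, like the other config fields that became kwargs.
    def __init__(self, language_hidden_size, num_hidden_layers=6,
                 hidden_size=1024, num_attention_heads=16,
                 intermediate_size=4096, attention_probs_dropout_prob=0.1,
                 layer_norm_eps=1e-6, encoder_hidden_size=1024,
                 initializer_range=0.02):
        super().__init__()
        self.encoder = MplugOwlVisualAbstractorEncoder(
            num_hidden_layers, hidden_size, num_attention_heads,
            intermediate_size, attention_probs_dropout_prob, layer_norm_eps,
            encoder_hidden_size)
        self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
        self.vit_eos = torch.nn.Parameter(
            torch.randn(1, 1, language_hidden_size))
        # Use the explicit argument in place of the removed config object.
        nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=initializer_range)
        # self.post_init() dropped: it is not defined on mmengine's BaseModel.
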
@@ -824,8 +823,8 @@ def forward(
         )
 
 
-class MplugOwlModel(MplugOwlPreTrainedModel):
-    config_class = MplugOwlConfig
+@MODELS.register_module()
+class MplugOwlModel(BaseModel):
     main_input_name = "pixel_values"
 
     def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
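
With @MODELS.register_module(), MplugOwlModel becomes buildable through MMPretrain's registry like the repository's other models. A minimal usage sketch, assuming the standard MMEngine registry API; the config keys shown are hypothetical and depend on what MplugOwlModel.__init__ ultimately accepts:

    from mmpretrain.registry import MODELS

    # Hypothetical: __init__ above still expects a MplugOwlConfig instance.
    model = MODELS.build(dict(type='MplugOwlModel', config=MplugOwlConfig()))
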
@@ -1080,9 +1079,4 @@ def custom_forward(*inputs):
             attentions=all_self_attentions,
         )
 
-@MODELS.register_module()
-class mPLUGOwl(BaseModel):
-    def __init__(self,):
-        pass
-
 