From 16ac01aa8884cbe2c95d22c2b4c44e95deb24fe9 Mon Sep 17 00:00:00 2001
From: qingtian5
Date: Mon, 28 Aug 2023 00:51:11 +0800
Subject: [PATCH] add mplugowl model

---
 .../models/multimodal/mplugowl/mplugowl.py | 22 +++++++------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/mmpretrain/models/multimodal/mplugowl/mplugowl.py b/mmpretrain/models/multimodal/mplugowl/mplugowl.py
index 82735ae4c4d..71e7a905a52 100644
--- a/mmpretrain/models/multimodal/mplugowl/mplugowl.py
+++ b/mmpretrain/models/multimodal/mplugowl/mplugowl.py
@@ -562,9 +562,9 @@ def forward(
 
 
 class MplugOwlVisualAbstractorLayer(BaseModel):
-    def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
+    def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024,chunk_size_feed_forward=None):
         super().__init__()
-        self.chunk_size_feed_forward = None
+        self.chunk_size_feed_forward = chunk_size_feed_forward
         self.seq_len_dim = 1
 
         self.layer_idx = layer_idx
@@ -661,12 +661,11 @@ def custom_forward(*inputs):
 
 
 class MplugOwlVisualAbstractorModel(BaseModel):
-    def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size):
-        super().__init__(config)
-        self.config = config
+    def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
+        super().__init__()
 
-        self.encoder = MplugOwlVisualAbstractorEncoder(config)
-        self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
+        self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size,num_attention_heads,intermediate_size,attention_probs_dropout_prob,layer_norm_eps,encoder_hidden_size)
+        self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
         self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
         nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
         self.post_init()
@@ -824,8 +823,8 @@ def forward(
         )
 
 
-class MplugOwlModel(MplugOwlPreTrainedModel):
-    config_class = MplugOwlConfig
+@MODELS.register_module()
+class MplugOwlModel(BaseModel):
     main_input_name = "pixel_values"
 
     def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
@@ -1080,9 +1079,4 @@ def custom_forward(*inputs):
         attentions=all_self_attentions,
     )
 
-@MODELS.register_module()
-class mPLUGOwl(BaseModel):
-    def __init__(self,):
-        pass
-
 
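
A minimal, self-contained sketch of the pattern the hunks above apply: hyperparameters that were previously read from a MplugOwlVisualAbstractorConfig object become explicit keyword arguments, and chunk_size_feed_forward becomes configurable instead of being hard-coded to None. ToyAbstractorLayer below is a hypothetical stand-in used only for illustration, not the code in mplugowl.py; it mirrors the new constructor signature but implements only a residual feed-forward step.

# Illustration only: a stand-in class mirroring the refactored signature of
# MplugOwlVisualAbstractorLayer (config object replaced by explicit kwargs,
# chunk_size_feed_forward now passed in rather than fixed to None).
import torch
import torch.nn as nn


class ToyAbstractorLayer(nn.Module):

    def __init__(self,
                 layer_idx,
                 hidden_size=1024,
                 num_attention_heads=16,
                 intermediate_size=4096,
                 attention_probs_dropout_prob=0.1,
                 layer_norm_eps=1e-6,
                 encoder_hidden_size=1024,
                 chunk_size_feed_forward=None):
        super().__init__()
        # Previously hard-coded to None; now configurable (first hunk above).
        self.chunk_size_feed_forward = chunk_size_feed_forward
        self.seq_len_dim = 1
        self.layer_idx = layer_idx
        # num_attention_heads and encoder_hidden_size are accepted for
        # signature parity but unused in this toy; the real layer also runs
        # cross-attention against ViT features of width encoder_hidden_size.
        self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size),
            nn.Dropout(attention_probs_dropout_prob),
        )

    def forward(self, hidden_states):
        # Residual feed-forward over (batch, seq_len, hidden_size) inputs.
        return hidden_states + self.ffn(self.norm(hidden_states))


# Explicit keyword arguments replace the former config object; the values
# here are illustrative.
layer = ToyAbstractorLayer(layer_idx=0, chunk_size_feed_forward=64)
queries = torch.randn(2, 64, 1024)  # (batch, num_query_tokens, hidden_size)
print(layer(queries).shape)  # torch.Size([2, 64, 1024])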