From 2d5a56dabd42dbde201a0996095b52755a9fd410 Mon Sep 17 00:00:00 2001
From: Daniel Kaplan
Date: Sat, 28 Oct 2023 22:18:43 -0400
Subject: [PATCH] Now probably supports clip

---
 megatron/model/encoders/vision/open_clip     |  1 +
 .../model/encoders/vision/vision_encoder.py  | 68 ++++++++++++++++++-
 2 files changed, 66 insertions(+), 3 deletions(-)
 create mode 160000 megatron/model/encoders/vision/open_clip

diff --git a/megatron/model/encoders/vision/open_clip b/megatron/model/encoders/vision/open_clip
new file mode 160000
index 000000000..d0befe114
--- /dev/null
+++ b/megatron/model/encoders/vision/open_clip
@@ -0,0 +1 @@
+Subproject commit d0befe114486a51b109bfc307fa3fbcaa3283b8b
diff --git a/megatron/model/encoders/vision/vision_encoder.py b/megatron/model/encoders/vision/vision_encoder.py
index 0dabc6b3b..ed39732d7 100644
--- a/megatron/model/encoders/vision/vision_encoder.py
+++ b/megatron/model/encoders/vision/vision_encoder.py
@@ -82,7 +82,7 @@ def forward(self, x):
 #     return embeddings
 
 
-def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+def load_pretrained_weights_dino(model, pretrained_weights, checkpoint_key):
     # [TODO] add logger here
     if urlparse(pretrained_weights).scheme:  # If it looks like an URL
         state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
@@ -118,8 +118,70 @@ def get_vision_encoder(
         )
         model = vits.__dict__[name](**vit_kwargs)
         if pretrained:
-            model = load_pretrained_weights(model, args.pretrained_weights, "teacher")
+            model = load_pretrained_weights_dino(model, args.pretrained_weights, "teacher")
         encoder = DinoWrapper(model,args)
+    elif "evaclip" in name:
+        model, preprocess = open_clip.create_model_and_transforms('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
+        # TODO unsure if preprocess is right - it was two things before...
+        encoder = ClipWrapper(model, args, preprocess)
     else:
         raise ValueError(f"vision encoder {name} not recognized")
-    return encoder
\ No newline at end of file
+    return encoder
+
+
+class ClipWrapper(nn.Module):
+    def __init__(self, encoder, config, transform):
+        super().__init__()
+        self.encoder = encoder
+        self.config = config
+        self.prepare_encoder()
+
+        # TODO not sure about this one either...
+        self.transform = transform
+
+
+    # TODO this is the same for both models basically, fix that
+    def freeze_model(self):
+        num_layers_to_unfreeze = self.config.num_layers_to_unfreeze
+        # Freeze everything
+        self.encoder.requires_grad_(False)
+        # Unfreeze the last num_layers_to_unfreeze layers
+        for child_name, child in list(self.encoder.named_modules())[-num_layers_to_unfreeze:]:
+            child.requires_grad_(True)
+
+        # Unfreeze norms?
+        # TODO check what norms are there...
+        recursive_freeze_unfreeze(self.encoder, param_types=['LayerNorm'], freeze=False)
+
+        # What about cls token? TODO
+
+    def prepare_encoder(self):
+        if self.config.freeze_encoder:
+            self.freeze_model()
+        if self.config.add_lora:
+            add_lora(self.encoder)
+
+    def forward(self, x):
+        '''
+        x: (b, t, c, h, w)
+        b = batch size,
+        t = number of frames in each image/video,
+        c = number of channels, h = height, w = width
+        '''
+        b, t, c, h, w = x.shape
+        combined_batch = rearrange(x, "b t c h w -> (b t) c h w")
+        preprocessed_vision = self.transform(combined_batch).half().contiguous()
+
+        if True:
+            # TODO confirm the open_clip vision tower really returns (pooled, tokens) here;
+            # encode_image() only gives back the pooled embedding
+            pooled, tokens = self.encoder.visual(preprocessed_vision)  # B*T, N_E, E
+            # For now, we just use the tokens
+            # Later, once we have the -2 layer, it will need to replace the CLS with the FINAL cls and just return that
+            embeddings = rearrange(tokens, "(b t) n_e e -> b (t n_e) e", b=b, t=t)
+
+        # else:
+        #     x = rearrange(x, "b t c h w -> (b t) c h w")
+        #     embeddings = self.encoder(x)  # B*T, N_E, E
+        #     embeddings = rearrange(embeddings, "(b t) n_e e -> b (t n_e) e", b=b, t=t)
+        return embeddings