Commit

Now probably supports clip
Daniel Kaplan committed Oct 29, 2023
1 parent 516a42d commit 2d5a56d
Showing 2 changed files with 66 additions and 3 deletions.
1 change: 1 addition & 0 deletions megatron/model/encoders/vision/open_clip
Submodule open_clip added at d0befe
68 changes: 65 additions & 3 deletions megatron/model/encoders/vision/vision_encoder.py
@@ -82,7 +82,7 @@ def forward(self, x):
        # return embeddings


-def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+def load_pretrained_weights_dino(model, pretrained_weights, checkpoint_key):
    # [TODO] add logger here
    if urlparse(pretrained_weights).scheme:  # if it looks like a URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
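For context, the URL-vs-local branch above follows the usual torch.hub pattern; a minimal standalone sketch of that dispatch (the real helper, going by its signature, presumably also pulls checkpoint_key out of the loaded dict):

from urllib.parse import urlparse
import torch

def load_checkpoint(path_or_url):
    # Remote checkpoints are downloaded (and cached) via torch.hub;
    # local paths fall through to a plain torch.load.
    if urlparse(path_or_url).scheme:
        return torch.hub.load_state_dict_from_url(path_or_url, map_location="cpu")
    return torch.load(path_or_url, map_location="cpu")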
@@ -118,8 +118,70 @@ def get_vision_encoder(
        )
        model = vits.__dict__[name](**vit_kwargs)
        if pretrained:
-            model = load_pretrained_weights(model, args.pretrained_weights, "teacher")
+            model = load_pretrained_weights_dino(model, args.pretrained_weights, "teacher")
        encoder = DinoWrapper(model, args)
    elif "evaclip" in name:
        # open_clip's create_model_and_transforms returns
        # (model, preprocess_train, preprocess_val); keep the eval transform.
        # TODO: confirm the eval transform is the right preprocessing here.
        model, _, preprocess = open_clip.create_model_and_transforms(
            'hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
        encoder = ClipWrapper(model, args, preprocess)
    else:
        raise ValueError(f"vision encoder {name} not recognized")
    return encoder

class ClipWrapper(nn.Module):
    def __init__(self, encoder, config, transform):
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.prepare_encoder()

        # TODO: verify this eval transform is the right preprocessing for EVA-CLIP
        self.transform = transform


    # TODO: this is nearly identical to DinoWrapper.freeze_model; factor it out
    # into a shared helper.
    def freeze_model(self):
        num_layers_to_unfreeze = self.config.num_layers_to_unfreeze
        # Freeze everything
        self.encoder.requires_grad_(False)
        # Unfreeze the last num_layers_to_unfreeze modules. Note that
        # named_modules() is a flat traversal, so these are the last N modules
        # in registration order, not necessarily the last N transformer blocks.
        for child_name, child in list(self.encoder.named_modules())[-num_layers_to_unfreeze:]:
            child.requires_grad_(True)

        # Keep the norms trainable. TODO: check which norm types this model uses.
        recursive_freeze_unfreeze(self.encoder, param_types=['LayerNorm'], freeze=False)

        # TODO: decide whether the CLS token should stay frozen.
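        # A hypothetical, more targeted alternative, sketched for reference (it
        # assumes an open_clip ViT whose blocks live at
        # encoder.visual.transformer.resblocks; verify before relying on it):
        #
        #     blocks = self.encoder.visual.transformer.resblocks
        #     for block in blocks[-num_layers_to_unfreeze:]:
        #         block.requires_grad_(True)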

def prepare_encoder(self):
if self.config.freeze_encoder:
self.freeze_model()
if self.config.add_lora:
add_lora(self.encoder)
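            # NOTE: freezing runs before add_lora, so parameters the LoRA hook
            # injects afterwards stay trainable (assuming add_lora creates fresh
            # adapter parameters rather than toggling existing ones).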

    def forward(self, x):
        '''
        x: (b, t, c, h, w)
            b = batch size
            t = number of frames per image/video
            c = channels, h = height, w = width
        '''
        b, t, c, h, w = x.shape
        combined_batch = rearrange(x, "b t c h w -> (b t) c h w")
        # TODO: confirm the open_clip preprocess accepts batched tensors; it is
        # built for PIL images, so Resize/Normalize may need to be applied manually.
        preprocessed_vision = self.transform(combined_batch).half().contiguous()

        # Encode with the CLIP visual tower. Assumes the tower is configured
        # with output_tokens=True, so it returns (pooled, tokens); for now we
        # only use the patch tokens.
        # TODO: once the -2 layer is exposed, replace the CLS token with the
        # final CLS and return just that.
        pooled, tokens = self.encoder.visual(preprocessed_vision)  # (B*T, N_E, E)
        embeddings = rearrange(tokens, "(b t) n_e e -> b (t n_e) e", b=b, t=t)
        return embeddings
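A quick smoke test of the wrapper above (a sketch, not part of the commit: the args fields mirror the ones ClipWrapper reads, a small ViT-B-32 stands in for EVA-CLIP, nn.Identity() sidesteps the PIL-only preprocess flagged in the TODO, and a GPU is assumed because forward() casts to fp16; it also assumes this module's helpers such as recursive_freeze_unfreeze are importable):

import torch
from torch import nn
import open_clip
from types import SimpleNamespace

args = SimpleNamespace(freeze_encoder=True, num_layers_to_unfreeze=2, add_lora=False)

model, _, _ = open_clip.create_model_and_transforms('ViT-B-32')  # small stand-in model
model.visual.output_tokens = True  # make the visual tower return (pooled, tokens)
model = model.cuda()               # forward() casts to fp16, so run on a GPU

encoder = ClipWrapper(model, args, nn.Identity())

video = torch.randn(2, 4, 3, 224, 224, device="cuda")  # (b, t, c, h, w)
embeddings = encoder(video)  # expected shape: (2, 4 * n_tokens, embed_dim)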
