Error in loading pretrained weight for 'mae_vit_base_patch16' #180

Open
nightrain-vampire opened this issue Oct 18, 2023 · 2 comments

@nightrain-vampire commented Oct 18, 2023

I tried to use mae_vit_base in the demo, but it reports the following error:

RuntimeError                              Traceback (most recent call last)
/data/user3/zspace/Mcm/demo/mae_visualize.ipynb cell 9 line 8
      [5] get_ipython().system('wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth')
      [7] chkpt_dir = 'mae_visualize_vit_base.pth'
----> [8] model_mae = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
      [9] print('Model loaded.')

/data/user3/zspace/Mcm/demo/mae_visualize.ipynb cell 9 line 1
     [17] # load model
     [18] checkpoint = torch.load(chkpt_dir, map_location='cpu')
---> [19] msg = model.load_state_dict(checkpoint['model'], strict=False)
     [20] print(msg)
     [21] return model

File ~/miniconda3/envs/mae/lib/python3.8/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MaskedAutoencoderViT:
	size mismatch for pos_embed: copying a param with shape torch.Size([1, 197, 768]) from checkpoint, the shape in current model is torch.Size([1, 4097, 768]).
	size mismatch for decoder_pos_embed: copying a param with shape torch.Size([1, 197, 512]) from checkpoint, the shape in current model is torch.Size([1, 4097, 512])

My code is below:

# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)

# download checkpoint if not exist
# !wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth

chkpt_dir = 'mae_visualize_vit_base.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
print('Model loaded.')

What's wrong with the pretrained model? I also tried 'mae_pretrain_vit_base_full.pth', but it gives the same error. Can anyone help?
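
For reference, a minimal sketch (assuming the file downloaded above) to confirm what input geometry the checkpoint expects; the key names come straight from the traceback:

import torch

# Inspect the positional-embedding shapes stored in the checkpoint.
ckpt = torch.load('mae_visualize_vit_base.pth', map_location='cpu')
print(ckpt['model']['pos_embed'].shape)          # torch.Size([1, 197, 768])
print(ckpt['model']['decoder_pos_embed'].shape)  # torch.Size([1, 197, 512])
# 197 = (224 // 16) ** 2 + 1: a 224x224 input split into 16x16 patches,
# plus one class token.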

@MakoOfficial

According to the error report, something is wrong with your model parameters. The checkpoint's image size is 224 and its patch size is 16, so the shape of pos_embed is [1, 197, 768]. But your initialized model's pos_embed is [1, 4097, 768], so I guess you changed the parameter "img_size" from 224 to 1024 while keeping "patch_size" at 16. Restoring that hyperparameter should fix it.
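
To make the numbers in the error concrete: the token count of pos_embed is (img_size // patch_size) ** 2 + 1 (patch tokens plus one class token). A quick sanity check with plain arithmetic (num_tokens is just an illustrative helper, not part of the MAE repo):

def num_tokens(img_size, patch_size):
    # number of patch tokens plus one class token
    return (img_size // patch_size) ** 2 + 1

print(num_tokens(224, 16))   # 197  -> the shape stored in the checkpoint
print(num_tokens(1024, 16))  # 4097 -> the shape of the current model

So the checkpoint only fits a model built for 224x224 inputs with 16x16 patches.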

@sALTaccount

I'm also unable to get the weights to load :/
