diff --git a/tools/model_converters/publish_model.py b/tools/model_converters/publish_model.py index e2660578af..29130001b7 100644 --- a/tools/model_converters/publish_model.py +++ b/tools/model_converters/publish_model.py @@ -3,6 +3,8 @@ import subprocess import torch +from mmengine.logging import print_log +from mmengine.utils import digit_version def parse_args(): @@ -10,26 +12,50 @@ def parse_args(): description='Process a checkpoint to be published') parser.add_argument('in_file', help='input checkpoint filename') parser.add_argument('out_file', help='output checkpoint filename') + parser.add_argument( + '--save-keys', + nargs='+', + type=str, + default=['meta', 'state_dict'], + help='keys to save in the published checkpoint') args = parser.parse_args() return args -def process_checkpoint(in_file, out_file): +def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): checkpoint = torch.load(in_file, map_location='cpu') - # remove optimizer for smaller file size - if 'optimizer' in checkpoint: - del checkpoint['optimizer'] + + # only keep `meta` and `state_dict` for smaller file size + ckpt_keys = list(checkpoint.keys()) + for k in ckpt_keys: + if k not in save_keys: + print_log( + f'Key `{k}` will be removed because it is not in ' + 'save_keys. If you want to keep it, ' + 'please set --save-keys.', + logger='current') + checkpoint.pop(k, None) + # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. - torch.save(checkpoint, out_file) + if digit_version(torch.__version__) >= digit_version('1.6'): + torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) + else: + torch.save(checkpoint, out_file) sha = subprocess.check_output(['sha256sum', out_file]).decode() - final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) + if out_file.endswith('.pth'): + out_file_name = out_file[:-4] + else: + out_file_name = out_file + final_file = out_file_name + f'-{sha[:8]}.pth' subprocess.Popen(['mv', out_file, final_file]) + print_log( + f'The published model is saved at {final_file}.', logger='current') def main(): args = parse_args() - process_checkpoint(args.in_file, args.out_file) + process_checkpoint(args.in_file, args.out_file, args.save_keys) if __name__ == '__main__':