[Feature] Support Side Adapter Network (#3232)
## Motivation

Support SAN for open-vocabulary semantic segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the ViT backbone parameters needed to implement the CLIP image encoder.
- Added the text encoder code.
- Added the multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added the SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added the loss implementation for mask-classification models such as SAN and MaskFormer, and removed the dependency on mmdetection.
- Added unit tests for the text encoder, the multimodal encoder-decoder, the SAN decode head, and the Hungarian assigner.

## Use cases

### Convert models

Both converters follow the same load-remap-save pattern; an illustrative sketch appears at the end of this message.

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth). Use tools/model_converters/san2mmseg.py to convert an official model into mmseg style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

SAN is trained with the CLIP models provided by OpenAI, which can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt). Use tools/model_converters/clip2mmseg.py to convert a CLIP model into mmseg style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison results

### Train on COCO-Stuff164k

| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <[email protected]>
Co-authored-by: yeedrag <[email protected]>
Co-authored-by: Yang-ChangHui <[email protected]>
Co-authored-by: Xu CAO <[email protected]>
Co-authored-by: xiexinch <[email protected]>
Co-authored-by: 小飞猪 <[email protected]>
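As promised above, here is the general pattern the converter scripts follow: load the official checkpoint, remap parameter names to the ones mmseg expects, and save the result. The sketch below is illustrative only; the rule that renames a `visual.` prefix to `image_encoder.` is a made-up example, and the real key mappings live in tools/model_converters/san2mmseg.py and clip2mmseg.py.

```python
# Illustrative load-remap-save skeleton for a checkpoint converter,
# assuming a plain dict-style checkpoint. The renaming rule below is a
# hypothetical example, not the actual mapping used by the scripts.
from collections import OrderedDict

import torch


def convert_checkpoint(src_path: str, dst_path: str) -> None:
    checkpoint = torch.load(src_path, map_location='cpu')
    # Official checkpoints may nest the weights under a 'state_dict' key.
    state_dict = checkpoint.get('state_dict', checkpoint)

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Example rule only: map an official prefix to the mmseg name.
        new_state_dict[key.replace('visual.', 'image_encoder.')] = value

    torch.save(dict(state_dict=new_state_dict), dst_path)
```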
1 parent 1471d1e · commit 608e319 · Showing 42 changed files with 4,114 additions and 29 deletions.
@@ -0,0 +1,137 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size_divisor=640,
    test_cfg=dict(size_divisor=32))

num_classes = 171
model = dict(
    type='MultimodalEncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/clip_vit_base_patch16_224.pth',
    asymetric_input=True,
    encoder_resolution=0.5,
    image_encoder=dict(
        type='VisionTransformer',
        img_size=(224, 224),
        patch_size=16,
        patch_pad=0,
        in_channels=3,
        embed_dims=768,
        num_layers=9,
        num_heads=12,
        mlp_ratio=4,
        out_origin=True,
        out_indices=(2, 5, 8),
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        output_cls_token=True,
        patch_bias=False,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        act_cfg=dict(type='QuickGELU'),
        norm_eval=False,
        interpolate_mode='bicubic',
        frozen_exclude=['pos_embed']),
    text_encoder=dict(
        type='CLIPTextEncoder',
        dataset_name=None,
        templates='vild',
        embed_dims=512,
        num_layers=12,
        num_heads=8,
        mlp_ratio=4,
        output_dims=512,
        cache_feature=True,
        cat_bg=True,
        norm_cfg=dict(type='LN', eps=1e-5)),
    decode_head=dict(
        type='SideAdapterCLIPHead',
        num_classes=num_classes,
        deep_supervision_idxs=[7],
        san_cfg=dict(
            in_channels=3,
            clip_channels=768,
            embed_dims=240,
            patch_size=16,
            patch_bias=True,
            num_queries=100,
            cfg_encoder=dict(
                num_encode_layer=8,
                num_heads=6,
                mlp_ratio=4),
            fusion_index=[0, 1, 2, 3],
            cfg_decoder=dict(
                num_heads=12,
                num_layers=1,
                embed_channels=256,
                mlp_channels=256,
                num_mlp=3,
                rescale=True),
            norm_cfg=dict(type='LN', eps=1e-6)),
        maskgen_cfg=dict(
            sos_token_format='cls_token',
            sos_token_num=100,
            cross_attn=False,
            num_layers=3,
            embed_dims=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            out_dims=512,
            final_norm=True,
            act_cfg=dict(type='QuickGELU'),
            norm_cfg=dict(type='LN', eps=1e-5),
            frozen_exclude=[]),
        align_corners=False,
        train_cfg=dict(
            num_points=12544,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(
                        type='DiceCost',
                        weight=5.0,
                        pred_act=True,
                        eps=1.0)
                ])),
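        # The three losses below mirror the three assigner costs above:
        # classification cross-entropy (class_weight keeps all real classes
        # at 1.0 and down-weights the appended "no object" slot to 0.1),
        # per-mask binary cross-entropy, and Dice loss.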
        loss_decode=[dict(type='CrossEntropyLoss',
                          loss_name='loss_cls_ce',
                          loss_weight=2.0,
                          class_weight=[1.0] * num_classes + [0.1]),
                     dict(type='CrossEntropyLoss',
                          use_sigmoid=True,
                          loss_name='loss_mask_ce',
                          loss_weight=5.0),
                     dict(type='DiceLoss',
                          ignore_index=None,
                          naive_dice=True,
                          eps=1,
                          loss_name='loss_mask_dice',
                          loss_weight=5.0)]),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # yapf: disable
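A quick way to sanity-check this base config is to build the segmentor through mmengine's registry. A minimal sketch, assuming mmseg 1.x and that the file above is saved at configs/_base_/models/san_vit-b16.py (the path the SAN configs below inherit from):

```python
# Sketch: instantiate the SAN segmentor from the base model config.
from mmengine.config import Config

from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()  # register mmseg modules with the default scope

cfg = Config.fromfile('configs/_base_/models/san_vit-b16.py')
model = MODELS.build(cfg.model)

# The CLIP image encoder should be frozen except for frozen_exclude
# entries, so only the side adapter and mask-generation layers train.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'trainable parameters: {trainable / 1e6:.1f}M')
```

Note that building does not load the CLIP weights; those should be read from `model.pretrained` when `init_weights()` or the train/test entry points run.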
@@ -0,0 +1,47 @@
# SAN

> [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/MendelXu/SAN">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->
This paper presents a new framework for open-vocabulary semantic segmentation with the pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting attention bias, which is applied in the CLIP model to recognize the class of masks. This decoupled design has the benefit of easing CLIP's recognition of the classes of mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to be adapted to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware. Our approach is fast, accurate, and adds only a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms other counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed. We hope our approach will serve as a solid baseline and help ease future research in open-vocabulary semantic segmentation.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/MendelXu/SAN/blob/main/resources/arch.png" width="800"/>
</div>

## Results and models

### COCO-Stuff164k

| Method | Backbone | Pretrained   | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU  | mIoU(ms+flip) | config | download |
| ------ | -------- | ------------ | --------- | ------- | -------- | -------------- | ------ | ----- | ------------- | ------ | -------- |
| SAN    | ViT-B_16 | CLIP_ViT-B16 | 640x640   | 60000   | 12.61    | -              | V100   | 41.93 | 41.77         | -      | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906-fd0a7684.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906.log) |
| SAN    | ViT-L_14 | CLIP_ViT-L14 | 640x640   | 60000   | 22.84    | -              | V100   | 45.78 | 43.99         | -      | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907-a11e098f.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907.log) |

## Notes
The pretrained weights in the config files are converted from open_clip models using tools/model_converters/clip2mmseg.py.
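For quick single-image inference with the released checkpoint, something like the following should work with the mmseg 1.x Python API (a sketch; the config path mirrors the training command in the commit message, and `demo/demo.png` is the sample image shipped with the repo):

```python
# Sketch: single-image inference with the released COCO-Stuff164k weights.
from mmseg.apis import inference_model, init_model

config_file = 'configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint_file = ('https://download.openmmlab.com/mmsegmentation/v0.5/'
                   'san/san-vit-b16_20230906-fd0a7684.pth')

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo/demo.png')
print(result.pred_sem_seg.data.shape)  # (1, H, W) tensor of class indices
```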

## Citation

```bibtex
@inproceedings{xu2023side,
  title={Side adapter network for open-vocabulary semantic segmentation},
  author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2945--2954},
  year={2023}
}
```
@@ -0,0 +1,82 @@
_base_ = [
    '../_base_/models/san_vit-b16.py', '../_base_/datasets/coco-stuff164k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomChoiceResize',
        scales=[int(640 * x * 0.1) for x in range(5, 16)],
        resize_type='ResizeShortestEdge',
        max_size=2560),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.0),
    dict(type='PhotoMetricDistortion'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackSegInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# By default, models are trained on 4 GPUs with 8 images per GPU
train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-base-patch16-224_3rdparty-d08f8887.pth'  # noqa
data_preprocessor = dict(
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    size_divisor=640,
    test_cfg=dict(size_divisor=32))
model = dict(
    pretrained=pretrained,
    text_encoder=dict(dataset_name='coco-stuff164k'),
    decode_head=dict(num_classes=171))

# training schedule for 60k
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=60000,
    val_interval=500,
    val_begin=55000)
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,
        interval=10000,
        save_best='mIoU'))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001),
    paramwise_cfg=dict(
        custom_keys={
            'img_encoder': dict(lr_mult=0.1, decay_mult=1.0),
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }),
    loss_scale='dynamic',
    clip_grad=dict(max_norm=0.01, norm_type=2))

param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=0,
        end=60000,
        by_epoch=False,
    )
]
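One note on `paramwise_cfg`: mmengine's default optimizer constructor matches each custom key as a substring of the parameter's name and scales the learning rate and weight decay accordingly. A minimal sketch of those semantics, using the values from this config (the parameter names in the example are hypothetical):

```python
# Sketch of custom_keys semantics: substring match on parameter names,
# then scale lr by lr_mult and weight decay by decay_mult.
base_lr, base_wd = 0.0001, 0.0001
custom_keys = {
    'img_encoder': dict(lr_mult=0.1, decay_mult=1.0),
    'pos_embed': dict(decay_mult=0.),
    'cls_token': dict(decay_mult=0.),
    'norm': dict(decay_mult=0.),
}


def resolve(name: str) -> tuple:
    lr_mult, decay_mult = 1.0, 1.0
    for key, opts in custom_keys.items():
        if key in name:  # substring match, as mmengine does
            lr_mult = opts.get('lr_mult', 1.0)
            decay_mult = opts.get('decay_mult', 1.0)
            break  # first match wins in this simplified version
    return base_lr * lr_mult, base_wd * decay_mult


# Hypothetical parameter names, for illustration only.
print(resolve('backbone.pos_embed'))       # lr kept, weight decay zeroed
print(resolve('decode_head.conv.weight'))  # no match, defaults apply
```

Separately, `AmpOptimWrapper` with `loss_scale='dynamic'` enables mixed-precision training, and `clip_grad` caps gradient norms at 0.01.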
@@ -0,0 +1,56 @@
_base_ = [
    '../_base_/models/san_vit-b16.py',
    '../_base_/datasets/pascal_context_59.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

data_preprocessor = dict(
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    size_divisor=640,
    test_cfg=dict(size_divisor=32))
model = dict(
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/vit_base_patch16_224.pth',
    text_encoder=dict(dataset_name='pascal_context'),
    decode_head=dict(num_classes=59))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=160000,
        by_epoch=False,
    )
]
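For reference, the schedule above warms the learning rate up linearly over the first 1,500 iterations and then decays it linearly (PolyLR with power=1.0) to zero at 160k iterations. A small sketch of the resulting curve using this config's values; mmengine's boundary handling may differ by a step:

```python
# Sketch: effective lr under LinearLR warmup followed by PolyLR decay.
base_lr = 6e-5        # optimizer lr above
start_factor = 1e-6   # LinearLR start_factor
warmup_end = 1500     # LinearLR end / PolyLR begin
total = 160000        # PolyLR end
power, eta_min = 1.0, 0.0


def lr_at(step: int) -> float:
    if step < warmup_end:
        # Linear ramp from base_lr * start_factor up to base_lr.
        factor = start_factor + (1 - start_factor) * step / warmup_end
        return base_lr * factor
    # Polynomial decay from base_lr down to eta_min.
    progress = (step - warmup_end) / (total - warmup_end)
    return (base_lr - eta_min) * (1 - progress) ** power + eta_min


for step in (0, 750, 1500, 80000, 160000):
    print(step, f'{lr_at(step):.2e}')
```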
@@ -0,0 +1,65 @@
_base_ = [
    '../_base_/models/san_vit-b16.py',
    '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)

metainfo = dict(
    classes=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
             'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
             'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
    palette=[[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
             [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
             [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
             [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
             [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]])
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(
    batch_size=1, dataset=dict(metainfo=metainfo, pipeline=test_pipeline))
test_dataloader = val_dataloader

data_preprocessor = dict(
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    size_divisor=640,
    test_cfg=dict(size_divisor=32))
model = dict(
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/vit_base_patch16_224.pth',
    text_encoder=dict(dataset_name='voc'),
    decode_head=dict(num_classes=20))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=160000,
        by_epoch=False,
    )
]
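As background on how these configs produce a dense prediction: SAN, like MaskFormer, is a mask-classification model, so inference combines per-query class probabilities with per-query mask predictions. A conceptual sketch of that combination, not the exact mmseg implementation (shapes are illustrative):

```python
# Conceptual sketch of mask-classification inference (MaskFormer-style).
import torch

num_queries, num_classes, h, w = 100, 20, 160, 160
cls_logits = torch.randn(num_queries, num_classes + 1)  # +1 "no object"
mask_logits = torch.randn(num_queries, h, w)

cls_prob = cls_logits.softmax(dim=-1)[:, :-1]  # drop the no-object slot
mask_prob = mask_logits.sigmoid()

# Weighted sum over queries gives per-class, per-pixel scores.
seg_scores = torch.einsum('qc,qhw->chw', cls_prob, mask_prob)
pred = seg_scores.argmax(dim=0)  # (h, w) map of class indices
```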