diff --git a/projects/CAT-Seg/README.md b/projects/CAT-Seg/README.md
new file mode 100644
index 0000000000..890e461ce4
--- /dev/null
+++ b/projects/CAT-Seg/README.md
@@ -0,0 +1,92 @@
+# CAT-Seg
+
+> [CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2303.11797)
+
+## Introduction
+
+
+
+[Official Repo](https://github.com/KU-CVLAB/CAT-Seg)
+
+[Code Snippet](cat_seg/models/cat_head.py)
+
+## Abstract
+
+
+
+Existing works on open-vocabulary semantic segmentation have utilized large-scale vision-language models, such as CLIP, to leverage their exceptional open-vocabulary recognition capabilities. However, the problem of transferring these capabilities learned from image-level supervision to the pixel-level task of segmentation and addressing arbitrary unseen categories at inference makes this task challenging. To address these issues, we aim to attentively relate objects within an image to given categories by leveraging relational information among class categories and visual semantics through aggregation, while also adapting the CLIP representations to the pixel-level task. However, we observe that direct optimization of the CLIP embeddings can harm its open-vocabulary capabilities. In this regard, we propose an alternative approach to optimize the image-text similarity map, i.e. the cost map, using a novel cost aggregation-based method. Our framework, namely CAT-Seg, achieves state-of-the-art performance across all benchmarks. We provide extensive ablation studies to validate our choices. [Project page](https://ku-cvlab.github.io/CAT-Seg).
+
+
+
+
+
+*CAT-Seg model structure*
+
+
+## Usage
+
+CAT-Seg training requires a pretrained `CLIP` model. We have implemented the `ViT-B` and `ViT-L` based `CLIP` models. To use the `ViT-bigG` or `ViT-H` ones, you need an additional dependency: please install [open_clip](https://github.com/mlfoundations/open_clip) first. The pretrained `CLIP` state dicts are loaded from [Huggingface-OpenCLIP](https://huggingface.co/models?library=open_clip). **If you encounter a `ConnectionError` when downloading the CLIP weights**, you can download them manually from the given repo and set `custom_clip_weights='/path/to/your/folder'` in the backbone section of the config file, as shown in the sketch below. The related dependencies are listed in [requirements/optional.txt](requirements/optional.txt):
+
+```shell
+pip install ftfy==6.0.1
+pip install huggingface-hub
+pip install regex
+```
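+
+If you have downloaded the OpenCLIP weights manually, you can point the backbone to them in the config. The snippet below is a trimmed-down, illustrative sketch rather than a complete config: the field names follow `CLIPOVCATSeg` in `cat_seg/models/clip_ovseg.py`, all values are placeholders, the omitted backbone fields (e.g. `feature_extractor`) follow the provided configs, and `custom_clip_weights` only takes effect for the OpenCLIP (`ViT-G`/`ViT-H`) variants.
+
+```python
+# Illustrative sketch only: the remaining backbone fields follow the
+# provided configs under ./configs/cat_seg/.
+model = dict(
+    backbone=dict(
+        type='CLIPOVCATSeg',
+        clip_pretrained='ViT-G',  # 'ViT-G' and 'ViT-H' use OpenCLIP weights
+        # folder holding the manually downloaded open_clip_pytorch_model.bin
+        custom_clip_weights='/path/to/your/folder',
+        train_class_json='data/coco.json',
+        test_class_json='data/ade150.json',
+        clip_finetune='attention',  # e.g. 'prompt', 'attention' or 'full'
+    ))
+```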
+
+In addition to the necessary [data preparation](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md), you also need the class-text files for the CLIP text encoder. Please download the class-text json files ([cls_texts](https://github.com/open-mmlab/mmsegmentation/files/11714914/cls_texts.zip)) first and arrange the folders as follows; an example of the expected file format is given after the directory tree:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── VOCdevkit
+│ │ ├── VOC2012
+│ │ ├── VOC2010
+│ │ ├── VOCaug
+│ ├── ade
+│ ├── coco_stuff164k
+│ ├── coco.json
+│ ├── pc59.json
+│ ├── pc459.json
+│ ├── ade150.json
+│ ├── ade847.json
+│ ├── voc20b.json
+│ ├── voc20.json
+```
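+
+For reference, the loader in `cat_seg/models/clip_ovseg.py` reads each class-text file as a flat JSON list of class-name strings; an entry may contain several comma-separated synonyms, which are split and averaged inside `CLIPOVCATSeg.class_embeddings`. The class names below are only an illustrative excerpt:
+
+```json
+["wall", "building, edifice", "sky", "floor, flooring", "tree"]
+```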
+
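+After preparing the datasets and class-text files, you can evaluate a checkpoint with MIM:
+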
+```shell
+# setup PYTHONPATH
+export PYTHONPATH=`pwd`:$PYTHONPATH
+# run evaluation
+mim test mmsegmentation ${CONFIG} --checkpoint ${CHECKPOINT} --launcher pytorch --gpus=8
+```
+
+## Results and models
+
+### ADE20K-150-ZeroShot
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
+| ------- | ------------- | --------- | ------- | -------: | -------------- | ------- | ---- | ------------: | ------------------------------------------------------------------------------------------: | --------------------------------------------------------------------------------------------------------------------------------------------- |
+| CAT-Seg | R-101 & ViT-B | 384x384 | 80000 | - | - | RTX3090 | 27.2 | - | [config](./configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384-54194d72.pth) |
+
+Note:
+
+- All CAT-Seg experiments are run on 4 RTX3090 GPUs, except the last one with the pretrained ViT-bigG CLIP model, which does not fit into RTX3090 memory (you may need an A100).
+- Due to the feature size bottleneck of the CLIP image encoder, inference and testing can only be done in `slide` mode; inference also takes longer because the test size is much larger than the training crop size of `(384, 384)`.
+- The ResNet backbones utilized in CAT-Seg models are standard `ResNet` rather than `ResNetV1c`.
+- The zero-shot segmentation results on PASCAL VOC and ADE20K are from the original paper. Our results are coming soon. We appreciate your contribution!
+- In addition to the zero-shot segmentation results, we also provide the evaluation results on the `val2017` set of **COCO-stuff164k**, the training dataset of CAT-Seg, for reference. The testing was done **without TTA**.
+- The number after the dataset name is the number of categories used for segmentation evaluation (except for the training dataset **COCO-stuff 164k**). **PASCAL VOC-20b** defines the "background" as the classes present in **PASCAL-Context-59** but not in **PASCAL VOC-20**.
+
+## Citation
+
+```bibtex
+@inproceedings{cho2023cat,
+  title={CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation},
+  author={Seokju Cho and Heeseong Shin and Sunghwan Hong and Seungjun An and Seungjun Lee and Anurag Arnab and Paul Hongsuck Seo and Seungryong Kim},
+  booktitle={CVPR},
+  year={2023}
+}
+```
diff --git a/projects/CAT-Seg/cat_seg/__init__.py b/projects/CAT-Seg/cat_seg/__init__.py
new file mode 100644
index 0000000000..2c51fbaa2e
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/__init__.py
@@ -0,0 +1,2 @@
+from .models import * # noqa: F401,F403
+from .utils import * # noqa: F401,F403
diff --git a/projects/CAT-Seg/cat_seg/models/__init__.py b/projects/CAT-Seg/cat_seg/models/__init__.py
new file mode 100644
index 0000000000..cd0e15d3ec
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .cat_aggregator import (AggregatorLayer, CATSegAggregator,
+ ClassAggregateLayer, SpatialAggregateLayer)
+from .cat_head import CATSegHead
+from .clip_ovseg import CLIPOVCATSeg
+
+__all__ = [
+ 'AggregatorLayer', 'CATSegAggregator', 'ClassAggregateLayer',
+ 'SpatialAggregateLayer', 'CATSegHead', 'CLIPOVCATSeg'
+]
diff --git a/projects/CAT-Seg/cat_seg/models/cat_aggregator.py b/projects/CAT-Seg/cat_seg/models/cat_aggregator.py
new file mode 100644
index 0000000000..a0483fe505
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/cat_aggregator.py
@@ -0,0 +1,763 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmengine.model import BaseModule
+from mmengine.utils import to_2tuple
+
+from mmseg.registry import MODELS
+from ..utils import FullAttention, LinearAttention
+
+
+class AGWindowMSA(BaseModule):
+ """Appearance Guidance Window based multi-head self-attention (W-MSA)
+ module with relative position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ appearance_dims (int): Number of appearance guidance feature channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
+ init_cfg (dict | None, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ window_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.appearance_dims = appearance_dims
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+
+ # About 2x faster than original impl
+ Wh, Ww = self.window_size
+ rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
+ rel_position_index = rel_index_coords + rel_index_coords.T
+ rel_position_index = rel_position_index.flip(1).contiguous()
+ self.register_buffer('relative_position_index', rel_position_index)
+
+ self.qk = nn.Linear(
+ embed_dims + appearance_dims, embed_dims * 2, bias=qkv_bias)
+ self.v = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x (tensor): input features with shape of (num_windows*B, N, C),
+ C = embed_dims + appearance_dims.
+ mask (tensor | None, Optional): mask with shape of (num_windows,
+ Wh*Ww, Wh*Ww), value should be between (-inf, 0].
+ """
+ B, N, _ = x.shape
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads,
+ self.embed_dims // self.num_heads).permute(
+ 2, 0, 3, 1,
+ 4) # 2 B NUM_HEADS N embed_dims//NUM_HEADS
+ v = self.v(x[:, :, :self.embed_dims]).reshape(
+ B, N, self.num_heads, self.embed_dims // self.num_heads).permute(
+ 0, 2, 1, 3) # B NUM_HEADS N embed_dims//NUM_HEADS
+ # make torchscript happy (cannot use tensor as tuple)
+ q, k = qk[0], qk[1]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @staticmethod
+ def double_step_seq(step1, len1, step2, len2):
+ """Double step sequence."""
+ seq1 = torch.arange(0, step1 * len1, step1)
+ seq2 = torch.arange(0, step2 * len2, step2)
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
+
+
+class AGShiftWindowMSA(BaseModule):
+ """Appearance Guidance Shifted Window Multihead Self-Attention Module.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ appearance_dims (int): Number of appearance guidance channels
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window.
+ shift_size (int, optional): The shift step of each window towards
+ right-bottom. If zero, act as regular window-msa. Defaults to 0.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Defaults: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Defaults: 0.
+ proj_drop_rate (float, optional): Dropout ratio of output.
+ Defaults: 0.
+ dropout_layer (dict, optional): The dropout_layer used before output.
+ Defaults: dict(type='DropPath', drop_prob=0.).
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ window_size,
+ shift_size=0,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0,
+ proj_drop_rate=0,
+ dropout_layer=dict(type='DropPath', drop_prob=0.),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.window_size = window_size
+ self.shift_size = shift_size
+ assert 0 <= self.shift_size < self.window_size
+
+ self.w_msa = AGWindowMSA(
+ embed_dims=embed_dims,
+ appearance_dims=appearance_dims,
+ num_heads=num_heads,
+ window_size=to_2tuple(window_size),
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate,
+ init_cfg=None)
+
+ self.drop = build_dropout(dropout_layer)
+
+ def forward(self, query, hw_shape):
+ """
+ Args:
+ query: The input query.
+ hw_shape: The shape of the feature height and width.
+ """
+ B, L, C = query.shape
+ H, W = hw_shape
+ assert L == H * W, 'input feature has wrong size'
+ query = query.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
+ H_pad, W_pad = query.shape[1], query.shape[2]
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_query = torch.roll(
+ query,
+ shifts=(-self.shift_size, -self.shift_size),
+ dims=(1, 2))
+
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ # nW, window_size, window_size, 1
+ mask_windows = self.window_partition(img_mask)
+ mask_windows = mask_windows.view(
+ -1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+ else:
+ shifted_query = query
+ attn_mask = None
+
+ # nW*B, window_size, window_size, C
+ query_windows = self.window_partition(shifted_query)
+ # nW*B, window_size*window_size, C
+ query_windows = query_windows.view(-1, self.window_size**2, C)
+
+ # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
+ attn_windows = self.w_msa(query_windows, mask=attn_mask)
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size,
+ self.w_msa.embed_dims)
+
+ # B H' W' self.w_msa.embed_dims
+ shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x,
+ shifts=(self.shift_size, self.shift_size),
+ dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, self.w_msa.embed_dims)
+
+ x = self.drop(x)
+ return x
+
+ def window_reverse(self, windows, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ window_size = self.window_size
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ def window_partition(self, x):
+ """
+ Args:
+ x: (B, H, W, C)
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ window_size = self.window_size
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ windows = windows.view(-1, window_size, window_size, C)
+ return windows
+
+
+class AGSwinBlock(BaseModule):
+ """Appearance Guidance Swin Transformer Block.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ appearance_dims (int): The appearance guidance dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
+ for FFNs.
+ window_size (int, optional): The local window scale.
+ Default: 7.
+ shift (bool, optional): whether to shift window or not.
+ Default False.
+ qkv_bias (bool, optional): enable bias for qkv if True.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float, optional): Dropout rate. Default: 0.
+ attn_drop_rate (float, optional): Attention dropout rate.
+ Default: 0.
+ drop_path_rate (float, optional): Stochastic depth rate.
+ Default: 0.
+ act_cfg (dict, optional): The config dict of activation function.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): The config dict of normalization.
+ Default: dict(type='LN').
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios=4,
+ window_size=7,
+ shift=False,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.attn = AGShiftWindowMSA(
+ embed_dims=embed_dims,
+ appearance_dims=appearance_dims,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=window_size // 2 if shift else 0,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ init_cfg=None)
+
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.ffn = FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * mlp_ratios,
+ num_fcs=2,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg,
+ add_identity=True,
+ init_cfg=None)
+
+ def forward(self, inputs, hw_shape):
+ """
+ Args:
+ inputs (list[Tensor]): appearance_guidance (B, H, W, C);
+ x (B, L, C)
+ hw_shape (tuple[int]): shape of feature.
+ """
+ x, appearance_guidance = inputs
+ B, L, C = x.shape
+ H, W = hw_shape
+ assert L == H * W, 'input feature has wrong size'
+
+ identity = x
+ x = self.norm1(x)
+
+ # appearance guidance
+ x = x.view(B, H, W, C)
+ if appearance_guidance is not None:
+ x = torch.cat([x, appearance_guidance], dim=-1).flatten(1, 2)
+
+ x = self.attn(x, hw_shape)
+
+ x = x + identity
+
+ identity = x
+ x = self.norm2(x)
+ x = self.ffn(x, identity=identity)
+
+ return x
+
+
+@MODELS.register_module()
+class SpatialAggregateLayer(BaseModule):
+ """Spatial aggregation layer of CAT-Seg.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ appearance_dims (int): The appearance guidance dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
+ for FFNs.
+ window_size (int, optional): The local window scale. Default: 7.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios,
+ window_size=7,
+ qk_scale=None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.block_1 = AGSwinBlock(
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios,
+ window_size=window_size,
+ shift=False,
+ qk_scale=qk_scale)
+ self.block_2 = AGSwinBlock(
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios,
+ window_size=window_size,
+ shift=True,
+ qk_scale=qk_scale)
+ self.guidance_norm = nn.LayerNorm(
+ appearance_dims) if appearance_dims > 0 else None
+
+ def forward(self, x, appearance_guidance):
+ """
+ Args:
+ x (torch.Tensor): B C T H W.
+ appearance_guidance (torch.Tensor): B C H W.
+ """
+ B, C, T, H, W = x.shape
+ x = x.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1, 2) # BT, HW, C
+ if appearance_guidance is not None:
+ appearance_guidance = appearance_guidance.repeat(
+ T, 1, 1, 1).permute(0, 2, 3, 1) # BT, HW, C
+ appearance_guidance = self.guidance_norm(appearance_guidance)
+ else:
+            assert self.guidance_norm is None
+ x = self.block_1((x, appearance_guidance), (H, W))
+ x = self.block_2((x, appearance_guidance), (H, W))
+ x = x.transpose(1, 2).reshape(B, T, C, -1)
+ x = x.transpose(1, 2).reshape(B, C, T, H, W)
+ return x
+
+
+class AttentionLayer(nn.Module):
+ """Attention layer for ClassAggregration of CAT-Seg.
+
+ Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L310 # noqa
+ """
+
+ def __init__(self,
+ hidden_dim,
+ guidance_dim,
+ nheads=8,
+ attention_type='linear'):
+ super().__init__()
+ self.nheads = nheads
+ self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
+ self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
+ self.v = nn.Linear(hidden_dim, hidden_dim)
+
+ if attention_type == 'linear':
+ self.attention = LinearAttention()
+ elif attention_type == 'full':
+ self.attention = FullAttention()
+ else:
+ raise NotImplementedError
+
+ def forward(self, x, guidance=None):
+ """
+ Args:
+ x: B*H_p*W_p, T, C
+ guidance: B*H_p*W_p, T, C
+ """
+ B, L, _ = x.shape
+ q = self.q(torch.cat([x, guidance],
+ dim=-1)) if guidance is not None else self.q(x)
+ k = self.k(torch.cat([x, guidance],
+ dim=-1)) if guidance is not None else self.k(x)
+ v = self.v(x)
+
+ q = q.reshape(B, L, self.nheads, -1)
+ k = k.reshape(B, L, self.nheads, -1)
+ v = v.reshape(B, L, self.nheads, -1)
+
+ out = self.attention(q, k, v)
+ out = out.reshape(B, L, -1)
+ return out
+
+
+@MODELS.register_module()
+class ClassAggregateLayer(BaseModule):
+ """Class aggregation layer of CAT-Seg.
+
+ Args:
+ hidden_dims (int): The feature dimension.
+ guidance_dims (int): The appearance guidance dimension.
+ num_heads (int): Parallel attention heads.
+ attention_type (str): Type of attention layer. Default: 'linear'.
+ pooling_size (tuple[int] | list[int]): Pooling size.
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ hidden_dims=64,
+ guidance_dims=64,
+ num_heads=8,
+ attention_type='linear',
+ pooling_size=(4, 4),
+ init_cfg=None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+ self.pool = nn.AvgPool2d(pooling_size)
+ self.attention = AttentionLayer(
+ hidden_dims,
+ guidance_dims,
+ nheads=num_heads,
+ attention_type=attention_type)
+ self.MLP = FFN(
+ embed_dims=hidden_dims,
+ feedforward_channels=hidden_dims * 4,
+ num_fcs=2)
+ self.norm1 = nn.LayerNorm(hidden_dims)
+ self.norm2 = nn.LayerNorm(hidden_dims)
+
+ def pool_features(self, x):
+ """Intermediate pooling layer for computational efficiency.
+
+ Args:
+ x: B, C, T, H, W
+ """
+ B, C, T, H, W = x.shape
+ x = x.transpose(1, 2).reshape(-1, C, H, W)
+ x = self.pool(x)
+ *_, H_, W_ = x.shape
+ x = x.reshape(B, T, C, H_, W_).transpose(1, 2)
+ return x
+
+ def forward(self, x, guidance):
+ """
+ Args:
+ x: B, C, T, H, W
+ guidance: B, T, C
+ """
+ B, C, T, H, W = x.size()
+ x_pool = self.pool_features(x)
+ *_, H_pool, W_pool = x_pool.size()
+
+ x_pool = x_pool.permute(0, 3, 4, 2, 1).reshape(-1, T, C)
+ # B*H_p*W_p T C
+ if guidance is not None:
+ guidance = guidance.repeat(H_pool * W_pool, 1, 1)
+
+ x_pool = x_pool + self.attention(self.norm1(x_pool),
+ guidance) # Attention
+ x_pool = x_pool + self.MLP(self.norm2(x_pool)) # MLP
+
+ x_pool = x_pool.reshape(B, H_pool * W_pool, T,
+ C).permute(0, 2, 3, 1).reshape(
+ B, T, C, H_pool,
+ W_pool).flatten(0, 1) # BT C H_p W_p
+ x_pool = F.interpolate(
+ x_pool, size=(H, W), mode='bilinear', align_corners=True)
+ x_pool = x_pool.reshape(B, T, C, H, W).transpose(1, 2) # B C T H W
+ x = x + x_pool # Residual
+
+ return x
+
+
+@MODELS.register_module()
+class AggregatorLayer(BaseModule):
+ """Single Aggregator Layer of CAT-Seg."""
+
+ def __init__(self,
+ embed_dims=64,
+ text_guidance_dims=512,
+ appearance_guidance_dims=512,
+ num_heads=4,
+ mlp_ratios=4,
+ window_size=7,
+ attention_type='linear',
+ pooling_size=(2, 2),
+ init_cfg=None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.spatial_agg = SpatialAggregateLayer(
+ embed_dims,
+ appearance_guidance_dims,
+ num_heads=num_heads,
+ mlp_ratios=mlp_ratios,
+ window_size=window_size)
+ self.class_agg = ClassAggregateLayer(
+ embed_dims,
+ text_guidance_dims,
+ num_heads=num_heads,
+ attention_type=attention_type,
+ pooling_size=pooling_size)
+
+ def forward(self, x, appearance_guidance, text_guidance):
+ """
+ Args:
+ x: B C T H W
+ """
+ x = self.spatial_agg(x, appearance_guidance)
+ x = self.class_agg(x, text_guidance)
+ return x
+
+
+@MODELS.register_module()
+class CATSegAggregator(BaseModule):
+ """CATSeg Aggregator.
+
+    This Aggregator is the mmseg implementation of
+    `CAT-Seg <https://arxiv.org/abs/2303.11797>`_.
+
+ Args:
+ text_guidance_dim (int): Text guidance dimensions. Default: 512.
+ text_guidance_proj_dim (int): Text guidance projection dimensions.
+ Default: 128.
+ appearance_guidance_dim (int): Appearance guidance dimensions.
+ Default: 512.
+ appearance_guidance_proj_dim (int): Appearance guidance projection
+ dimensions. Default: 128.
+ num_layers (int): Aggregator layer number. Default: 4.
+ num_heads (int): Attention layer head number. Default: 4.
+ embed_dims (int): Input feature dimensions. Default: 128.
+ pooling_size (tuple | list): Pooling size of the class aggregator
+ layer. Default: (6, 6).
+ mlp_ratios (int): The hidden dimension ratio w.r.t. input dimension.
+ Default: 4.
+ window_size (int): Swin block window size. Default:12.
+ attention_type (str): Attention type of class aggregator layer.
+ Default:'linear'.
+ prompt_channel (int): Prompt channels. Default: 80.
+ """
+
+ def __init__(self,
+ text_guidance_dim=512,
+ text_guidance_proj_dim=128,
+ appearance_guidance_dim=512,
+ appearance_guidance_proj_dim=128,
+ num_layers=4,
+ num_heads=4,
+ embed_dims=128,
+ pooling_size=(6, 6),
+ mlp_ratios=4,
+ window_size=12,
+ attention_type='linear',
+ prompt_channel=80,
+ **kwargs):
+ super().__init__(**kwargs)
+ self.num_layers = num_layers
+ self.embed_dims = embed_dims
+
+ self.layers = nn.ModuleList([
+ AggregatorLayer(
+ embed_dims=embed_dims,
+ text_guidance_dims=text_guidance_proj_dim,
+ appearance_guidance_dims=appearance_guidance_proj_dim,
+ num_heads=num_heads,
+ mlp_ratios=mlp_ratios,
+ window_size=window_size,
+ attention_type=attention_type,
+ pooling_size=pooling_size) for _ in range(num_layers)
+ ])
+
+ self.conv1 = nn.Conv2d(
+ prompt_channel, embed_dims, kernel_size=7, stride=1, padding=3)
+
+ self.guidance_projection = nn.Sequential(
+ nn.Conv2d(
+ appearance_guidance_dim,
+ appearance_guidance_proj_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1),
+ nn.ReLU(),
+ ) if appearance_guidance_dim > 0 else None
+
+ self.text_guidance_projection = nn.Sequential(
+ nn.Linear(text_guidance_dim, text_guidance_proj_dim),
+ nn.ReLU(),
+ ) if text_guidance_dim > 0 else None
+
+ def feature_map(self, img_feats, text_feats):
+ """Concatenation type cost volume.
+
+ For ablation study of cost volume type.
+ """
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
+ img_feats = img_feats.unsqueeze(2).repeat(1, 1, text_feats.shape[1], 1,
+ 1)
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
+ text_feats = text_feats.mean(dim=-2)
+ text_feats = F.normalize(text_feats, dim=-1) # B T C
+ text_feats = text_feats.unsqueeze(-1).unsqueeze(-1).repeat(
+ 1, 1, 1, img_feats.shape[-2], img_feats.shape[-1]).transpose(1, 2)
+ return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W
+
+ def correlation(self, img_feats, text_feats):
+ """Correlation of image features and text features."""
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
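+        # cost volume between image and text embeddings: (B, P, T, H, W),
+        # where P is the number of prompt templates and T the number of classes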
+ corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
+ return corr
+
+ def corr_embed(self, x):
+ """Correlation embeddings encoding."""
+ B = x.shape[0]
+ corr_embed = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ corr_embed = self.conv1(corr_embed)
+ corr_embed = corr_embed.reshape(B, -1, self.embed_dims, x.shape[-2],
+ x.shape[-1]).transpose(1, 2)
+ return corr_embed
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (dict): including the following keys,
+ 'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
+ `self.feature_extractor`.
+ 'clip_text_feat': the text feature extracted by clip text
+ encoder.
+ 'clip_text_feat_test': the text feature extracted by clip text
+ encoder for testing.
+                'clip_img_feat': the image feature extracted by the clip
+                    image encoder.
+ """
+ img_feats = inputs['clip_img_feat']
+ B = img_feats.size(0)
+ appearance_guidance = inputs[
+ 'appearance_feat'][::-1] # order (out_indices) 2, 1, 0
+ text_feats = inputs['clip_text_feat'] if self.training else inputs[
+ 'clip_text_feat_test']
+ text_feats = text_feats.repeat(B, 1, 1, 1)
+
+ corr = self.correlation(img_feats, text_feats)
+ # corr = self.feature_map(img_feats, text_feats)
+ corr_embed = self.corr_embed(corr)
+
+ projected_guidance, projected_text_guidance = None, None
+
+ if self.guidance_projection is not None:
+ projected_guidance = self.guidance_projection(
+ appearance_guidance[0])
+
+ if self.text_guidance_projection is not None:
+ text_feats = text_feats.mean(dim=-2)
+ text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
+ projected_text_guidance = self.text_guidance_projection(text_feats)
+
+ for layer in self.layers:
+ corr_embed = layer(corr_embed, projected_guidance,
+ projected_text_guidance)
+
+ return dict(
+ corr_embed=corr_embed, appearance_feats=appearance_guidance[1:])
diff --git a/projects/CAT-Seg/cat_seg/models/cat_head.py b/projects/CAT-Seg/cat_seg/models/cat_head.py
new file mode 100644
index 0000000000..36bb1c5617
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/cat_head.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.registry import MODELS
+
+
+class UpBlock(nn.Module):
+ """Upsample Block with two consecutive convolution layers."""
+
+ def __init__(self, in_channels, out_channels, guidance_channels):
+ super().__init__()
+ self.up = nn.ConvTranspose2d(
+ in_channels,
+ in_channels - guidance_channels,
+ kernel_size=2,
+ stride=2)
+ self.conv1 = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=False,
+ norm_cfg=dict(type='GN', num_groups=out_channels // 16))
+ self.conv2 = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=False,
+ norm_cfg=dict(type='GN', num_groups=out_channels // 16))
+
+ def forward(self, x, guidance=None):
+ """Forward function with visual guidance."""
+ x = self.up(x)
+ if guidance is not None:
+ T = x.size(0) // guidance.size(0)
+ # guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
+ guidance = guidance.repeat(T, 1, 1, 1)
+ x = torch.cat([x, guidance], dim=1)
+ x = self.conv1(x)
+
+ return self.conv2(x)
+
+
+@MODELS.register_module()
+class CATSegHead(BaseDecodeHead):
+ """CATSeg Head.
+
+    This segmentation head is the mmseg implementation of
+    `CAT-Seg <https://arxiv.org/abs/2303.11797>`_.
+
+ Args:
+        embed_dims (int): The number of input dimensions.
+        decoder_dims (list): The number of decoder dimensions.
+        decoder_guidance_dims (list): The number of appearance guidance
+            dimensions.
+        decoder_guidance_proj_dims (list): The number of appearance
+            guidance projection dimensions.
+ """
+
+ def __init__(self,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(256, 128),
+ decoder_guidance_proj_dims=(32, 16),
+ **kwargs):
+ super().__init__(**kwargs)
+ self.decoder_guidance_projection = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(
+ dec_dims,
+ dec_dims_proj,
+ kernel_size=3,
+ stride=1,
+ padding=1),
+ nn.ReLU(),
+ ) for dec_dims, dec_dims_proj in zip(decoder_guidance_dims,
+ decoder_guidance_proj_dims)
+ ]) if decoder_guidance_dims[0] > 0 else None
+
+ self.decoder1 = UpBlock(embed_dims, decoder_dims[0],
+ decoder_guidance_proj_dims[0])
+ self.decoder2 = UpBlock(decoder_dims[0], decoder_dims[1],
+ decoder_guidance_proj_dims[1])
+ self.conv_seg = nn.Conv2d(
+ decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, inputs):
+ """Forward function.
+
+ Args:
+ inputs (dict): Input features including the following features,
+ corr_embed: aggregated correlation embeddings.
+ appearance_feats: decoder appearance feature guidance.
+ """
+ # decoder guidance projection
+ if self.decoder_guidance_projection is not None:
+ projected_decoder_guidance = [
+ proj(g) for proj, g in zip(self.decoder_guidance_projection,
+ inputs['appearance_feats'])
+ ]
+
+ # decoder layers
+ B = inputs['corr_embed'].size(0)
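+        # fold the class dimension into the batch so the decoder runs as
+        # plain 2D convolutions: (B, C, T, H, W) -> (B*T, C, H, W)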
+ corr_embed = inputs['corr_embed'].transpose(1, 2).flatten(0, 1)
+ corr_embed = self.decoder1(corr_embed, projected_decoder_guidance[0])
+ corr_embed = self.decoder2(corr_embed, projected_decoder_guidance[1])
+
+ output = self.cls_seg(corr_embed)
+
+ # rearrange the output to (B, T, H, W)
+ H_ori, W_ori = output.shape[-2:]
+ output = output.reshape(B, -1, H_ori, W_ori)
+ return output
diff --git a/projects/CAT-Seg/cat_seg/models/clip_ovseg.py b/projects/CAT-Seg/cat_seg/models/clip_ovseg.py
new file mode 100644
index 0000000000..cb67744e34
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/clip_ovseg.py
@@ -0,0 +1,293 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from huggingface_hub.utils._errors import LocalEntryNotFoundError
+from mmengine.model import BaseModule
+
+from mmseg.registry import MODELS
+from mmseg.utils import ConfigType
+from ..utils import clip_wrapper
+from ..utils.clip_templates import (IMAGENET_TEMPLATES,
+ IMAGENET_TEMPLATES_SELECT)
+
+
+@MODELS.register_module()
+class CLIPOVCATSeg(BaseModule):
+ """CLIP based Open Vocabulary CAT-Seg model backbone.
+
+    This backbone is the modified implementation of the `CAT-Seg Backbone
+    <https://github.com/KU-CVLAB/CAT-Seg>`_. It combines the CLIP model and
+    another feature extractor, a.k.a. the appearance guidance extractor
+    in the original `CAT-Seg`.
+
+ Args:
+ feature_extractor (ConfigType): Appearance guidance extractor
+ config dict.
+ train_class_json (str): The training class json file.
+ test_class_json (str): The path to test class json file.
+ clip_pretrained (str): The pre-trained clip type.
+ clip_finetune (str): The finetuning settings of clip model.
+        custom_clip_weights (str): The customized clip weights directory. When
+ encountering huggingface model download errors, you can manually
+ download the pretrained weights.
+ backbone_multiplier (float): The learning rate multiplier.
+ Default: 0.01.
+ prompt_depth (int): The prompt depth. Default: 0.
+ prompt_length (int): The prompt length. Default: 0.
+ prompt_ensemble_type (str): The prompt ensemble type.
+ Default: "imagenet".
+ pixel_mean (List[float]): The pixel mean for feature extractor.
+        pixel_std (List[float]): The pixel std for feature extractor.
+        clip_pixel_mean (List[float]): The pixel mean for clip model.
+        clip_pixel_std (List[float]): The pixel std for clip model.
+        clip_img_feat_size (List[int]): Size of the dense clip image
+            features from the image encoder.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ feature_extractor: ConfigType,
+ train_class_json: str,
+ test_class_json: str,
+ clip_pretrained: str,
+ clip_finetune: str,
+ custom_clip_weights: str = None,
+ backbone_multiplier=0.01,
+ prompt_depth: int = 0,
+ prompt_length: int = 0,
+ prompt_ensemble_type: str = 'imagenet',
+ pixel_mean: List[float] = [123.675, 116.280, 103.530],
+ pixel_std: List[float] = [58.395, 57.120, 57.375],
+ clip_pixel_mean: List[float] = [
+ 122.7709383, 116.7460125, 104.09373615
+ ],
+ clip_pixel_std: List[float] = [68.5005327, 66.6321579, 70.3231630],
+ clip_img_feat_size: List[int] = [24, 24],
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ # normalization parameters
+ self.register_buffer('pixel_mean',
+ torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
+ self.register_buffer('pixel_std',
+ torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
+ self.register_buffer('clip_pixel_mean',
+ torch.Tensor(clip_pixel_mean).view(1, -1, 1, 1),
+ False)
+ self.register_buffer('clip_pixel_std',
+ torch.Tensor(clip_pixel_std).view(1, -1, 1, 1),
+ False)
+ self.clip_resolution = (
+ 384, 384) if clip_pretrained == 'ViT-B/16' else (336, 336)
+ # modified clip image encoder with fixed size dense output
+ self.clip_img_feat_size = clip_img_feat_size
+
+ # prepare clip templates
+ self.prompt_ensemble_type = prompt_ensemble_type
+ if self.prompt_ensemble_type == 'imagenet_select':
+ prompt_templates = IMAGENET_TEMPLATES_SELECT
+ elif self.prompt_ensemble_type == 'imagenet':
+ prompt_templates = IMAGENET_TEMPLATES
+ elif self.prompt_ensemble_type == 'single':
+ prompt_templates = [
+ 'A photo of a {} in the scene',
+ ]
+ else:
+ raise NotImplementedError
+ self.prompt_templates = prompt_templates
+
+ # build the feature extractor
+ self.feature_extractor = MODELS.build(feature_extractor)
+
+ # build CLIP model
+ with open(train_class_json) as f_in:
+ self.class_texts = json.load(f_in)
+ with open(test_class_json) as f_in:
+ self.test_class_texts = json.load(f_in)
+ assert self.class_texts is not None
+ if self.test_class_texts is None:
+ self.test_class_texts = self.class_texts
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.tokenizer = None
+ if clip_pretrained == 'ViT-G' or clip_pretrained == 'ViT-H':
+ # for OpenCLIP models
+ import open_clip
+ name, pretrain = (
+ 'ViT-H-14',
+ 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else (
+ 'ViT-bigG-14', 'laion2b_s39b_b160k')
+ try:
+ open_clip_model = open_clip.create_model_and_transforms(
+ name,
+ pretrained=pretrain,
+ device=device,
+ force_image_size=336,
+ )
+ clip_model, _, clip_preprocess = open_clip_model
+            except (ConnectionError, LocalEntryNotFoundError) as e:
+                print(f'Got {e} when loading weights from huggingface!')
+ print(
+ f'Will load {pretrain} weights from {custom_clip_weights}.'
+ )
+ assert custom_clip_weights is not None, 'Please specify custom weights directory.' # noqa
+ assert os.path.exists(
+ os.path.join(custom_clip_weights,
+ 'open_clip_pytorch_model.bin')
+ ), 'Please provide a valid directory for manually downloaded model.' # noqa
+ open_clip_model = open_clip.create_model_and_transforms(
+ name,
+ pretrained=None,
+ device='cpu',
+ force_image_size=336,
+ )
+ clip_model, _, clip_preprocess = open_clip_model
+
+ open_clip.load_checkpoint(
+ clip_model,
+ os.path.expanduser(
+ os.path.join(custom_clip_weights,
+ 'open_clip_pytorch_model.bin')))
+ clip_model.to(torch.device(device))
+
+ self.tokenizer = open_clip.get_tokenizer(name)
+ else:
+ # for OpenAI models
+ clip_model, clip_preprocess = clip_wrapper.load(
+ clip_pretrained,
+ device=device,
+ jit=False,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length)
+
+ # pre-encode classes text prompts
+ text_features = self.class_embeddings(self.class_texts,
+ prompt_templates, clip_model,
+ device).permute(1, 0, 2).float()
+ text_features_test = self.class_embeddings(self.test_class_texts,
+ prompt_templates,
+ clip_model,
+ device).permute(1, 0,
+ 2).float()
+ self.register_buffer('text_features', text_features, False)
+ self.register_buffer('text_features_test', text_features_test, False)
+
+ # prepare CLIP model finetune
+ self.clip_finetune = clip_finetune
+ self.clip_model = clip_model.float()
+ self.clip_preprocess = clip_preprocess
+
+ for name, params in self.clip_model.named_parameters():
+ if 'visual' in name:
+ if clip_finetune == 'prompt':
+ params.requires_grad = True if 'prompt' in name else False
+ elif clip_finetune == 'attention':
+ if 'attn' in name or 'position' in name:
+ params.requires_grad = True
+ else:
+ params.requires_grad = False
+ elif clip_finetune == 'full':
+ params.requires_grad = True
+ else:
+ params.requires_grad = False
+ else:
+ params.requires_grad = False
+
+ finetune_backbone = backbone_multiplier > 0.
+ for name, params in self.feature_extractor.named_parameters():
+ if 'norm0' in name:
+ params.requires_grad = False
+ else:
+ params.requires_grad = finetune_backbone
+
+ @torch.no_grad()
+ def class_embeddings(self,
+ classnames,
+ templates,
+ clip_model,
+ device='cpu'):
+ """Convert class names to text embeddings by clip model.
+
+ Args:
+ classnames (list): loaded from json file.
+ templates (dict): text template.
+ clip_model (nn.Module): prepared clip model.
+ device (str | torch.device): loading device of text
+ encoder results.
+ """
+ zeroshot_weights = []
+ for classname in classnames:
+ if ', ' in classname:
+ classname_splits = classname.split(', ')
+ texts = []
+ for template in templates:
+ for cls_split in classname_splits:
+ texts.append(template.format(cls_split))
+ else:
+ texts = [template.format(classname)
+ for template in templates] # format with class
+ if self.tokenizer is not None:
+ texts = self.tokenizer(texts).to(device)
+ else:
+ texts = clip_wrapper.tokenize(texts).to(device)
+ class_embeddings = clip_model.encode_text(texts)
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
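+            # if synonyms were expanded above, average their embeddings back
+            # to one embedding per template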
+ if len(templates) != class_embeddings.shape[0]:
+ class_embeddings = class_embeddings.reshape(
+ len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ class_embedding = class_embeddings
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
+ return zeroshot_weights
+
+ def custom_normalize(self, inputs):
+ """Input normalization for clip model and feature extractor
+ respectively.
+
+ Args:
+ inputs: batched input images.
+ """
+ # clip images
+ batched_clip = (inputs - self.clip_pixel_mean) / self.clip_pixel_std
+ batched_clip = F.interpolate(
+ batched_clip,
+ size=self.clip_resolution,
+ mode='bilinear',
+ align_corners=False)
+ # feature extractor images
+ batched = (inputs - self.pixel_mean) / self.pixel_std
+ return batched, batched_clip
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs: minibatch image. (B, 3, H, W)
+ Returns:
+ outputs (dict):
+ 'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
+ `self.feature_extractor`.
+ 'clip_text_feat': the text feature extracted by clip text encoder.
+ 'clip_text_feat_test': the text feature extracted by clip text
+ encoder for testing.
+            'clip_img_feat': the image feature extracted by the clip image
+                encoder.
+ """
+ inputs, clip_inputs = self.custom_normalize(inputs)
+ outputs = dict()
+ # extract appearance guidance feature
+ outputs['appearance_feat'] = self.feature_extractor(inputs)
+
+ # extract clip features
+ outputs['clip_text_feat'] = self.text_features
+ outputs['clip_text_feat_test'] = self.text_features_test
+ clip_features = self.clip_model.encode_image(
+ clip_inputs, dense=True) # B, 577(24x24+1), C
+ B = clip_features.size(0)
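+        # drop the CLS token and reshape the patch tokens into a 2D feature
+        # map of size `clip_img_feat_size`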
+ outputs['clip_img_feat'] = clip_features[:, 1:, :].permute(
+ 0, 2, 1).reshape(B, -1, *self.clip_img_feat_size)
+
+ return outputs
diff --git a/projects/CAT-Seg/cat_seg/utils/__init__.py b/projects/CAT-Seg/cat_seg/utils/__init__.py
new file mode 100644
index 0000000000..88746b2cba
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
+ IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
+from .self_attention_block import FullAttention, LinearAttention
+
+__all__ = [
+ 'FullAttention', 'LinearAttention', 'IMAGENET_TEMPLATES',
+ 'IMAGENET_TEMPLATES_SELECT', 'IMAGENET_TEMPLATES_SELECT_CLIP',
+ 'ViLD_templates'
+]
diff --git a/projects/CAT-Seg/cat_seg/utils/bpe_vocab/bpe_simple_vocab_16e6.txt.gz b/projects/CAT-Seg/cat_seg/utils/bpe_vocab/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000..7b5088a527
Binary files /dev/null and b/projects/CAT-Seg/cat_seg/utils/bpe_vocab/bpe_simple_vocab_16e6.txt.gz differ
diff --git a/projects/CAT-Seg/cat_seg/utils/clip_model.py b/projects/CAT-Seg/cat_seg/utils/clip_model.py
new file mode 100644
index 0000000000..977444f5b5
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/clip_model.py
@@ -0,0 +1,651 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ """Custom implementation of Bottleneck in ResNet."""
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+ # all conv layers have stride 1.
+ # an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool,
+ # and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(
+ OrderedDict([('-1', nn.AvgPool2d(stride)),
+ ('0',
+ nn.Conv2d(
+ inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=False)),
+ ('1', nn.BatchNorm2d(planes * self.expansion))]))
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ """Attention Pool2d."""
+
+ def __init__(self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads: int,
+ output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1],
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat(
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False)
+ return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+ """A ResNet class that is similar to torchvision's but contains the
+ following changes:
+
+ - There are now 3 "stem" convolutions as opposed to 1, with an average
+ pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is
+ prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self,
+ layers,
+ output_dim,
+ heads,
+ input_resolution=224,
+ width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(
+ width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ # this is a *mutable* variable used during construction
+ self._inplanes = width
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
+ heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ """Build resnet layers."""
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): the input mini-batch images.
+ """
+
+ def stem(x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ """Wrapper of GELU activation layer."""
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ """Attention block with residual connection."""
+
+ def __init__(self,
+ d_model: int,
+ n_head: int,
+ attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
+ ('gelu', QuickGELU()),
+ ('c_proj', nn.Linear(d_model * 4, d_model))]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+ self.mask_pre_mlp = True
+
+ def attention(self, x: torch.Tensor):
+ """Calculate mask multi-head-attention."""
+ self.attn_mask = self.attn_mask.to(
+ dtype=x.dtype,
+ device=x.device) if self.attn_mask is not None else None
+ return self.attn(
+ x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+ def forward_dense(self, x: torch.Tensor):
+ """Reinplementation of forward function for dense prediction of image
+ encoder in CLIP model.
+
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ y = self.ln_1(x)
+ y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
+ L, N, D = y.shape # L N 3D
+
+ y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0,
+ 3).reshape(3 * N, L, D // 3)
+ y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
+
+ q, k, v = y.tensor_split(3, dim=0)
+ v = v.transpose(1, 0) + x # L N D
+
+ v = v + self.mlp(self.ln_2(v))
+ return v
+
+
+class Transformer(nn.Module):
+ """General Transformer Architecture for both image and text encoder."""
+
+ def __init__(self,
+ width: int,
+ layers: int,
+ heads: int,
+ attn_mask: torch.Tensor = None,
+ prompt_length=0,
+ prompt_depth=0):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[
+ ResidualAttentionBlock(width, heads, attn_mask)
+ for _ in range(layers)
+ ])
+
+ self.prompt_length = prompt_length
+ self.prompt_depth = prompt_depth
+ self.prompt_tokens = nn.Parameter(
+ torch.zeros(prompt_depth, prompt_length,
+ width)) if prompt_length > 0 else None
+ if self.prompt_tokens is not None:
+ nn.init.xavier_uniform_(self.prompt_tokens)
+
+ def forward(self, x: torch.Tensor, dense=False):
+ """
+ Args:
+ x (torch.Tensor): input features.
+ dense (bool): whether use reimplemented dense forward
+ function in the last layer.
+ """
+ for i, resblock in enumerate(self.resblocks):
+ if self.prompt_length > 0 and i < self.prompt_depth:
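+                # (re)insert the learnable prompt tokens right after the
+                # class token for the first `prompt_depth` blocks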
+ length = self.prompt_length + 1 if i > 0 else 1
+ x = torch.cat((x[0:1, :, :], self.prompt_tokens[i].repeat(
+ x.shape[1], 1, 1).permute(1, 0, 2), x[length:, :, :]))
+
+ if i == self.layers - 1 and dense:
+ x = resblock.forward_dense(x)
+ x = torch.cat((x[0:1, :, :], x[self.prompt_length + 1::, :]),
+ dim=0)
+ else:
+ x = resblock(x)
+
+ return x
+
+
+class VisualTransformer(nn.Module):
+ """Visual encoder for CLIP model."""
+
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
+ layers: int, heads: int, output_dim: int, prompt_depth: int,
+ prompt_length: int):
+ super().__init__()
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False)
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
+ (input_resolution // patch_size)**2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(
+ width,
+ layers,
+ heads,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.patch_size = patch_size
+ self.input_resolution = input_resolution
+
+ def forward(self, x: torch.Tensor, dense=False):
+ """
+ Args:
+ x (torch.Tensor): input features.
+ dense (bool): whether use reimplemented dense forward
+ function in the last layer.
+ """
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([
+ self.class_embedding.to(x.dtype) + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
+ ],
+ dim=1) # shape = [*, grid ** 2 + 1, width]
+
+ if dense and (x.shape[1] != self.positional_embedding.shape[0]):
+ x = x + self.resized_pos_embed(self.input_resolution,
+ x.shape[1]).to(x.dtype)
+ else:
+ x = x + self.positional_embedding.to(x.dtype)
+
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, dense)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if dense:
+ x = self.ln_post(x[:, :, :])
+ else:
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+ def resized_pos_embed(self, in_res, tgt_res, mode='bicubic'):
+ """Resize the position embedding."""
+ # assert L == (input_resolution // self.patch_size) ** 2 + 1
+ L, D = self.positional_embedding.shape
+
+ in_side = in_res // self.patch_size
+ # tgt_side = tgt_res // self.patch_size
+ tgt_side = int((tgt_res - 1)**0.5)
+
+ cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
+ pos_embed = self.positional_embedding[1:].reshape(
+ 1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
+ resized_pos_embed = F.interpolate(
+ pos_embed,
+ size=(tgt_side, tgt_side),
+ mode=mode,
+ align_corners=False,
+ ) # 1 D S S -> 1 D S' S'
+ resized_pos_embed = resized_pos_embed.squeeze(0).reshape(
+ D, -1).T # L'-1 D
+
+ return torch.cat((cls_pos, resized_pos_embed), dim=0)
+
+
+class CLIP(nn.Module):
+ """Custom implementation of CLIP model.
+
+ Refer to: https://github.com/openai/CLIP
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int,
+ # prompt
+ prompt_depth: int = 0,
+ prompt_length: int = 0,
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ self.image_resolution = image_resolution
+
+ if isinstance(vision_layers, (tuple, list)):
+ assert prompt_length == 0 and prompt_depth == 0
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width)
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length,
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask())
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(
+ torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+ def build_attention_mask(self):
+ """Create causal attention mask."""
+        # lazily create causal attention mask, with full attention between
+        # the vision tokens. pytorch uses additive attention mask; fill
+        # with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float('-inf'))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ """Return the dtype of the model."""
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image, masks=None, pool_mask=None, dense=False):
+ """Image encoding."""
+ if pool_mask is not None:
+ return self.visual(
+ image.type(self.dtype), mask=pool_mask, dense=dense)
+ if masks is None:
+ return self.visual(image.type(self.dtype), dense=dense)
+ else:
+ return self.visual(image.type(self.dtype), masks.type(self.dtype))
+
+ def encode_text(self, text):
+ """Texts encoding."""
+ x = self.token_embedding(text).type(
+ self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number
+ # in each sequence)
+ x = x[torch.arange(x.shape[0]),
+ text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ """
+ Args:
+ image (torch.Tensor): input images.
+ text (torch.Tensor): input text.
+ """
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+        # normalized features
+ image_features = image_features / image_features.norm(
+ dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(
+ dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16."""
+
+ def _convert_weights_to_fp16(layer):
+ if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ layer.weight.data = layer.weight.data.half()
+ if layer.bias is not None:
+ layer.bias.data = layer.bias.data.half()
+
+ if isinstance(layer, nn.MultiheadAttention):
+ for attr in [
+ *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
+ 'in_proj_bias', 'bias_k', 'bias_v'
+ ]:
+ tensor = getattr(layer, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ['text_projection', 'proj']:
+ if hasattr(layer, name):
+ attr = getattr(layer, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict, prompt_depth=0, prompt_length=0):
+ """Build a CLIP model from given pretrained weights."""
+ vit = 'visual.proj' in state_dict
+
+ if vit:
+ vision_width = state_dict['visual.conv1.weight'].shape[0]
+ vision_layers = len([
+ k for k in state_dict.keys()
+ if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
+ ])
+ vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
+ grid_size = round(
+ (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len({
+ k.split('.')[2]
+ for k in state_dict if k.startswith(f'visual.layer{b}')
+ }) for b in [1, 2, 3, 4]
+ ]
+ vision_layers = tuple(counts)
+ vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
+ output_width = round(
+ (state_dict['visual.attnpool.positional_embedding'].shape[0] -
+ 1)**0.5)
+ vision_patch_size = None
+ assert output_width**2 + 1 == state_dict[
+ 'visual.attnpool.positional_embedding'].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict['text_projection'].shape[1]
+ context_length = state_dict['positional_embedding'].shape[0]
+ vocab_size = state_dict['token_embedding.weight'].shape[0]
+ transformer_width = state_dict['ln_final.weight'].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len({
+ k.split('.')[2]
+ for k in state_dict if k.startswith('transformer.resblocks')
+ })
+
+ model = CLIP(
+ embed_dim,
+ image_resolution,
+ vision_layers,
+ vision_width,
+ vision_patch_size,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length,
+ )
+
+ for key in ['input_resolution', 'context_length', 'vocab_size']:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
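+# A minimal usage sketch (the checkpoint path is hypothetical): ``build_model``
+# consumes the state dict of an OpenAI TorchScript checkpoint, which is what
+# ``load(..., jit=False)`` in clip_wrapper.py does under the hood:
+#
+#   jit_model = torch.jit.load('ViT-B-16.pt', map_location='cpu')
+#   model = build_model(jit_model.state_dict(), prompt_depth=0, prompt_length=0)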
diff --git a/projects/CAT-Seg/cat_seg/utils/clip_templates.py b/projects/CAT-Seg/cat_seg/utils/clip_templates.py
new file mode 100644
index 0000000000..bfc32dfc56
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/clip_templates.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Source: https://github.com/openai/CLIP.
+
+IMAGENET_TEMPLATES = [
+ 'a bad photo of a {}.',
+ 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+ # 'A photo of a {} in the scene.',
+]
+
+# v1: 59.0875
+IMAGENET_TEMPLATES_SELECT = [
+ 'itap of a {}.',
+ 'a bad photo of the {}.',
+ 'a origami {}.',
+ 'a photo of the large {}.',
+ 'a {} in a video game.',
+ 'art of the {}.',
+ 'a photo of the small {}.',
+ 'A photo of a {} in the scene',
+]
+
+# v9
+IMAGENET_TEMPLATES_SELECT_CLIP = [
+ 'a bad photo of the {}.',
+ 'a photo of the large {}.',
+ 'a photo of the small {}.',
+ 'a cropped photo of a {}.',
+ 'This is a photo of a {}',
+ 'This is a photo of a small {}',
+ 'This is a photo of a medium {}',
+ 'This is a photo of a large {}',
+ 'This is a masked photo of a {}',
+ 'This is a masked photo of a small {}',
+ 'This is a masked photo of a medium {}',
+ 'This is a masked photo of a large {}',
+ 'This is a cropped photo of a {}',
+ 'This is a cropped photo of a small {}',
+ 'This is a cropped photo of a medium {}',
+ 'This is a cropped photo of a large {}',
+ 'A photo of a {} in the scene',
+ 'a bad photo of the {} in the scene',
+ 'a photo of the large {} in the scene',
+ 'a photo of the small {} in the scene',
+ 'a cropped photo of a {} in the scene',
+ 'a photo of a masked {} in the scene',
+ 'There is a {} in the scene',
+ 'There is the {} in the scene',
+ 'This is a {} in the scene',
+ 'This is the {} in the scene',
+ 'This is one {} in the scene',
+ 'There is a masked {} in the scene',
+ 'There is the masked {} in the scene',
+ 'This is a masked {} in the scene',
+ 'This is the masked {} in the scene',
+ 'This is one masked {} in the scene',
+]
+
+# v10, for comparison
+# IMAGENET_TEMPLATES_SELECT_CLIP = [
+# 'a photo of a {}.',
+#
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'a photo of a {} in the scene',
+# 'a photo of a {} in the scene',
+#
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+# ]
+
+ViLD_templates = [
+ 'There is {article} {category} in the scene.',
+ 'There is the {category} in the scene.',
+ 'a photo of {article} {category} in the scene.',
+ 'a photo of the {category} in the scene.',
+ 'a photo of one {category} in the scene.', 'itap of {article} {category}.',
+ 'itap of my {category}.', 'itap of the {category}.',
+ 'a photo of {article} {category}.', 'a photo of my {category}.',
+ 'a photo of the {category}.', 'a photo of one {category}.',
+ 'a photo of many {category}.', 'a good photo of {article} {category}.',
+ 'a good photo of the {category}.', 'a bad photo of {article} {category}.',
+ 'a bad photo of the {category}.', 'a photo of a nice {category}.',
+ 'a photo of the nice {category}.', 'a photo of a cool {category}.',
+ 'a photo of the cool {category}.', 'a photo of a weird {category}.',
+ 'a photo of the weird {category}.', 'a photo of a small {category}.',
+ 'a photo of the small {category}.', 'a photo of a large {category}.',
+ 'a photo of the large {category}.', 'a photo of a clean {category}.',
+ 'a photo of the clean {category}.', 'a photo of a dirty {category}.',
+ 'a photo of the dirty {category}.',
+ 'a bright photo of {article} {category}.',
+ 'a bright photo of the {category}.',
+ 'a dark photo of {article} {category}.', 'a dark photo of the {category}.',
+ 'a photo of a hard to see {category}.',
+ 'a photo of the hard to see {category}.',
+ 'a low resolution photo of {article} {category}.',
+ 'a low resolution photo of the {category}.',
+ 'a cropped photo of {article} {category}.',
+ 'a cropped photo of the {category}.',
+ 'a close-up photo of {article} {category}.',
+ 'a close-up photo of the {category}.',
+ 'a jpeg corrupted photo of {article} {category}.',
+ 'a jpeg corrupted photo of the {category}.',
+ 'a blurry photo of {article} {category}.',
+ 'a blurry photo of the {category}.',
+ 'a pixelated photo of {article} {category}.',
+ 'a pixelated photo of the {category}.',
+ 'a black and white photo of the {category}.',
+ 'a black and white photo of {article} {category}.',
+ 'a plastic {category}.', 'the plastic {category}.', 'a toy {category}.',
+ 'the toy {category}.', 'a plushie {category}.', 'the plushie {category}.',
+ 'a cartoon {category}.', 'the cartoon {category}.',
+ 'an embroidered {category}.', 'the embroidered {category}.',
+ 'a painting of the {category}.', 'a painting of a {category}.'
+]
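+# A minimal formatting sketch (the class name 'cat' is only an example): the
+# ``{}`` templates take a bare class name, while ``ViLD_templates`` expect the
+# named fields ``article`` and ``category``:
+#
+#   prompts = [t.format('cat') for t in IMAGENET_TEMPLATES_SELECT_CLIP]
+#   vild_prompt = ViLD_templates[0].format(article='a', category='cat')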
diff --git a/projects/CAT-Seg/cat_seg/utils/clip_wrapper.py b/projects/CAT-Seg/cat_seg/utils/clip_wrapper.py
new file mode 100644
index 0000000000..f809d2b828
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/clip_wrapper.py
@@ -0,0 +1,275 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Referred to: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/third_party/clip.py # noqa
+import hashlib
+import os
+import urllib
+import warnings
+from typing import List, Union
+
+import torch
+from PIL import Image
+from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
+ ToTensor)
+from tqdm import tqdm
+
+from .clip_model import build_model
+from .tokenizer import SimpleTokenizer as _Tokenizer
+
+__all__ = ['available_models', 'load', 'tokenize']
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ 'RN50':
+ 'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt', # noqa
+ 'RN101':
+ 'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt', # noqa
+ 'RN50x4':
+ 'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt', # noqa
+ 'RN50x16':
+ 'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt', # noqa
+ 'RN50x64':
+ 'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt', # noqa
+ 'ViT-B/32':
+ 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt', # noqa
+ 'ViT-B/16':
+ 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt', # noqa
+ 'ViT-L/14':
+ 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt', # noqa
+ 'ViT-L/14@336px':
+ 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt', # noqa
+}
+
+
+def _download(url: str, root: str = os.path.expanduser('~/.cache/clip')):
+ """Download clip pretrained weights."""
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split('/')[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(
+ f'{download_target} exists and is not a regular file')
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target,
+ 'rb').read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(
+                f'{download_target} exists, but the SHA256 checksum does '
+                'not match; re-downloading the file')
+
+ with urllib.request.urlopen(url) as source, open(download_target,
+ 'wb') as output:
+ with tqdm(
+ total=int(source.info().get('Content-Length')),
+ ncols=80) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target,
+ 'rb').read()).hexdigest() != expected_sha256:
+ raise RuntimeError(
+            'Model has been downloaded but the SHA256 checksum does not '
+            'match')
+
+ return download_target
+
+
+def available_models():
+ """Returns a list of available models."""
+ return list(_MODELS.keys())
+
+
+def load(name: str,
+ device: Union[str, torch.device] = 'cuda'
+ if torch.cuda.is_available() else 'cpu',
+ jit=True,
+ prompt_depth=0,
+ prompt_length=0):
+ """Load target clip model."""
+ if name not in _MODELS:
+ raise RuntimeError(
+ f'Model {name} not found; available models = {available_models()}')
+
+ model_path = _download(_MODELS[name])
+ model = torch.jit.load(
+ model_path, map_location=device if jit else 'cpu').eval()
+ n_px = model.input_resolution.item()
+
+ transform = Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert('RGB'),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ if not jit:
+ model = build_model(model.state_dict(), prompt_depth,
+ prompt_length).to(device)
+ return model, transform
+
+ # patch the device names
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [
+ n for n in device_holder.graph.findAllNodes('prim::Constant')
+ if 'Device' in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('prim::Constant'):
+ if 'value' in node.attributeNames() and str(
+ node['value']).startswith('cuda'):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if device == 'cpu':
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('aten::to'):
+ inputs = list(node.inputs())
+ for i in [1, 2]:
+ # dtype can be the second or third argument to
+ # aten::to()
+ if inputs[i].node()['value'] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, transform
+
+
+def load_custom(name: str,
+ device: Union[str, torch.device] = 'cuda'
+ if torch.cuda.is_available() else 'cpu',
+ jit=True,
+ n_px=224):
+ """Load a customized clip model."""
+ if name not in _MODELS:
+ raise RuntimeError(
+ f'Model {name} not found; available models = {available_models()}')
+
+ model_path = _download(_MODELS[name])
+ model = torch.jit.load(
+ model_path, map_location=device if jit else 'cpu').eval()
+ # n_px = model.input_resolution.item()
+
+ transform = Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert('RGB'),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ if not jit:
+ model = build_model(model.state_dict()).to(device)
+ return model, transform
+
+ # patch the device names
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [
+ n for n in device_holder.graph.findAllNodes('prim::Constant')
+ if 'Device' in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('prim::Constant'):
+ if 'value' in node.attributeNames() and str(
+ node['value']).startswith('cuda'):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if device == 'cpu':
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('aten::to'):
+ inputs = list(node.inputs())
+                    for i in [1, 2]:
+                        # dtype can be the second or third argument to
+                        # aten::to()
+ if inputs[i].node()['value'] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, transform
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77):
+ """Convert texts to tokens."""
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder['<|startoftext|>']
+ eot_token = _tokenizer.encoder['<|endoftext|>']
+ # encode each template text phrase
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
+ for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ raise RuntimeError(
+                f'Input {texts[i]} is too long for context length '
+                f'{context_length}')
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
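+# A minimal usage sketch (assumes a CUDA device, since the non-JIT branch keeps
+# the fp16-converted weights):
+#
+#   model, preprocess = load('ViT-B/16', device='cuda', jit=False)
+#   tokens = tokenize(['a photo of a cat', 'a photo of a dog']).to('cuda')
+#   with torch.no_grad():
+#       text_features = model.encode_text(tokens)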
diff --git a/projects/CAT-Seg/cat_seg/utils/self_attention_block.py b/projects/CAT-Seg/cat_seg/utils/self_attention_block.py
new file mode 100644
index 0000000000..1c06cbd99e
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/self_attention_block.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class LinearAttention(nn.Module):
+ """Multi-Head linear attention proposed in "Transformers are RNNs".
+
+ Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L247 # noqa
+ """
+
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.eps = eps
+
+ def forward(self, queries, keys, values):
+ """
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = F.elu(queries) + 1
+ K = F.elu(keys) + 1
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum('nshd,nshv->nhdv', K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV,
+ Z) * v_length
+
+ return queried_values.contiguous()
+
+
+class FullAttention(nn.Module):
+ """Multi-head scaled dot-product attention, a.k.a full attention.
+
+ Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L276 # noqa
+ """
+
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum('nlhd,nshd->nlsh', queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(
+ ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]),
+ float('-inf'))
+
+ # Compute the attention and the weighted average
+        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum('nlsh,nshd->nlhd', A, values)
+
+ return queried_values.contiguous()
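+# A minimal shape sketch (sizes are arbitrary): both blocks consume
+# [batch, tokens, heads, dim] tensors and return the same layout.
+#
+#   q = torch.randn(2, 100, 4, 32)
+#   k = torch.randn(2, 576, 4, 32)
+#   v = torch.randn(2, 576, 4, 32)
+#   out = LinearAttention()(q, k, v)  # (2, 100, 4, 32)
+#   out = FullAttention()(q, k, v)    # (2, 100, 4, 32)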
diff --git a/projects/CAT-Seg/cat_seg/utils/tokenizer.py b/projects/CAT-Seg/cat_seg/utils/tokenizer.py
new file mode 100644
index 0000000000..c84711b067
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/tokenizer.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ """Return default BPE vocabulary path."""
+ return os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ 'bpe_vocab/bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """Returns list of utf-8 byte and a corresponding list of unicode strings.
+
+ The reversible bpe codes work on unicode strings. This means you need a
+ large # of unicode characters in your vocab if you want to avoid UNKs. When
+ you're at something like a 10B token dataset you end up needing around 5K
+ for decent coverage. This is a significant percentage of your normal, say,
+ 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
+ unicode strings. And avoids mapping to whitespace/control characters the
+ bpe code barfs on.
+ """
+ bs = list(range(ord('!'),
+ ord('~') + 1)) + list(range(
+ ord('¡'),
+ ord('¬') + 1)) + list(range(ord('®'),
+ ord('ÿ') + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length
+ strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ """Clean string."""
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """Clean whitespace in string."""
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer:
+ """Customized Tokenizer implementation."""
+
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {
+ '<|startoftext|>': '<|startoftext|>',
+ '<|endoftext|>': '<|endoftext|>'
+ }
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",  # noqa
+            re.IGNORECASE)
+
+ def bpe(self, token):
+ """Refer to bpe vocabulary dictionary."""
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + '</w>'
+
+ while True:
+ bigram = min(
+ pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[
+ i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ """Encode text strings."""
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b]
+ for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token]
+ for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ """Decoder tokens to strings."""
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+ 'utf-8', errors='replace').replace('', ' ')
+ return text
diff --git a/projects/CAT-Seg/configs/_base_/datasets/ade20k_384x384.py b/projects/CAT-Seg/configs/_base_/datasets/ade20k_384x384.py
new file mode 100644
index 0000000000..488ba3d7f6
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/datasets/ade20k_384x384.py
@@ -0,0 +1,68 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+crop_size = (384, 384)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(
+ type='RandomResize',
+ scale=(2048, 512),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/training', seg_map_path='annotations/training'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/CAT-Seg/configs/_base_/datasets/coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/_base_/datasets/coco-stuff164k_384x384.py
new file mode 100644
index 0000000000..dd051761d4
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/datasets/coco-stuff164k_384x384.py
@@ -0,0 +1,62 @@
+# dataset settings
+dataset_type = 'COCOStuffDataset'
+data_root = 'data/coco_stuff164k'
+crop_size = (384, 384)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/train2017', seg_map_path='annotations/train2017'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/val2017', seg_map_path='annotations/val2017'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/CAT-Seg/configs/_base_/datasets/pascal_context_59_384x384.py b/projects/CAT-Seg/configs/_base_/datasets/pascal_context_59_384x384.py
new file mode 100644
index 0000000000..250c5990f6
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/datasets/pascal_context_59_384x384.py
@@ -0,0 +1,72 @@
+# dataset settings
+dataset_type = 'PascalContextDataset59'
+data_root = 'data/VOCdevkit/VOC2010/'
+
+img_scale = (520, 520)
+crop_size = (384, 384)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(
+ type='RandomResize',
+ scale=img_scale,
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
+ ann_file='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
+ ann_file='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/CAT-Seg/configs/_base_/default_runtime.py b/projects/CAT-Seg/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000..272b4d2467
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/default_runtime.py
@@ -0,0 +1,15 @@
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+log_processor = dict(by_epoch=False)
+log_level = 'INFO'
+load_from = None
+resume = False
+
+tta_model = dict(type='SegTTAModel')
diff --git a/projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py b/projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 0000000000..0dcd6c4d1b
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,24 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=1e-4,
+ power=0.9,
+ begin=0,
+ end=80000,
+ by_epoch=False)
+]
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py
new file mode 100644
index 0000000000..bab43a6a39
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py
@@ -0,0 +1,103 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
+ '../_base_/datasets/ade20k_384x384.py'
+]
+
+custom_imports = dict(imports=['cat_seg'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (384, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ # due to the clip model, we do normalization in backbone forward()
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+# model_cfg
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ type='ResNet',
+ depth=101,
+ # only use the first three layers
+ num_stages=3,
+ out_indices=(0, 1, 2),
+ dilations=(1, 1, 1),
+ strides=(1, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint='torchvision://resnet101'),
+ ),
+ train_class_json='data/ade150.json',
+ test_class_json='data/ade150.json',
+ clip_pretrained='ViT-B/16',
+ clip_finetune='attention',
+ ),
+ neck=dict(
+ type='CATSegAggregator',
+ appearance_guidance_dim=1024,
+ num_layers=2,
+ pooling_size=(1, 1),
+ ),
+ decode_head=dict(
+ type='CATSegHead',
+ in_channels=128,
+ channels=128,
+ num_classes=150,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(512, 256),
+ decoder_guidance_proj_dims=(32, 16),
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+ dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
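+# A minimal sanity-check sketch (run from the repository root with this project
+# on PYTHONPATH; the path below refers to this config file):
+#
+#   from mmengine.config import Config
+#   cfg = Config.fromfile(
+#       'projects/CAT-Seg/configs/cat_seg/'
+#       'catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py')
+#   print(cfg.model.backbone.type)  # CLIPOVCATSeg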
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_pascal-context-59-384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_pascal-context-59-384x384.py
new file mode 100644
index 0000000000..8b412cb86f
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_pascal-context-59-384x384.py
@@ -0,0 +1,103 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
+ '../_base_/datasets/pascal_context_59_384x384.py'
+]
+
+custom_imports = dict(imports=['cat_seg'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (384, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ # due to the clip model, we do normalization in backbone forward()
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+# model_cfg
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ type='ResNet',
+ depth=101,
+ # only use the first three layers
+ num_stages=3,
+ out_indices=(0, 1, 2),
+ dilations=(1, 1, 1),
+ strides=(1, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint='torchvision://resnet101'),
+ ),
+ train_class_json='data/pc59.json',
+ test_class_json='data/pc59.json',
+ clip_pretrained='ViT-B/16',
+ clip_finetune='attention',
+ ),
+ neck=dict(
+ type='CATSegAggregator',
+ appearance_guidance_dim=1024,
+ num_layers=2,
+ pooling_size=(1, 1),
+ ),
+ decode_head=dict(
+ type='CATSegHead',
+ in_channels=128,
+ channels=128,
+ num_classes=59,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(512, 256),
+ decoder_guidance_proj_dims=(32, 16),
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+ dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py
new file mode 100644
index 0000000000..52bf712fea
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py
@@ -0,0 +1,102 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
+ '../_base_/datasets/coco-stuff164k_384x384.py'
+]
+
+custom_imports = dict(imports=['cat_seg'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (384, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ # due to the clip model, we do normalization in backbone forward()
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+# model_cfg
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ type='ResNet',
+ depth=101,
+ # only use the first three layers
+ num_stages=3,
+ out_indices=(0, 1, 2),
+ dilations=(1, 1, 1),
+ strides=(1, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint='torchvision://resnet101'),
+ ),
+ train_class_json='data/coco.json',
+ test_class_json='data/coco.json',
+ clip_pretrained='ViT-B/16',
+ clip_finetune='attention',
+ ),
+ neck=dict(
+ type='CATSegAggregator',
+ appearance_guidance_dim=1024,
+ num_layers=2,
+ ),
+ decode_head=dict(
+ type='CATSegHead',
+ in_channels=128,
+ channels=128,
+ num_classes=171,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(512, 256),
+ decoder_guidance_proj_dims=(32, 16),
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+ dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitg-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitg-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
new file mode 100644
index 0000000000..345945d028
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitg-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
@@ -0,0 +1,11 @@
+_base_ = './catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py' # noqa
+
+model = dict(
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ clip_pretrained='ViT-G',
+ custom_clip_weights='~/CLIP-ViT-bigG-14-laion2B-39B-b160k'),
+ neck=dict(
+ text_guidance_dim=1280,
+ appearance_guidance_dim=512,
+ ))
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vith-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vith-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
new file mode 100644
index 0000000000..2f09b8c9ca
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vith-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
@@ -0,0 +1,11 @@
+_base_ = './catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py' # noqa
+
+model = dict(
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ clip_pretrained='ViT-H',
+ custom_clip_weights='~/CLIP-ViT-H-14-laion2B-s32B-b79K'),
+ neck=dict(
+ text_guidance_dim=1024,
+ appearance_guidance_dim=512,
+ ))
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
new file mode 100644
index 0000000000..bb4d57ae21
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
@@ -0,0 +1,72 @@
+_base_ = './catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py' # noqa
+
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
+crop_size = (384, 384)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ pretrain_img_size=384,
+ embed_dims=128,
+ depths=[2, 2, 18],
+ num_heads=[4, 8, 16],
+ window_size=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=(0, 1, 2),
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ clip_pretrained='ViT-L/14@336px',
+ ),
+ neck=dict(
+ text_guidance_dim=768,
+ appearance_guidance_dim=512,
+ ),
+ decode_head=dict(
+ embed_dims=128,
+ decoder_guidance_dims=(256, 128),
+ ))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=2,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+ dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
diff --git a/projects/CAT-Seg/utils/__init__.py b/projects/CAT-Seg/utils/__init__.py
new file mode 100644
index 0000000000..02d85f29cb
--- /dev/null
+++ b/projects/CAT-Seg/utils/__init__.py
@@ -0,0 +1,7 @@
+from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
+ IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
+
+__all__ = [
+ 'IMAGENET_TEMPLATES', 'IMAGENET_TEMPLATES_SELECT',
+ 'IMAGENET_TEMPLATES_SELECT_CLIP', 'ViLD_templates'
+]