-
Notifications
You must be signed in to change notification settings - Fork 2
/
LNL.py
137 lines (112 loc) · 5.44 KB
/
LNL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Author: Omid Nejati
Email: [email protected]
LNL : Introducing locality mechanism into Transformer in Transformer (TNT)
"""
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from timm.models.registry import register_model
from models.localvit import LocalityFeedForward
from models.tnt import Attention, TNT
import math
def _cfg(url='', **kwargs):
    """Return the default model config dict for a TNT-style variant.

    The base entries describe a 1000-class, 3x224x224-input model with
    bicubic interpolation, ImageNet mean/std, ``pixel_embed.proj`` as the
    first conv and ``head`` as the classifier. Any keyword arguments
    override matching base entries or are appended as new ones.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra or overriding config entries.

    Returns:
        A new dict for each call.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='pixel_embed.proj',
        classifier='head',
    )
    # Overrides keep their original key position; new keys go at the end.
    config.update(kwargs)
    return config
# Per-variant configs (tiny / small / base). All three replace the ImageNet
# normalisation with mean = std = 0.5 per channel; each entry is its own dict.
default_cfgs = {
    variant: _cfg(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    for variant in (
        'tnt_t_conv_patch16_224',
        'tnt_s_conv_patch16_224',
        'tnt_b_conv_patch16_224',
    )
}
class Block(nn.Module):
    """TNT block whose outer MLP is replaced by a ``LocalityFeedForward``.

    An inner transformer (attention + MLP) runs over the pixel tokens. The
    result is projected and added onto the patch tokens, which then pass
    through the outer attention. Finally, the non-class patch tokens are laid
    out on a square 2-D grid and fed through ``LocalityFeedForward``; the
    class token bypasses that step.

    Args:
        dim: Channel width of the outer (patch) tokens.
        in_dim: Channel width of the inner (pixel) tokens.
        num_pixel: Pixel tokens per patch; the inner-to-outer projection maps
            ``in_dim * num_pixel`` features to ``dim``.
        num_heads: Head count for the outer ``Attention``.
        in_num_head: Head count for the inner ``Attention``.
        mlp_ratio: Ratio forwarded to ``LocalityFeedForward``. The inner MLP
            always uses a fixed hidden width of ``4 * in_dim``.
        qkv_bias: Forwarded as ``qkv_bias`` to both attentions.
        drop: Forwarded as ``proj_drop`` to both attentions and as ``drop``
            to the inner ``Mlp``.
        attn_drop: Forwarded as ``attn_drop`` to both attentions.
        drop_path: Stochastic-depth rate for the inner attention, inner MLP
            and outer attention residual branches (identity when 0).
        act_layer: Activation for the inner ``Mlp``.
        norm_layer: Normalisation layer constructor.
    """

    def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
                 qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Inner transformer (over pixel tokens).
        self.norm_in = norm_layer(in_dim)
        self.attn_in = Attention(
            in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        self.norm_mlp_in = norm_layer(in_dim)
        # The inner MLP ratio is fixed at 4, independent of ``mlp_ratio``.
        self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4),
                          out_features=in_dim, act_layer=act_layer, drop=drop)
        # Projects the flattened pixel tokens of one patch into the patch-token width.
        self.norm1_proj = norm_layer(in_dim)
        self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True)
        # Outer transformer (over patch tokens).
        self.norm_out = norm_layer(dim)
        self.attn_out = Attention(
            dim, dim, num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # Replaces the outer MLP of plain TNT; applied on the 2-D patch grid in forward().
        self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=dim)

    def forward(self, pixel_embed, patch_embed):
        """Run one block.

        Args:
            pixel_embed: Inner tokens. After ``norm1_proj`` they are reshaped
                to ``(B, N - 1, -1)``, so presumably they are laid out as
                ``(B * (N - 1), num_pixel, in_dim)``. TODO confirm.
            patch_embed: Outer tokens ``(B, N, dim)``. Token 0 is the class
                token, and the remaining ``N - 1`` tokens must form a square
                grid. This tensor is not modified in place.

        Returns:
            ``(pixel_embed, patch_embed, weights)``, where ``weights`` is the
            second output of the outer attention.
        """
        # Inner transformer over pixel tokens.
        x, _ = self.attn_in(self.norm_in(pixel_embed))
        pixel_embed = pixel_embed + self.drop_path(x)
        pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))

        # Outer transformer over patch tokens.
        B, N, C = patch_embed.size()
        grid = int(math.sqrt(N - 1))  # side of the square patch grid (class token excluded)
        inner = self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
        # Rebuild with torch.cat instead of slice assignment: the caller's tensor is
        # left untouched, and autograd does not see an in-place write on the input.
        patch_embed = torch.cat([patch_embed[:, :1], patch_embed[:, 1:] + inner], dim=1)
        x, weights = self.attn_out(self.norm_out(patch_embed))
        patch_embed = patch_embed + self.drop_path(x)

        # Locality feed-forward on the 2-D patch grid; the class token skips it.
        cls_token, patch_tokens = torch.split(patch_embed, [1, N - 1], dim=1)  # (B, 1, C), (B, N - 1, C)
        patch_tokens = patch_tokens.transpose(1, 2).view(B, C, grid, grid)  # (B, C, grid, grid)
        patch_tokens = self.conv(patch_tokens).flatten(2).transpose(1, 2)  # (B, N - 1, C)
        patch_embed = torch.cat([cls_token, patch_tokens], dim=1)
        return pixel_embed, patch_embed, weights
class LocalViT_TNT(TNT):
    """Transformer in Transformer (https://arxiv.org/abs/2103.00112) with locality.

    The parent ``TNT`` is constructed with the same arguments. Its block stack
    is then replaced by ``depth`` instances of this module's ``Block``, and
    ``_init_weights`` is re-applied to the whole model.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12,
                 num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
        super().__init__(img_size, patch_size, in_chans, num_classes, embed_dim, in_dim, depth,
                         num_heads, in_num_head, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
                         drop_path_rate, norm_layer, first_stride)
        # Number of inner (pixel) tokens per patch, as produced by the parent's pixel embedding.
        pixels_per_patch = self.pixel_embed.new_patch_size ** 2
        # Stochastic depth: the drop-path rate rises linearly from 0 to drop_path_rate.
        path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, in_dim=in_dim, num_pixel=pixels_per_patch, num_heads=num_heads,
                in_num_head=in_num_head, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=rate, norm_layer=norm_layer,
            )
            for rate in path_rates
        ])
        # Re-run weight initialisation so the freshly built blocks are covered too.
        self.apply(self._init_weights)
@register_model
def LNL_Ti(pretrained=False, **kwargs):
    """Build the tiny LNL model.

    Uses ``LocalViT_TNT`` with patch size 16, embed dim 192, inner dim 12,
    depth 12, 3 outer heads and 3 inner heads, without qkv bias, and attaches
    the ``'tnt_t_conv_patch16_224'`` default config.

    Args:
        pretrained: If true, call ``load_pretrained`` on the model.
        **kwargs: Forwarded to ``LocalViT_TNT``; ``in_chans`` (default 3) is
            also passed to ``load_pretrained``.

    Returns:
        The constructed model.

    NOTE(review): the attached config has an empty ``'url'``; confirm where
    pretrained weights are expected to come from.
    """
    model = LocalViT_TNT(
        patch_size=16, embed_dim=192, in_dim=12, depth=12,
        num_heads=3, in_num_head=3, qkv_bias=False, **kwargs
    )
    model.default_cfg = default_cfgs['tnt_t_conv_patch16_224']
    if not pretrained:
        return model
    channels = kwargs.get('in_chans', 3)
    load_pretrained(model, num_classes=model.num_classes, in_chans=channels)
    return model
@register_model
def LNL_S(pretrained=False, **kwargs):
    """Build the small LNL model.

    Uses ``LocalViT_TNT`` with patch size 16, embed dim 384, inner dim 24,
    depth 12, 6 outer heads and 4 inner heads, without qkv bias, and attaches
    the ``'tnt_s_conv_patch16_224'`` default config.

    Args:
        pretrained: If true, call ``load_pretrained`` on the model.
        **kwargs: Forwarded to ``LocalViT_TNT``; ``in_chans`` (default 3) is
            also passed to ``load_pretrained``.

    Returns:
        The constructed model. The original ended in a bare ``return``, which
        made this registered factory return ``None``.

    NOTE(review): the attached config has an empty ``'url'``; confirm where
    pretrained weights are expected to come from.
    """
    model = LocalViT_TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
                         qkv_bias=False, **kwargs)
    model.default_cfg = default_cfgs['tnt_s_conv_patch16_224']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model