-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
103 lines (76 loc) · 3 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch.nn as nn
import torch
from torchvision.models.resnet import resnet50
class Residual(nn.Module):
    """Depthwise conv -> GELU -> BatchNorm branch added back onto its input.

    Args:
        dim: Number of input (and output) channels.
        kernel_size: Spatial size of the depthwise convolution kernel.
    """

    def __init__(self, dim, kernel_size):
        super().__init__()
        # groups=dim gives every channel its own filter (depthwise), and
        # padding="same" keeps the spatial size so the skip addition in
        # forward() is shape-compatible.
        self.conv1 = nn.Conv2d(
            dim, dim, kernel_size, groups=dim, padding="same"
        )
        self.gelu = nn.GELU()
        self.batchnorm = nn.BatchNorm2d(dim)

    def forward(self, y):
        """Return ``batchnorm(gelu(conv1(y))) + y``."""
        branch = self.batchnorm(self.gelu(self.conv1(y)))
        return branch + y
class ConvMixerBlock(nn.Module):
    """One mixer layer: a ``Residual`` block followed by a 1x1 convolution.

    The ``Residual`` sub-module uses the given ``kernel_size``; the 1x1
    convolution that follows is passed through GELU and BatchNorm.

    Args:
        dim: Number of channels, unchanged by the block.
        kernel_size: Kernel size forwarded to the ``Residual`` sub-module.
    """

    def __init__(self, dim, kernel_size):
        super().__init__()
        self.resblock = Residual(dim, kernel_size)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.gelu1 = nn.GELU()
        self.batchnorm1 = nn.BatchNorm2d(dim)

    def forward(self, x):
        """Apply the residual block, then 1x1 conv -> GELU -> BatchNorm."""
        mixed = self.resblock(x)
        return self.batchnorm1(self.gelu1(self.conv1(mixed)))
class ConvMixer(nn.Module):
    """Patch-embedding stem, a stack of ``ConvMixerBlock``s, and a linear head.

    Args:
        dim: Channel width used by the stem and every mixer block.
        depth: Number of ``ConvMixerBlock`` layers to stack.
        kernel_size: Kernel size passed to each ``ConvMixerBlock``.
        patch_size: Kernel size and stride of the stem convolution, so the
            stem sees non-overlapping ``patch_size`` x ``patch_size`` patches.
        n_classes: Number of outputs of the final linear layer.
        in_channels: Number of channels in the input images. Defaults to 3,
            which was previously hard-coded.
    """

    def __init__(self, dim, depth, kernel_size, patch_size, n_classes=2,
                 in_channels=3):
        super().__init__()
        # Stem: stride == kernel size, so patches do not overlap.
        self.conv1 = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size
        )
        self.gelu1 = nn.GELU()
        self.batchnorm1 = nn.BatchNorm2d(dim)
        self.convmixblock = nn.Sequential(
            *(ConvMixerBlock(dim, kernel_size) for _ in range(depth))
        )
        # Global average pool to 1x1, then flatten to (batch, dim).
        self.gvp = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(dim, n_classes)

    def forward(self, x):
        """Return the linear head's output for a batch of images.

        Args:
            x: Input accepted by the stem ``nn.Conv2d``, i.e. a tensor with
                ``in_channels`` channels.

        Returns:
            The output of the final ``nn.Linear`` (no activation applied),
            with ``n_classes`` values per sample.
        """
        features = self.batchnorm1(self.gelu1(self.conv1(x)))
        features = self.convmixblock(features)
        pooled = self.flatten(self.gvp(features))
        return self.linear(pooled)
class ResNetFeatureModel(nn.Module):
    """Truncated pretrained ResNet-50 followed by a 2-way linear head.

    The ResNet-50 children are copied in order up to and including the child
    named ``output_layer``; the result is flattened and fed to
    ``nn.Linear(2048, 2)``.

    Args:
        output_layer: Name of the last ResNet-50 child module to keep
            (``get_model`` passes ``'avgpool'``).

    Raises:
        ValueError: If ``output_layer`` is not the name of a ResNet-50 child.
            Without this check every child (including ``fc``) would be kept
            and the model would only fail later, inside ``forward``.
    """

    def __init__(self, output_layer):
        super().__init__()
        self.output_layer = output_layer
        # Pretrained ResNet-50 backbone.
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is to stay compatible with whatever
        # version this project pins -- confirm before changing.
        pretrained_resnet = resnet50(pretrained=True)

        # Keep the children up to and including `output_layer`.
        self.children_list = []
        for name, child in pretrained_resnet.named_children():
            self.children_list.append(child)
            if name == self.output_layer:
                break
        else:
            # Loop finished without hitting `break`: the name was never seen.
            valid_names = [n for n, _ in pretrained_resnet.named_children()]
            raise ValueError(
                f"output_layer {output_layer!r} is not a child of resnet50; "
                f"expected one of {valid_names}"
            )

        self.net = nn.Sequential(*self.children_list)
        self.flatten = nn.Flatten()
        # The head expects 2048 flattened features per sample.
        # NOTE(review): presumably only output_layer='avgpool' yields that
        # size -- verify before using another cut point.
        self.linear = nn.Sequential(
            nn.Linear(2048, 2)
        )

    def forward(self, x):
        """Run the truncated backbone, flatten, and apply the linear head."""
        features = self.net(x)
        return self.linear(self.flatten(features))
def get_model(MODEL_NAME='conv-mix'):
    """Build one of the models defined in this module by name.

    Args:
        MODEL_NAME: ``'conv-mix'`` for a ``ConvMixer`` (dim=768, depth=32,
            kernel_size=7, patch_size=4) or ``'res-net'`` for a
            ``ResNetFeatureModel`` cut at ``'avgpool'``.

    Returns:
        The newly constructed ``nn.Module``.

    Raises:
        ValueError: If ``MODEL_NAME`` is not one of the supported names.
            Previously this fell through to ``return model`` with ``model``
            unbound and surfaced as an ``UnboundLocalError``.
    """
    if MODEL_NAME == 'conv-mix':
        return ConvMixer(dim=768, depth=32, kernel_size=7, patch_size=4)
    if MODEL_NAME == 'res-net':
        return ResNetFeatureModel(output_layer='avgpool')
    raise ValueError(
        f"Unknown MODEL_NAME {MODEL_NAME!r}; expected 'conv-mix' or 'res-net'"
    )