forked from xinntao/BasicSR-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
example_arch.py
52 lines (39 loc) · 1.89 KB
/
example_arch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from torch import nn as nn
from torch.nn import functional as F
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class ExampleArch(nn.Module):
"""Example architecture.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
upscale (int): Upsampling factor. Default: 4.
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4):
super(ExampleArch, self).__init__()
self.upscale = upscale
self.conv1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# initialization
default_init_weights(
[self.conv1, self.conv2, self.conv3, self.upconv1, self.upconv2, self.conv_hr, self.conv_last], 0.1)
def forward(self, x):
feat = self.lrelu(self.conv1(x))
feat = self.lrelu(self.conv2(feat))
feat = self.lrelu(self.conv3(feat))
out = self.lrelu(self.pixel_shuffle(self.upconv1(feat)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.conv_last(self.lrelu(self.conv_hr(out)))
base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
out += base
return out