-
Notifications
You must be signed in to change notification settings - Fork 3
/
pa_pool.py
69 lines (62 loc) · 2.44 KB
/
pa_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn.functional as F
def pa_avg_pool(in_dict):
    """Mask weighted avg pooling, producing one pooled vector per part.

    Args:
        in_dict: dict with
            'feat': pytorch tensor, with shape [N, C, H, W]
            'pap_mask': pytorch tensor, with shape [N, pC, pH, pW]; the
                feature map must be expandable to (pH, pW) spatially
                (normally H == pH and W == pW).
    Returns:
        dict with
            'feat_list': a list (length = pC) of pytorch tensors with shape
                [N, C, 1, 1]; entry i is sum(feat * mask_i) / sum(mask_i) over
                spatial positions, or 0 where sum(mask_i) == 0.
            'visible': float pytorch tensor with shape [N, pC]; 1 where the
                part mask has a non-zero sum, else 0.
    Raises:
        ValueError: if `feat` or `pap_mask` is not 4-D.
        RuntimeError: if the feature map cannot be expanded to the mask's
            spatial size.
    """
    feat = in_dict['feat']
    mask = in_dict['pap_mask']
    if feat.dim() != 4 or mask.dim() != 4:
        raise ValueError(
            'pa_avg_pool expects 4-D feat and pap_mask, got shapes '
            f'{tuple(feat.shape)} and {tuple(mask.shape)}')
    _, _, pH, pW = mask.size()
    # Same spatial broadcasting rule as the original expand(): a size-1 dim
    # may broadcast, any other mismatch raises RuntimeError.
    feat = feat.expand(-1, -1, pH, pW)

    # Promote like the elementwise `feat * mask` would (e.g. bool masks), but
    # always compute the average in a floating dtype.
    dtype = torch.result_type(feat, mask)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    # [N, pC, pH*pW]
    weights = mask.to(dtype).flatten(2)
    # [N, pH*pW, C]
    values = feat.to(dtype).flatten(2).transpose(1, 2)
    # [N, pC, C] -- a batched matmul sums feat * mask over spatial positions
    # without materializing a [N, pC, C, pH, pW] product tensor.
    summed = weights @ values
    # [N, pC, 1]
    area = weights.sum(-1, keepdim=True)
    # Empty parts have a zero numerator; dividing by 1 there yields 0 without
    # relying on a tiny epsilon that rounds to 0 in float16 (-> NaN).
    pooled = summed / area.masked_fill(area == 0, 1)

    # [N, pC]
    visible = (mask.sum(-1).sum(-1) != 0).float()
    # pC * [N, C, 1, 1]
    feat_list = list(pooled[..., None, None].unbind(1))
    out_dict = {'feat_list': feat_list, 'visible': visible}
    return out_dict
def pa_max_pool(in_dict):
    """Implement `local max pooling` as `masking + global max pooling`.

    Args:
        in_dict: dict with
            'feat': pytorch tensor, with shape [N, C, H, W]
            'pap_mask': pytorch tensor, with shape [N, pC, pH, pW]; each part
                slice (as [N, 1, pH, pW]) must be expandable to feat's shape.
    Returns:
        dict with
            'feat_list': a list (length = pC) of pytorch tensors with shape
                [N, C, 1, 1]; entry i is the spatial max of feat * mask_i.
            'visible': float pytorch tensor with shape [N, pC]; 1 where the
                part mask has a non-zero sum, else 0.
    Raises:
        ValueError: if `feat` or `pap_mask` is not 4-D.
    NOTE:
        The implementation of `masking + global max pooling` is only equivalent
        to `local max pooling` when feature values are non-negative, which holds
        for ResNet that has ReLU as final operation of all blocks.
    """
    feat = in_dict['feat']
    mask = in_dict['pap_mask']
    if feat.dim() != 4 or mask.dim() != 4:
        raise ValueError(
            'pa_max_pool expects 4-D feat and pap_mask, got shapes '
            f'{tuple(feat.shape)} and {tuple(mask.shape)}')
    num_parts = mask.size(1)
    # pC * [N, C, 1, 1]: zero out everything outside part i (mask broadcast
    # over channels), then take the global spatial max. Parts are pooled one
    # at a time so no [N, pC, C, H, W] tensor is ever materialized.
    feat_list = [
        F.adaptive_max_pool2d(feat * mask[:, i].unsqueeze(1).expand_as(feat), 1)
        for i in range(num_parts)
    ]
    # [N, pC]
    visible = (mask.sum(-1).sum(-1) != 0).float()
    out_dict = {'feat_list': feat_list, 'visible': visible}
    return out_dict
class PAPool:
    """Part-aligned pooling whose flavour is chosen from config.

    ``cfg.max_or_avg == 'avg'`` selects ``pa_avg_pool``; every other value
    (not only ``'max'``) selects ``pa_max_pool``. Calling an instance
    forwards all positional and keyword arguments unchanged to the selected
    function and returns its result.
    """

    def __init__(self, cfg):
        if cfg.max_or_avg == 'avg':
            chosen = pa_avg_pool
        else:
            # NOTE(review): unrecognized values silently fall back to max pooling.
            chosen = pa_max_pool
        self.pool = chosen

    def __call__(self, *args, **kwargs):
        pool_fn = self.pool
        return pool_fn(*args, **kwargs)