Skip to content

Commit

Permalink
Updated DetectV8 head
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel committed Jul 23, 2024
1 parent d3f5c27 commit 9e47d56
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions tools/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,45 +360,46 @@ def __init__(self, old_detect, use_rvc2: bool):

self.cv2 = old_detect.cv2
self.cv3 = old_detect.cv3
self.dfl = old_detect.dfl
# self.dfl = CustomDFL()
self.f = old_detect.f
self.i = old_detect.i

self.use_rvc2 = use_rvc2

self.proj_conv = nn.Conv2d(old_detect.dfl.c1, 1, 1, bias=False).requires_grad_(False)
x = torch.arange(old_detect.dfl.c1, dtype=torch.float)
self.proj_conv.weight.data[:] = nn.Parameter(x.view(1, old_detect.dfl.c1, 1, 1))

def forward(self, x):
shape = x[0].shape # BCHW

outputs = []
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
box = self.cv2[i](x[i])
h, w = box.shape[2:]

# ------------------------------
# DFL PART
box = box.view(shape[0], 4, self.reg_max, h*w).permute(0, 2, 1, 3)
box = self.proj_conv(F.softmax(box, dim=1))[:, 0]
box = box.reshape([shape[0], 4, h, w])
# ------------------------------

cls = self.cv3[i](x[i])
cls_output = cls.sigmoid()
if self.use_rvc2:
conf, _ = cls_output.max(1, keepdim=True)
else:
conf = torch.ones(
(cls_output.shape[0], 1, cls_output.shape[2], cls_output.shape[3]),
device=cls_output.device,
)

box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split(
(self.reg_max * 4, self.nc), 1
)
box = self.dfl(box)
cls_output = cls.sigmoid()
# Get the max
if self.use_rvc2:
conf, _ = cls_output.max(1, keepdim=True)
else:
conf = torch.ones(
(cls_output.shape[0], 1, cls_output.shape[2]), device=cls_output.device
)
# Concatenate
y = torch.cat([box, conf, cls_output], axis=1)
# Split to 3 channels
outputs = []
start, end = 0, 0
for i, xi in enumerate(x):
end += xi.shape[-2] * xi.shape[-1]
outputs.append(
y[:, :, start:end].view(xi.shape[0], -1, xi.shape[-2], xi.shape[-1])
)
start += xi.shape[-2] * xi.shape[-1]
output = torch.cat([box, conf, cls_output], axis=1)
outputs.append(output)

return outputs


class OBBV8(nn.Module):
"""YOLOv8 OBB detection head for detection with rotation models."""
dynamic = False # force grid reconstruction
Expand Down

0 comments on commit 9e47d56

Please sign in to comment.