Add fix for other V8 heads
HonzaCuhel committed Jul 29, 2024
1 parent 9e47d56 commit bd8010b
Showing 2 changed files with 28 additions and 157 deletions.
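In short, the commit rewrites the OBBV8, PoseV8, and SegmentV8 heads to inherit from DetectV8 instead of nn.Module, so the shared detection decode lives in one place and each head only appends its task-specific tensors. A minimal, runnable sketch of the pattern (the detection path is stubbed here; the real one is DetectV8.forward in the heads.py diff below):

import torch
import torch.nn as nn

class DetectV8(nn.Module):
    def forward(self, x):
        # shared detection decode: one output map per feature level
        return list(x)

class OBBV8(DetectV8):
    def forward(self, x):
        outputs = super().forward(x)    # reuse the detection path
        outputs.append(torch.zeros(1))  # stub for the head-specific tensor (angles here)
        return outputs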
183 changes: 27 additions & 156 deletions tools/modules/heads.py
@@ -352,15 +352,12 @@ def __init__(self, old_detect, use_rvc2: bool):
         super().__init__()
         self.nc = old_detect.nc # number of classes
         self.nl = old_detect.nl # number of detection layers
-        self.reg_max = (
-            old_detect.reg_max
-        ) # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
+        self.reg_max = old_detect.reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
         self.no = old_detect.no # number of outputs per anchor
         self.stride = old_detect.stride # strides computed during build

         self.cv2 = old_detect.cv2
         self.cv3 = old_detect.cv3
-        # self.dfl = CustomDFL()
         self.f = old_detect.f
         self.i = old_detect.i

@@ -371,7 +368,7 @@ def __init__(self, old_detect, use_rvc2: bool):
         self.proj_conv.weight.data[:] = nn.Parameter(x.view(1, old_detect.dfl.c1, 1, 1))

     def forward(self, x):
-        shape = x[0].shape # BCHW
+        bs = x[0].shape[0] # batch size

         outputs = []
         for i in range(self.nl):
@@ -380,9 +377,9 @@ def forward(self, x):

             # ------------------------------
             # DFL PART
-            box = box.view(shape[0], 4, self.reg_max, h*w).permute(0, 2, 1, 3)
+            box = box.view(bs, 4, self.reg_max, h*w).permute(0, 2, 1, 3)
             box = self.proj_conv(F.softmax(box, dim=1))[:, 0]
-            box = box.reshape([shape[0], 4, h, w])
+            box = box.reshape([bs, 4, h, w])
             # ------------------------------

             cls = self.cv3[i](x[i])
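The DFL part above decodes each box side as the expectation of a discrete distribution over reg_max bins: a softmax over the bins followed by a fixed 1x1 convolution whose weights are presumably 0..reg_max-1 (the proj_conv prepared in __init__). A self-contained illustration of the same computation, with assumed sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

bs, reg_max, h, w = 1, 16, 8, 8
box = torch.randn(bs, 4 * reg_max, h, w)  # raw regression output of one level

proj = nn.Conv2d(reg_max, 1, 1, bias=False)  # fixed projection: weights 0..reg_max-1
proj.weight.data[:] = torch.arange(reg_max, dtype=torch.float).view(1, reg_max, 1, 1)

d = box.view(bs, 4, reg_max, h * w).permute(0, 2, 1, 3)   # (bs, reg_max, 4, h*w)
d = proj(F.softmax(d, dim=1))[:, 0].reshape(bs, 4, h, w)  # expected bin per box side
print(d.shape)  # torch.Size([1, 4, 8, 8])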
@@ -400,131 +397,49 @@

         return outputs

-class OBBV8(nn.Module):
-    """YOLOv8 OBB detection head for detection with rotation models."""
-    dynamic = False # force grid reconstruction
-    export = False # export mode
-    shape = None
-    anchors = torch.empty(0) # init
-    strides = torch.empty(0) # init
-
+class OBBV8(DetectV8):
+    """YOLOv8 OBB detection head for detection with rotation models."""
     def __init__(self, old_obb, use_rvc2):
-        super().__init__()
-        self.nc = old_obb.nc # number of classes
-        self.nl = old_obb.nl # number of detection layers
-        self.reg_max = old_obb.reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
-        self.no = old_obb.no # number of outputs per anchor
-        self.stride = old_obb.stride # strides computed during build
-
-        self.cv2 = old_obb.cv2
-        self.cv3 = old_obb.cv3
-        self.dfl = old_obb.dfl
-        self.f = old_obb.f
-        self.i = old_obb.i
-
-        self.use_rvc2 = use_rvc2
-
+        super().__init__(old_obb, use_rvc2)
         self.ne = old_obb.ne # number of extra parameters
         self.cv4 = old_obb.cv4

     def forward(self, x):
+        # Detection part
+        outputs = super().forward(x)
+
+        # OBB part
         bs = x[0].shape[0] # batch size
         angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
-        # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
         angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
-        # ---------------------
-        # Detection part
-        # ---------------------
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
-
-        box, cls = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
-        box = self.dfl(box)
-        cls_output = cls.sigmoid()
-        # Get the max
-        if self.use_rvc2:
-            conf, _ = cls_output.max(1, keepdim=True)
-        else:
-            conf = torch.ones((cls_output.shape[0], 1, cls_output.shape[2]), device=cls_output.device)
-        # Concatenate
-        y = torch.cat([box, conf, cls_output], axis=1)
-        # Split to 3 channels
-        outputs = []
-        start, end = 0, 0
-        for i, xi in enumerate(x):
-            end += xi.shape[-2]*xi.shape[-1]
-            outputs.append(y[:, :, start:end].view(xi.shape[0], -1, xi.shape[-2], xi.shape[-1]))
-            start += xi.shape[-2]*xi.shape[-1]

         # Append the angle
         outputs.append(angle)

         return outputs
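
The angle decode itself is unchanged: sigmoid() lands in (0, 1), so (angle.sigmoid() - 0.25) * math.pi lands in (-pi/4, 3pi/4), as the inline comment says. A quick check:

import math
import torch

a = torch.tensor([-10.0, 0.0, 10.0])
print((a.sigmoid() - 0.25) * math.pi)  # ≈ [-pi/4, pi/4, 3pi/4]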


-class PoseV8(nn.Module):
+class PoseV8(DetectV8):
     """YOLOv8 Pose head for keypoints models."""
-    dynamic = False # force grid reconstruction
-    export = False # export mode
-    shape = None
-    anchors = torch.empty(0) # init
-    strides = torch.empty(0) # init

     def __init__(self, old_kpts, use_rvc2):
-        super().__init__()
-        self.nc = old_kpts.nc # number of classes
-        self.nl = old_kpts.nl # number of detection layers
-        self.reg_max = old_kpts.reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
-        self.no = old_kpts.no # number of outputs per anchor
-        self.stride = old_kpts.stride # strides computed during build
-
-        self.cv2 = old_kpts.cv2
-        self.cv3 = old_kpts.cv3
-        self.dfl = old_kpts.dfl
-        self.f = old_kpts.f
-        self.i = old_kpts.i
-
+        super().__init__(old_kpts, use_rvc2)
         self.kpt_shape = old_kpts.kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
         self.nk = old_kpts.nk # number of keypoints total

         self.cv4 = old_kpts.cv4
-        self.use_rvc2 = use_rvc2

     def forward(self, x):
         """Perform forward pass through YOLO model and return predictions."""
         bs = x[0].shape[0] # batch size

-        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
-        # ---------------------
-        # Detection part
-        # ---------------------
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)

-        # box, cls = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
-        x_cat = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2)
-        if self.shape != bs:
-            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
-            self.shape = bs
-        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

-        box = self.dfl(box)
-        cls_output = cls.sigmoid()
-        # Get the max
-        if self.use_rvc2:
-            conf, _ = cls_output.max(1, keepdim=True)
-        else:
-            conf = torch.ones((cls_output.shape[0], 1, cls_output.shape[2]), device=cls_output.device)
-        # Concatenate
-        y = torch.cat([box, conf, cls_output], axis=1)
-        # Split to 3 channels
-        outputs = []
-        start, end = 0, 0
-        for i, xi in enumerate(x):
-            end += xi.shape[-2]*xi.shape[-1]
-            outputs.append(y[:, :, start:end].view(xi.shape[0], -1, xi.shape[-2], xi.shape[-1]))
-            start += xi.shape[-2]*xi.shape[-1]
-
+        # Detection part
+        outputs = super().forward(x)
+
+        # Pose part
+        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
         pred_kpt = self.kpts_decode(bs, kpt)
         outputs.append(pred_kpt)

@@ -540,68 +455,23 @@ def kpts_decode(self, bs, kpts):
         return a.view(bs, self.nk, -1)


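kpts_decode (collapsed above) returns the keypoints flattened to (batch, nk, anchors). Assuming the default COCO keypoint shape of (17, 3), which matches the (bs, 17*3, h*w) comment in PoseV8.forward, the shapes work out as:

import torch

bs, n_anchors = 1, 8400  # 8400 = 80*80 + 40*40 + 20*20 anchors at a 640x640 input
a = torch.randn(bs, 17, 3, n_anchors)  # x, y, visibility per keypoint
print(a.view(bs, 17 * 3, -1).shape)  # torch.Size([1, 51, 8400])
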
-class SegmentV8(nn.Module):
+class SegmentV8(DetectV8):
     """YOLOv8 Segment head for segmentation models."""
-    dynamic = False # force grid reconstruction
-    export = False # export mode
-    shape = None
-    anchors = torch.empty(0) # init
-    strides = torch.empty(0) # init
-
     def __init__(self, old_segment, use_rvc2):
-        super().__init__()
-        self.nc = old_segment.nc # number of classes
-        self.nl = old_segment.nl # number of detection layers
-        self.reg_max = old_segment.reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
-        self.no = old_segment.no # number of outputs per anchor
-        self.stride = old_segment.stride # strides computed during build
-
-        self.cv2 = old_segment.cv2
-        self.cv3 = old_segment.cv3
-        self.dfl = old_segment.dfl
-        self.f = old_segment.f
-        self.i = old_segment.i
-
+        super().__init__(old_segment, use_rvc2)
         self.nm = old_segment.nm # number of masks
         self.npr = old_segment.npr # number of protos
         self.proto = old_segment.proto # protos
         self.detect = old_segment.detect

         self.cv4 = old_segment.cv4

-        self.use_rvc2 = use_rvc2
-
     def forward(self, x):
-        p = self.proto(x[0]) # mask protos
-        bs = p.shape[0] # batch size
-
-        mc = [self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)] # mask coefficients
-        # ---------------------
-        # Detection part
-        # ---------------------
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
-
-        box, cls = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
-        box = self.dfl(box)
-        cls_output = cls.sigmoid()
-        # Get the max
-        if self.use_rvc2:
-            conf, _ = cls_output.max(1, keepdim=True)
-        else:
-            conf = torch.ones((cls_output.shape[0], 1, cls_output.shape[2]), device=cls_output.device)
-        # Concatenate
-        y = torch.cat([box, conf, cls_output], axis=1)
-        # Split to 3 channels
-        outputs = []
-        start, end = 0, 0
-        for i, xi in enumerate(x):
-            end += xi.shape[-2]*xi.shape[-1]
-            outputs.append(mc[i].view(xi.shape[0], -1, xi.shape[-2], xi.shape[-1]))
-            outputs.append(y[:, :, start:end].view(xi.shape[0], -1, xi.shape[-2], xi.shape[-1]))
-            start += xi.shape[-2]*xi.shape[-1]
-
-        outputs.append(p)
+        outputs = super().forward(x)
+        # Masks
+        outputs.extend([self.cv4[i](x[i]) for i in range(self.nl)])
+        # Mask protos
+        outputs.append(self.proto(x[0]))

         return outputs

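Note the new segmentation output order: the three detection maps from super().forward(x) come first, then the three per-level mask-coefficient maps from cv4, then the mask protos. The exporter change below reorders the output names accordingly.
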
@@ -626,6 +496,7 @@ def forward(self, x):
         x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
         return x

+
 class DetectV10(DetectV8):
     """YOLOv10 Detect head for detection models."""
     def __init__(self, old_detect, use_rvc2):
2 changes: 1 addition & 1 deletion tools/yolo/yolov8_exporter.py
@@ -40,7 +40,7 @@ def get_output_names(mode: int) -> List[str]:
     if mode == DETECT_MODE:
         return ["output1_yolov8", "output2_yolov8", "output3_yolov8"]
     elif mode == SEGMENT_MODE:
-        return ["output1_masks", "output1_yolov8", "output2_masks", "output2_yolov8", "output3_masks", "output3_yolov8", "protos_output"]
+        return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "output1_masks", "output2_masks", "output3_masks", "protos_output"]
     elif mode == OBB_MODE:
         return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "angle_output"]
     elif mode == POSE_MODE:
