diff --git a/mmaction/visualization/action_visualizer.py b/mmaction/visualization/action_visualizer.py index 7a3bfab85e..37bb10c17f 100644 --- a/mmaction/visualization/action_visualizer.py +++ b/mmaction/visualization/action_visualizer.py @@ -214,7 +214,7 @@ def add_datasample(self, texts = ['Frame %d of total %d frames' % (frame_idx, tol_video)] self.set_image(frame) - if draw_gt and 'gt_labels' in data_sample: + if draw_gt and 'gt_label' in data_sample: gt_labels = data_sample.gt_label idx = gt_labels.tolist() class_labels = [''] * len(idx) @@ -226,14 +226,15 @@ def add_datasample(self, prefix = 'Ground truth: ' texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) - if draw_pred and 'pred_labels' in data_sample: - pred_labels = data_sample.pred_labels - idx = pred_labels.item.tolist() + if draw_pred and 'pred_label' in data_sample: + pred_labels = data_sample.pred_label + idx = pred_labels.tolist() score_labels = [''] * len(idx) class_labels = [''] * len(idx) - if draw_score and 'score' in pred_labels: + if draw_score and 'pred_score' in data_sample: score_labels = [ - f', {pred_labels.score[i].item():.2f}' for i in idx + f', {data_sample.pred_score[i].item():.2f}' + for i in idx ] if classes is not None: