From 5d2680fac4b66836c59aed9482f04c09e1b96f9f Mon Sep 17 00:00:00 2001 From: jts250 Date: Thu, 21 Sep 2023 01:02:48 +0800 Subject: [PATCH 01/17] Support ASFormer --- configs/_base_/models/asformer.py | 12 + .../asformer/asformer_50salads1.py | 120 ++++ .../asformer/asformer_50salads2.py | 120 ++++ .../segmentation/asformer/asformer_gtea1.py | 121 ++++ .../segmentation/asformer/asformer_gtea2.py | 121 ++++ .../segmentation/asformer/asformer_gtea3.py | 121 ++++ .../segmentation/asformer/asformer_gtea4.py | 121 ++++ mmaction/datasets/__init__.py | 3 +- mmaction/datasets/action_segment_dataset.py | 65 ++ mmaction/datasets/transforms/__init__.py | 12 +- mmaction/datasets/transforms/formatting.py | 56 ++ mmaction/datasets/transforms/loading.py | 37 ++ mmaction/evaluation/metrics/__init__.py | 2 +- mmaction/evaluation/metrics/segment_metric.py | 196 ++++++ mmaction/models/__init__.py | 1 + mmaction/models/action_segmentors/__init__.py | 6 + mmaction/models/action_segmentors/asformer.py | 620 ++++++++++++++++++ mmaction/models/losses/__init__.py | 2 +- mmaction/models/losses/mse_loss.py | 36 + 19 files changed, 1764 insertions(+), 8 deletions(-) create mode 100644 configs/_base_/models/asformer.py create mode 100644 configs/segmentation/asformer/asformer_50salads1.py create mode 100644 configs/segmentation/asformer/asformer_50salads2.py create mode 100644 configs/segmentation/asformer/asformer_gtea1.py create mode 100644 configs/segmentation/asformer/asformer_gtea2.py create mode 100644 configs/segmentation/asformer/asformer_gtea3.py create mode 100644 configs/segmentation/asformer/asformer_gtea4.py create mode 100644 mmaction/datasets/action_segment_dataset.py create mode 100644 mmaction/evaluation/metrics/segment_metric.py create mode 100644 mmaction/models/action_segmentors/__init__.py create mode 100644 mmaction/models/action_segmentors/asformer.py create mode 100644 mmaction/models/losses/mse_loss.py diff --git a/configs/_base_/models/asformer.py b/configs/_base_/models/asformer.py new file mode 100644 index 0000000000..6468449ada --- /dev/null +++ b/configs/_base_/models/asformer.py @@ -0,0 +1,12 @@ +# model settings +model = dict( + type='ASFormer', + num_layers=10, + num_f_maps=64, + input_dim=2048, + num_decoders=3, + num_classes=11, + channel_masking_rate=0.5, + sample_rate=1, + r1=2, + r2=2) diff --git a/configs/segmentation/asformer/asformer_50salads1.py b/configs/segmentation/asformer/asformer_50salads1.py new file mode 100644 index 0000000000..77ee1f0f16 --- /dev/null +++ b/configs/segmentation/asformer/asformer_50salads1.py @@ -0,0 +1,120 @@ +_base_ = [ + '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' +] # dataset settings +dataset_type = 'ActionSegmentDataset' +data_root = 'data/action_seg/50salads/' +data_root_val = 'data/action_seg/50salads/' +ann_file_train = 'data/action_seg/50salads/splits/train.split1.bundle' +ann_file_val = 'data/action_seg/50salads/splits/test.split1.bundle' +ann_file_test = 'data/action_seg/50salads/splits/test.split1.bundle' + +model = dict( + type='ASFormer', + num_layers=10, + num_f_maps=64, + input_dim=2048, + num_decoders=3, + num_classes=19, + channel_masking_rate=0.3, + sample_rate=2, + r1=2, + r2=2) + +train_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=( + 'num_classes', + 'actions_dict', + 'index2label', + 'ground_truth', + 'classes', + )) +] + +val_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + 
keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +test_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 120 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=0, + val_interval=10) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 80, + 100, + ], + gamma=0.5) +] +work_dir = './work_dirs/50salads1/' +test_evaluator = dict( + type='SegmentMetric', + metric_type='ALL', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) +val_evaluator = test_evaluator +default_hooks = dict(checkpoint=dict(interval=10, max_keep_ckpts=3)) diff --git a/configs/segmentation/asformer/asformer_50salads2.py b/configs/segmentation/asformer/asformer_50salads2.py new file mode 100644 index 0000000000..f83569c14e --- /dev/null +++ b/configs/segmentation/asformer/asformer_50salads2.py @@ -0,0 +1,120 @@ +_base_ = [ + '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' +] # dataset settings +dataset_type = 'ActionSegmentDataset' +data_root = 'data/action_seg/50salads/' +data_root_val = 'data/action_seg/50salads/' +ann_file_train = 'data/action_seg/50salads/splits/train.split2.bundle' +ann_file_val = 'data/action_seg/50salads/splits/test.split2.bundle' +ann_file_test = 'data/action_seg/50salads/splits/test.split2.bundle' + +model = dict( + type='ASFormer', + num_layers=10, + num_f_maps=64, + input_dim=2048, + num_decoders=3, + num_classes=19, + channel_masking_rate=0.3, + sample_rate=2, + r1=2, + r2=2) + +train_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=( + 'num_classes', + 'actions_dict', + 'index2label', + 'ground_truth', + 'classes', + )) +] + +val_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +test_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + 
sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 120 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=0, + val_interval=10) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 80, + 100, + ], + gamma=0.5) +] +work_dir = './work_dirs/50salads2/' +test_evaluator = dict( + type='SegmentMetric', + metric_type='ALL', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) +val_evaluator = test_evaluator +default_hooks = dict(checkpoint=dict(interval=10, max_keep_ckpts=3)) diff --git a/configs/segmentation/asformer/asformer_gtea1.py b/configs/segmentation/asformer/asformer_gtea1.py new file mode 100644 index 0000000000..fc0315e1cc --- /dev/null +++ b/configs/segmentation/asformer/asformer_gtea1.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' +] # dataset settings +dataset_type = 'ActionSegmentDataset' +data_root = 'data/action_seg/gtea/' +data_root_val = 'data/action_seg/gtea/' +ann_file_train = 'data/action_seg/gtea/splits/train.split1.bundle' +ann_file_val = 'data/action_seg/gtea/splits/test.split1.bundle' +ann_file_test = 'data/action_seg/gtea/splits/test.split1.bundle' + +train_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=( + 'num_classes', + 'actions_dict', + 'index2label', + 'ground_truth', + 'classes', + )) +] + +val_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +test_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + 
type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 120 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=0, + val_interval=5) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) +''' +param_scheduler = [ + dict( + monitor= 'F1@50', + param_name='lr', + type='ReduceOnPlateauParamScheduler', + rule='less', + factor=0.5, + patience=3,#33 + verbose=True) +] +''' +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 80, + 100, + ], + gamma=0.5) +] + +work_dir = './work_dirs/gtea1/' +test_evaluator = dict( + type='SegmentMetric', + metric_type='ALL', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) +val_evaluator = test_evaluator +default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=3)) diff --git a/configs/segmentation/asformer/asformer_gtea2.py b/configs/segmentation/asformer/asformer_gtea2.py new file mode 100644 index 0000000000..2bb9cb7fbf --- /dev/null +++ b/configs/segmentation/asformer/asformer_gtea2.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' +] # dataset settings +dataset_type = 'ActionSegmentDataset' +data_root = 'data/action_seg/gtea/' +data_root_val = 'data/action_seg/gtea/' +ann_file_train = 'data/action_seg/gtea/splits/train.split2.bundle' +ann_file_val = 'data/action_seg/gtea/splits/test.split2.bundle' +ann_file_test = 'data/action_seg/gtea/splits/test.split2.bundle' + +train_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=( + 'num_classes', + 'actions_dict', + 'index2label', + 'ground_truth', + 'classes', + )) +] + +val_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +test_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 120 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=0, + val_interval=5) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) +''' +param_scheduler = [ + dict( + monitor= 'F1@50', + 
param_name='lr', + type='ReduceOnPlateauParamScheduler', + rule='less', + factor=0.5, + patience=3,#33 + verbose=True) +] +''' +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 80, + 100, + ], + gamma=0.5) +] + +work_dir = './work_dirs/asformer_gtea2/' +test_evaluator = dict( + type='SegmentMetric', + metric_type='ALL', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) +val_evaluator = test_evaluator +default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=6)) diff --git a/configs/segmentation/asformer/asformer_gtea3.py b/configs/segmentation/asformer/asformer_gtea3.py new file mode 100644 index 0000000000..8bad7fdf93 --- /dev/null +++ b/configs/segmentation/asformer/asformer_gtea3.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' +] # dataset settings +dataset_type = 'ActionSegmentDataset' +data_root = 'data/action_seg/gtea/' +data_root_val = 'data/action_seg/gtea/' +ann_file_train = 'data/action_seg/gtea/splits/train.split3.bundle' +ann_file_val = 'data/action_seg/gtea/splits/test.split3.bundle' +ann_file_test = 'data/action_seg/gtea/splits/test.split3.bundle' + +train_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=( + 'num_classes', + 'actions_dict', + 'index2label', + 'ground_truth', + 'classes', + )) +] + +val_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +test_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 120 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=0, + val_interval=5) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) +''' +param_scheduler = [ + dict( + monitor= 'F1@50', + param_name='lr', + type='ReduceOnPlateauParamScheduler', + rule='less', + factor=0.5, + patience=3,#33 + verbose=True) +] +''' +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 80, + 100, + ], + gamma=0.5) +] + +work_dir = './work_dirs/asformer_gtea3/' +test_evaluator = dict( + type='SegmentMetric', + metric_type='ALL', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) 
+val_evaluator = test_evaluator +default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=6)) diff --git a/configs/segmentation/asformer/asformer_gtea4.py b/configs/segmentation/asformer/asformer_gtea4.py new file mode 100644 index 0000000000..e531fc3d26 --- /dev/null +++ b/configs/segmentation/asformer/asformer_gtea4.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' +] # dataset settings +dataset_type = 'ActionSegmentDataset' +data_root = 'data/action_seg/gtea/' +data_root_val = 'data/action_seg/gtea/' +ann_file_train = 'data/action_seg/gtea/splits/train.split4.bundle' +ann_file_val = 'data/action_seg/gtea/splits/test.split4.bundle' +ann_file_test = 'data/action_seg/gtea/splits/test.split4.bundle' + +train_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=( + 'num_classes', + 'actions_dict', + 'index2label', + 'ground_truth', + 'classes', + )) +] + +val_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +test_pipeline = [ + dict(type='LoadSegmentationFeature'), + dict( + type='PackSegmentationInputs', + keys=('classes', ), + meta_keys=('num_classes', 'actions_dict', 'index2label', + 'ground_truth', 'classes')) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 120 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=0, + val_interval=5) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) +''' +param_scheduler = [ + dict( + monitor= 'F1@50', + param_name='lr', + type='ReduceOnPlateauParamScheduler', + rule='less', + factor=0.5, + patience=3,#33 + verbose=True) +] +''' +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 80, + 100, + ], + gamma=0.5) +] + +work_dir = './work_dirs/asformer_gtea4/' +test_evaluator = dict( + type='SegmentMetric', + metric_type='ALL', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) +val_evaluator = test_evaluator +default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=6)) diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py index cc838f8f31..36d384ac0d 100644 --- a/mmaction/datasets/__init__.py +++ b/mmaction/datasets/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .action_segment_dataset import ActionSegmentDataset from .activitynet_dataset import ActivityNetDataset from .audio_dataset import AudioDataset from .ava_dataset import AVADataset, AVAKineticsDataset @@ -15,5 +16,5 @@ 'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset', 'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset', 'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset', - 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC' + 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC', 'ActionSegmentDataset' ] diff --git a/mmaction/datasets/action_segment_dataset.py b/mmaction/datasets/action_segment_dataset.py new file mode 100644 index 0000000000..120a511dc6 --- /dev/null +++ b/mmaction/datasets/action_segment_dataset.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Union + +from mmengine.fileio import exists + +from mmaction.registry import DATASETS +from mmaction.utils import ConfigType +from .base import BaseActionDataset + + +@DATASETS.register_module() +class ActionSegmentDataset(BaseActionDataset): + + def __init__(self, + ann_file: str, + pipeline: List[Union[dict, Callable]], + data_prefix: Optional[ConfigType] = dict(video=''), + test_mode: bool = False, + **kwargs): + + super().__init__( + ann_file, + pipeline=pipeline, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation file to get video information.""" + exists(self.ann_file) + file_ptr = open(self.ann_file, 'r') # read bundle + list_of_examples = file_ptr.read().split('\n')[:-1] + file_ptr.close() + gts = [ + self.data_prefix['video'] + 'groundTruth/' + vid + for vid in list_of_examples + ] + features_npy = [ + self.data_prefix['video'] + 'features/' + vid.split('.')[0] + + '.npy' for vid in list_of_examples + ] + data_list = [] + + file_ptr_d = open(self.data_prefix['video'] + '/mapping.txt', 'r') + actions = file_ptr_d.read().split('\n')[:-1] + file_ptr.close() + actions_dict = dict() + for a in actions: + actions_dict[a.split()[1]] = int(a.split()[0]) + index2label = dict() + for k, v in actions_dict.items(): + index2label[v] = k + num_classes = len(actions_dict) + + # gts:txt list of examples:txt features_npy:npy + for idx, feature in enumerate(features_npy): + video_info = dict() + feature_path = features_npy[idx] + video_info['feature_path'] = feature_path + video_info['actions_dict'] = actions_dict + video_info['index2label'] = index2label + video_info['ground_truth_path'] = gts[idx] + video_info['num_classes'] = num_classes + data_list.append(video_info) + return data_list diff --git a/mmaction/datasets/transforms/__init__.py b/mmaction/datasets/transforms/__init__.py index 3d1ee91e27..d08316ba9e 100644 --- a/mmaction/datasets/transforms/__init__.py +++ b/mmaction/datasets/transforms/__init__.py @@ -1,13 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .formatting import (FormatAudioShape, FormatGCNInput, FormatShape, - PackActionInputs, PackLocalizationInputs, Transpose) + PackActionInputs, PackLocalizationInputs, + PackSegmentationInputs, Transpose) from .loading import (ArrayDecode, AudioFeatureSelector, BuildPseudoClip, DecordDecode, DecordInit, DenseSampleFrames, GenerateLocalizationLabels, ImageDecode, LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature, - LoadProposals, LoadRGBFromFile, OpenCVDecode, OpenCVInit, - PIMSDecode, PIMSInit, PyAVDecode, PyAVDecodeMotionVector, - PyAVInit, RawFrameDecode, SampleAVAFrames, SampleFrames, + LoadProposals, LoadRGBFromFile, LoadSegmentationFeature, + OpenCVDecode, OpenCVInit, PIMSDecode, PIMSInit, + PyAVDecode, PyAVDecodeMotionVector, PyAVInit, + RawFrameDecode, SampleAVAFrames, SampleFrames, UniformSample, UntrimmedSampleFrames) from .pose_transforms import (DecompressPose, GeneratePoseTarget, GenSkeFeat, JointToBone, MergeSkeFeat, MMCompact, MMDecode, @@ -37,5 +39,5 @@ 'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop', 'ToMotion', 'TorchVisionWrapper', 'Transpose', 'UniformSample', 'UniformSampleFrames', 'UntrimmedSampleFrames', 'MMUniformSampleFrames', 'MMDecode', 'MMCompact', - 'CLIPTokenize' + 'CLIPTokenize', 'LoadSegmentationFeature', 'PackSegmentationInputs' ] diff --git a/mmaction/datasets/transforms/formatting.py b/mmaction/datasets/transforms/formatting.py index a8e9b9ab82..0a76e0d3d3 100644 --- a/mmaction/datasets/transforms/formatting.py +++ b/mmaction/datasets/transforms/formatting.py @@ -169,6 +169,62 @@ def __repr__(self) -> str: return repr_str +@TRANSFORMS.register_module() +class PackSegmentationInputs(BaseTransform): + + def __init__(self, keys=(), meta_keys=('video_name', )): + self.keys = keys + self.meta_keys = meta_keys + + def transform(self, results): + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_samples' (obj:`DetDataSample`): The annotation info of the + sample. + """ + packed_results = dict() + if 'raw_feature' in results: + raw_feature = results['raw_feature'] + packed_results['inputs'] = to_tensor(raw_feature) + else: + raise ValueError('Cannot get "raw_feature" in the input ' + 'dict of `PackSegmentationInputs`.') + + data_sample = ActionDataSample() + for key in self.keys: + if key not in results: + continue + if key == 'classes': + instance_data = InstanceData() + instance_data[key] = to_tensor(results[key]) + data_sample.gt_instances = instance_data + elif key == 'proposals': + instance_data = InstanceData() + instance_data[key] = to_tensor(results[key]) + data_sample.proposals = instance_data + else: + raise NotImplementedError( + f"Key '{key}' is not supported in `PackSegmentationInputs`" + ) + + img_meta = {k: results[k] for k in self.meta_keys if k in results} + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str + + @TRANSFORMS.register_module() class Transpose(BaseTransform): """Transpose image channels to a given order. 
diff --git a/mmaction/datasets/transforms/loading.py b/mmaction/datasets/transforms/loading.py index 8d789ab4c3..c5101dcdae 100644 --- a/mmaction/datasets/transforms/loading.py +++ b/mmaction/datasets/transforms/loading.py @@ -1784,6 +1784,43 @@ def __repr__(self) -> str: return repr_str +@TRANSFORMS.register_module() +class LoadSegmentationFeature(BaseTransform): + """Load Video features for localizer with given video_name list. + + The required key are "feature_path", "ground_truth_path", + added or modified keys are "actions_dict", "raw_feature", + "ground_truth", "classes". + + Args: + raw_feature_ext (str): Raw feature file extension. Default: '.csv'. + """ + + def transform(self, results): + """Perform the LoadSegmentationFeature loading. + + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. + """ + raw_feature = np.load(results['feature_path']) + file_ptr = open(results['ground_truth_path'], 'r') + content = file_ptr.read().split('\n')[:-1] + classes = np.zeros(min(np.shape(raw_feature)[1], len(content))) + for i in range(len(classes)): + classes[i] = results['actions_dict'][content[i]] + + results['raw_feature'] = raw_feature + results['ground_truth'] = content + results['classes'] = classes + + return results + + def __repr__(self): + repr_str = f'{self.__class__.__name__}' + return repr_str + + @TRANSFORMS.register_module() class LoadLocalizationFeature(BaseTransform): """Load Video features for localizer with given video_name list. diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py index 341ec577ce..2655a2f0e9 100644 --- a/mmaction/evaluation/metrics/__init__.py +++ b/mmaction/evaluation/metrics/__init__.py @@ -9,5 +9,5 @@ __all__ = [ 'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix', 'MultiSportsMetric', 'RetrievalMetric', 'VQAAcc', 'ReportVQA', 'VQAMCACC', - 'RetrievalRecall' + 'RetrievalRecall', 'SegmentMetric' ] diff --git a/mmaction/evaluation/metrics/segment_metric.py b/mmaction/evaluation/metrics/segment_metric.py new file mode 100644 index 0000000000..6e707b7373 --- /dev/null +++ b/mmaction/evaluation/metrics/segment_metric.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import Any, Optional, Sequence, Tuple + +import numpy as np +from mmengine.evaluator import BaseMetric + +from mmaction.registry import METRICS +from mmaction.utils import ConfigType + + +@METRICS.register_module() +class SegmentMetric(BaseMetric): + """Action Segmentation dataset evaluation metric.""" + + def __init__(self, + metric_type: str = 'TEM', + collect_device: str = 'cpu', + prefix: Optional[str] = None, + metric_options: dict = {}, + dump_config: ConfigType = dict(out='')): + super().__init__(collect_device=collect_device, prefix=prefix) + self.metric_type = metric_type + + assert 'out' in dump_config + self.output_format = dump_config.pop('output_format', 'csv') + self.out = dump_config['out'] + + self.metric_options = metric_options + if self.metric_type == 'AR@AN': + self.ground_truth = {} + + def process(self, data_batch: Sequence[Tuple[Any, dict]], + predictions: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[Tuple[Any, dict]]): A batch of data + from the dataloader. 
+ predictions (Sequence[dict]): A batch of outputs from + the model. + """ + for pred in predictions: + self.results.append(pred) + + if self.metric_type == 'ALL': + data_batch = data_batch['data_samples'] + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + if self.metric_type == 'ALL': + return self.compute_ALL(results) + return OrderedDict() + + def compute_ALL(self, results: list) -> dict: + """ALL evaluation metric.""" + eval_results = OrderedDict() + overlap = [.1, .25, .5] + tp, fp, fn = np.zeros(3), np.zeros(3), np.zeros(3) + + correct = 0 + total = 0 + edit = 0 + + for vid in self.results: + + gt_content = vid['ground'] + recog_content = vid['recognition'] + + for i in range(len(gt_content)): + total += 1 + if gt_content[i] == recog_content[i]: + correct += 1 + + edit += self.edit_score(recog_content, gt_content) + + for s in range(len(overlap)): + tp1, fp1, fn1 = self.f_score(recog_content, gt_content, + overlap[s]) + tp[s] += tp1 + fp[s] += fp1 + fn[s] += fn1 + eval_results['Acc'] = 100 * float(correct) / total + eval_results['Edit'] = (1.0 * edit) / len(self.results) + f1s = np.array([0, 0, 0], dtype=float) + for s in range(len(overlap)): + precision = tp[s] / float(tp[s] + fp[s]) + recall = tp[s] / float(tp[s] + fn[s]) + + f1 = 2.0 * (precision * recall) / (precision + recall) + + f1 = np.nan_to_num(f1) * 100 + f1s[s] = f1 + + eval_results['F1@10'] = f1s[0] + eval_results['F1@25'] = f1s[1] + eval_results['F1@50'] = f1s[2] + + return eval_results + + def f_score(self, + recognized, + ground_truth, + overlap, + bg_class=['background']): + p_label, p_start, p_end = self.get_labels_start_end_time( + recognized, bg_class) + y_label, y_start, y_end = self.get_labels_start_end_time( + ground_truth, bg_class) + + tp = 0 + fp = 0 + + hits = np.zeros(len(y_label)) + + for j in range(len(p_label)): + intersection = np.minimum(p_end[j], y_end) - np.maximum( + p_start[j], y_start) + union = np.maximum(p_end[j], y_end) - np.minimum( + p_start[j], y_start) + IoU = (1.0 * intersection / union) * ( + [p_label[j] == y_label[x] for x in range(len(y_label))]) + # Get the best scoring segment + idx = np.array(IoU).argmax() + + if IoU[idx] >= overlap and not hits[idx]: + tp += 1 + hits[idx] = 1 + else: + fp += 1 + fn = len(y_label) - sum(hits) + return float(tp), float(fp), float(fn) + + def edit_score(self, + recognized, + ground_truth, + norm=True, + bg_class=['background']): + P, _, _ = self.get_labels_start_end_time(recognized, bg_class) + Y, _, _ = self.get_labels_start_end_time(ground_truth, bg_class) + return self.levenstein(P, Y, norm) + + def get_labels_start_end_time(self, + frame_wise_labels, + bg_class=['background']): + labels = [] + starts = [] + ends = [] + last_label = frame_wise_labels[0] + if frame_wise_labels[0] not in bg_class: + labels.append(frame_wise_labels[0]) + starts.append(0) + for i in range(len(frame_wise_labels)): + if frame_wise_labels[i] != last_label: + if frame_wise_labels[i] not in bg_class: + labels.append(frame_wise_labels[i]) + starts.append(i) + if last_label not in bg_class: + ends.append(i) + last_label = frame_wise_labels[i] + if last_label not in bg_class: + ends.append(i) + return labels, starts, ends + + def levenstein(self, p, y, norm=False): + m_row = len(p) + n_col = len(y) + D = np.zeros([m_row + 1, n_col + 
1], np.float64) + for i in range(m_row + 1): + D[i, 0] = i + for i in range(n_col + 1): + D[0, i] = i + + for j in range(1, n_col + 1): + for i in range(1, m_row + 1): + if y[j - 1] == p[i - 1]: + D[i, j] = D[i - 1, j - 1] + else: + D[i, j] = min(D[i - 1, j] + 1, D[i, j - 1] + 1, + D[i - 1, j - 1] + 1) + + if norm: + score = (1 - D[-1, -1] / max(m_row, n_col)) * 100 + else: + score = D[-1, -1] + + return score diff --git a/mmaction/models/__init__.py b/mmaction/models/__init__.py index 08f7d41f52..4b0436e1d9 100644 --- a/mmaction/models/__init__.py +++ b/mmaction/models/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .action_segmentors import * # noqa: F401,F403 from .backbones import * # noqa: F401,F403 from .common import * # noqa: F401,F403 from .data_preprocessors import * # noqa: F401,F403 diff --git a/mmaction/models/action_segmentors/__init__.py b/mmaction/models/action_segmentors/__init__.py new file mode 100644 index 0000000000..eeb81b53e9 --- /dev/null +++ b/mmaction/models/action_segmentors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .asformer import ASFormer + +__all__ = [ + 'ASFormer', +] diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py new file mode 100644 index 0000000000..b83639cb4d --- /dev/null +++ b/mmaction/models/action_segmentors/asformer.py @@ -0,0 +1,620 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModel + +from mmaction.registry import MODELS + + +@MODELS.register_module() +class ASFormer(BaseModel): + """Boundary Matching Network for temporal action proposal generation.""" + + def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, + num_classes, channel_masking_rate, sample_rate): + super().__init__() + self.model = MyTransformer(3, num_layers, r1, r2, num_f_maps, + input_dim, num_classes, + channel_masking_rate) + print('Model Size: ', sum(p.numel() for p in self.model.parameters())) + self.num_classes = num_classes + self.mse = MODELS.build(dict(type='MeanSquareErrorLoss')) + self.ce = MODELS.build(dict(type='CrossEntropyLoss')) + + def init_weights(self) -> None: + """Initiate the parameters from scratch.""" + pass + + def forward(self, inputs, data_samples, mode, **kwargs): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: + + - ``tensor``: Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - ``predict``: Forward and return the predictions, which are fully + processed to a list of :obj:`ActionDataSample`. + - ``loss``: Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[:obj:`ActionDataSample`], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to ``tensor``. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of ``ActionDataSample``. + - If ``mode="loss"``, return a dict of tensor. 
+ """ + input = torch.stack(inputs) + if mode == 'tensor': + return self._forward(inputs, **kwargs) + if mode == 'predict': + return self.predict(input, data_samples, **kwargs) + elif mode == 'loss': + return self.loss(input, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode') + + def loss(self, batch_inputs, batch_data_samples, **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Raw Inputs of the recognizer. + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`ActionDataSample`]): The batch + data samples. It usually includes information such + as ``gt_labels``. + + Returns: + dict: A dictionary of loss components. + """ + device = batch_inputs.device + batch_target_tensor = torch.ones( + len(batch_inputs), + max(tensor.size(1) for tensor in batch_inputs), + dtype=torch.long) * (-100) + mask = torch.zeros( + len(batch_inputs), + self.num_classes, + max(tensor.size(1) for tensor in batch_inputs), + dtype=torch.float) + for i in range(len(batch_inputs)): + batch_target_tensor[i, + :np.shape(batch_data_samples[i].classes)[0]] \ + = torch.from_numpy(batch_data_samples[i].classes) + + mask[i, i, :np.shape(batch_data_samples[i].classes)[0]] = \ + torch.ones(self.num_classes, + np.shape(batch_data_samples[i].classes)[0]) + + batch_target_tensor = batch_target_tensor.to(device) + batch_target_tensor = batch_target_tensor.to(device) + mask = mask.to(device) + batch_inputs = batch_inputs.to(device) + ps = self.model(batch_inputs, mask) + loss = 0 + for p in ps: + loss += self.ce( + p.transpose(2, 1).contiguous().view(-1, self.num_classes), + batch_target_tensor.view(-1), + ignore_index=-100) + loss += 0.15 * torch.mean( + torch.clamp( + self.mse( + F.log_softmax(p[:, :, 1:], dim=1), + F.log_softmax(p.detach()[:, :, :-1], dim=1)), + min=0, + max=16) * mask[:, :, 1:]) + + loss_dict = dict(loss=loss) + return loss_dict + + def predict(self, batch_inputs, batch_data_samples, **kwargs): + """Define the computation performed at every call when testing.""" + device = batch_inputs.device + actions_dict = batch_data_samples[0].actions_dict + batch_target_tensor = torch.ones( + len(batch_inputs), + max(tensor.size(1) for tensor in batch_inputs), + dtype=torch.long) * (-100) + batch_target = [ + data_sample.classes for data_sample in batch_data_samples + ] + mask = torch.zeros( + len(batch_inputs), + self.num_classes, + max(tensor.size(1) for tensor in batch_inputs), + dtype=torch.float) + for i in range(len(batch_inputs)): + batch_target_tensor[i, :np.shape(batch_data_samples[i].classes + )[0]] = torch.from_numpy( + batch_data_samples[i].classes) + mask[i, :, :np. 
+ shape(batch_data_samples[i].classes)[0]] = torch.ones( + self.num_classes, + np.shape(batch_data_samples[i].classes)[0]) + batch_target_tensor = batch_target_tensor.to(device) + mask = mask.to(device) + batch_inputs = batch_inputs.to(device) + predictions = self.model(batch_inputs, mask) + for i in range(len(predictions)): + confidence, predicted = torch.max( + F.softmax(predictions[i], dim=1).data, 1) + confidence, predicted = confidence.squeeze(), predicted.squeeze() + confidence, predicted = confidence.squeeze(), predicted.squeeze() + recognition = [] + ground = [ + batch_data_samples[0].index2label[idx] for idx in batch_target[0] + ] + for i in range(len(predicted)): + recognition = np.concatenate((recognition, [ + list(actions_dict.keys())[list(actions_dict.values()).index( + predicted[i].item())] + ])) + output = [dict(ground=ground, recognition=recognition)] + return output + + def _forward(self, x): + """Define the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + Returns: + torch.Tensor: The output of the module. + """ + print(x.shape) + + return x.shape + + +def exponential_descrease(idx_decoder, p=3): + return math.exp(-p * idx_decoder) + + +class AttentionHelper(nn.Module): + + def __init__(self): + super(AttentionHelper, self).__init__() + self.softmax = nn.Softmax(dim=-1) + + def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask): + """scalar dot attention. + + :param proj_query: shape of (B, C, L) => + (Batch_Size, Feature_Dimension, Length) + :param proj_key: shape of (B, C, L) + :param proj_val: shape of (B, C, L) + :param padding_mask: shape of (B, C, L) + :return: attention value of shape (B, C, L) + """ + m, c1, l1 = proj_query.shape + m, c2, l2 = proj_key.shape + + assert c1 == c2 + + energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key) + attention = energy / np.sqrt(c1) + attention = attention + torch.log(padding_mask + 1e-6) + attention = self.softmax(attention) + attention = attention * padding_mask + attention = attention.permute(0, 2, 1) + out = torch.bmm(proj_val, attention) + return out, attention + + +class AttLayer(nn.Module): + + def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, + att_type): # r1 = r2 (2) + super(AttLayer, self).__init__() + self.query_conv = nn.Conv1d( + in_channels=q_dim, out_channels=q_dim // r1, kernel_size=1) + self.key_conv = nn.Conv1d( + in_channels=k_dim, out_channels=k_dim // r2, kernel_size=1) + self.value_conv = nn.Conv1d( + in_channels=v_dim, out_channels=v_dim // r3, kernel_size=1) + + self.conv_out = nn.Conv1d( + in_channels=v_dim // r3, out_channels=v_dim, kernel_size=1) + + self.bl = bl + self.stage = stage + self.att_type = att_type + assert self.att_type in ['normal_att', 'block_att', 'sliding_att'] + assert self.stage in ['encoder', 'decoder'] + + self.att_helper = AttentionHelper() + self.window_mask = self.construct_window_mask() + + def construct_window_mask(self): + """construct window mask of shape (1, l, l + l//2 + l//2), used for + sliding window self attention.""" + window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2))) + for i in range(self.bl): + window_mask[:, i, i:i + self.bl] = 1 + return window_mask + + def forward(self, x1, x2, mask): + query = self.query_conv(x1) + key = self.key_conv(x1) + + if self.stage == 'decoder': + assert x2 is not None + value = self.value_conv(x2) + else: + value = self.value_conv(x1) + + if self.att_type == 'normal_att': + return self._normal_self_att(query, key, value, mask) + elif self.att_type 
== 'block_att': + return self._block_wise_self_att(query, key, value, mask) + elif self.att_type == 'sliding_att': + return self._sliding_window_self_att(query, key, value, mask) + + def _normal_self_att(self, q, k, v, mask): + device = q.device + m_batchsize, c1, L = q.size() + _, c2, L = k.size() + _, c3, L = v.size() + padding_mask = torch.ones( + (m_batchsize, 1, L)).to(device) * mask[:, 0:1, :] + output, attentions = self.att_helper.scalar_dot_att( + q, k, v, padding_mask) + output = self.conv_out(F.relu(output)) + output = output[:, :, 0:L] + return output * mask[:, 0:1, :] + + def _block_wise_self_att(self, q, k, v, mask): + device = q.device + m_batchsize, c1, L = q.size() + _, c2, L = k.size() + _, c3, L = v.size() + + nb = L // self.bl + if L % self.bl != 0: + q = torch.cat([ + q, + torch.zeros( + (m_batchsize, c1, self.bl - L % self.bl)).to(device) + ], + dim=-1) + k = torch.cat([ + k, + torch.zeros( + (m_batchsize, c2, self.bl - L % self.bl)).to(device) + ], + dim=-1) + v = torch.cat([ + v, + torch.zeros( + (m_batchsize, c3, self.bl - L % self.bl)).to(device) + ], + dim=-1) + nb += 1 + + padding_mask = torch.cat([ + torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :], + torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device) + ], + dim=-1) + + q = q.reshape(m_batchsize, c1, nb, + self.bl).permute(0, 2, 1, + 3).reshape(m_batchsize * nb, c1, + self.bl) + padding_mask = padding_mask.reshape( + m_batchsize, 1, nb, + self.bl).permute(0, 2, 1, 3).reshape(m_batchsize * nb, 1, self.bl) + k = k.reshape(m_batchsize, c2, nb, + self.bl).permute(0, 2, 1, + 3).reshape(m_batchsize * nb, c2, + self.bl) + v = v.reshape(m_batchsize, c3, nb, + self.bl).permute(0, 2, 1, + 3).reshape(m_batchsize * nb, c3, + self.bl) + + output, attentions = self.att_helper.scalar_dot_att( + q, k, v, padding_mask) + output = self.conv_out(F.relu(output)) + + output = output.reshape(m_batchsize, nb, c3, self.bl).permute( + 0, 2, 1, 3).reshape(m_batchsize, c3, nb * self.bl) + output = output[:, :, 0:L] + return output * mask[:, 0:1, :] + + def _sliding_window_self_att(self, q, k, v, mask): + device = q.device + m_batchsize, c1, L = q.size() + _, c2, _ = k.size() + _, c3, _ = v.size() + nb = L // self.bl + if L % self.bl != 0: + q = torch.cat([ + q, + torch.zeros( + (m_batchsize, c1, self.bl - L % self.bl)).to(device) + ], + dim=-1) + k = torch.cat([ + k, + torch.zeros( + (m_batchsize, c2, self.bl - L % self.bl)).to(device) + ], + dim=-1) + v = torch.cat([ + v, + torch.zeros( + (m_batchsize, c3, self.bl - L % self.bl)).to(device) + ], + dim=-1) + nb += 1 + padding_mask = torch.cat([ + torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :], + torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device) + ], + dim=-1) + q = q.reshape(m_batchsize, c1, nb, + self.bl).permute(0, 2, 1, + 3).reshape(m_batchsize * nb, c1, + self.bl) + k = torch.cat([ + torch.zeros(m_batchsize, c2, self.bl // 2).to(device), k, + torch.zeros(m_batchsize, c2, self.bl // 2).to(device) + ], + dim=-1) + v = torch.cat([ + torch.zeros(m_batchsize, c3, self.bl // 2).to(device), v, + torch.zeros(m_batchsize, c3, self.bl // 2).to(device) + ], + dim=-1) + padding_mask = torch.cat([ + torch.zeros(m_batchsize, 1, self.bl // 2).to(device), padding_mask, + torch.zeros(m_batchsize, 1, self.bl // 2).to(device) + ], + dim=-1) + k = torch.cat([ + k[:, :, i * self.bl:(i + 1) * self.bl + (self.bl // 2) * 2] + for i in range(nb) + ], + dim=0) # special case when self.bl = 1 + v = torch.cat([ + v[:, :, i * self.bl:(i + 1) * self.bl + (self.bl // 2) * 
2] + for i in range(nb) + ], + dim=0) + padding_mask = torch.cat([ + padding_mask[:, :, i * self.bl:(i + 1) * self.bl + + (self.bl // 2) * 2] for i in range(nb) + ], + dim=0) # of shape (m*nb, 1, 2l) + final_mask = self.window_mask.to(device).repeat( + m_batchsize * nb, 1, 1) * padding_mask + + output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask) + output = self.conv_out(F.relu(output)) + + output = output.reshape(m_batchsize, nb, -1, self.bl).permute( + 0, 2, 1, 3).reshape(m_batchsize, -1, nb * self.bl) + output = output[:, :, 0:L] + return output * mask[:, 0:1, :] + + +class MultiHeadAttLayer(nn.Module): + + def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, + num_head): + super(MultiHeadAttLayer, self).__init__() + # assert v_dim % num_head == 0 + self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1) + self.layers = nn.ModuleList([ + copy.deepcopy( + AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)) + for i in range(num_head) + ]) + self.dropout = nn.Dropout(p=0.5) + + def forward(self, x1, x2, mask): + out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1) + out = self.conv_out(self.dropout(out)) + return out + + +class ConvFeedForward(nn.Module): + + def __init__(self, dilation, in_channels, out_channels): + super(ConvFeedForward, self).__init__() + self.layer = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + 3, + padding=dilation, + dilation=dilation), nn.ReLU()) + + def forward(self, x): + return self.layer(x) + + +class FCFeedForward(nn.Module): + + def __init__(self, in_channels, out_channels): + super(FCFeedForward, self).__init__() + self.layer = nn.Sequential( + nn.Conv1d(in_channels, out_channels, 1), nn.ReLU(), nn.Dropout(), + nn.Conv1d(out_channels, out_channels, 1)) + + def forward(self, x): + return self.layer(x) + + +class AttModule(nn.Module): + + def __init__(self, dilation, in_channels, out_channels, r1, r2, att_type, + stage, alpha): + super(AttModule, self).__init__() + self.feed_forward = ConvFeedForward(dilation, in_channels, + out_channels) + self.instance_norm = nn.InstanceNorm1d( + in_channels, track_running_stats=False) + self.att_layer = AttLayer( + in_channels, + in_channels, + out_channels, + r1, + r1, + r2, + dilation, + att_type=att_type, + stage=stage) # dilation + self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1) + self.dropout = nn.Dropout() + self.alpha = alpha + + def forward(self, x, f, mask): + out = self.feed_forward(x) + out = self.alpha * self.att_layer(self.instance_norm(out), f, + mask) + out + out = self.conv_1x1(out) + out = self.dropout(out) + return (x + out) * mask[:, 0:1, :] + + +class PositionalEncoding(nn.Module): + """Implement the PE function.""" + + def __init__(self, d_model, max_len=10000): + super(PositionalEncoding, self).__init__() + # Compute the positional encodings once in log space. 
+ pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).permute(0, 2, 1) # of shape (1, d_model, l) + self.pe = nn.Parameter(pe, requires_grad=True) + + # self.register_buffer('pe', pe) + + def forward(self, x): + return x + self.pe[:, :, 0:x.shape[2]] + + +class Encoder(nn.Module): + + def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, + channel_masking_rate, att_type, alpha): + super(Encoder, self).__init__() + self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1) + self.layers = nn.ModuleList([ + AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type, + 'encoder', alpha) for i in # 2**i + range(num_layers) + ]) + + self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1) + self.dropout = nn.Dropout2d(p=channel_masking_rate) + self.channel_masking_rate = channel_masking_rate + + def forward(self, x, mask): + ''' + :param x: (N, C, L) + :param mask: + :return: + ''' + + if self.channel_masking_rate > 0: + x = x.unsqueeze(2) + x = self.dropout(x) + x = x.squeeze(2) + + feature = self.conv_1x1(x) + for layer in self.layers: + feature = layer(feature, None, mask) + + out = self.conv_out(feature) * mask[:, 0:1, :] + + return out, feature + + +class Decoder(nn.Module): + + def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, + att_type, alpha): + super(Decoder, self).__init__( + ) # self.position_en = PositionalEncoding(d_model=num_f_maps) + self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1) + self.layers = nn.ModuleList([ + AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type, + 'decoder', alpha) for i in # 2 ** i + range(num_layers) + ]) + self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1) + + def forward(self, x, fencoder, mask): + + feature = self.conv_1x1(x) + for layer in self.layers: + feature = layer(feature, fencoder, mask) + + out = self.conv_out(feature) * mask[:, 0:1, :] + + return out, feature + + +class MyTransformer(nn.Module): + + def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, + num_classes, channel_masking_rate): + super(MyTransformer, self).__init__() + self.encoder = Encoder( + num_layers, + r1, + r2, + num_f_maps, + input_dim, + num_classes, + channel_masking_rate, + att_type='sliding_att', + alpha=1) + self.decoders = nn.ModuleList([ + copy.deepcopy( + Decoder( + num_layers, + r1, + r2, + num_f_maps, + num_classes, + num_classes, + att_type='sliding_att', + alpha=exponential_descrease(s))) + for s in range(num_decoders) + ]) # num_decoders + + def forward(self, x, mask): + out, feature = self.encoder(x, mask) + outputs = out.unsqueeze(0) + + for decoder in self.decoders: + out, feature = decoder( + F.softmax(out, dim=1) * mask[:, 0:1, :], + feature * mask[:, 0:1, :], mask) + outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) + + return outputs diff --git a/mmaction/models/losses/__init__.py b/mmaction/models/losses/__init__.py index 41afcb7ace..18accb15f9 100644 --- a/mmaction/models/losses/__init__.py +++ b/mmaction/models/losses/__init__.py @@ -12,5 +12,5 @@ __all__ = [ 'BaseWeightedLoss', 'CrossEntropyLoss', 'NLLLoss', 'BCELossWithLogits', 'BinaryLogisticRegressionLoss', 'BMNLoss', 'OHEMHingeLoss', 'SSNLoss', - 'HVULoss', 'CBFocalLoss' + 'HVULoss', 'CBFocalLoss', 'MeanSquareErrorLoss' ] diff --git a/mmaction/models/losses/mse_loss.py 
b/mmaction/models/losses/mse_loss.py new file mode 100644 index 0000000000..7f97bed123 --- /dev/null +++ b/mmaction/models/losses/mse_loss.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F + +from mmaction.registry import MODELS +from .base import BaseWeightedLoss + + +@MODELS.register_module() +class MeanSquareErrorLoss(BaseWeightedLoss): + """Mean Square Error Loss.""" + + def __init__(self, loss_weight: float = 1., reduction: str = 'none'): + super().__init__(loss_weight=loss_weight) + self.reduction = reduction + + def _forward(self, cls_score: torch.Tensor, label: torch.Tensor, + **kwargs) -> torch.Tensor: + """Forward function. + + Args: + cls_score (torch.Tensor): The class score. + label (torch.Tensor): The ground truth label. + kwargs: Any keyword argument to be used to calculate + MeanSquareError loss. + + Returns: + torch.Tensor: The returned MeanSquareError loss. + """ + if cls_score.size() == label.size(): + assert len(kwargs) == 0, \ + ('For now, no extra args are supported for soft label, ' + f'but get {kwargs}') + + loss_cls = F.mse_loss(cls_score, label, reduction=self.reduction) + return loss_cls From 233d79ef643a582e40ed4e196eab7a1c21e43801 Mon Sep 17 00:00:00 2001 From: jts250 Date: Thu, 21 Sep 2023 09:46:26 +0800 Subject: [PATCH 02/17] update metrics/__init__.py --- mmaction/evaluation/metrics/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py index 2655a2f0e9..75bfca5a2d 100644 --- a/mmaction/evaluation/metrics/__init__.py +++ b/mmaction/evaluation/metrics/__init__.py @@ -5,6 +5,7 @@ from .multimodal_metric import VQAMCACC, ReportVQA, RetrievalRecall, VQAAcc from .multisports_metric import MultiSportsMetric from .retrieval_metric import RetrievalMetric +from .segment_metric import SegmentMetric __all__ = [ 'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix', From 9e8eefaeb828b91e7ad881bbab41462e82fa0cc1 Mon Sep 17 00:00:00 2001 From: jts250 Date: Thu, 21 Sep 2023 09:57:42 +0800 Subject: [PATCH 03/17] update losses/__init__.py --- mmaction/models/losses/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmaction/models/losses/__init__.py b/mmaction/models/losses/__init__.py index 18accb15f9..b2ccee1c76 100644 --- a/mmaction/models/losses/__init__.py +++ b/mmaction/models/losses/__init__.py @@ -5,6 +5,7 @@ from .cross_entropy_loss import (BCELossWithLogits, CBFocalLoss, CrossEntropyLoss) from .hvu_loss import HVULoss +from .mse_loss import MeanSquareErrorLoss from .nll_loss import NLLLoss from .ohem_hinge_loss import OHEMHingeLoss from .ssn_loss import SSNLoss From 0d5af7a0bea72348f5bc39d4cf122c3465621da4 Mon Sep 17 00:00:00 2001 From: jts250 Date: Mon, 9 Oct 2023 21:43:44 +0800 Subject: [PATCH 04/17] update evaluation --- mmaction/evaluation/metrics/segment_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmaction/evaluation/metrics/segment_metric.py b/mmaction/evaluation/metrics/segment_metric.py index 6e707b7373..e812b2be35 100644 --- a/mmaction/evaluation/metrics/segment_metric.py +++ b/mmaction/evaluation/metrics/segment_metric.py @@ -14,7 +14,7 @@ class SegmentMetric(BaseMetric): """Action Segmentation dataset evaluation metric.""" def __init__(self, - metric_type: str = 'TEM', + metric_type: str = 'ALL', collect_device: str = 'cpu', prefix: Optional[str] = None, metric_options: dict = {}, From 9a86243e34604356e3631c9588b051f3918d63e6 Mon Sep 17 00:00:00 2001 From: 
jts250 Date: Mon, 9 Oct 2023 22:54:35 +0800 Subject: [PATCH 05/17] update readme --- configs/segmentation/asformer/README.md | 90 ++++++++++++++++++++++ configs/segmentation/asformer/metafile.yml | 26 +++++++ 2 files changed, 116 insertions(+) create mode 100644 configs/segmentation/asformer/README.md create mode 100644 configs/segmentation/asformer/metafile.yml diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md new file mode 100644 index 0000000000..699b6a866a --- /dev/null +++ b/configs/segmentation/asformer/README.md @@ -0,0 +1,90 @@ +# ASFormer + +[ASFormer: Transformer for Action Segmentation](https://arxiv.org/pdf/2110.08568.pdf) + + + +## Abstract + + + +Algorithms for the action segmentation task typically use temporal models to predict +what action is occurring at each frame for a minute-long daily activity. Recent studies have shown the potential of Transformer in modeling the relations among elements +in sequential data. However, there are several major concerns when directly applying +the Transformer to the action segmentation task, such as the lack of inductive biases +with small training sets, the deficit in processing long input sequence, and the limitation of the decoder architecture to utilize temporal relations among multiple action segments to refine the initial predictions. To address these concerns, we design an efficient +Transformer-based model for the action segmentation task, named ASFormer, with three +distinctive characteristics: (i) We explicitly bring in the local connectivity inductive priors because of the high locality of features. It constrains the hypothesis space within a +reliable scope, and is beneficial for the action segmentation task to learn a proper target +function with small training sets. (ii) We apply a pre-defined hierarchical representation pattern that efficiently handles long input sequences. (iii) We carefully design the +decoder to refine the initial predictions from the encoder. Extensive experiments on +three public datasets demonstrate the effectiveness of our methods. The original code is available at +https://github.com/ChinaYi/ASFormer. + + + +
+ +
+ +## Results and Models + +### ActivityNet feature + +| feature | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | iter time(s) | config | ckpt | log | +| :-----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :----------: | :--------------------------------------------: | :------------------------------------------: | :------------------------------------------: | +| gtea | 1 | None | 67.25 | 32.89 | 49.43 | 56.64 | 75.29 | 8693 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | + +1. The **gpus** indicates the number of gpu we used to get the checkpoint. + According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, + e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu. +2. For feature column, cuhk_mean_100 denotes the widely used cuhk activitynet feature extracted by [anet2016-cuhk](https://github.com/yjxiong/anet2016-cuhk). +3. We evaluate the action detection performance of BMN, using [anet_cuhk_2017](https://download.openmmlab.com/mmaction/localization/cuhk_anet17_pred.json) submission for ActivityNet2017 Untrimmed Video Classification Track to assign label for each action proposal. + +\*We train BMN with the [official repo](https://github.com/JJBOY/BMN-Boundary-Matching-Network), evaluate its proposal generation and action detection performance with [anet_cuhk_2017](https://download.openmmlab.com/mmaction/localization/cuhk_anet17_pred.json) for label assigning. + +For more details on data preparation, you can refer to [ActivityNet Data Preparation](/tools/data/activitynet/README.md). + +## Train + +Train ASFormer model on features dataset for action segmentation. + +```shell +bash tools/dist_train.sh configs/segmentation/asformer/asformer_gtea.py 1 +``` + +For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). + +## Test + +Test ASFormer on features dataset for action segmentation. + +```shell +python3 tools/test.py configs/segmentation/asformer/asformer_gtea.py CHECKPOINT.PTH +``` + +For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). 
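
The ACC, EDIT and F1@{10,25,50} columns reported above follow the standard action segmentation protocol: frame-wise accuracy, segmental edit score, and segmental F1 at IoU thresholds of 0.10, 0.25 and 0.50. The snippet below is a standalone sketch of how these three quantities are conventionally computed from frame-wise label sequences; the helper names are illustrative and are not the ones used in `segment_metric.py`.

```python
import numpy as np


def get_segments(labels):
    """Collapse a frame-wise label sequence into (label, start, end) segments."""
    segments, start = [], 0
    for i in range(1, len(labels) + 1):
        if i == len(labels) or labels[i] != labels[start]:
            segments.append((labels[start], start, i))
            start = i
    return segments


def frame_accuracy(pred, gt):
    """ACC: percentage of frames whose predicted label matches the ground truth."""
    return 100.0 * float(np.mean(np.asarray(pred) == np.asarray(gt)))


def edit_score(pred, gt):
    """EDIT: normalized Levenshtein distance between the two segment label sequences."""
    p = [s[0] for s in get_segments(pred)]
    g = [s[0] for s in get_segments(gt)]
    if len(p) == 0 and len(g) == 0:
        return 100.0
    d = np.zeros((len(p) + 1, len(g) + 1))
    d[:, 0] = np.arange(len(p) + 1)
    d[0, :] = np.arange(len(g) + 1)
    for i in range(1, len(p) + 1):
        for j in range(1, len(g) + 1):
            cost = 0 if p[i - 1] == g[j - 1] else 1
            d[i, j] = min(d[i - 1, j] + 1, d[i, j - 1] + 1, d[i - 1, j - 1] + cost)
    return 100.0 * (1 - d[-1, -1] / max(len(p), len(g)))


def f1_at_k(pred, gt, overlap):
    """F1@k: a predicted segment is a true positive if it overlaps an unmatched
    ground-truth segment of the same label with IoU >= overlap."""
    pred_seg, gt_seg = get_segments(pred), get_segments(gt)
    tp, used = 0, set()
    for label, ps, pe in pred_seg:
        best_iou, best_j = 0.0, None
        for j, (gl, gs, ge) in enumerate(gt_seg):
            if gl != label or j in used:
                continue
            inter = max(0, min(pe, ge) - max(ps, gs))
            span = max(pe, ge) - min(ps, gs)  # equals the union when the segments overlap
            iou = inter / span
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_iou >= overlap and best_j is not None:
            tp += 1
            used.add(best_j)
    fp, fn = len(pred_seg) - tp, len(gt_seg) - tp
    if tp == 0:
        return 0.0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 100.0 * 2 * precision * recall / (precision + recall)


# Toy example: 10 frames, three actions, one boundary predicted one frame early.
gt = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
pred = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
print(frame_accuracy(pred, gt))  # 90.0  (one mislabelled frame)
print(edit_score(pred, gt))      # 100.0 (segment order is untouched)
print(f1_at_k(pred, gt, 0.5))    # 100.0 (every segment still has IoU >= 0.5)
```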
+ +## Citation + +```BibTeX +@inproceedings{chinayi_ASformer, + author={Fangqiu Yi and Hongyu Wen and Tingting Jiang}, + booktitle={The British Machine Vision Conference (BMVC)}, + title={ASFormer: Transformer for Action Segmentation}, + year={2021}, +} +``` + + + +```BibTeX +@inproceedings{fathi2011learning, + title={Learning to recognize objects in egocentric activities}, + author={Fathi, Alireza and Ren, Xiaofeng and Rehg, James M}, + booktitle={CVPR 2011}, + pages={3281--3288}, + year={2011}, + organization={IEEE} +} +``` diff --git a/configs/segmentation/asformer/metafile.yml b/configs/segmentation/asformer/metafile.yml new file mode 100644 index 0000000000..21a0e103ca --- /dev/null +++ b/configs/segmentation/asformer/metafile.yml @@ -0,0 +1,26 @@ +Collections: +- Name: ASFormer + README: configs/segmentation/asformer/README.md + Paper: + URL: https://arxiv.org/pdf/2110.08568.pdf + Title: "ASFormer: Transformer for Action Segmentation" + +Models: + - Name: bmn_2xb8-400x100-9e_activitynet-feature + Config: configs/segmentation/asformer/asformer_gtea.py + In Collection: ASFormer + Metadata: + Batch Size: 1 + Epochs: 120 + Training Data: GTEA + Training Resources: 1 GPU + Modality: RGB + Results: + - Dataset: GTEA + Task: Action Segmentation + Metrics: + Acc: 79.76 + Edit: 85.92 + F1@10: 90.02 + F1@25: 88.75 + F1@50: 80.23 From 9c9914e5a9cb6984886d6616cd11ff7877d1a324 Mon Sep 17 00:00:00 2001 From: jts250 Date: Tue, 10 Oct 2023 23:10:56 +0800 Subject: [PATCH 06/17] update tools/data --- tools/data/action_seg/README.md | 134 ++++++++++++++++++ tools/data/action_seg/download_datasets.sh | 21 +++ .../action_seg/generate_boundary_array.py | 61 ++++++++ tools/data/action_seg/generate_gt_array.py | 100 +++++++++++++ 4 files changed, 316 insertions(+) create mode 100644 tools/data/action_seg/README.md create mode 100644 tools/data/action_seg/download_datasets.sh create mode 100644 tools/data/action_seg/generate_boundary_array.py create mode 100644 tools/data/action_seg/generate_gt_array.py diff --git a/tools/data/action_seg/README.md b/tools/data/action_seg/README.md new file mode 100644 index 0000000000..b24a9a3157 --- /dev/null +++ b/tools/data/action_seg/README.md @@ -0,0 +1,134 @@ +# Preparing Datasets for Action Segmentation + +## Introduction + + + +```BibTeX +@inproceedings{fathi2011learning, + title={Learning to recognize objects in egocentric activities}, + author={Fathi, Alireza and Ren, Xiaofeng and Rehg, James M}, + booktitle={CVPR 2011}, + pages={3281--3288}, + year={2011}, + organization={IEEE} +} +``` + +```BibTeX +@inproceedings{stein2013combining, + title={Combining embedded accelerometers with computer vision for recognizing food preparation activities}, + author={Stein, Sebastian and McKenna, Stephen J}, + booktitle={Proceedings of the 2013 ACM international joint conference on Pervasive and ubiquitous computing}, + pages={729--738}, + year={2013} +} +``` + +```BibTeX +@inproceedings{kuehne2014language, + title={The language of actions: Recovering the syntax and semantics of goal-directed human activities}, + author={Kuehne, Hilde and Arslan, Ali and Serre, Thomas}, + booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, + pages={780--787}, + year={2014} +} +``` + +For basic dataset information, you can refer to the articles. +Before we start, please make sure that the directory is located at `$MMACTION2/tools/data/`. +To run the bash scripts below, you need to install `unzip`. you can install it by `sudo apt-get install unzip`. 
+ +## Step 1. Prepare Annotations and Features + +First of all, you can run the following script to prepare annotations and features. + +```shell +bash download_datasets.sh +``` + +## Step 2. Preprocess the Data + +you can execute the following script to preprocess the downloaded data and generate two folders for each dataset, 'gt_arr' and 'gt_boundary_arr'. + +```shell +python tools/data/action_seg/generate_boundary_array.py --dataset-dir action_seg +python tools/data/action_seg/generate_gt_array.py --dataset_dir data/action_seg +``` + +## Step 3. Check Directory Structure + +After the whole data process for GTEA, 50Salads and Breakfast preparation, +you will get the features, splits ,annotation files and groundtruth boundaries for the datasets. + +For extracting features from your own videos, please refer to [activitynet](/tools/data/activitynet/README.md). + +In the context of the whole project (for GTEA, 50Salads and Breakfast), the folder structure will look like: + +``` +mmaction2 +├── mmaction +├── tools +├── configs +├── data +│ ├── action_seg +│ │ ├── gtea +│ │ │ ├── features +│ │ │ │ ├── S1_Cheese_C1.npy +│ │ │ │ ├── S1_Coffee_C1.npy +│ │ │ │ ├── ... +│ │ │ ├── groundTruth +│ │ │ │ ├── S1_Cheese_C1.txt +│ │ │ │ ├── S1_Coffee_C1.txt +│ │ │ │ ├── ... +│ │ │ ├── gt_arr +│ │ │ │ ├── S1_Cheese_C1.npy +│ │ │ │ ├── S1_Coffee_C1.npy +│ │ │ │ ├── ... +│ │ │ ├── gt_boundary_arr +│ │ │ │ ├── S1_Cheese_C1.npy +│ │ │ │ ├── S1_Coffee_C1.npy +│ │ │ │ ├── ... +│ │ │ ├── splits +│ │ │ │ ├── fifa_mean_dur_split1.pt +│ │ │ │ ├── fifa_mean_dur_split2.pt +│ │ │ │ ├── ... +│ │ │ │ ├── test.split0.bundle +│ │ │ │ ├── test.split1.bundle +│ │ │ │ ├── ... +│ │ │ │ ├── train.split0.bundle +│ │ │ │ ├── train.split1.bundle +│ │ │ │ ├── ... +│ │ │ │ ├── train_split1_mean_duration.txt +│ │ │ │ ├── train_split2_mean_duration.txt +│ │ │ │ ├── ... +│ │ │ │ ├── ... +│ │ │ ├── mapping.txt +│ │ ├── 50salads +│ │ │ ├── features +│ │ │ │ ├── ... +│ │ │ ├── groundTruth +│ │ │ │ ├── ... +│ │ │ ├── gt_arr +│ │ │ │ ├── ... +│ │ │ ├── gt_boundary_arr +│ │ │ │ ├── ... +│ │ │ ├── splits +│ │ │ │ ├── ... +│ │ │ ├── mapping.txt +│ │ ├── breakfast +│ │ │ ├── features +│ │ │ │ ├── ... +│ │ │ ├── groundTruth +│ │ │ │ ├── ... +│ │ │ ├── gt_arr +│ │ │ │ ├── ... +│ │ │ ├── gt_boundary_arr +│ │ │ │ ├── ... +│ │ │ ├── splits +│ │ │ │ ├── ... +│ │ │ ├── mapping.txt + +``` + +For training and evaluating on GTEA, 50Salads and Breakfast, please refer to [Training and Test Tutorial](/docs/en/user_guides/train_test.md). diff --git a/tools/data/action_seg/download_datasets.sh b/tools/data/action_seg/download_datasets.sh new file mode 100644 index 0000000000..65dda9b53c --- /dev/null +++ b/tools/data/action_seg/download_datasets.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +set -e + +DATA_DIR="../../../data" + +if [[ ! -d "${DATA_DIR}" ]]; then + echo "${DATA_DIR} does not exist. Creating"; + mkdir -p ${DATA_DIR} +fi + +cd ${DATA_DIR} +wget https://zenodo.org/record/3625992/files/data.zip --no-check-certificate + +# sudo apt-get install unzip +unzip data.zip +rm data.zip + +mv data action_seg + +cd - diff --git a/tools/data/action_seg/generate_boundary_array.py b/tools/data/action_seg/generate_boundary_array.py new file mode 100644 index 0000000000..262a0394b9 --- /dev/null +++ b/tools/data/action_seg/generate_boundary_array.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import glob +import os + +import numpy as np + + +def get_arguments() -> argparse.Namespace: + """parse all the arguments from command line interface return a list of + parsed arguments.""" + + parser = argparse.ArgumentParser( + description='generate ground truth arrays for boundary regression.') + parser.add_argument( + '--dataset_dir', + type=str, + help='path to a dataset directory', + ) + + return parser.parse_args() + + +def main() -> None: + args = get_arguments() + + datasets = ['50salads', 'gtea', 'breakfast'] + + for dataset in datasets: + # make directory for saving ground truth numpy arrays + save_dir = os.path.join(args.dataset_dir, dataset, 'gt_boundary_arr') + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + gt_dir = os.path.join(args.dataset_dir, dataset, 'groundTruth') + gt_paths = glob.glob(os.path.join(gt_dir, '*.txt')) + + for gt_path in gt_paths: + # the name of ground truth text file + gt_name = os.path.relpath(gt_path, gt_dir) + + with open(gt_path, 'r') as f: + gt = f.read().split('\n')[:-1] + + # define the frame where new action starts as boundary frame + boundary = np.zeros(len(gt)) + last = gt[0] + boundary[0] = 1 + for i in range(1, len(gt)): + if last != gt[i]: + boundary[i] = 1 + last = gt[i] + + # save array + np.save(os.path.join(save_dir, gt_name[:-4] + '.npy'), boundary) + + print('Done') + + +if __name__ == '__main__': + main() diff --git a/tools/data/action_seg/generate_gt_array.py b/tools/data/action_seg/generate_gt_array.py new file mode 100644 index 0000000000..e4f7ffe7d0 --- /dev/null +++ b/tools/data/action_seg/generate_gt_array.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import glob +import os +import sys +from typing import Dict + +import numpy as np + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +dataset_names = ['50salads', 'breakfast', 'gtea'] + + +def get_class2id_map(dataset: str, + dataset_dir: str = './dataset') -> Dict[str, int]: + """ + Args: + dataset: 50salads, gtea, breakfast + dataset_dir: the path to the dataset directory + """ + + assert (dataset in dataset_names + ), 'You have to choose 50salads, gtea or breakfast as dataset.' 
+ + with open( + os.path.join(dataset_dir, '{}/mapping.txt'.format(dataset)), + 'r') as f: + actions = f.read().split('\n')[:-1] + + class2id_map = dict() + for a in actions: + class2id_map[a.split()[1]] = int(a.split()[0]) + + return class2id_map + + +def get_id2class_map(dataset: str, + dataset_dir: str = './dataset') -> Dict[int, str]: + class2id_map = get_class2id_map(dataset, dataset_dir) + + return {val: key for key, val in class2id_map.items()} + + +def get_n_classes(dataset: str, dataset_dir: str = './dataset') -> int: + return len(get_class2id_map(dataset, dataset_dir)) + + +def get_arguments() -> argparse.Namespace: + """parse all the arguments from command line interface return a list of + parsed arguments.""" + + parser = argparse.ArgumentParser( + description='convert ground truth txt files to numpy array') + parser.add_argument( + '--dataset_dir', + type=str, + default='./dataset', + help='path to a dataset directory (default: ./dataset)', + ) + + return parser.parse_args() + + +def main() -> None: + args = get_arguments() + + datasets = ['50salads', 'gtea', 'breakfast'] + + for dataset in datasets: + # make directory for saving ground truth numpy arrays + save_dir = os.path.join(args.dataset_dir, dataset, 'gt_arr') + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + # class to index mapping + class2id_map = get_class2id_map(dataset, dataset_dir=args.dataset_dir) + + gt_dir = os.path.join(args.dataset_dir, dataset, 'groundTruth') + gt_paths = glob.glob(os.path.join(gt_dir, '*.txt')) + + for gt_path in gt_paths: + # the name of ground truth text file + gt_name = os.path.relpath(gt_path, gt_dir) + + with open(gt_path, 'r') as f: + gt = f.read().split('\n')[:-1] + + gt_array = np.zeros(len(gt)) + for i in range(len(gt)): + gt_array[i] = class2id_map[gt[i]] + + # save array + np.save(os.path.join(save_dir, gt_name[:-4] + '.npy'), gt_array) + + print('Done') + + +if __name__ == '__main__': + main() From ffa99708501344958fe7581947ff8a1d615e9c84 Mon Sep 17 00:00:00 2001 From: jts250 Date: Tue, 10 Oct 2023 23:41:39 +0800 Subject: [PATCH 07/17] modify configs --- ...er_50salads2.py => asformer_breakfast1.py} | 29 +++-- .../segmentation/asformer/asformer_gtea1.py | 12 -- .../segmentation/asformer/asformer_gtea2.py | 121 ------------------ .../segmentation/asformer/asformer_gtea3.py | 121 ------------------ .../segmentation/asformer/asformer_gtea4.py | 121 ------------------ mmaction/evaluation/metrics/segment_metric.py | 3 +- 6 files changed, 16 insertions(+), 391 deletions(-) rename configs/segmentation/asformer/{asformer_50salads2.py => asformer_breakfast1.py} (85%) delete mode 100644 configs/segmentation/asformer/asformer_gtea2.py delete mode 100644 configs/segmentation/asformer/asformer_gtea3.py delete mode 100644 configs/segmentation/asformer/asformer_gtea4.py diff --git a/configs/segmentation/asformer/asformer_50salads2.py b/configs/segmentation/asformer/asformer_breakfast1.py similarity index 85% rename from configs/segmentation/asformer/asformer_50salads2.py rename to configs/segmentation/asformer/asformer_breakfast1.py index f83569c14e..eb80e98383 100644 --- a/configs/segmentation/asformer/asformer_50salads2.py +++ b/configs/segmentation/asformer/asformer_breakfast1.py @@ -2,23 +2,23 @@ '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' ] # dataset settings dataset_type = 'ActionSegmentDataset' -data_root = 'data/action_seg/50salads/' -data_root_val = 'data/action_seg/50salads/' -ann_file_train = 'data/action_seg/50salads/splits/train.split2.bundle' 
-ann_file_val = 'data/action_seg/50salads/splits/test.split2.bundle' -ann_file_test = 'data/action_seg/50salads/splits/test.split2.bundle' +data_root = 'data/action_seg/breakfast/' +data_root_val = 'data/action_seg/breakfast/' +ann_file_train = 'data/action_seg/breakfast/splits/train.split1.bundle' +ann_file_val = 'data/action_seg/breakfast/splits/test.split1.bundle' +ann_file_test = 'data/action_seg/breakfast/splits/test.split1.bundle' model = dict( type='ASFormer', - num_layers=10, - num_f_maps=64, + channel_masking_rate=0.3, input_dim=2048, + num_classes=48, num_decoders=3, - num_classes=19, - channel_masking_rate=0.3, - sample_rate=2, + num_f_maps=64, + num_layers=10, r1=2, - r2=2) + r2=2, + sample_rate=1) train_pipeline = [ dict(type='LoadSegmentationFeature'), @@ -93,7 +93,7 @@ type='EpochBasedTrainLoop', max_epochs=max_epochs, val_begin=0, - val_interval=10) + val_interval=5) val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') @@ -111,10 +111,11 @@ ], gamma=0.5) ] -work_dir = './work_dirs/50salads2/' + +work_dir = './work_dirs/breakfast1/' test_evaluator = dict( type='SegmentMetric', metric_type='ALL', dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) val_evaluator = test_evaluator -default_hooks = dict(checkpoint=dict(interval=10, max_keep_ckpts=3)) +default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=3)) diff --git a/configs/segmentation/asformer/asformer_gtea1.py b/configs/segmentation/asformer/asformer_gtea1.py index fc0315e1cc..5fe16b059d 100644 --- a/configs/segmentation/asformer/asformer_gtea1.py +++ b/configs/segmentation/asformer/asformer_gtea1.py @@ -87,18 +87,6 @@ test_cfg = dict(type='TestLoop') optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) -''' -param_scheduler = [ - dict( - monitor= 'F1@50', - param_name='lr', - type='ReduceOnPlateauParamScheduler', - rule='less', - factor=0.5, - patience=3,#33 - verbose=True) -] -''' param_scheduler = [ dict( type='MultiStepLR', diff --git a/configs/segmentation/asformer/asformer_gtea2.py b/configs/segmentation/asformer/asformer_gtea2.py deleted file mode 100644 index 2bb9cb7fbf..0000000000 --- a/configs/segmentation/asformer/asformer_gtea2.py +++ /dev/null @@ -1,121 +0,0 @@ -_base_ = [ - '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' -] # dataset settings -dataset_type = 'ActionSegmentDataset' -data_root = 'data/action_seg/gtea/' -data_root_val = 'data/action_seg/gtea/' -ann_file_train = 'data/action_seg/gtea/splits/train.split2.bundle' -ann_file_val = 'data/action_seg/gtea/splits/test.split2.bundle' -ann_file_test = 'data/action_seg/gtea/splits/test.split2.bundle' - -train_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=( - 'num_classes', - 'actions_dict', - 'index2label', - 'ground_truth', - 'classes', - )) -] - -val_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=('num_classes', 'actions_dict', 'index2label', - 'ground_truth', 'classes')) -] - -test_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=('num_classes', 'actions_dict', 'index2label', - 'ground_truth', 'classes')) -] - -train_dataloader = dict( - batch_size=1, - num_workers=1, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - drop_last=True, - dataset=dict( - type=dataset_type, - ann_file=ann_file_train, - 
data_prefix=dict(video=data_root), - pipeline=train_pipeline)) - -val_dataloader = dict( - batch_size=1, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_val, - data_prefix=dict(video=data_root_val), - pipeline=val_pipeline, - test_mode=True)) - -test_dataloader = dict( - batch_size=1, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_test, - data_prefix=dict(video=data_root_val), - pipeline=test_pipeline, - test_mode=True)) - -max_epochs = 120 -train_cfg = dict( - type='EpochBasedTrainLoop', - max_epochs=max_epochs, - val_begin=0, - val_interval=5) - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) -''' -param_scheduler = [ - dict( - monitor= 'F1@50', - param_name='lr', - type='ReduceOnPlateauParamScheduler', - rule='less', - factor=0.5, - patience=3,#33 - verbose=True) -] -''' -param_scheduler = [ - dict( - type='MultiStepLR', - begin=0, - end=max_epochs, - by_epoch=True, - milestones=[ - 80, - 100, - ], - gamma=0.5) -] - -work_dir = './work_dirs/asformer_gtea2/' -test_evaluator = dict( - type='SegmentMetric', - metric_type='ALL', - dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) -val_evaluator = test_evaluator -default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=6)) diff --git a/configs/segmentation/asformer/asformer_gtea3.py b/configs/segmentation/asformer/asformer_gtea3.py deleted file mode 100644 index 8bad7fdf93..0000000000 --- a/configs/segmentation/asformer/asformer_gtea3.py +++ /dev/null @@ -1,121 +0,0 @@ -_base_ = [ - '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' -] # dataset settings -dataset_type = 'ActionSegmentDataset' -data_root = 'data/action_seg/gtea/' -data_root_val = 'data/action_seg/gtea/' -ann_file_train = 'data/action_seg/gtea/splits/train.split3.bundle' -ann_file_val = 'data/action_seg/gtea/splits/test.split3.bundle' -ann_file_test = 'data/action_seg/gtea/splits/test.split3.bundle' - -train_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=( - 'num_classes', - 'actions_dict', - 'index2label', - 'ground_truth', - 'classes', - )) -] - -val_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=('num_classes', 'actions_dict', 'index2label', - 'ground_truth', 'classes')) -] - -test_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=('num_classes', 'actions_dict', 'index2label', - 'ground_truth', 'classes')) -] - -train_dataloader = dict( - batch_size=1, - num_workers=1, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - drop_last=True, - dataset=dict( - type=dataset_type, - ann_file=ann_file_train, - data_prefix=dict(video=data_root), - pipeline=train_pipeline)) - -val_dataloader = dict( - batch_size=1, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_val, - data_prefix=dict(video=data_root_val), - pipeline=val_pipeline, - test_mode=True)) - -test_dataloader = dict( - batch_size=1, - num_workers=8, - persistent_workers=True, - 
sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_test, - data_prefix=dict(video=data_root_val), - pipeline=test_pipeline, - test_mode=True)) - -max_epochs = 120 -train_cfg = dict( - type='EpochBasedTrainLoop', - max_epochs=max_epochs, - val_begin=0, - val_interval=5) - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5)) -''' -param_scheduler = [ - dict( - monitor= 'F1@50', - param_name='lr', - type='ReduceOnPlateauParamScheduler', - rule='less', - factor=0.5, - patience=3,#33 - verbose=True) -] -''' -param_scheduler = [ - dict( - type='MultiStepLR', - begin=0, - end=max_epochs, - by_epoch=True, - milestones=[ - 80, - 100, - ], - gamma=0.5) -] - -work_dir = './work_dirs/asformer_gtea3/' -test_evaluator = dict( - type='SegmentMetric', - metric_type='ALL', - dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) -val_evaluator = test_evaluator -default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=6)) diff --git a/configs/segmentation/asformer/asformer_gtea4.py b/configs/segmentation/asformer/asformer_gtea4.py deleted file mode 100644 index e531fc3d26..0000000000 --- a/configs/segmentation/asformer/asformer_gtea4.py +++ /dev/null @@ -1,121 +0,0 @@ -_base_ = [ - '../../_base_/models/asformer.py', '../../_base_/default_runtime.py' -] # dataset settings -dataset_type = 'ActionSegmentDataset' -data_root = 'data/action_seg/gtea/' -data_root_val = 'data/action_seg/gtea/' -ann_file_train = 'data/action_seg/gtea/splits/train.split4.bundle' -ann_file_val = 'data/action_seg/gtea/splits/test.split4.bundle' -ann_file_test = 'data/action_seg/gtea/splits/test.split4.bundle' - -train_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=( - 'num_classes', - 'actions_dict', - 'index2label', - 'ground_truth', - 'classes', - )) -] - -val_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=('num_classes', 'actions_dict', 'index2label', - 'ground_truth', 'classes')) -] - -test_pipeline = [ - dict(type='LoadSegmentationFeature'), - dict( - type='PackSegmentationInputs', - keys=('classes', ), - meta_keys=('num_classes', 'actions_dict', 'index2label', - 'ground_truth', 'classes')) -] - -train_dataloader = dict( - batch_size=1, - num_workers=1, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - drop_last=True, - dataset=dict( - type=dataset_type, - ann_file=ann_file_train, - data_prefix=dict(video=data_root), - pipeline=train_pipeline)) - -val_dataloader = dict( - batch_size=1, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_val, - data_prefix=dict(video=data_root_val), - pipeline=val_pipeline, - test_mode=True)) - -test_dataloader = dict( - batch_size=1, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_test, - data_prefix=dict(video=data_root_val), - pipeline=test_pipeline, - test_mode=True)) - -max_epochs = 120 -train_cfg = dict( - type='EpochBasedTrainLoop', - max_epochs=max_epochs, - val_begin=0, - val_interval=5) - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, 
weight_decay=1e-5)) -''' -param_scheduler = [ - dict( - monitor= 'F1@50', - param_name='lr', - type='ReduceOnPlateauParamScheduler', - rule='less', - factor=0.5, - patience=3,#33 - verbose=True) -] -''' -param_scheduler = [ - dict( - type='MultiStepLR', - begin=0, - end=max_epochs, - by_epoch=True, - milestones=[ - 80, - 100, - ], - gamma=0.5) -] - -work_dir = './work_dirs/asformer_gtea4/' -test_evaluator = dict( - type='SegmentMetric', - metric_type='ALL', - dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) -val_evaluator = test_evaluator -default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=6)) diff --git a/mmaction/evaluation/metrics/segment_metric.py b/mmaction/evaluation/metrics/segment_metric.py index e812b2be35..b9e2c45974 100644 --- a/mmaction/evaluation/metrics/segment_metric.py +++ b/mmaction/evaluation/metrics/segment_metric.py @@ -22,13 +22,12 @@ def __init__(self, super().__init__(collect_device=collect_device, prefix=prefix) self.metric_type = metric_type + assert metric_type == 'ALL' assert 'out' in dump_config self.output_format = dump_config.pop('output_format', 'csv') self.out = dump_config['out'] self.metric_options = metric_options - if self.metric_type == 'AR@AN': - self.ground_truth = {} def process(self, data_batch: Sequence[Tuple[Any, dict]], predictions: Sequence[dict]) -> None: From 6a493a8b51da4599c30c8670941c4803cd9a7fe9 Mon Sep 17 00:00:00 2001 From: jts250 Date: Wed, 11 Oct 2023 00:38:35 +0800 Subject: [PATCH 08/17] modify metafile&readme --- configs/segmentation/asformer/README.md | 57 ++++++++++++++----- ...mer_50salads1.py => asformer_50salads2.py} | 8 +-- ...r_breakfast1.py => asformer_breakfast2.py} | 8 +-- .../{asformer_gtea1.py => asformer_gtea2.py} | 8 +-- configs/segmentation/asformer/metafile.yml | 52 ++++++++++++++--- 5 files changed, 101 insertions(+), 32 deletions(-) rename configs/segmentation/asformer/{asformer_50salads1.py => asformer_50salads2.py} (92%) rename configs/segmentation/asformer/{asformer_breakfast1.py => asformer_breakfast2.py} (92%) rename configs/segmentation/asformer/{asformer_gtea1.py => asformer_gtea2.py} (92%) diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md index 699b6a866a..99d757ac11 100644 --- a/configs/segmentation/asformer/README.md +++ b/configs/segmentation/asformer/README.md @@ -18,8 +18,7 @@ distinctive characteristics: (i) We explicitly bring in the local connectivity i reliable scope, and is beneficial for the action segmentation task to learn a proper target function with small training sets. (ii) We apply a pre-defined hierarchical representation pattern that efficiently handles long input sequences. (iii) We carefully design the decoder to refine the initial predictions from the encoder. Extensive experiments on -three public datasets demonstrate the effectiveness of our methods. The original code is available at -https://github.com/ChinaYi/ASFormer. +three public datasets demonstrate the effectiveness of our methods. @@ -27,30 +26,42 @@ https://github.com/ChinaYi/ASFormer. 
-## Results and Models +## Results ### ActivityNet feature -| feature | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | iter time(s) | config | ckpt | log | -| :-----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :----------: | :--------------------------------------------: | :------------------------------------------: | :------------------------------------------: | -| gtea | 1 | None | 67.25 | 32.89 | 49.43 | 56.64 | 75.29 | 8693 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| dataset | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | iter time(s) | config | ckpt | log | +| :--------------: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :----------: | :-----------------------------------------: | :---------------------------------------: | :---------------------------------------: | +| gtea_split2 | 1 | None | 80.34 | 81.58 | 89.30 | 87.83 | 75.28 | 1500 | - | - | - | - | +| gtea_split1 | 1 | None | 76.54 | 80.36 | 84.80 | 83.39 | 77.74 | 1500 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| gtea_split3 | 1 | None | 82.41 | 90.03 | 92.13 | 92.37 | 86.26 | 1500 | - | - | - | - | +| gtea_split4 | 1 | None | 79.77 | 91.70 | 92.88 | 92.39 | 81.65 | 1500 | - | - | - | - | +| 50salads_split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | - | - | - | - | +| 50salads_split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| 50salads_split3 | 1 | None | 85.51 | 82.23 | 85.71 | 84.29 | 78.57 | 7200 | - | - | - | - | +| 50salads_split4 | 1 | None | 87.27 | 80.46 | 85.99 | 83.14 | 78.86 | 7200 | - | - | - | - | +| 50salads_split5 | 1 | None | 87.96 | 75.29 | 84.60 | 83.13 | 76.28 | 7200 | - | - | - | - | +| breakfast_split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | - | - | - | - | +| breakfast_split1 | 1 | None | 75.52 | 76.87 | 77.06 | 73.05 | 61.77 | 8800 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| breakfast_split3 | 1 | None | 74.86 | 74.33 | 76.17 | 70.85 | 58.07 | 8800 | - | - | - | - | +| breakfast_split4 | 1 | None | 70.39 | 71.54 | 73.42 | 66.61 | 52.76 | 8800 | - | - | - | - | 1. The **gpus** indicates the number of gpu we used to get the checkpoint. 
According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, - e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu. -2. For feature column, cuhk_mean_100 denotes the widely used cuhk activitynet feature extracted by [anet2016-cuhk](https://github.com/yjxiong/anet2016-cuhk). -3. We evaluate the action detection performance of BMN, using [anet_cuhk_2017](https://download.openmmlab.com/mmaction/localization/cuhk_anet17_pred.json) submission for ActivityNet2017 Untrimmed Video Classification Track to assign label for each action proposal. + e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu. . -\*We train BMN with the [official repo](https://github.com/JJBOY/BMN-Boundary-Matching-Network), evaluate its proposal generation and action detection performance with [anet_cuhk_2017](https://download.openmmlab.com/mmaction/localization/cuhk_anet17_pred.json) for label assigning. +2. We train ASFormer with the [official repo](https://github.com/ChinaYi/ASFormer), evaluate its proposal segmentation performance with GTEA, Breakfast and 50Salads. -For more details on data preparation, you can refer to [ActivityNet Data Preparation](/tools/data/activitynet/README.md). +3. For experiments with other splits, we simply change the names of the training and testing datasets in the configs file. + +For more details on data preparation, you can refer to [Preparing Datasets for Action Segmentation](/tools/data/action_seg/README.md). ## Train Train ASFormer model on features dataset for action segmentation. ```shell -bash tools/dist_train.sh configs/segmentation/asformer/asformer_gtea.py 1 +bash tools/dist_train.sh configs/segmentation/asformer/asformer_gtea2.py 1 ``` For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). @@ -60,7 +71,7 @@ For more details, you can refer to the **Training** part in the [Training and Te Test ASFormer on features dataset for action segmentation. ```shell -python3 tools/test.py configs/segmentation/asformer/asformer_gtea.py CHECKPOINT.PTH +python3 tools/test.py configs/segmentation/asformer/asformer_gtea2.py CHECKPOINT.PTH ``` For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). 
@@ -88,3 +99,23 @@ For more details, you can refer to the **Testing** part in the [Training and Tes organization={IEEE} } ``` + +```BibTeX +@inproceedings{stein2013combining, + title={Combining embedded accelerometers with computer vision for recognizing food preparation activities}, + author={Stein, Sebastian and McKenna, Stephen J}, + booktitle={Proceedings of the 2013 ACM international joint conference on Pervasive and ubiquitous computing}, + pages={729--738}, + year={2013} +} +``` + +```BibTeX +@inproceedings{kuehne2014language, + title={The language of actions: Recovering the syntax and semantics of goal-directed human activities}, + author={Kuehne, Hilde and Arslan, Ali and Serre, Thomas}, + booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, + pages={780--787}, + year={2014} +} +``` diff --git a/configs/segmentation/asformer/asformer_50salads1.py b/configs/segmentation/asformer/asformer_50salads2.py similarity index 92% rename from configs/segmentation/asformer/asformer_50salads1.py rename to configs/segmentation/asformer/asformer_50salads2.py index 77ee1f0f16..f83569c14e 100644 --- a/configs/segmentation/asformer/asformer_50salads1.py +++ b/configs/segmentation/asformer/asformer_50salads2.py @@ -4,9 +4,9 @@ dataset_type = 'ActionSegmentDataset' data_root = 'data/action_seg/50salads/' data_root_val = 'data/action_seg/50salads/' -ann_file_train = 'data/action_seg/50salads/splits/train.split1.bundle' -ann_file_val = 'data/action_seg/50salads/splits/test.split1.bundle' -ann_file_test = 'data/action_seg/50salads/splits/test.split1.bundle' +ann_file_train = 'data/action_seg/50salads/splits/train.split2.bundle' +ann_file_val = 'data/action_seg/50salads/splits/test.split2.bundle' +ann_file_test = 'data/action_seg/50salads/splits/test.split2.bundle' model = dict( type='ASFormer', @@ -111,7 +111,7 @@ ], gamma=0.5) ] -work_dir = './work_dirs/50salads1/' +work_dir = './work_dirs/50salads2/' test_evaluator = dict( type='SegmentMetric', metric_type='ALL', diff --git a/configs/segmentation/asformer/asformer_breakfast1.py b/configs/segmentation/asformer/asformer_breakfast2.py similarity index 92% rename from configs/segmentation/asformer/asformer_breakfast1.py rename to configs/segmentation/asformer/asformer_breakfast2.py index eb80e98383..2e3154a562 100644 --- a/configs/segmentation/asformer/asformer_breakfast1.py +++ b/configs/segmentation/asformer/asformer_breakfast2.py @@ -4,9 +4,9 @@ dataset_type = 'ActionSegmentDataset' data_root = 'data/action_seg/breakfast/' data_root_val = 'data/action_seg/breakfast/' -ann_file_train = 'data/action_seg/breakfast/splits/train.split1.bundle' -ann_file_val = 'data/action_seg/breakfast/splits/test.split1.bundle' -ann_file_test = 'data/action_seg/breakfast/splits/test.split1.bundle' +ann_file_train = 'data/action_seg/breakfast/splits/train.split2.bundle' +ann_file_val = 'data/action_seg/breakfast/splits/test.split2.bundle' +ann_file_test = 'data/action_seg/breakfast/splits/test.split2.bundle' model = dict( type='ASFormer', @@ -112,7 +112,7 @@ gamma=0.5) ] -work_dir = './work_dirs/breakfast1/' +work_dir = './work_dirs/breakfast2/' test_evaluator = dict( type='SegmentMetric', metric_type='ALL', diff --git a/configs/segmentation/asformer/asformer_gtea1.py b/configs/segmentation/asformer/asformer_gtea2.py similarity index 92% rename from configs/segmentation/asformer/asformer_gtea1.py rename to configs/segmentation/asformer/asformer_gtea2.py index 5fe16b059d..cc0f2e80ee 100644 --- 
a/configs/segmentation/asformer/asformer_gtea1.py +++ b/configs/segmentation/asformer/asformer_gtea2.py @@ -4,9 +4,9 @@ dataset_type = 'ActionSegmentDataset' data_root = 'data/action_seg/gtea/' data_root_val = 'data/action_seg/gtea/' -ann_file_train = 'data/action_seg/gtea/splits/train.split1.bundle' -ann_file_val = 'data/action_seg/gtea/splits/test.split1.bundle' -ann_file_test = 'data/action_seg/gtea/splits/test.split1.bundle' +ann_file_train = 'data/action_seg/gtea/splits/train.split2.bundle' +ann_file_val = 'data/action_seg/gtea/splits/test.split2.bundle' +ann_file_test = 'data/action_seg/gtea/splits/test.split2.bundle' train_pipeline = [ dict(type='LoadSegmentationFeature'), @@ -100,7 +100,7 @@ gamma=0.5) ] -work_dir = './work_dirs/gtea1/' +work_dir = './work_dirs/gtea2/' test_evaluator = dict( type='SegmentMetric', metric_type='ALL', diff --git a/configs/segmentation/asformer/metafile.yml b/configs/segmentation/asformer/metafile.yml index 21a0e103ca..fda26604ab 100644 --- a/configs/segmentation/asformer/metafile.yml +++ b/configs/segmentation/asformer/metafile.yml @@ -6,8 +6,8 @@ Collections: Title: "ASFormer: Transformer for Action Segmentation" Models: - - Name: bmn_2xb8-400x100-9e_activitynet-feature - Config: configs/segmentation/asformer/asformer_gtea.py + - Name: asformer_gtea2 + Config: configs/segmentation/asformer/asformer_gtea2.py In Collection: ASFormer Metadata: Batch Size: 1 @@ -19,8 +19,46 @@ Models: - Dataset: GTEA Task: Action Segmentation Metrics: - Acc: 79.76 - Edit: 85.92 - F1@10: 90.02 - F1@25: 88.75 - F1@50: 80.23 + Acc: 80.34 + Edit: 81.58 + F1@10: 89.30 + F1@25: 87.83 + F1@50: 75.28 + + - Name: asformer_50salads2 + Config: configs/segmentation/asformer/asformer_50salads2.py + In Collection: ASFormer + Metadata: + Batch Size: 1 + Epochs: 120 + Training Data: GTEA + Training Resources: 1 GPU + Modality: RGB + Results: + - Dataset: GTEA + Task: Action Segmentation + Metrics: + Acc: 87.55 + Edit: 79.10 + F1@10: 85.17 + F1@25: 83.73 + F1@50: 77.99 + + - Name: asformer_breakfast2 + Config: configs/segmentation/asformer/asformer_breakfast2.py + In Collection: ASFormer + Metadata: + Batch Size: 1 + Epochs: 120 + Training Data: GTEA + Training Resources: 1 GPU + Modality: RGB + Results: + - Dataset: GTEA + Task: Action Segmentation + Metrics: + Acc: 74.12 + Edit: 76.53 + F1@10: 77.74 + F1@25: 72.62 + F1@50: 60.43 From 82cfcae37b623b3d2ce6f57d8401ceaa30c848b1 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Tue, 10 Oct 2023 22:48:45 -0400 Subject: [PATCH 09/17] update readme --- configs/segmentation/asformer/README.md | 44 ++++++++++++------- ... => asformer_1xb1-120e_50salads-split2.py} | 0 ...=> asformer_1xb1-120e_breakfast-split2.py} | 0 ...2.py => asformer_1xb1-120e_gtea-split2.py} | 0 4 files changed, 27 insertions(+), 17 deletions(-) rename configs/segmentation/asformer/{asformer_50salads2.py => asformer_1xb1-120e_50salads-split2.py} (100%) rename configs/segmentation/asformer/{asformer_breakfast2.py => asformer_1xb1-120e_breakfast-split2.py} (100%) rename configs/segmentation/asformer/{asformer_gtea2.py => asformer_1xb1-120e_gtea-split2.py} (100%) diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md index 99d757ac11..d0f507658d 100644 --- a/configs/segmentation/asformer/README.md +++ b/configs/segmentation/asformer/README.md @@ -28,23 +28,33 @@ three public datasets demonstrate the effectiveness of our methods. 
## Results -### ActivityNet feature - -| dataset | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | iter time(s) | config | ckpt | log | -| :--------------: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :----------: | :-----------------------------------------: | :---------------------------------------: | :---------------------------------------: | -| gtea_split2 | 1 | None | 80.34 | 81.58 | 89.30 | 87.83 | 75.28 | 1500 | - | - | - | - | -| gtea_split1 | 1 | None | 76.54 | 80.36 | 84.80 | 83.39 | 77.74 | 1500 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | -| gtea_split3 | 1 | None | 82.41 | 90.03 | 92.13 | 92.37 | 86.26 | 1500 | - | - | - | - | -| gtea_split4 | 1 | None | 79.77 | 91.70 | 92.88 | 92.39 | 81.65 | 1500 | - | - | - | - | -| 50salads_split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | - | - | - | - | -| 50salads_split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | -| 50salads_split3 | 1 | None | 85.51 | 82.23 | 85.71 | 84.29 | 78.57 | 7200 | - | - | - | - | -| 50salads_split4 | 1 | None | 87.27 | 80.46 | 85.99 | 83.14 | 78.86 | 7200 | - | - | - | - | -| 50salads_split5 | 1 | None | 87.96 | 75.29 | 84.60 | 83.13 | 76.28 | 7200 | - | - | - | - | -| breakfast_split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | - | - | - | - | -| breakfast_split1 | 1 | None | 75.52 | 76.87 | 77.06 | 73.05 | 61.77 | 8800 | - | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | -| breakfast_split3 | 1 | None | 74.86 | 74.33 | 76.17 | 70.85 | 58.07 | 8800 | - | - | - | - | -| breakfast_split4 | 1 | None | 70.39 | 71.54 | 73.42 | 66.61 | 52.76 | 8800 | - | - | - | - | +### GTEA + +| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | +| :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | +| split2 | 1 | None | 80.34 | 81.58 | 89.30 | 87.83 | 75.28 | 1500 | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | | +| split1 | 1 | None | 76.54 | 80.36 | 84.80 | 83.39 | 77.74 | 1500 | - | - | - | +| split3 | 1 | None | 82.41 | 90.03 | 92.13 | 92.37 | 86.26 | 1500 | - | - | - | +| split4 | 1 | None | 79.77 | 91.70 | 92.88 | 92.39 | 81.65 
| 1500 | - | - | - | + +### 50Salads + +| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | +| :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | +| split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | | | +| split3 | 1 | None | 85.51 | 82.23 | 85.71 | 84.29 | 78.57 | 7200 | - | - | - | +| split4 | 1 | None | 87.27 | 80.46 | 85.99 | 83.14 | 78.86 | 7200 | - | - | - | +| split5 | 1 | None | 87.96 | 75.29 | 84.60 | 83.13 | 76.28 | 7200 | - | - | - | + +### Breakfast + +| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | +| :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | +| split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| split1 | 1 | None | 75.52 | 76.87 | 77.06 | 73.05 | 61.77 | 8800 | - | - | - | +| split3 | 1 | None | 74.86 | 74.33 | 76.17 | 70.85 | 58.07 | 8800 | - | - | - | +| split4 | 1 | None | 70.39 | 71.54 | 73.42 | 66.61 | 52.76 | 8800 | - | - | - | 1. The **gpus** indicates the number of gpu we used to get the checkpoint. 
According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, diff --git a/configs/segmentation/asformer/asformer_50salads2.py b/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2.py similarity index 100% rename from configs/segmentation/asformer/asformer_50salads2.py rename to configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2.py diff --git a/configs/segmentation/asformer/asformer_breakfast2.py b/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2.py similarity index 100% rename from configs/segmentation/asformer/asformer_breakfast2.py rename to configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2.py diff --git a/configs/segmentation/asformer/asformer_gtea2.py b/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2.py similarity index 100% rename from configs/segmentation/asformer/asformer_gtea2.py rename to configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2.py From 0c96fd8c5176a6f708b8cdeaea9b658a2972c8ae Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Tue, 10 Oct 2023 23:13:04 -0400 Subject: [PATCH 10/17] update readme --- configs/segmentation/asformer/README.md | 18 +++++++----------- ...r_1xb1-120e_50salads-split2-i3d-feature.py} | 0 ..._1xb1-120e_breakfast-split2-i3d-feature.py} | 0 ...ormer_1xb1-120e_gtea-split2-i3d-feature.py} | 0 4 files changed, 7 insertions(+), 11 deletions(-) rename configs/segmentation/asformer/{asformer_1xb1-120e_50salads-split2.py => asformer_1xb1-120e_50salads-split2-i3d-feature.py} (100%) rename configs/segmentation/asformer/{asformer_1xb1-120e_breakfast-split2.py => asformer_1xb1-120e_breakfast-split2-i3d-feature.py} (100%) rename configs/segmentation/asformer/{asformer_1xb1-120e_gtea-split2.py => asformer_1xb1-120e_gtea-split2-i3d-feature.py} (100%) diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md index d0f507658d..512edccc93 100644 --- a/configs/segmentation/asformer/README.md +++ b/configs/segmentation/asformer/README.md @@ -23,7 +23,7 @@ three public datasets demonstrate the effectiveness of our methods.
- +
## Results @@ -32,7 +32,7 @@ three public datasets demonstrate the effectiveness of our methods. | split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | | :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | -| split2 | 1 | None | 80.34 | 81.58 | 89.30 | 87.83 | 75.28 | 1500 | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | | +| split2 | 1 | None | 80.34 | 81.58 | 89.30 | 87.83 | 75.28 | 1500 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature_20231011-b5aaf789.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.log) | | split1 | 1 | None | 76.54 | 80.36 | 84.80 | 83.39 | 77.74 | 1500 | - | - | - | | split3 | 1 | None | 82.41 | 90.03 | 92.13 | 92.37 | 86.26 | 1500 | - | - | - | | split4 | 1 | None | 79.77 | 91.70 | 92.88 | 92.39 | 81.65 | 1500 | - | - | - | @@ -41,7 +41,7 @@ three public datasets demonstrate the effectiveness of our methods. | split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | | :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | -| split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature_20231011-25dc57d5.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.log) | | split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | | | | split3 | 1 | None | 85.51 | 82.23 | 85.71 | 84.29 | 78.57 | 7200 | - | - | - | | split4 | 1 | None | 87.27 | 80.46 | 85.99 | 83.14 | 78.86 | 7200 | - | - | - | @@ -51,18 +51,14 @@ three public datasets demonstrate the effectiveness of our methods. 
| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | | :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | -| split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | [config](/configs/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature_20220908-79f92857.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/bmn/bmn_2xb8-400x100-9e_activitynet-feature.log) | +| split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth) | | split1 | 1 | None | 75.52 | 76.87 | 77.06 | 73.05 | 61.77 | 8800 | - | - | - | | split3 | 1 | None | 74.86 | 74.33 | 76.17 | 70.85 | 58.07 | 8800 | - | - | - | | split4 | 1 | None | 70.39 | 71.54 | 73.42 | 66.61 | 52.76 | 8800 | - | - | - | 1. The **gpus** indicates the number of gpu we used to get the checkpoint. - According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, - e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu. . -2. We train ASFormer with the [official repo](https://github.com/ChinaYi/ASFormer), evaluate its proposal segmentation performance with GTEA, Breakfast and 50Salads. - -3. For experiments with other splits, we simply change the names of the training and testing datasets in the configs file. +2. We only provide checkpoints of one split. For experiments with other splits, we simply change the names of the training and testing datasets in the configs file, i.e., modifying `ann_file_train`, `ann_file_val` and `ann_file_test`. For more details on data preparation, you can refer to [Preparing Datasets for Action Segmentation](/tools/data/action_seg/README.md). @@ -71,7 +67,7 @@ For more details on data preparation, you can refer to [Preparing Datasets for A Train ASFormer model on features dataset for action segmentation. ```shell -bash tools/dist_train.sh configs/segmentation/asformer/asformer_gtea2.py 1 +bash tools/dist_train.sh configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py 1 ``` For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). @@ -81,7 +77,7 @@ For more details, you can refer to the **Training** part in the [Training and Te Test ASFormer on features dataset for action segmentation. ```shell -python3 tools/test.py configs/segmentation/asformer/asformer_gtea2.py CHECKPOINT.PTH +python3 tools/test.py configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py CHECKPOINT.PTH ``` For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). 
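
As a concrete illustration of note 2 above, a run on another split can be written as a small config that inherits from the released split-2 file. The file name below is hypothetical, and the dataloader and evaluator entries are overridden explicitly because the `_base_` config has already resolved the split-2 paths into them.

```python
# Hypothetical configs/segmentation/asformer/asformer_1xb1-120e_gtea-split3-i3d-feature.py
_base_ = ['./asformer_1xb1-120e_gtea-split2-i3d-feature.py']

# Only the split-specific annotation bundles and output paths change.
ann_file_train = 'data/action_seg/gtea/splits/train.split3.bundle'
ann_file_val = 'data/action_seg/gtea/splits/test.split3.bundle'
ann_file_test = 'data/action_seg/gtea/splits/test.split3.bundle'

# The base dataloaders already captured the split-2 bundles, so point them at
# the new files explicitly (MMEngine merges these dicts into the base ones).
train_dataloader = dict(dataset=dict(ann_file=ann_file_train))
val_dataloader = dict(dataset=dict(ann_file=ann_file_val))
test_dataloader = dict(dataset=dict(ann_file=ann_file_test))

work_dir = './work_dirs/gtea3/'
# The dump path in the base evaluator was resolved with the split-2 work_dir,
# so refresh it as well.
test_evaluator = dict(
    dump_config=dict(out=f'{work_dir}/results.json', output_format='json'))
val_evaluator = test_evaluator
```

Such a config can then be launched with the same commands shown in the Train and Test sections above.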
diff --git a/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2.py b/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py similarity index 100% rename from configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2.py rename to configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py diff --git a/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2.py b/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py similarity index 100% rename from configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2.py rename to configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py diff --git a/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2.py b/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py similarity index 100% rename from configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2.py rename to configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py From d5db9f74d6992f3aec0dfd4e300a1397cf927d05 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Tue, 10 Oct 2023 23:19:52 -0400 Subject: [PATCH 11/17] update readme --- configs/segmentation/asformer/README.md | 6 +++--- configs/segmentation/asformer/metafile.yml | 12 +++++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md index 512edccc93..f8c47432d3 100644 --- a/configs/segmentation/asformer/README.md +++ b/configs/segmentation/asformer/README.md @@ -23,7 +23,7 @@ three public datasets demonstrate the effectiveness of our methods.
## Results @@ -42,7 +42,7 @@ three public datasets demonstrate the effectiveness of our methods. | split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | | :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | | split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature_20231011-25dc57d5.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.log) | -| split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | | | +| split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | - | - | | split3 | 1 | None | 85.51 | 82.23 | 85.71 | 84.29 | 78.57 | 7200 | - | - | - | | split4 | 1 | None | 87.27 | 80.46 | 85.99 | 83.14 | 78.86 | 7200 | - | - | - | | split5 | 1 | None | 87.96 | 75.29 | 84.60 | 83.13 | 76.28 | 7200 | - | - | - | @@ -51,7 +51,7 @@ three public datasets demonstrate the effectiveness of our methods. | split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log | | :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: | -| split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth) | +| split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.log) | | split1 | 1 | None | 75.52 | 76.87 | 77.06 | 73.05 | 61.77 | 8800 | - | - | - | | split3 | 1 | None | 74.86 | 74.33 | 76.17 | 70.85 | 58.07 | 8800 | - | - | - | | split4 | 1 | None | 70.39 | 71.54 | 73.42 | 66.61 | 52.76 | 8800 | - | - | - | diff --git a/configs/segmentation/asformer/metafile.yml b/configs/segmentation/asformer/metafile.yml index fda26604ab..714c8f5ebc 100644 --- a/configs/segmentation/asformer/metafile.yml +++ b/configs/segmentation/asformer/metafile.yml @@ -7,7 +7,7 @@ Collections: Models: - Name: asformer_gtea2 - Config: configs/segmentation/asformer/asformer_gtea2.py + Config: configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py In Collection: ASFormer Metadata: Batch Size: 1 @@ -24,9 +24,11 @@ Models: F1@10: 89.30 F1@25: 87.83 F1@50: 75.28 + Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.log + Weights: 
https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature_20231011-25dc57d5.pth - Name: asformer_50salads2 - Config: configs/segmentation/asformer/asformer_50salads2.py + Config: configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py In Collection: ASFormer Metadata: Batch Size: 1 @@ -43,9 +45,11 @@ Models: F1@10: 85.17 F1@25: 83.73 F1@50: 77.99 + Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.log + Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature_20231011-b5aaf789.pth - Name: asformer_breakfast2 - Config: configs/segmentation/asformer/asformer_breakfast2.py + Config: configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py In Collection: ASFormer Metadata: Batch Size: 1 @@ -62,3 +66,5 @@ Models: F1@10: 77.74 F1@25: 72.62 F1@50: 60.43 + Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.log + Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth From 3fa86bca6da086afd974056b726816fa2fe22670 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Tue, 10 Oct 2023 23:27:37 -0400 Subject: [PATCH 12/17] update readme --- configs/segmentation/asformer/README.md | 2 +- configs/segmentation/asformer/metafile.yml | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md index f8c47432d3..6cfe4e1f42 100644 --- a/configs/segmentation/asformer/README.md +++ b/configs/segmentation/asformer/README.md @@ -58,7 +58,7 @@ three public datasets demonstrate the effectiveness of our methods. 1. The **gpus** indicates the number of gpu we used to get the checkpoint. -2. We only provide checkpoints of one split. For experiments with other splits, we simply change the names of the training and testing datasets in the configs file, i.e., modifying `ann_file_train`, `ann_file_val` and `ann_file_test`. +2. We report results trained on every split, but only provide checkpoints of one split. For experiments with other splits, simply change the paths to the training and testing datasets in the configs file, i.e., modifying `ann_file_train`, `ann_file_val` and `ann_file_test`. For more details on data preparation, you can refer to [Preparing Datasets for Action Segmentation](/tools/data/action_seg/README.md). 
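In practice, the revised note above amounts to editing three annotation paths in the chosen config and leaving everything else untouched. A minimal sketch of what that change could look like for another split (the GTEA split-3 paths below are hypothetical and only illustrate which keys to edit):

```python
# Illustrative only: in a copy of the split-2 config, repoint the annotation
# bundles at a different split (assumed directory layout).
ann_file_train = 'data/action_seg/gtea/splits/train.split3.bundle'
ann_file_val = 'data/action_seg/gtea/splits/test.split3.bundle'
ann_file_test = 'data/action_seg/gtea/splits/test.split3.bundle'
```

Model, schedule and pipeline settings can stay identical to the split-2 configuration.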
diff --git a/configs/segmentation/asformer/metafile.yml b/configs/segmentation/asformer/metafile.yml index 714c8f5ebc..5915efb531 100644 --- a/configs/segmentation/asformer/metafile.yml +++ b/configs/segmentation/asformer/metafile.yml @@ -6,7 +6,7 @@ Collections: Title: "ASFormer: Transformer for Action Segmentation" Models: - - Name: asformer_gtea2 + - Name: asformer_1xb1-120e_gtea-split2-i3d-feature Config: configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py In Collection: ASFormer Metadata: @@ -16,7 +16,7 @@ Models: Training Resources: 1 GPU Modality: RGB Results: - - Dataset: GTEA + - Dataset: GTEA split2 Task: Action Segmentation Metrics: Acc: 80.34 @@ -27,17 +27,17 @@ Models: Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.log Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature_20231011-25dc57d5.pth - - Name: asformer_50salads2 + - Name: asformer_1xb1-120e_50salads-split2-i3d-feature Config: configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py In Collection: ASFormer Metadata: Batch Size: 1 Epochs: 120 - Training Data: GTEA + Training Data: 50salads Training Resources: 1 GPU Modality: RGB Results: - - Dataset: GTEA + - Dataset: 50salads split2 Task: Action Segmentation Metrics: Acc: 87.55 @@ -48,17 +48,17 @@ Models: Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.log Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature_20231011-b5aaf789.pth - - Name: asformer_breakfast2 + - Name: asformer_1xb1-120e_breakfast-split2-i3d-feature Config: configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py In Collection: ASFormer Metadata: Batch Size: 1 Epochs: 120 - Training Data: GTEA + Training Data: breakfast Training Resources: 1 GPU Modality: RGB Results: - - Dataset: GTEA + - Dataset: breakfast split2 Task: Action Segmentation Metrics: Acc: 74.12 From 4ad54159c7e6452548ec1e0ab2b7a2bdcc15164d Mon Sep 17 00:00:00 2001 From: jts250 Date: Thu, 12 Oct 2023 09:25:18 +0800 Subject: [PATCH 13/17] modify tools/data/action_seg/readme --- tools/data/action_seg/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/data/action_seg/README.md b/tools/data/action_seg/README.md index b24a9a3157..3ada793e2d 100644 --- a/tools/data/action_seg/README.md +++ b/tools/data/action_seg/README.md @@ -61,8 +61,6 @@ python tools/data/action_seg/generate_gt_array.py --dataset_dir data/action_seg After the whole data process for GTEA, 50Salads and Breakfast preparation, you will get the features, splits ,annotation files and groundtruth boundaries for the datasets. -For extracting features from your own videos, please refer to [activitynet](/tools/data/activitynet/README.md). 
- In the context of the whole project (for GTEA, 50Salads and Breakfast), the folder structure will look like: ``` From 788964facb8f18bebc870038685593d804b08429 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Thu, 12 Oct 2023 00:11:41 -0400 Subject: [PATCH 14/17] fix lint --- mmaction/datasets/__init__.py | 2 +- mmaction/evaluation/metrics/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py index aaff6bb3da..be417d5a81 100644 --- a/mmaction/datasets/__init__.py +++ b/mmaction/datasets/__init__.py @@ -17,6 +17,6 @@ 'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset', 'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset', 'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset', - 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC', 'CharadesSTADataset', + 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC', 'CharadesSTADataset', 'ActionSegmentDataset' ] diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py index 43d252e4ca..96ad99c96a 100644 --- a/mmaction/evaluation/metrics/__init__.py +++ b/mmaction/evaluation/metrics/__init__.py @@ -8,7 +8,6 @@ from .segment_metric import SegmentMetric from .video_grounding_metric import RecallatTopK - __all__ = [ 'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix', 'MultiSportsMetric', 'RetrievalMetric', 'VQAAcc', 'ReportVQA', 'VQAMCACC', From 5d935d4900b1423d7da0947f72ef8c468f34ea96 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Thu, 12 Oct 2023 00:55:29 -0400 Subject: [PATCH 15/17] add docstring --- mmaction/models/action_segmentors/asformer.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py index b83639cb4d..534595a19c 100644 --- a/mmaction/models/action_segmentors/asformer.py +++ b/mmaction/models/action_segmentors/asformer.py @@ -60,8 +60,6 @@ def forward(self, inputs, data_samples, mode, **kwargs): - If ``mode="loss"``, return a dict of tensor. """ input = torch.stack(inputs) - if mode == 'tensor': - return self._forward(inputs, **kwargs) if mode == 'predict': return self.predict(input, data_samples, **kwargs) elif mode == 'loss': @@ -169,19 +167,6 @@ def predict(self, batch_inputs, batch_data_samples, **kwargs): output = [dict(ground=ground, recognition=recognition)] return output - def _forward(self, x): - """Define the computation performed at every call. - - Args: - x (torch.Tensor): The input data. - Returns: - torch.Tensor: The output of the module. - """ - print(x.shape) - - return x.shape - - def exponential_descrease(idx_decoder, p=3): return math.exp(-p * idx_decoder) @@ -448,6 +433,13 @@ def __init__(self, dilation, in_channels, out_channels): dilation=dilation), nn.ReLU()) def forward(self, x): + """Define the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + Returns: + torch.Tensor: The output of the module. + """ return self.layer(x) @@ -579,7 +571,7 @@ def forward(self, x, fencoder, mask): class MyTransformer(nn.Module): - + """An encoder-decoder transformer""" def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate): super(MyTransformer, self).__init__() @@ -608,6 +600,13 @@ def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, ]) # num_decoders def forward(self, x, mask): + """Define the computation performed at every call. 
+ + Args: + x (torch.Tensor): The input data. + Returns: + torch.Tensor: The output of the module. + """ out, feature = self.encoder(x, mask) outputs = out.unsqueeze(0) @@ -617,4 +616,4 @@ def forward(self, x, mask): feature * mask[:, 0:1, :], mask) outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) - return outputs + return outputs \ No newline at end of file From 2f87a573e3d51707cb8ddd3514d3e0d41d6a0199 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Thu, 12 Oct 2023 01:06:18 -0400 Subject: [PATCH 16/17] add docstring --- mmaction/models/action_segmentors/asformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py index 534595a19c..5de3f4f0ba 100644 --- a/mmaction/models/action_segmentors/asformer.py +++ b/mmaction/models/action_segmentors/asformer.py @@ -167,6 +167,7 @@ def predict(self, batch_inputs, batch_data_samples, **kwargs): output = [dict(ground=ground, recognition=recognition)] return output + def exponential_descrease(idx_decoder, p=3): return math.exp(-p * idx_decoder) @@ -571,7 +572,8 @@ def forward(self, x, fencoder, mask): class MyTransformer(nn.Module): - """An encoder-decoder transformer""" + """An encoder-decoder transformer.""" + def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate): super(MyTransformer, self).__init__() @@ -601,7 +603,7 @@ def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim, def forward(self, x, mask): """Define the computation performed at every call. - + Args: x (torch.Tensor): The input data. Returns: @@ -616,4 +618,4 @@ def forward(self, x, mask): feature * mask[:, 0:1, :], mask) outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) - return outputs \ No newline at end of file + return outputs From df478d03ccd9da255442562f7becbbab65310c53 Mon Sep 17 00:00:00 2001 From: KaiHoo Date: Thu, 12 Oct 2023 01:32:45 -0400 Subject: [PATCH 17/17] add docstring --- mmaction/models/action_segmentors/asformer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py index 5de3f4f0ba..25a072016a 100644 --- a/mmaction/models/action_segmentors/asformer.py +++ b/mmaction/models/action_segmentors/asformer.py @@ -453,6 +453,13 @@ def __init__(self, in_channels, out_channels): nn.Conv1d(out_channels, out_channels, 1)) def forward(self, x): + """Define the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + Returns: + torch.Tensor: The output of the module. + """ return self.layer(x) @@ -506,6 +513,13 @@ def __init__(self, d_model, max_len=10000): # self.register_buffer('pe', pe) def forward(self, x): + """Define the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + Returns: + torch.Tensor: The output of the module. + """ return x + self.pe[:, :, 0:x.shape[2]]
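The final hunks above add docstrings to `forward` methods that do nothing more than add a precomputed positional table, truncated to the current sequence length, onto a channel-first feature map. As a self-contained illustration of that pattern, here is a minimal sketch of such a module; it is a simplified stand-in with a fixed sinusoidal buffer, not the exact implementation in `asformer.py` (which keeps its own table and may, for example, make it learnable).

```python
import math

import torch
import torch.nn as nn


class SimplePositionalEncoding(nn.Module):
    """Sketch of a channel-first sinusoidal positional encoding.

    Expects inputs of shape (batch, d_model, length), matching the 1D
    convolutional layout used throughout ASFormer. ``d_model`` is assumed
    to be even in this simplified version.
    """

    def __init__(self, d_model: int, max_len: int = 10000):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Store as (1, d_model, max_len) so it broadcasts over the batch dim
        # and can be sliced along the temporal axis.
        self.register_buffer('pe', pe.transpose(0, 1).unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, d_model, length) -> add the first ``length`` positions.
        return x + self.pe[:, :, :x.shape[2]]


# Toy usage: 64-d features over 300 frames.
feats = torch.randn(1, 64, 300)
out = SimplePositionalEncoding(d_model=64)(feats)
assert out.shape == feats.shape
```

Adding the encoding instead of concatenating it keeps the channel count unchanged, so the dilated temporal convolutions that follow need no change to their input dimension.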