diff --git a/configs/_base_/models/asformer.py b/configs/_base_/models/asformer.py
new file mode 100644
index 0000000000..6468449ada
--- /dev/null
+++ b/configs/_base_/models/asformer.py
@@ -0,0 +1,12 @@
+# model settings
+model = dict(
+ type='ASFormer',
+ num_layers=10,
+ num_f_maps=64,
+ input_dim=2048,
+ num_decoders=3,
+ num_classes=11,
+ channel_masking_rate=0.5,
+ sample_rate=1,
+ r1=2,
+ r2=2)
diff --git a/configs/segmentation/asformer/README.md b/configs/segmentation/asformer/README.md
new file mode 100644
index 0000000000..6cfe4e1f42
--- /dev/null
+++ b/configs/segmentation/asformer/README.md
@@ -0,0 +1,127 @@
+# ASFormer
+
+[ASFormer: Transformer for Action Segmentation](https://arxiv.org/pdf/2110.08568.pdf)
+
+
+
+## Abstract
+
+
+
+Algorithms for the action segmentation task typically use temporal models to predict
+what action is occurring at each frame for a minute-long daily activity. Recent studies have shown the potential of Transformer in modeling the relations among elements
+in sequential data. However, there are several major concerns when directly applying
+the Transformer to the action segmentation task, such as the lack of inductive biases
+with small training sets, the deficit in processing long input sequence, and the limitation of the decoder architecture to utilize temporal relations among multiple action segments to refine the initial predictions. To address these concerns, we design an efficient
+Transformer-based model for the action segmentation task, named ASFormer, with three
+distinctive characteristics: (i) We explicitly bring in the local connectivity inductive priors because of the high locality of features. It constrains the hypothesis space within a
+reliable scope, and is beneficial for the action segmentation task to learn a proper target
+function with small training sets. (ii) We apply a pre-defined hierarchical representation pattern that efficiently handles long input sequences. (iii) We carefully design the
+decoder to refine the initial predictions from the encoder. Extensive experiments on
+three public datasets demonstrate the effectiveness of our methods.
+
+
+
+
+
+
+
+## Results
+
+### GTEA
+
+| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log |
+| :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: |
+| split2 | 1 | None | 80.34 | 81.58 | 89.30 | 87.83 | 75.28 | 1500 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature_20231011-b5aaf789.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.log) |
+| split1 | 1 | None | 76.54 | 80.36 | 84.80 | 83.39 | 77.74 | 1500 | - | - | - |
+| split3 | 1 | None | 82.41 | 90.03 | 92.13 | 92.37 | 86.26 | 1500 | - | - | - |
+| split4 | 1 | None | 79.77 | 91.70 | 92.88 | 92.39 | 81.65 | 1500 | - | - | - |
+
+### 50Salads
+
+| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log |
+| :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: |
+| split2 | 1 | None | 87.55 | 79.10 | 85.17 | 83.73 | 77.99 | 7200 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature_20231011-25dc57d5.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.log) |
+| split1 | 1 | None | 81.44 | 73.25 | 82.04 | 80.27 | 71.84 | 7200 | - | - | - |
+| split3 | 1 | None | 85.51 | 82.23 | 85.71 | 84.29 | 78.57 | 7200 | - | - | - |
+| split4 | 1 | None | 87.27 | 80.46 | 85.99 | 83.14 | 78.86 | 7200 | - | - | - |
+| split5 | 1 | None | 87.96 | 75.29 | 84.60 | 83.13 | 76.28 | 7200 | - | - | - |
+
+### Breakfast
+
+| split | gpus | pretrain | ACC | EDIT | F1@10 | F1@25 | F1@50 | gpu_mem(M) | config | ckpt | log |
+| :----: | :--: | :------: | :---: | :---: | :---: | :---: | :---: | :--------: | :------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------: |
+| split2 | 1 | None | 74.12 | 76.53 | 77.74 | 72.62 | 60.43 | 8800 | [config](/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.log) |
+| split1 | 1 | None | 75.52 | 76.87 | 77.06 | 73.05 | 61.77 | 8800 | - | - | - |
+| split3 | 1 | None | 74.86 | 74.33 | 76.17 | 70.85 | 58.07 | 8800 | - | - | - |
+| split4 | 1 | None | 70.39 | 71.54 | 73.42 | 66.61 | 52.76 | 8800 | - | - | - |
+
+1. The **gpus** column indicates the number of GPUs we used to get the checkpoint.
+
+2. We report results trained on every split, but only provide checkpoints for one split per dataset. For experiments with other splits, simply change the training and testing annotation paths in the config file, i.e., modify `ann_file_train`, `ann_file_val` and `ann_file_test` (see the example override below).
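+
+For example, a GTEA split3 config could be written as a small override of the provided split2 config. This is only an illustrative sketch: the file name is not shipped in this PR, and it assumes the new file lives next to the split2 config.
+
+```python
+# asformer_1xb1-120e_gtea-split3-i3d-feature.py (illustrative sketch)
+_base_ = ['./asformer_1xb1-120e_gtea-split2-i3d-feature.py']
+
+data_root = 'data/action_seg/gtea/'
+ann_file_train = data_root + 'splits/train.split3.bundle'
+ann_file_val = data_root + 'splits/test.split3.bundle'
+ann_file_test = data_root + 'splits/test.split3.bundle'
+
+# the dataloaders in the base config already hold the split2 paths,
+# so the annotation files are overridden explicitly here
+train_dataloader = dict(dataset=dict(ann_file=ann_file_train))
+val_dataloader = dict(dataset=dict(ann_file=ann_file_val))
+test_dataloader = dict(dataset=dict(ann_file=ann_file_test))
+```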
+
+For more details on data preparation, you can refer to [Preparing Datasets for Action Segmentation](/tools/data/action_seg/README.md).
+
+## Train
+
+You can train the ASFormer model on a feature dataset for action segmentation with the following command.
+
+```shell
+bash tools/dist_train.sh configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py 1
+```
+
+For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
+
+## Test
+
+You can test the ASFormer model on a feature dataset for action segmentation with the following command.
+
+```shell
+python3 tools/test.py configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py CHECKPOINT.PTH
+```
+
+For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
+
+## Citation
+
+```BibTeX
+@inproceedings{chinayi_ASformer,
+ author={Fangqiu Yi and Hongyu Wen and Tingting Jiang},
+ booktitle={The British Machine Vision Conference (BMVC)},
+ title={ASFormer: Transformer for Action Segmentation},
+ year={2021},
+}
+```
+
+
+
+```BibTeX
+@inproceedings{fathi2011learning,
+ title={Learning to recognize objects in egocentric activities},
+ author={Fathi, Alireza and Ren, Xiaofeng and Rehg, James M},
+ booktitle={CVPR 2011},
+ pages={3281--3288},
+ year={2011},
+ organization={IEEE}
+}
+```
+
+```BibTeX
+@inproceedings{stein2013combining,
+ title={Combining embedded accelerometers with computer vision for recognizing food preparation activities},
+ author={Stein, Sebastian and McKenna, Stephen J},
+ booktitle={Proceedings of the 2013 ACM international joint conference on Pervasive and ubiquitous computing},
+ pages={729--738},
+ year={2013}
+}
+```
+
+```BibTeX
+@inproceedings{kuehne2014language,
+ title={The language of actions: Recovering the syntax and semantics of goal-directed human activities},
+ author={Kuehne, Hilde and Arslan, Ali and Serre, Thomas},
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
+ pages={780--787},
+ year={2014}
+}
+```
diff --git a/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py b/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py
new file mode 100644
index 0000000000..f83569c14e
--- /dev/null
+++ b/configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py
@@ -0,0 +1,120 @@
+_base_ = [
+ '../../_base_/models/asformer.py', '../../_base_/default_runtime.py'
+] # dataset settings
+dataset_type = 'ActionSegmentDataset'
+data_root = 'data/action_seg/50salads/'
+data_root_val = 'data/action_seg/50salads/'
+ann_file_train = 'data/action_seg/50salads/splits/train.split2.bundle'
+ann_file_val = 'data/action_seg/50salads/splits/test.split2.bundle'
+ann_file_test = 'data/action_seg/50salads/splits/test.split2.bundle'
+
+model = dict(
+ type='ASFormer',
+ num_layers=10,
+ num_f_maps=64,
+ input_dim=2048,
+ num_decoders=3,
+ num_classes=19,
+ channel_masking_rate=0.3,
+ sample_rate=2,
+ r1=2,
+ r2=2)
+
+train_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=(
+ 'num_classes',
+ 'actions_dict',
+ 'index2label',
+ 'ground_truth',
+ 'classes',
+ ))
+]
+
+val_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=('num_classes', 'actions_dict', 'index2label',
+ 'ground_truth', 'classes'))
+]
+
+test_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=('num_classes', 'actions_dict', 'index2label',
+ 'ground_truth', 'classes'))
+]
+
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ drop_last=True,
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_train,
+ data_prefix=dict(video=data_root),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_val,
+ data_prefix=dict(video=data_root_val),
+ pipeline=val_pipeline,
+ test_mode=True))
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_test,
+ data_prefix=dict(video=data_root_val),
+ pipeline=test_pipeline,
+ test_mode=True))
+
+max_epochs = 120
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=max_epochs,
+ val_begin=0,
+ val_interval=10)
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5))
+param_scheduler = [
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[
+ 80,
+ 100,
+ ],
+ gamma=0.5)
+]
+work_dir = './work_dirs/50salads2/'
+test_evaluator = dict(
+ type='SegmentMetric',
+ metric_type='ALL',
+ dump_config=dict(out=f'{work_dir}/results.json', output_format='json'))
+val_evaluator = test_evaluator
+default_hooks = dict(checkpoint=dict(interval=10, max_keep_ckpts=3))
diff --git a/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py b/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py
new file mode 100644
index 0000000000..2e3154a562
--- /dev/null
+++ b/configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py
@@ -0,0 +1,121 @@
+_base_ = [
+ '../../_base_/models/asformer.py', '../../_base_/default_runtime.py'
+] # dataset settings
+dataset_type = 'ActionSegmentDataset'
+data_root = 'data/action_seg/breakfast/'
+data_root_val = 'data/action_seg/breakfast/'
+ann_file_train = 'data/action_seg/breakfast/splits/train.split2.bundle'
+ann_file_val = 'data/action_seg/breakfast/splits/test.split2.bundle'
+ann_file_test = 'data/action_seg/breakfast/splits/test.split2.bundle'
+
+model = dict(
+ type='ASFormer',
+ channel_masking_rate=0.3,
+ input_dim=2048,
+ num_classes=48,
+ num_decoders=3,
+ num_f_maps=64,
+ num_layers=10,
+ r1=2,
+ r2=2,
+ sample_rate=1)
+
+train_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=(
+ 'num_classes',
+ 'actions_dict',
+ 'index2label',
+ 'ground_truth',
+ 'classes',
+ ))
+]
+
+val_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=('num_classes', 'actions_dict', 'index2label',
+ 'ground_truth', 'classes'))
+]
+
+test_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=('num_classes', 'actions_dict', 'index2label',
+ 'ground_truth', 'classes'))
+]
+
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ drop_last=True,
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_train,
+ data_prefix=dict(video=data_root),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_val,
+ data_prefix=dict(video=data_root_val),
+ pipeline=val_pipeline,
+ test_mode=True))
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_test,
+ data_prefix=dict(video=data_root_val),
+ pipeline=test_pipeline,
+ test_mode=True))
+
+max_epochs = 120
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=max_epochs,
+ val_begin=0,
+ val_interval=5)
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5))
+param_scheduler = [
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[
+ 80,
+ 100,
+ ],
+ gamma=0.5)
+]
+
+work_dir = './work_dirs/breakfast2/'
+test_evaluator = dict(
+ type='SegmentMetric',
+ metric_type='ALL',
+ dump_config=dict(out=f'{work_dir}/results.json', output_format='json'))
+val_evaluator = test_evaluator
+default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=3))
diff --git a/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py b/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py
new file mode 100644
index 0000000000..cc0f2e80ee
--- /dev/null
+++ b/configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py
@@ -0,0 +1,109 @@
+_base_ = [
+ '../../_base_/models/asformer.py', '../../_base_/default_runtime.py'
+] # dataset settings
+dataset_type = 'ActionSegmentDataset'
+data_root = 'data/action_seg/gtea/'
+data_root_val = 'data/action_seg/gtea/'
+ann_file_train = 'data/action_seg/gtea/splits/train.split2.bundle'
+ann_file_val = 'data/action_seg/gtea/splits/test.split2.bundle'
+ann_file_test = 'data/action_seg/gtea/splits/test.split2.bundle'
+
+train_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=(
+ 'num_classes',
+ 'actions_dict',
+ 'index2label',
+ 'ground_truth',
+ 'classes',
+ ))
+]
+
+val_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=('num_classes', 'actions_dict', 'index2label',
+ 'ground_truth', 'classes'))
+]
+
+test_pipeline = [
+ dict(type='LoadSegmentationFeature'),
+ dict(
+ type='PackSegmentationInputs',
+ keys=('classes', ),
+ meta_keys=('num_classes', 'actions_dict', 'index2label',
+ 'ground_truth', 'classes'))
+]
+
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ drop_last=True,
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_train,
+ data_prefix=dict(video=data_root),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_val,
+ data_prefix=dict(video=data_root_val),
+ pipeline=val_pipeline,
+ test_mode=True))
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_test,
+ data_prefix=dict(video=data_root_val),
+ pipeline=test_pipeline,
+ test_mode=True))
+
+max_epochs = 120
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=max_epochs,
+ val_begin=0,
+ val_interval=5)
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+optim_wrapper = dict(optimizer=dict(type='Adam', lr=0.0005, weight_decay=1e-5))
+param_scheduler = [
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[
+ 80,
+ 100,
+ ],
+ gamma=0.5)
+]
+
+work_dir = './work_dirs/gtea2/'
+test_evaluator = dict(
+ type='SegmentMetric',
+ metric_type='ALL',
+ dump_config=dict(out=f'{work_dir}/results.json', output_format='json'))
+val_evaluator = test_evaluator
+default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=3))
diff --git a/configs/segmentation/asformer/metafile.yml b/configs/segmentation/asformer/metafile.yml
new file mode 100644
index 0000000000..5915efb531
--- /dev/null
+++ b/configs/segmentation/asformer/metafile.yml
@@ -0,0 +1,70 @@
+Collections:
+- Name: ASFormer
+ README: configs/segmentation/asformer/README.md
+ Paper:
+ URL: https://arxiv.org/pdf/2110.08568.pdf
+ Title: "ASFormer: Transformer for Action Segmentation"
+
+Models:
+ - Name: asformer_1xb1-120e_gtea-split2-i3d-feature
+ Config: configs/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.py
+ In Collection: ASFormer
+ Metadata:
+ Batch Size: 1
+ Epochs: 120
+ Training Data: GTEA
+ Training Resources: 1 GPU
+ Modality: RGB
+ Results:
+ - Dataset: GTEA split2
+ Task: Action Segmentation
+ Metrics:
+ Acc: 80.34
+ Edit: 81.58
+ F1@10: 89.30
+ F1@25: 87.83
+ F1@50: 75.28
+ Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature.log
+ Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_gtea-split2-i3d-feature_20231011-b5aaf789.pth
+
+ - Name: asformer_1xb1-120e_50salads-split2-i3d-feature
+ Config: configs/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.py
+ In Collection: ASFormer
+ Metadata:
+ Batch Size: 1
+ Epochs: 120
+ Training Data: 50salads
+ Training Resources: 1 GPU
+ Modality: RGB
+ Results:
+ - Dataset: 50salads split2
+ Task: Action Segmentation
+ Metrics:
+ Acc: 87.55
+ Edit: 79.10
+ F1@10: 85.17
+ F1@25: 83.73
+ F1@50: 77.99
+ Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature.log
+ Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_50salads-split2-i3d-feature_20231011-25dc57d5.pth
+
+ - Name: asformer_1xb1-120e_breakfast-split2-i3d-feature
+ Config: configs/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.py
+ In Collection: ASFormer
+ Metadata:
+ Batch Size: 1
+ Epochs: 120
+ Training Data: breakfast
+ Training Resources: 1 GPU
+ Modality: RGB
+ Results:
+ - Dataset: breakfast split2
+ Task: Action Segmentation
+ Metrics:
+ Acc: 74.12
+ Edit: 76.53
+ F1@10: 77.74
+ F1@25: 72.62
+ F1@50: 60.43
+ Training Log: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature.log
+ Weights: https://download.openmmlab.com/mmaction/v1.0/segmentation/asformer/asformer_1xb1-120e_breakfast-split2-i3d-feature_20231011-10e557f3.pth
diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py
index eef565309d..be417d5a81 100644
--- a/mmaction/datasets/__init__.py
+++ b/mmaction/datasets/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .action_segment_dataset import ActionSegmentDataset
from .activitynet_dataset import ActivityNetDataset
from .audio_dataset import AudioDataset
from .ava_dataset import AVADataset, AVAKineticsDataset
@@ -16,5 +17,6 @@
'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset',
'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset',
'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset',
- 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC', 'CharadesSTADataset'
+ 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC', 'CharadesSTADataset',
+ 'ActionSegmentDataset'
]
diff --git a/mmaction/datasets/action_segment_dataset.py b/mmaction/datasets/action_segment_dataset.py
new file mode 100644
index 0000000000..120a511dc6
--- /dev/null
+++ b/mmaction/datasets/action_segment_dataset.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, List, Optional, Union
+
+from mmengine.fileio import exists
+
+from mmaction.registry import DATASETS
+from mmaction.utils import ConfigType
+from .base import BaseActionDataset
+
+
+@DATASETS.register_module()
+class ActionSegmentDataset(BaseActionDataset):
+ """Action segmentation dataset.
+
+ The annotation bundle file lists the video names of a split; for each
+ video, the paths to its pre-extracted features and frame-wise ground
+ truth, together with the label mapping, are collected in
+ ``load_data_list``.
+ """
+
+ def __init__(self,
+ ann_file: str,
+ pipeline: List[Union[dict, Callable]],
+ data_prefix: Optional[ConfigType] = dict(video=''),
+ test_mode: bool = False,
+ **kwargs):
+
+ super().__init__(
+ ann_file,
+ pipeline=pipeline,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self) -> List[dict]:
+ """Load annotation file to get video information."""
+ exists(self.ann_file)
+ file_ptr = open(self.ann_file, 'r') # read bundle
+ list_of_examples = file_ptr.read().split('\n')[:-1]
+ file_ptr.close()
+ gts = [
+ self.data_prefix['video'] + 'groundTruth/' + vid
+ for vid in list_of_examples
+ ]
+ features_npy = [
+ self.data_prefix['video'] + 'features/' + vid.split('.')[0] +
+ '.npy' for vid in list_of_examples
+ ]
+ data_list = []
+
+ file_ptr_d = open(self.data_prefix['video'] + '/mapping.txt', 'r')
+ actions = file_ptr_d.read().split('\n')[:-1]
+ file_ptr_d.close()
+ actions_dict = dict()
+ for a in actions:
+ actions_dict[a.split()[1]] = int(a.split()[0])
+ index2label = dict()
+ for k, v in actions_dict.items():
+ index2label[v] = k
+ num_classes = len(actions_dict)
+
+ # gts: per-frame label .txt paths; features_npy: feature .npy paths
+ for idx, feature in enumerate(features_npy):
+ video_info = dict()
+ feature_path = features_npy[idx]
+ video_info['feature_path'] = feature_path
+ video_info['actions_dict'] = actions_dict
+ video_info['index2label'] = index2label
+ video_info['ground_truth_path'] = gts[idx]
+ video_info['num_classes'] = num_classes
+ data_list.append(video_info)
+ return data_list
diff --git a/mmaction/datasets/transforms/__init__.py b/mmaction/datasets/transforms/__init__.py
index 3d1ee91e27..d08316ba9e 100644
--- a/mmaction/datasets/transforms/__init__.py
+++ b/mmaction/datasets/transforms/__init__.py
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import (FormatAudioShape, FormatGCNInput, FormatShape,
- PackActionInputs, PackLocalizationInputs, Transpose)
+ PackActionInputs, PackLocalizationInputs,
+ PackSegmentationInputs, Transpose)
from .loading import (ArrayDecode, AudioFeatureSelector, BuildPseudoClip,
DecordDecode, DecordInit, DenseSampleFrames,
GenerateLocalizationLabels, ImageDecode,
LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature,
- LoadProposals, LoadRGBFromFile, OpenCVDecode, OpenCVInit,
- PIMSDecode, PIMSInit, PyAVDecode, PyAVDecodeMotionVector,
- PyAVInit, RawFrameDecode, SampleAVAFrames, SampleFrames,
+ LoadProposals, LoadRGBFromFile, LoadSegmentationFeature,
+ OpenCVDecode, OpenCVInit, PIMSDecode, PIMSInit,
+ PyAVDecode, PyAVDecodeMotionVector, PyAVInit,
+ RawFrameDecode, SampleAVAFrames, SampleFrames,
UniformSample, UntrimmedSampleFrames)
from .pose_transforms import (DecompressPose, GeneratePoseTarget, GenSkeFeat,
JointToBone, MergeSkeFeat, MMCompact, MMDecode,
@@ -37,5 +39,5 @@
'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop', 'ToMotion',
'TorchVisionWrapper', 'Transpose', 'UniformSample', 'UniformSampleFrames',
'UntrimmedSampleFrames', 'MMUniformSampleFrames', 'MMDecode', 'MMCompact',
- 'CLIPTokenize'
+ 'CLIPTokenize', 'LoadSegmentationFeature', 'PackSegmentationInputs'
]
diff --git a/mmaction/datasets/transforms/formatting.py b/mmaction/datasets/transforms/formatting.py
index 0ae1475c8b..cec44090e4 100644
--- a/mmaction/datasets/transforms/formatting.py
+++ b/mmaction/datasets/transforms/formatting.py
@@ -168,6 +168,62 @@ def __repr__(self) -> str:
return repr_str
+@TRANSFORMS.register_module()
+class PackSegmentationInputs(BaseTransform):
+ """Pack the raw features and frame-wise labels for action
+ segmentation."""
+
+ def __init__(self, keys=(), meta_keys=('video_name', )):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def transform(self, results):
+ """Method to pack the input data.
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Returns:
+ dict:
+
+ - 'inputs' (obj:`torch.Tensor`): The forward data of models.
+ - 'data_samples' (obj:`ActionDataSample`): The annotation info of
+ the sample.
+ """
+ packed_results = dict()
+ if 'raw_feature' in results:
+ raw_feature = results['raw_feature']
+ packed_results['inputs'] = to_tensor(raw_feature)
+ else:
+ raise ValueError('Cannot get "raw_feature" in the input '
+ 'dict of `PackSegmentationInputs`.')
+
+ data_sample = ActionDataSample()
+ for key in self.keys:
+ if key not in results:
+ continue
+ if key == 'classes':
+ instance_data = InstanceData()
+ instance_data[key] = to_tensor(results[key])
+ data_sample.gt_instances = instance_data
+ elif key == 'proposals':
+ instance_data = InstanceData()
+ instance_data[key] = to_tensor(results[key])
+ data_sample.proposals = instance_data
+ else:
+ raise NotImplementedError(
+ f"Key '{key}' is not supported in `PackSegmentationInputs`"
+ )
+
+ img_meta = {k: results[k] for k in self.meta_keys if k in results}
+ data_sample.set_metainfo(img_meta)
+ packed_results['data_samples'] = data_sample
+ return packed_results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ repr_str += f'(meta_keys={self.meta_keys})'
+ return repr_str
+
+
@TRANSFORMS.register_module()
class Transpose(BaseTransform):
"""Transpose image channels to a given order.
diff --git a/mmaction/datasets/transforms/loading.py b/mmaction/datasets/transforms/loading.py
index 8d789ab4c3..c5101dcdae 100644
--- a/mmaction/datasets/transforms/loading.py
+++ b/mmaction/datasets/transforms/loading.py
@@ -1784,6 +1784,43 @@ def __repr__(self) -> str:
return repr_str
+@TRANSFORMS.register_module()
+class LoadSegmentationFeature(BaseTransform):
+ """Load Video features for localizer with given video_name list.
+
+ The required key are "feature_path", "ground_truth_path",
+ added or modified keys are "actions_dict", "raw_feature",
+ "ground_truth", "classes".
+
+ Args:
+ raw_feature_ext (str): Raw feature file extension. Default: '.csv'.
+ """
+
+ def transform(self, results):
+ """Perform the LoadSegmentationFeature loading.
+
+ Args:
+ results (dict): The resulting dict to be modified and passed
+ to the next transform in pipeline.
+ """
+ raw_feature = np.load(results['feature_path'])
+ file_ptr = open(results['ground_truth_path'], 'r')
+ content = file_ptr.read().split('\n')[:-1]
+ classes = np.zeros(min(np.shape(raw_feature)[1], len(content)))
+ for i in range(len(classes)):
+ classes[i] = results['actions_dict'][content[i]]
+
+ results['raw_feature'] = raw_feature
+ results['ground_truth'] = content
+ results['classes'] = classes
+
+ return results
+
+ def __repr__(self):
+ repr_str = f'{self.__class__.__name__}'
+ return repr_str
+
+
@TRANSFORMS.register_module()
class LoadLocalizationFeature(BaseTransform):
"""Load Video features for localizer with given video_name list.
diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py
index fd50aded2e..96ad99c96a 100644
--- a/mmaction/evaluation/metrics/__init__.py
+++ b/mmaction/evaluation/metrics/__init__.py
@@ -5,10 +5,11 @@
from .multimodal_metric import VQAMCACC, ReportVQA, RetrievalRecall, VQAAcc
from .multisports_metric import MultiSportsMetric
from .retrieval_metric import RetrievalMetric
+from .segment_metric import SegmentMetric
from .video_grounding_metric import RecallatTopK
__all__ = [
'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix',
'MultiSportsMetric', 'RetrievalMetric', 'VQAAcc', 'ReportVQA', 'VQAMCACC',
- 'RetrievalRecall', 'RecallatTopK'
+ 'RetrievalRecall', 'RecallatTopK', 'SegmentMetric'
]
diff --git a/mmaction/evaluation/metrics/segment_metric.py b/mmaction/evaluation/metrics/segment_metric.py
new file mode 100644
index 0000000000..b9e2c45974
--- /dev/null
+++ b/mmaction/evaluation/metrics/segment_metric.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Any, Optional, Sequence, Tuple
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmaction.registry import METRICS
+from mmaction.utils import ConfigType
+
+
+@METRICS.register_module()
+class SegmentMetric(BaseMetric):
+ """Action Segmentation dataset evaluation metric."""
+
+ def __init__(self,
+ metric_type: str = 'ALL',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None,
+ metric_options: dict = {},
+ dump_config: ConfigType = dict(out='')):
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.metric_type = metric_type
+
+ assert metric_type == 'ALL'
+ assert 'out' in dump_config
+ self.output_format = dump_config.pop('output_format', 'csv')
+ self.out = dump_config['out']
+
+ self.metric_options = metric_options
+
+ def process(self, data_batch: Sequence[Tuple[Any, dict]],
+ predictions: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[Tuple[Any, dict]]): A batch of data
+ from the dataloader.
+ predictions (Sequence[dict]): A batch of outputs from
+ the model.
+ """
+ for pred in predictions:
+ self.results.append(pred)
+
+ if self.metric_type == 'ALL':
+ data_batch = data_batch['data_samples']
+
+ def compute_metrics(self, results: list) -> dict:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+ Returns:
+ dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ if self.metric_type == 'ALL':
+ return self.compute_ALL(results)
+ return OrderedDict()
+
+ def compute_ALL(self, results: list) -> dict:
+ """ALL evaluation metric."""
+ eval_results = OrderedDict()
+ overlap = [.1, .25, .5]
+ tp, fp, fn = np.zeros(3), np.zeros(3), np.zeros(3)
+
+ correct = 0
+ total = 0
+ edit = 0
+
+ for vid in self.results:
+
+ gt_content = vid['ground']
+ recog_content = vid['recognition']
+
+ for i in range(len(gt_content)):
+ total += 1
+ if gt_content[i] == recog_content[i]:
+ correct += 1
+
+ edit += self.edit_score(recog_content, gt_content)
+
+ for s in range(len(overlap)):
+ tp1, fp1, fn1 = self.f_score(recog_content, gt_content,
+ overlap[s])
+ tp[s] += tp1
+ fp[s] += fp1
+ fn[s] += fn1
+ eval_results['Acc'] = 100 * float(correct) / total
+ eval_results['Edit'] = (1.0 * edit) / len(self.results)
+ f1s = np.array([0, 0, 0], dtype=float)
+ for s in range(len(overlap)):
+ precision = tp[s] / float(tp[s] + fp[s])
+ recall = tp[s] / float(tp[s] + fn[s])
+
+ f1 = 2.0 * (precision * recall) / (precision + recall)
+
+ f1 = np.nan_to_num(f1) * 100
+ f1s[s] = f1
+
+ eval_results['F1@10'] = f1s[0]
+ eval_results['F1@25'] = f1s[1]
+ eval_results['F1@50'] = f1s[2]
+
+ return eval_results
+
+ def f_score(self,
+ recognized,
+ ground_truth,
+ overlap,
+ bg_class=['background']):
+ p_label, p_start, p_end = self.get_labels_start_end_time(
+ recognized, bg_class)
+ y_label, y_start, y_end = self.get_labels_start_end_time(
+ ground_truth, bg_class)
+
+ tp = 0
+ fp = 0
+
+ hits = np.zeros(len(y_label))
+
+ for j in range(len(p_label)):
+ intersection = np.minimum(p_end[j], y_end) - np.maximum(
+ p_start[j], y_start)
+ union = np.maximum(p_end[j], y_end) - np.minimum(
+ p_start[j], y_start)
+ IoU = (1.0 * intersection / union) * (
+ [p_label[j] == y_label[x] for x in range(len(y_label))])
+ # Get the best scoring segment
+ idx = np.array(IoU).argmax()
+
+ if IoU[idx] >= overlap and not hits[idx]:
+ tp += 1
+ hits[idx] = 1
+ else:
+ fp += 1
+ fn = len(y_label) - sum(hits)
+ return float(tp), float(fp), float(fn)
+
+ def edit_score(self,
+ recognized,
+ ground_truth,
+ norm=True,
+ bg_class=['background']):
+ P, _, _ = self.get_labels_start_end_time(recognized, bg_class)
+ Y, _, _ = self.get_labels_start_end_time(ground_truth, bg_class)
+ return self.levenstein(P, Y, norm)
+
+ def get_labels_start_end_time(self,
+ frame_wise_labels,
+ bg_class=['background']):
+ labels = []
+ starts = []
+ ends = []
+ last_label = frame_wise_labels[0]
+ if frame_wise_labels[0] not in bg_class:
+ labels.append(frame_wise_labels[0])
+ starts.append(0)
+ for i in range(len(frame_wise_labels)):
+ if frame_wise_labels[i] != last_label:
+ if frame_wise_labels[i] not in bg_class:
+ labels.append(frame_wise_labels[i])
+ starts.append(i)
+ if last_label not in bg_class:
+ ends.append(i)
+ last_label = frame_wise_labels[i]
+ if last_label not in bg_class:
+ ends.append(i)
+ return labels, starts, ends
+
+ def levenstein(self, p, y, norm=False):
+ m_row = len(p)
+ n_col = len(y)
+ D = np.zeros([m_row + 1, n_col + 1], np.float64)
+ for i in range(m_row + 1):
+ D[i, 0] = i
+ for i in range(n_col + 1):
+ D[0, i] = i
+
+ for j in range(1, n_col + 1):
+ for i in range(1, m_row + 1):
+ if y[j - 1] == p[i - 1]:
+ D[i, j] = D[i - 1, j - 1]
+ else:
+ D[i, j] = min(D[i - 1, j] + 1, D[i, j - 1] + 1,
+ D[i - 1, j - 1] + 1)
+
+ if norm:
+ score = (1 - D[-1, -1] / max(m_row, n_col)) * 100
+ else:
+ score = D[-1, -1]
+
+ return score
diff --git a/mmaction/models/__init__.py b/mmaction/models/__init__.py
index 08f7d41f52..4b0436e1d9 100644
--- a/mmaction/models/__init__.py
+++ b/mmaction/models/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .action_segmentors import * # noqa: F401,F403
from .backbones import * # noqa: F401,F403
from .common import * # noqa: F401,F403
from .data_preprocessors import * # noqa: F401,F403
diff --git a/mmaction/models/action_segmentors/__init__.py b/mmaction/models/action_segmentors/__init__.py
new file mode 100644
index 0000000000..eeb81b53e9
--- /dev/null
+++ b/mmaction/models/action_segmentors/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .asformer import ASFormer
+
+__all__ = [
+ 'ASFormer',
+]
diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py
new file mode 100644
index 0000000000..25a072016a
--- /dev/null
+++ b/mmaction/models/action_segmentors/asformer.py
@@ -0,0 +1,635 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModel
+
+from mmaction.registry import MODELS
+
+
+@MODELS.register_module()
+class ASFormer(BaseModel):
+ """Boundary Matching Network for temporal action proposal generation."""
+
+ def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim,
+ num_classes, channel_masking_rate, sample_rate):
+ super().__init__()
+ self.model = MyTransformer(num_decoders, num_layers, r1, r2,
+ num_f_maps, input_dim, num_classes,
+ channel_masking_rate)
+ print('Model Size: ', sum(p.numel() for p in self.model.parameters()))
+ self.num_classes = num_classes
+ self.mse = MODELS.build(dict(type='MeanSquareErrorLoss'))
+ self.ce = MODELS.build(dict(type='CrossEntropyLoss'))
+
+ def init_weights(self) -> None:
+ """Initiate the parameters from scratch."""
+ pass
+
+ def forward(self, inputs, data_samples, mode, **kwargs):
+ """The unified entry for a forward process in both training and test.
+
+ The method should accept three modes:
+
+ - ``tensor``: Forward the whole network and return tensor or tuple of
+ tensor without any post-processing, same as a common nn.Module.
+ - ``predict``: Forward and return the predictions, which are fully
+ processed to a list of :obj:`ActionDataSample`.
+ - ``loss``: Forward and return a dict of losses according to the given
+ inputs and data samples.
+
+ Note that this method doesn't handle back propagation or optimizer
+ updating, which are done in the :meth:`train_step`.
+
+ Args:
+ inputs (Tensor): The input tensor with shape
+ (N, C, ...) in general.
+ data_samples (List[:obj:`ActionDataSample`], optional): The
+ annotation data of every samples. Defaults to None.
+ mode (str): Return what kind of value. Defaults to ``tensor``.
+
+ Returns:
+ The return type depends on ``mode``.
+
+ - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+ - If ``mode="predict"``, return a list of ``ActionDataSample``.
+ - If ``mode="loss"``, return a dict of tensor.
+ """
+ input = torch.stack(inputs)
+ if mode == 'predict':
+ return self.predict(input, data_samples, **kwargs)
+ elif mode == 'loss':
+ return self.loss(input, data_samples, **kwargs)
+ else:
+ raise RuntimeError(f'Invalid mode "{mode}". '
+ 'Only supports loss and predict mode')
+
+ def loss(self, batch_inputs, batch_data_samples, **kwargs):
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ batch_inputs (Tensor): Raw Inputs of the recognizer.
+ These should usually be mean centered and std scaled.
+ batch_data_samples (List[:obj:`ActionDataSample`]): The batch
+ data samples. It usually includes information such
+ as ``gt_labels``.
+
+ Returns:
+ dict: A dictionary of loss components.
+ """
+ device = batch_inputs.device
+ batch_target_tensor = torch.ones(
+ len(batch_inputs),
+ max(tensor.size(1) for tensor in batch_inputs),
+ dtype=torch.long) * (-100)
+ mask = torch.zeros(
+ len(batch_inputs),
+ self.num_classes,
+ max(tensor.size(1) for tensor in batch_inputs),
+ dtype=torch.float)
+ for i in range(len(batch_inputs)):
+ batch_target_tensor[i,
+ :np.shape(batch_data_samples[i].classes)[0]] \
+ = torch.from_numpy(batch_data_samples[i].classes)
+
+ mask[i, :, :np.shape(batch_data_samples[i].classes)[0]] = \
+ torch.ones(self.num_classes,
+ np.shape(batch_data_samples[i].classes)[0])
+
+ batch_target_tensor = batch_target_tensor.to(device)
+ mask = mask.to(device)
+ batch_inputs = batch_inputs.to(device)
+ ps = self.model(batch_inputs, mask)
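+ # Total loss over all stage outputs: frame-wise cross entropy plus a
+ # truncated MSE smoothing term (weight 0.15, clamped to [0, 16])
+ # between the log-probabilities of neighbouring frames.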
+ loss = 0
+ for p in ps:
+ loss += self.ce(
+ p.transpose(2, 1).contiguous().view(-1, self.num_classes),
+ batch_target_tensor.view(-1),
+ ignore_index=-100)
+ loss += 0.15 * torch.mean(
+ torch.clamp(
+ self.mse(
+ F.log_softmax(p[:, :, 1:], dim=1),
+ F.log_softmax(p.detach()[:, :, :-1], dim=1)),
+ min=0,
+ max=16) * mask[:, :, 1:])
+
+ loss_dict = dict(loss=loss)
+ return loss_dict
+
+ def predict(self, batch_inputs, batch_data_samples, **kwargs):
+ """Define the computation performed at every call when testing."""
+ device = batch_inputs.device
+ actions_dict = batch_data_samples[0].actions_dict
+ batch_target_tensor = torch.ones(
+ len(batch_inputs),
+ max(tensor.size(1) for tensor in batch_inputs),
+ dtype=torch.long) * (-100)
+ batch_target = [
+ data_sample.classes for data_sample in batch_data_samples
+ ]
+ mask = torch.zeros(
+ len(batch_inputs),
+ self.num_classes,
+ max(tensor.size(1) for tensor in batch_inputs),
+ dtype=torch.float)
+ for i in range(len(batch_inputs)):
+ batch_target_tensor[i, :np.shape(batch_data_samples[i].classes
+ )[0]] = torch.from_numpy(
+ batch_data_samples[i].classes)
+ mask[i, :, :np.
+ shape(batch_data_samples[i].classes)[0]] = torch.ones(
+ self.num_classes,
+ np.shape(batch_data_samples[i].classes)[0])
+ batch_target_tensor = batch_target_tensor.to(device)
+ mask = mask.to(device)
+ batch_inputs = batch_inputs.to(device)
+ predictions = self.model(batch_inputs, mask)
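+ # Only the output of the final decoder (last refinement stage) is
+ # kept for recognition; class indices are mapped back to label names.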
+ for i in range(len(predictions)):
+ confidence, predicted = torch.max(
+ F.softmax(predictions[i], dim=1).data, 1)
+ confidence, predicted = confidence.squeeze(), predicted.squeeze()
+ recognition = []
+ ground = [
+ batch_data_samples[0].index2label[idx] for idx in batch_target[0]
+ ]
+ for i in range(len(predicted)):
+ recognition = np.concatenate((recognition, [
+ list(actions_dict.keys())[list(actions_dict.values()).index(
+ predicted[i].item())]
+ ]))
+ output = [dict(ground=ground, recognition=recognition)]
+ return output
+
+
+def exponential_descrease(idx_decoder, p=3):
+ return math.exp(-p * idx_decoder)
+
+
+class AttentionHelper(nn.Module):
+
+ def __init__(self):
+ super(AttentionHelper, self).__init__()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
+ """scalar dot attention.
+
+ :param proj_query: shape of (B, C, L) =>
+ (Batch_Size, Feature_Dimension, Length)
+ :param proj_key: shape of (B, C, L)
+ :param proj_val: shape of (B, C, L)
+ :param padding_mask: shape of (B, C, L)
+ :return: attention value of shape (B, C, L)
+ """
+ m, c1, l1 = proj_query.shape
+ m, c2, l2 = proj_key.shape
+
+ assert c1 == c2
+
+ energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)
+ attention = energy / np.sqrt(c1)
+ attention = attention + torch.log(padding_mask + 1e-6)
+ attention = self.softmax(attention)
+ attention = attention * padding_mask
+ attention = attention.permute(0, 2, 1)
+ out = torch.bmm(proj_val, attention)
+ return out, attention
+
+
+class AttLayer(nn.Module):
+
+ def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage,
+ att_type): # r1 = r2 (2)
+ super(AttLayer, self).__init__()
+ self.query_conv = nn.Conv1d(
+ in_channels=q_dim, out_channels=q_dim // r1, kernel_size=1)
+ self.key_conv = nn.Conv1d(
+ in_channels=k_dim, out_channels=k_dim // r2, kernel_size=1)
+ self.value_conv = nn.Conv1d(
+ in_channels=v_dim, out_channels=v_dim // r3, kernel_size=1)
+
+ self.conv_out = nn.Conv1d(
+ in_channels=v_dim // r3, out_channels=v_dim, kernel_size=1)
+
+ self.bl = bl
+ self.stage = stage
+ self.att_type = att_type
+ assert self.att_type in ['normal_att', 'block_att', 'sliding_att']
+ assert self.stage in ['encoder', 'decoder']
+
+ self.att_helper = AttentionHelper()
+ self.window_mask = self.construct_window_mask()
+
+ def construct_window_mask(self):
+ """construct window mask of shape (1, l, l + l//2 + l//2), used for
+ sliding window self attention."""
+ window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
+ for i in range(self.bl):
+ window_mask[:, i, i:i + self.bl] = 1
+ return window_mask
+
+ def forward(self, x1, x2, mask):
+ query = self.query_conv(x1)
+ key = self.key_conv(x1)
+
+ if self.stage == 'decoder':
+ assert x2 is not None
+ value = self.value_conv(x2)
+ else:
+ value = self.value_conv(x1)
+
+ if self.att_type == 'normal_att':
+ return self._normal_self_att(query, key, value, mask)
+ elif self.att_type == 'block_att':
+ return self._block_wise_self_att(query, key, value, mask)
+ elif self.att_type == 'sliding_att':
+ return self._sliding_window_self_att(query, key, value, mask)
+
+ def _normal_self_att(self, q, k, v, mask):
+ device = q.device
+ m_batchsize, c1, L = q.size()
+ _, c2, L = k.size()
+ _, c3, L = v.size()
+ padding_mask = torch.ones(
+ (m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
+ output, attentions = self.att_helper.scalar_dot_att(
+ q, k, v, padding_mask)
+ output = self.conv_out(F.relu(output))
+ output = output[:, :, 0:L]
+ return output * mask[:, 0:1, :]
+
+ def _block_wise_self_att(self, q, k, v, mask):
+ device = q.device
+ m_batchsize, c1, L = q.size()
+ _, c2, L = k.size()
+ _, c3, L = v.size()
+
+ nb = L // self.bl
+ if L % self.bl != 0:
+ q = torch.cat([
+ q,
+ torch.zeros(
+ (m_batchsize, c1, self.bl - L % self.bl)).to(device)
+ ],
+ dim=-1)
+ k = torch.cat([
+ k,
+ torch.zeros(
+ (m_batchsize, c2, self.bl - L % self.bl)).to(device)
+ ],
+ dim=-1)
+ v = torch.cat([
+ v,
+ torch.zeros(
+ (m_batchsize, c3, self.bl - L % self.bl)).to(device)
+ ],
+ dim=-1)
+ nb += 1
+
+ padding_mask = torch.cat([
+ torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
+ torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device)
+ ],
+ dim=-1)
+
+ q = q.reshape(m_batchsize, c1, nb,
+ self.bl).permute(0, 2, 1,
+ 3).reshape(m_batchsize * nb, c1,
+ self.bl)
+ padding_mask = padding_mask.reshape(
+ m_batchsize, 1, nb,
+ self.bl).permute(0, 2, 1, 3).reshape(m_batchsize * nb, 1, self.bl)
+ k = k.reshape(m_batchsize, c2, nb,
+ self.bl).permute(0, 2, 1,
+ 3).reshape(m_batchsize * nb, c2,
+ self.bl)
+ v = v.reshape(m_batchsize, c3, nb,
+ self.bl).permute(0, 2, 1,
+ 3).reshape(m_batchsize * nb, c3,
+ self.bl)
+
+ output, attentions = self.att_helper.scalar_dot_att(
+ q, k, v, padding_mask)
+ output = self.conv_out(F.relu(output))
+
+ output = output.reshape(m_batchsize, nb, c3, self.bl).permute(
+ 0, 2, 1, 3).reshape(m_batchsize, c3, nb * self.bl)
+ output = output[:, :, 0:L]
+ return output * mask[:, 0:1, :]
+
+ def _sliding_window_self_att(self, q, k, v, mask):
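+ """Sliding-window attention: queries are split into blocks of length
+ ``self.bl`` and each block attends to itself plus ``self.bl // 2``
+ frames of context on either side."""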
+ device = q.device
+ m_batchsize, c1, L = q.size()
+ _, c2, _ = k.size()
+ _, c3, _ = v.size()
+ nb = L // self.bl
+ if L % self.bl != 0:
+ q = torch.cat([
+ q,
+ torch.zeros(
+ (m_batchsize, c1, self.bl - L % self.bl)).to(device)
+ ],
+ dim=-1)
+ k = torch.cat([
+ k,
+ torch.zeros(
+ (m_batchsize, c2, self.bl - L % self.bl)).to(device)
+ ],
+ dim=-1)
+ v = torch.cat([
+ v,
+ torch.zeros(
+ (m_batchsize, c3, self.bl - L % self.bl)).to(device)
+ ],
+ dim=-1)
+ nb += 1
+ padding_mask = torch.cat([
+ torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
+ torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device)
+ ],
+ dim=-1)
+ q = q.reshape(m_batchsize, c1, nb,
+ self.bl).permute(0, 2, 1,
+ 3).reshape(m_batchsize * nb, c1,
+ self.bl)
+ k = torch.cat([
+ torch.zeros(m_batchsize, c2, self.bl // 2).to(device), k,
+ torch.zeros(m_batchsize, c2, self.bl // 2).to(device)
+ ],
+ dim=-1)
+ v = torch.cat([
+ torch.zeros(m_batchsize, c3, self.bl // 2).to(device), v,
+ torch.zeros(m_batchsize, c3, self.bl // 2).to(device)
+ ],
+ dim=-1)
+ padding_mask = torch.cat([
+ torch.zeros(m_batchsize, 1, self.bl // 2).to(device), padding_mask,
+ torch.zeros(m_batchsize, 1, self.bl // 2).to(device)
+ ],
+ dim=-1)
+ k = torch.cat([
+ k[:, :, i * self.bl:(i + 1) * self.bl + (self.bl // 2) * 2]
+ for i in range(nb)
+ ],
+ dim=0) # special case when self.bl = 1
+ v = torch.cat([
+ v[:, :, i * self.bl:(i + 1) * self.bl + (self.bl // 2) * 2]
+ for i in range(nb)
+ ],
+ dim=0)
+ padding_mask = torch.cat([
+ padding_mask[:, :, i * self.bl:(i + 1) * self.bl +
+ (self.bl // 2) * 2] for i in range(nb)
+ ],
+ dim=0) # of shape (m*nb, 1, 2l)
+ final_mask = self.window_mask.to(device).repeat(
+ m_batchsize * nb, 1, 1) * padding_mask
+
+ output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask)
+ output = self.conv_out(F.relu(output))
+
+ output = output.reshape(m_batchsize, nb, -1, self.bl).permute(
+ 0, 2, 1, 3).reshape(m_batchsize, -1, nb * self.bl)
+ output = output[:, :, 0:L]
+ return output * mask[:, 0:1, :]
+
+
+class MultiHeadAttLayer(nn.Module):
+
+ def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type,
+ num_head):
+ super(MultiHeadAttLayer, self).__init__()
+ # assert v_dim % num_head == 0
+ self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
+ self.layers = nn.ModuleList([
+ copy.deepcopy(
+ AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type))
+ for i in range(num_head)
+ ])
+ self.dropout = nn.Dropout(p=0.5)
+
+ def forward(self, x1, x2, mask):
+ out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
+ out = self.conv_out(self.dropout(out))
+ return out
+
+
+class ConvFeedForward(nn.Module):
+
+ def __init__(self, dilation, in_channels, out_channels):
+ super(ConvFeedForward, self).__init__()
+ self.layer = nn.Sequential(
+ nn.Conv1d(
+ in_channels,
+ out_channels,
+ 3,
+ padding=dilation,
+ dilation=dilation), nn.ReLU())
+
+ def forward(self, x):
+ """Define the computation performed at every call.
+
+ Args:
+ x (torch.Tensor): The input data.
+ Returns:
+ torch.Tensor: The output of the module.
+ """
+ return self.layer(x)
+
+
+class FCFeedForward(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super(FCFeedForward, self).__init__()
+ self.layer = nn.Sequential(
+ nn.Conv1d(in_channels, out_channels, 1), nn.ReLU(), nn.Dropout(),
+ nn.Conv1d(out_channels, out_channels, 1))
+
+ def forward(self, x):
+ """Define the computation performed at every call.
+
+ Args:
+ x (torch.Tensor): The input data.
+ Returns:
+ torch.Tensor: The output of the module.
+ """
+ return self.layer(x)
+
+
+class AttModule(nn.Module):
+
+ def __init__(self, dilation, in_channels, out_channels, r1, r2, att_type,
+ stage, alpha):
+ super(AttModule, self).__init__()
+ self.feed_forward = ConvFeedForward(dilation, in_channels,
+ out_channels)
+ self.instance_norm = nn.InstanceNorm1d(
+ in_channels, track_running_stats=False)
+ self.att_layer = AttLayer(
+ in_channels,
+ in_channels,
+ out_channels,
+ r1,
+ r1,
+ r2,
+ dilation,
+ att_type=att_type,
+ stage=stage) # dilation
+ self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
+ self.dropout = nn.Dropout()
+ self.alpha = alpha
+
+ def forward(self, x, f, mask):
+ out = self.feed_forward(x)
+ out = self.alpha * self.att_layer(self.instance_norm(out), f,
+ mask) + out
+ out = self.conv_1x1(out)
+ out = self.dropout(out)
+ return (x + out) * mask[:, 0:1, :]
+
+
+class PositionalEncoding(nn.Module):
+ """Implement the PE function."""
+
+ def __init__(self, d_model, max_len=10000):
+ super(PositionalEncoding, self).__init__()
+ # Compute the positional encodings once in log space.
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).permute(0, 2, 1) # of shape (1, d_model, l)
+ self.pe = nn.Parameter(pe, requires_grad=True)
+
+ # self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ """Define the computation performed at every call.
+
+ Args:
+ x (torch.Tensor): The input data.
+ Returns:
+ torch.Tensor: The output of the module.
+ """
+ return x + self.pe[:, :, 0:x.shape[2]]
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes,
+ channel_masking_rate, att_type, alpha):
+ super(Encoder, self).__init__()
+ self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
+ self.layers = nn.ModuleList([
+ AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type,
+ 'encoder', alpha) for i in range(num_layers)
+ ])
+
+ self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
+ self.dropout = nn.Dropout2d(p=channel_masking_rate)
+ self.channel_masking_rate = channel_masking_rate
+
+ def forward(self, x, mask):
+ """Forward computation.
+
+ Args:
+ x (torch.Tensor): Input features of shape (N, C, L).
+ mask (torch.Tensor): Padding mask of shape (N, num_classes, L).
+
+ Returns:
+ tuple: Frame-wise class logits and the encoder feature map.
+ """
+
+ if self.channel_masking_rate > 0:
+ x = x.unsqueeze(2)
+ x = self.dropout(x)
+ x = x.squeeze(2)
+
+ feature = self.conv_1x1(x)
+ for layer in self.layers:
+ feature = layer(feature, None, mask)
+
+ out = self.conv_out(feature) * mask[:, 0:1, :]
+
+ return out, feature
+
+
+class Decoder(nn.Module):
+
+ def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes,
+ att_type, alpha):
+ super(Decoder, self).__init__()
+ # self.position_en = PositionalEncoding(d_model=num_f_maps)
+ self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
+ self.layers = nn.ModuleList([
+ AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type,
+ 'decoder', alpha) for i in range(num_layers)
+ ])
+ self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
+
+ def forward(self, x, fencoder, mask):
+
+ feature = self.conv_1x1(x)
+ for layer in self.layers:
+ feature = layer(feature, fencoder, mask)
+
+ out = self.conv_out(feature) * mask[:, 0:1, :]
+
+ return out, feature
+
+
+class MyTransformer(nn.Module):
+ """An encoder-decoder transformer."""
+
+ def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim,
+ num_classes, channel_masking_rate):
+ super(MyTransformer, self).__init__()
+ self.encoder = Encoder(
+ num_layers,
+ r1,
+ r2,
+ num_f_maps,
+ input_dim,
+ num_classes,
+ channel_masking_rate,
+ att_type='sliding_att',
+ alpha=1)
+ self.decoders = nn.ModuleList([
+ copy.deepcopy(
+ Decoder(
+ num_layers,
+ r1,
+ r2,
+ num_f_maps,
+ num_classes,
+ num_classes,
+ att_type='sliding_att',
+ alpha=exponential_descrease(s)))
+ for s in range(num_decoders)
+ ]) # num_decoders
+
+ def forward(self, x, mask):
+ """Define the computation performed at every call.
+
+ Args:
+ x (torch.Tensor): The input data.
+ Returns:
+ torch.Tensor: The output of the module.
+ """
+ out, feature = self.encoder(x, mask)
+ outputs = out.unsqueeze(0)
+
+ for decoder in self.decoders:
+ out, feature = decoder(
+ F.softmax(out, dim=1) * mask[:, 0:1, :],
+ feature * mask[:, 0:1, :], mask)
+ outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
+
+ return outputs
diff --git a/mmaction/models/losses/__init__.py b/mmaction/models/losses/__init__.py
index 41afcb7ace..b2ccee1c76 100644
--- a/mmaction/models/losses/__init__.py
+++ b/mmaction/models/losses/__init__.py
@@ -5,6 +5,7 @@
from .cross_entropy_loss import (BCELossWithLogits, CBFocalLoss,
CrossEntropyLoss)
from .hvu_loss import HVULoss
+from .mse_loss import MeanSquareErrorLoss
from .nll_loss import NLLLoss
from .ohem_hinge_loss import OHEMHingeLoss
from .ssn_loss import SSNLoss
@@ -12,5 +13,5 @@
__all__ = [
'BaseWeightedLoss', 'CrossEntropyLoss', 'NLLLoss', 'BCELossWithLogits',
'BinaryLogisticRegressionLoss', 'BMNLoss', 'OHEMHingeLoss', 'SSNLoss',
- 'HVULoss', 'CBFocalLoss'
+ 'HVULoss', 'CBFocalLoss', 'MeanSquareErrorLoss'
]
diff --git a/mmaction/models/losses/mse_loss.py b/mmaction/models/losses/mse_loss.py
new file mode 100644
index 0000000000..7f97bed123
--- /dev/null
+++ b/mmaction/models/losses/mse_loss.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+
+from mmaction.registry import MODELS
+from .base import BaseWeightedLoss
+
+
+@MODELS.register_module()
+class MeanSquareErrorLoss(BaseWeightedLoss):
+ """Mean Square Error Loss."""
+
+ def __init__(self, loss_weight: float = 1., reduction: str = 'none'):
+ super().__init__(loss_weight=loss_weight)
+ self.reduction = reduction
+
+ def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
+ **kwargs) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ cls_score (torch.Tensor): The class score.
+ label (torch.Tensor): The ground truth label.
+ kwargs: Any keyword argument to be used to calculate
+ MeanSquareError loss.
+
+ Returns:
+ torch.Tensor: The returned MeanSquareError loss.
+ """
+ if cls_score.size() == label.size():
+ assert len(kwargs) == 0, \
+ ('For now, no extra args are supported for soft label, '
+ f'but get {kwargs}')
+
+ loss_cls = F.mse_loss(cls_score, label, reduction=self.reduction)
+ return loss_cls
diff --git a/tools/data/action_seg/README.md b/tools/data/action_seg/README.md
new file mode 100644
index 0000000000..3ada793e2d
--- /dev/null
+++ b/tools/data/action_seg/README.md
@@ -0,0 +1,132 @@
+# Preparing Datasets for Action Segmentation
+
+## Introduction
+
+
+
+```BibTeX
+@inproceedings{fathi2011learning,
+ title={Learning to recognize objects in egocentric activities},
+ author={Fathi, Alireza and Ren, Xiaofeng and Rehg, James M},
+ booktitle={CVPR 2011},
+ pages={3281--3288},
+ year={2011},
+ organization={IEEE}
+}
+```
+
+```BibTeX
+@inproceedings{stein2013combining,
+ title={Combining embedded accelerometers with computer vision for recognizing food preparation activities},
+ author={Stein, Sebastian and McKenna, Stephen J},
+ booktitle={Proceedings of the 2013 ACM international joint conference on Pervasive and ubiquitous computing},
+ pages={729--738},
+ year={2013}
+}
+```
+
+```BibTeX
+@inproceedings{kuehne2014language,
+ title={The language of actions: Recovering the syntax and semantics of goal-directed human activities},
+ author={Kuehne, Hilde and Arslan, Ali and Serre, Thomas},
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
+ pages={780--787},
+ year={2014}
+}
+```
+
+For basic dataset information, you can refer to the articles above.
+Before we start, please make sure that your current directory is `$MMACTION2/tools/data/action_seg/`.
+To run the bash scripts below, you need to install `unzip`, e.g., via `sudo apt-get install unzip`.
+
+## Step 1. Prepare Annotations and Features
+
+First of all, you can run the following script to prepare annotations and features.
+
+```shell
+bash download_datasets.sh
+```
+
+## Step 2. Preprocess the Data
+
+You can execute the following commands to preprocess the downloaded data. They generate two folders, `gt_arr` and `gt_boundary_arr`, for each dataset. Note that these commands are run from the `$MMACTION2` root directory.
+
+```shell
+python tools/data/action_seg/generate_boundary_array.py --dataset_dir data/action_seg
+python tools/data/action_seg/generate_gt_array.py --dataset_dir data/action_seg
+```
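+
+As a quick sanity check, you can load the generated arrays for one video and inspect them.
+The snippet below assumes it is run from the `$MMACTION2` root directory and uses one of the GTEA ground truth files shown in the directory structure in Step 3:
+
+```python
+import numpy as np
+
+# frame-level class ids, as defined in gtea/mapping.txt
+gt = np.load('data/action_seg/gtea/gt_arr/S1_Cheese_C1.npy')
+# 1.0 at frames where a new action starts, 0.0 elsewhere
+boundary = np.load('data/action_seg/gtea/gt_boundary_arr/S1_Cheese_C1.npy')
+
+assert gt.shape == boundary.shape  # one label/indicator per frame
+print(gt[:10])
+print(boundary[:10])
+```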
+
+## Step 3. Check Directory Structure
+
+After preparing the GTEA, 50Salads and Breakfast data,
+you will get the features, splits, annotation files and ground-truth boundary arrays for these datasets.
+
+In the context of the whole project (for GTEA, 50Salads and Breakfast), the folder structure will look like:
+
+```
+mmaction2
+├── mmaction
+├── tools
+├── configs
+├── data
+│ ├── action_seg
+│ │ ├── gtea
+│ │ │ ├── features
+│ │ │ │ ├── S1_Cheese_C1.npy
+│ │ │ │ ├── S1_Coffee_C1.npy
+│ │ │ │ ├── ...
+│ │ │ ├── groundTruth
+│ │ │ │ ├── S1_Cheese_C1.txt
+│ │ │ │ ├── S1_Coffee_C1.txt
+│ │ │ │ ├── ...
+│ │ │ ├── gt_arr
+│ │ │ │ ├── S1_Cheese_C1.npy
+│ │ │ │ ├── S1_Coffee_C1.npy
+│ │ │ │ ├── ...
+│ │ │ ├── gt_boundary_arr
+│ │ │ │ ├── S1_Cheese_C1.npy
+│ │ │ │ ├── S1_Coffee_C1.npy
+│ │ │ │ ├── ...
+│ │ │ ├── splits
+│ │ │ │ ├── fifa_mean_dur_split1.pt
+│ │ │ │ ├── fifa_mean_dur_split2.pt
+│ │ │ │ ├── ...
+│ │ │ │ ├── test.split0.bundle
+│ │ │ │ ├── test.split1.bundle
+│ │ │ │ ├── ...
+│ │ │ │ ├── train.split0.bundle
+│ │ │ │ ├── train.split1.bundle
+│ │ │ │ ├── ...
+│ │ │ │ ├── train_split1_mean_duration.txt
+│ │ │ │ ├── train_split2_mean_duration.txt
+│ │ │ │ ├── ...
+│ │ │ │ ├── ...
+│ │ │ ├── mapping.txt
+│ │ ├── 50salads
+│ │ │ ├── features
+│ │ │ │ ├── ...
+│ │ │ ├── groundTruth
+│ │ │ │ ├── ...
+│ │ │ ├── gt_arr
+│ │ │ │ ├── ...
+│ │ │ ├── gt_boundary_arr
+│ │ │ │ ├── ...
+│ │ │ ├── splits
+│ │ │ │ ├── ...
+│ │ │ ├── mapping.txt
+│ │ ├── breakfast
+│ │ │ ├── features
+│ │ │ │ ├── ...
+│ │ │ ├── groundTruth
+│ │ │ │ ├── ...
+│ │ │ ├── gt_arr
+│ │ │ │ ├── ...
+│ │ │ ├── gt_boundary_arr
+│ │ │ │ ├── ...
+│ │ │ ├── splits
+│ │ │ │ ├── ...
+│ │ │ ├── mapping.txt
+
+```
+
+For training and evaluating on GTEA, 50Salads and Breakfast, please refer to [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
diff --git a/tools/data/action_seg/download_datasets.sh b/tools/data/action_seg/download_datasets.sh
new file mode 100644
index 0000000000..65dda9b53c
--- /dev/null
+++ b/tools/data/action_seg/download_datasets.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+set -e
+
+DATA_DIR="../../../data"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+ echo "${DATA_DIR} does not exist. Creating";
+ mkdir -p ${DATA_DIR}
+fi
+
+cd ${DATA_DIR}
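+# download the pre-extracted features, annotations and splits for
+# GTEA, 50Salads and Breakfast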
+wget https://zenodo.org/record/3625992/files/data.zip --no-check-certificate
+
+# sudo apt-get install unzip
+unzip data.zip
+rm data.zip
+
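+# the archive extracts to a folder named 'data'; rename it to 'action_seg'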
+mv data action_seg
+
+cd -
diff --git a/tools/data/action_seg/generate_boundary_array.py b/tools/data/action_seg/generate_boundary_array.py
new file mode 100644
index 0000000000..262a0394b9
--- /dev/null
+++ b/tools/data/action_seg/generate_boundary_array.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import glob
+import os
+
+import numpy as np
+
+
+def get_arguments() -> argparse.Namespace:
+ """parse all the arguments from command line interface return a list of
+ parsed arguments."""
+
+ parser = argparse.ArgumentParser(
+ description='generate ground truth arrays for boundary regression.')
+ parser.add_argument(
+ '--dataset_dir',
+ type=str,
+ help='path to a dataset directory',
+ )
+
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = get_arguments()
+
+ datasets = ['50salads', 'gtea', 'breakfast']
+
+ for dataset in datasets:
+ # make directory for saving ground truth numpy arrays
+ save_dir = os.path.join(args.dataset_dir, dataset, 'gt_boundary_arr')
+ if not os.path.exists(save_dir):
+ os.mkdir(save_dir)
+
+ gt_dir = os.path.join(args.dataset_dir, dataset, 'groundTruth')
+ gt_paths = glob.glob(os.path.join(gt_dir, '*.txt'))
+
+ for gt_path in gt_paths:
+ # the name of ground truth text file
+ gt_name = os.path.relpath(gt_path, gt_dir)
+
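+            # each line of the ground truth file holds the action label of
+            # one frame; the trailing empty line is dropped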
+ with open(gt_path, 'r') as f:
+ gt = f.read().split('\n')[:-1]
+
+ # define the frame where new action starts as boundary frame
+ boundary = np.zeros(len(gt))
+ last = gt[0]
+ boundary[0] = 1
+ for i in range(1, len(gt)):
+ if last != gt[i]:
+ boundary[i] = 1
+ last = gt[i]
+
+ # save array
+ np.save(os.path.join(save_dir, gt_name[:-4] + '.npy'), boundary)
+
+ print('Done')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/data/action_seg/generate_gt_array.py b/tools/data/action_seg/generate_gt_array.py
new file mode 100644
index 0000000000..e4f7ffe7d0
--- /dev/null
+++ b/tools/data/action_seg/generate_gt_array.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import glob
+import os
+import sys
+from typing import Dict
+
+import numpy as np
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+dataset_names = ['50salads', 'breakfast', 'gtea']
+
+
+def get_class2id_map(dataset: str,
+ dataset_dir: str = './dataset') -> Dict[str, int]:
+ """
+ Args:
+ dataset: 50salads, gtea, breakfast
+ dataset_dir: the path to the dataset directory
+ """
+
+ assert (dataset in dataset_names
+ ), 'You have to choose 50salads, gtea or breakfast as dataset.'
+
+ with open(
+ os.path.join(dataset_dir, '{}/mapping.txt'.format(dataset)),
+ 'r') as f:
+ actions = f.read().split('\n')[:-1]
+
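+    # each line of mapping.txt has the form '<class_id> <class_name>'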
+ class2id_map = dict()
+ for a in actions:
+ class2id_map[a.split()[1]] = int(a.split()[0])
+
+ return class2id_map
+
+
+def get_id2class_map(dataset: str,
+ dataset_dir: str = './dataset') -> Dict[int, str]:
+ class2id_map = get_class2id_map(dataset, dataset_dir)
+
+ return {val: key for key, val in class2id_map.items()}
+
+
+def get_n_classes(dataset: str, dataset_dir: str = './dataset') -> int:
+ return len(get_class2id_map(dataset, dataset_dir))
+
+
+def get_arguments() -> argparse.Namespace:
+ """parse all the arguments from command line interface return a list of
+ parsed arguments."""
+
+ parser = argparse.ArgumentParser(
+ description='convert ground truth txt files to numpy array')
+ parser.add_argument(
+ '--dataset_dir',
+ type=str,
+ default='./dataset',
+ help='path to a dataset directory (default: ./dataset)',
+ )
+
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = get_arguments()
+
+ datasets = ['50salads', 'gtea', 'breakfast']
+
+ for dataset in datasets:
+ # make directory for saving ground truth numpy arrays
+ save_dir = os.path.join(args.dataset_dir, dataset, 'gt_arr')
+ if not os.path.exists(save_dir):
+ os.mkdir(save_dir)
+
+ # class to index mapping
+ class2id_map = get_class2id_map(dataset, dataset_dir=args.dataset_dir)
+
+ gt_dir = os.path.join(args.dataset_dir, dataset, 'groundTruth')
+ gt_paths = glob.glob(os.path.join(gt_dir, '*.txt'))
+
+ for gt_path in gt_paths:
+ # the name of ground truth text file
+ gt_name = os.path.relpath(gt_path, gt_dir)
+
+ with open(gt_path, 'r') as f:
+ gt = f.read().split('\n')[:-1]
+
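+            # convert the frame-wise action labels to their integer class ids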
+ gt_array = np.zeros(len(gt))
+ for i in range(len(gt)):
+ gt_array[i] = class2id_map[gt[i]]
+
+ # save array
+ np.save(os.path.join(save_dir, gt_name[:-4] + '.npy'), gt_array)
+
+ print('Done')
+
+
+if __name__ == '__main__':
+ main()