diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea4f4d3..e494129 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: flake8 - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.11.5 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-yapf diff --git a/mmflow/datasets/AutoFlow.py b/mmflow/datasets/AutoFlow.py new file mode 100644 index 0000000..6c33018 --- /dev/null +++ b/mmflow/datasets/AutoFlow.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from typing import Any + +from .base_dataset import BaseDataset +from .builder import DATASETS + + +@DATASETS.register_module() +class AutoFlow(BaseDataset): + """AutoFlow dataset.""" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + def load_data_info(self) -> None: + """load data information.""" + + self.subset_dir = 'test' if self.test_mode else 'train' + + self.data_root = osp.join(self.data_root, self.subset_dir) + self.img1_dir = self.data_root + self.img2_dir = self.data_root + self.flow_root = self.data_root + + self.img_suffix = '.png' + self.flow_suffix = '.flo' + + self.all_scene = os.listdir(self.img1_dir) + + img1_filenames = [] + img2_filenames = [] + flow_filenames = [] + + for s in self.all_scene: + file_dir = os.listdir(osp.join(self.img1_dir, s)) + for i in file_dir: + img_file_dir = osp.join(self.img1_dir, s, i) + flow_file_dir = osp.join(self.flow_root, s, i) + + flow_filenames_ = self.get_data_filename( + flow_file_dir, self.flow_suffix) + img_filenames_ = self.get_data_filename( + img_file_dir, self.img_suffix) + + flow_filenames.append(flow_filenames_[0]) + img1_filenames.append(img_filenames_[0]) + img2_filenames.append(img_filenames_[1]) + + # img1_filenames, img2_filenames = self._revise_dir(flow_filenames) + self.load_img_info(self.data_infos, img1_filenames, img2_filenames) + self.load_ann_info(self.data_infos, flow_filenames, 'filename_flow') diff --git a/mmflow/datasets/CrowdFlow.py b/mmflow/datasets/CrowdFlow.py new file mode 100644 index 0000000..1492cf3 --- /dev/null +++ b/mmflow/datasets/CrowdFlow.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from typing import Any + +from .base_dataset import BaseDataset +from .builder import DATASETS + + +@DATASETS.register_module() +class CrowdFlow(BaseDataset): + """CrowdFlow dataset.""" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + def load_data_info(self) -> None: + """load data information.""" + + self.subset_dir = 'test' if self.test_mode else 'train' + + self.data_root = osp.join(self.data_root, self.subset_dir) + self.img1_dir = osp.join(self.data_root, 'images') + self.img2_dir = osp.join(self.data_root, 'images') + self.flow_root = osp.join(self.data_root, 'gt_flow') + + self.img_suffix = '.png' + self.flow_suffix = '.flo' + + self.all_scene = os.listdir(self.img1_dir) + + img1_filenames = [] + img2_filenames = [] + flow_filenames = [] + + for s in self.all_scene: + img_file_dir = osp.join(self.img1_dir, s) + flow_file_dir = osp.join(self.flow_root, s) + + flow_filenames_ = self.get_data_filename(flow_file_dir, + self.flow_suffix) + img_filenames_ = self.get_data_filename(img_file_dir, + self.img_suffix) + flow_num = len(flow_filenames_) + for i in range(flow_num): + flow_filenames += [flow_filenames_[i]] + img1_filenames += [img_filenames_[i]] + img2_filenames += [img_filenames_[i + 1]] + + # img1_filenames, img2_filenames = self._revise_dir(flow_filenames) + self.load_img_info(self.data_infos, img1_filenames, img2_filenames) + self.load_ann_info(self.data_infos, flow_filenames, 'filename_flow') diff --git a/mmflow/datasets/__init__.py b/mmflow/datasets/__init__.py index 5631e21..4a1c53e 100644 --- a/mmflow/datasets/__init__.py +++ b/mmflow/datasets/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .AutoFlow import AutoFlow from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset from .chairssdhom import ChairsSDHom +from .CrowdFlow import CrowdFlow from .dataset_wrappers import ConcatDataset, RepeatDataset from .flyingchairs import FlyingChairs from .flyingchairsocc import FlyingChairsOcc @@ -33,5 +35,5 @@ 'read_flow_kitti', 'GaussianNoise', 'RandomTranslate', 'Compose', 'InputPad', 'FlyingThings3DSubset', 'FlyingThings3D', 'Sintel', 'KITTI2012', 'KITTI2015', 'ChairsSDHom', 'HD1K', 'FlyingChairsOcc', - 'render_color_wheel' + 'CrowdFlow', 'AutoFlow', 'render_color_wheel' ]