From a11e9632f8439c5671203e5ef1b11e327d209bf6 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sun, 3 Jul 2022 17:25:37 +0800 Subject: [PATCH 1/2] [Fix] RAFT KITTI test cfg and set_env --- configs/_base_/datasets/kitti2015_raft_test.py | 3 ++- mmflow/utils/set_env.py | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/configs/_base_/datasets/kitti2015_raft_test.py b/configs/_base_/datasets/kitti2015_raft_test.py index 2a67368c..9c5a5dfc 100644 --- a/configs/_base_/datasets/kitti2015_raft_test.py +++ b/configs/_base_/datasets/kitti2015_raft_test.py @@ -1,4 +1,5 @@ -img_norm_cfg = dict(mean=[0., 0., 0.], std=[255., 255., 255.], to_rgb=False) +img_norm_cfg = dict( + mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False) kitti_test_pipeline = [ dict(type='LoadImageFromFile'), diff --git a/mmflow/utils/set_env.py b/mmflow/utils/set_env.py index 1f38b92a..ae9fc860 100644 --- a/mmflow/utils/set_env.py +++ b/mmflow/utils/set_env.py @@ -11,7 +11,6 @@ def setup_multi_processes(cfg): """Setup multi-processing environment variables.""" logger = get_root_logger() - # set multi-process start method if platform.system() != 'Windows': mp_start_method = cfg.get('mp_start_method', None) @@ -31,7 +30,13 @@ def setup_multi_processes(cfg): else: logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}') - if cfg.data.train_dataloader.workers_per_gpu > 1: + if cfg.data.get('train_dataloader') is not None: + workers_per_gpu = cfg.data.train_dataloader.get('workers_per_gpu', 1) + elif cfg.data.get('test_dataloader') is not None: + workers_per_gpu = cfg.data.test_dataloader.get('workers_per_gpu', 1) + else: + workers_per_gpu = 0 + if workers_per_gpu > 1: # setup OMP threads # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa omp_num_threads = cfg.get('omp_num_threads', None) From ae6b61a36ff7026b0d7f0ba98be70cee44c9a639 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sun, 3 Jul 2022 17:42:59 +0800 Subject: [PATCH 2/2] num_worker default value --- mmflow/utils/set_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmflow/utils/set_env.py b/mmflow/utils/set_env.py index ae9fc860..a99a161f 100644 --- a/mmflow/utils/set_env.py +++ b/mmflow/utils/set_env.py @@ -31,9 +31,9 @@ def setup_multi_processes(cfg): logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}') if cfg.data.get('train_dataloader') is not None: - workers_per_gpu = cfg.data.train_dataloader.get('workers_per_gpu', 1) + workers_per_gpu = cfg.data.train_dataloader.get('workers_per_gpu', 0) elif cfg.data.get('test_dataloader') is not None: - workers_per_gpu = cfg.data.test_dataloader.get('workers_per_gpu', 1) + workers_per_gpu = cfg.data.test_dataloader.get('workers_per_gpu', 0) else: workers_per_gpu = 0 if workers_per_gpu > 1: