From 25b459cd5c68e0ba19d7a866eff4731c95b30289 Mon Sep 17 00:00:00 2001
From: LZHgrla
Date: Wed, 10 Jan 2024 10:34:01 +0800
Subject: [PATCH 1/3] fix

---
 mmengine/_strategy/base.py          |  9 +++++
 mmengine/_strategy/colossalai.py    | 37 ++++++++++++--------
 mmengine/_strategy/deepspeed.py     | 39 ++++++++++++++--------
 mmengine/_strategy/single_device.py | 36 ++++++++++++--------
 mmengine/runner/_flexible_runner.py | 21 ++++++++++++
 mmengine/runner/loops.py            |  6 ++++
 mmengine/runner/runner.py           | 52 ++++++++++++++++++++++-------
 7 files changed, 149 insertions(+), 51 deletions(-)

diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py
index 5df3a79c92..2c3fa2607a 100644
--- a/mmengine/_strategy/base.py
+++ b/mmengine/_strategy/base.py
@@ -893,6 +893,15 @@ def load_checkpoint(
                 Defaults to None.
         """

+    @abstractmethod
+    def resume_seed(self, filename: str):
+        """Resume seed from given ``filename``.
+
+        Args:
+            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+                ``open-mmlab://xxx``.
+        """
+
     @abstractmethod
     def resume(
         self,
diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py
index cfbb925c67..4a019bac57 100644
--- a/mmengine/_strategy/colossalai.py
+++ b/mmengine/_strategy/colossalai.py
@@ -352,6 +352,29 @@ def prepare(
         self._prepared = True
         return self._prepared_components()

+    def resume_seed(self, filename: str):
+        """Resume seed from given ``filename``.
+
+        Args:
+            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+                ``open-mmlab://xxx``.
+        """
+        self.logger.info(f'Resume seed from {filename}')
+
+        checkpoint = _load_checkpoint(
+            osp.join(filename, 'meta.pth'), map_location='cpu')
+
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                self.logger.warning(f'The value of random seed in the '
+                                    f'checkpoint "{resumed_seed}" is '
+                                    f'different from the value in '
+                                    f'`randomness` config "{current_seed}"')
+            self._randomness.update(seed=resumed_seed)
+            self._set_randomness(**self._randomness)
+
     def resume(
         self,
         filename: str,
@@ -379,18 +402,6 @@ def resume(
             self.booster.load_lr_scheduler(
                 scheduler, f'{schedulers_dir}/scheduler_{i}.pth')

-        # resume random seed
-        resumed_seed = extra_ckpt['meta'].get('seed', None)
-        current_seed = self._randomness.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
-                self.logger.warning(f'The value of random seed in the '
-                                    f'checkpoint "{resumed_seed}" is '
-                                    f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
-            self._randomness.update(seed=resumed_seed)
-            self._set_randomness(**self._randomness)
-
         # resume iter
         self.dispatch_kwargs['cur_iter'] = extra_ckpt['meta']['iter']

@@ -408,7 +419,7 @@ def load_checkpoint(
         """Load checkpoint from given ``filename``.

         Warning:
-            `map_localtion` and `callback` parameters are not supported yet.
+            `map_location` and `callback` parameters are not supported yet.

         Args:
             filename (str): Accept local filepath, URL, ``torchvision://xxx``,
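The shared core of all three `resume_seed` implementations in this commit is the same resolution rule: a seed stored in the checkpoint's `meta` dict takes precedence over the one in the `randomness` config, with a warning when the two disagree. A minimal standalone sketch of that rule, using a hypothetical `resolve_seed` helper that is not part of mmengine:

```python
import logging
from typing import Optional


def resolve_seed(checkpoint_meta: dict, configured_seed: Optional[int],
                 logger: logging.Logger) -> Optional[int]:
    """Return the seed to use: the checkpoint's seed wins over the config."""
    resumed_seed = checkpoint_meta.get('seed', None)
    if resumed_seed is None:
        return configured_seed
    if configured_seed is not None and resumed_seed != configured_seed:
        logger.warning(f'The value of random seed in the checkpoint '
                       f'"{resumed_seed}" is different from the value in '
                       f'`randomness` config "{configured_seed}"')
    return resumed_seed
```

The strategies differ only in where they find the `meta` dict: ColossalAI reads a dedicated `meta.pth` inside the checkpoint directory, while the DeepSpeed variant below extracts it from a model-state shard.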
diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py
index 378616db3d..bc7a36a776 100644
--- a/mmengine/_strategy/deepspeed.py
+++ b/mmengine/_strategy/deepspeed.py
@@ -18,6 +18,7 @@
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
+from mmengine.runner.checkpoint import _load_checkpoint
 from mmengine.utils import apply_to, digit_version, get_git_hash

 from .base import BaseStrategy
@@ -416,7 +417,7 @@ def load_checkpoint(
         """Load checkpoint from given ``filename``.

         Warning:
-            `map_localtion` and `callback` parameters are not supported yet.
+            `map_location` and `callback` parameters are not supported yet.

         Args:
             filename (str): Accept local filepath, URL, ``torchvision://xxx``,
@@ -437,6 +438,30 @@ def load_checkpoint(

         return extra_ckpt

+    def resume_seed(self, filename: str):
+        """Resume seed from given ``filename``.
+
+        Args:
+            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+                ``open-mmlab://xxx``.
+        """
+        self.logger.info(f'Resume seed from {filename}')
+
+        from deepspeed.utils.zero_to_fp32 import get_model_state_files
+        filename = get_model_state_files(filename)[0]
+        checkpoint = _load_checkpoint(filename, map_location='cpu')
+
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                self.logger.warning(f'The value of random seed in the '
+                                    f'checkpoint "{resumed_seed}" is '
+                                    f'different from the value in '
+                                    f'`randomness` config "{current_seed}"')
+            self._randomness.update(seed=resumed_seed)
+            self._set_randomness(**self._randomness)
+
     def resume(
         self,
         filename: str,
@@ -480,18 +505,6 @@ def resume(
             param_schedulers = extra_ckpt.pop('param_schedulers')
             self.load_scheduler_state_dict(param_schedulers)

-        # resume random seed
-        resumed_seed = extra_ckpt['meta'].get('seed', None)
-        current_seed = self._randomness.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
-                self.logger.warning(f'The value of random seed in the '
-                                    f'checkpoint "{resumed_seed}" is '
-                                    f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
-            self._randomness.update(seed=resumed_seed)
-            self._set_randomness(**self._randomness)
-
         return extra_ckpt

     def save_checkpoint(
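`DeepSpeedStrategy.resume_seed` cannot read a single `.pth` file: a ZeRO checkpoint is a directory of shards and, judging from the diff, mmengine's `meta` dict travels inside the model-state files, so `get_model_state_files(filename)[0]` picks any one shard. A rough, hedged equivalent using a plain glob; the `*_model_states.pt` pattern is an illustrative stand-in for what DeepSpeed's helper resolves, not a guaranteed layout:

```python
import glob
import os.path as osp

import torch


def load_meta_from_zero_checkpoint(checkpoint_dir: str) -> dict:
    """Read the ``meta`` dict from one shard of a ZeRO checkpoint dir."""
    state_files = sorted(
        glob.glob(
            osp.join(checkpoint_dir, '**', '*_model_states.pt'),
            recursive=True))
    # Any shard will do; the meta dict is saved into each model-state file.
    checkpoint = torch.load(state_files[0], map_location='cpu')
    return checkpoint['meta']
```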
diff --git a/mmengine/_strategy/single_device.py b/mmengine/_strategy/single_device.py
index c7d8accd5a..10dde91f55 100644
--- a/mmengine/_strategy/single_device.py
+++ b/mmengine/_strategy/single_device.py
@@ -9,6 +9,7 @@
 from mmengine.model import revert_sync_batchnorm
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import STRATEGIES
+from mmengine.runner.checkpoint import _load_checkpoint
 from mmengine.utils import get_git_hash

 from .base import BaseStrategy
@@ -135,7 +136,6 @@ def load_checkpoint(
                 checkpoint after loading the checkpoint.
                 Defaults to None.
         """
-        from mmengine.runner.checkpoint import _load_checkpoint

         self.logger.info(f'Load checkpoint from {filename}')

@@ -155,6 +155,28 @@ def load_checkpoint(

         return checkpoint

+    def resume_seed(self, filename: str):
+        """Resume seed from given ``filename``.
+
+        Args:
+            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+                ``open-mmlab://xxx``.
+        """
+        self.logger.info(f'Resume seed from {filename}')
+
+        checkpoint = _load_checkpoint(filename, map_location='cpu')
+
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                self.logger.warning(f'The value of random seed in the '
+                                    f'checkpoint "{resumed_seed}" is '
+                                    f'different from the value in '
+                                    f'`randomness` config "{current_seed}"')
+            self._randomness.update(seed=resumed_seed)
+            self._set_randomness(**self._randomness)
+
     def resume(
         self,
         filename: str,
@@ -200,18 +222,6 @@ def resume(
         if resume_param_scheduler and hasattr(self, 'param_schedulers'):
             self.load_scheduler_state_dict(checkpoint.pop('param_schedulers'))

-        # resume random seed
-        resumed_seed = checkpoint['meta'].get('seed', None)
-        current_seed = self._randomness.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
-                self.logger.warning(f'The value of random seed in the '
-                                    f'checkpoint "{resumed_seed}" is '
-                                    f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
-            self._randomness.update(seed=resumed_seed)
-            self._set_randomness(**self._randomness)
-
         # resume iter
         cur_iter = checkpoint['meta']['iter']

diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py
index 6d727fb4d5..cf66c846ed 100644
--- a/mmengine/runner/_flexible_runner.py
+++ b/mmengine/runner/_flexible_runner.py
@@ -405,6 +405,9 @@ def __init__(
         self.logger.info(f'Hooks will be executed in the following '
                          f'order:\n{self.get_hooks_info()}')

+        # resume seed if needed
+        self.resume_seed()
+
         # dump `cfg` to `work_dir`
         self.dump_config()

@@ -1116,6 +1119,24 @@ def get_hooks_info(self) -> str:
             stage_hook_infos.append(info)
         return '\n'.join(stage_hook_infos)

+    def resume_seed(self):
+        """Resume seed if needed."""
+
+        # decide which checkpoint, if any, to resume the seed from
+        resume_from = None
+        if isinstance(self._resume, str):
+            resume_from = self._resume
+        elif self._resume and self._load_from is None:
+            # auto resume from the latest checkpoint
+            resume_from = find_latest_checkpoint(self.work_dir)
+            self.logger.info(
+                f'Auto resumed from the latest checkpoint {resume_from}.')
+        elif self._resume and self._load_from is not None:
+            # resume from the specified checkpoint
+            resume_from = self._load_from
+        if resume_from is not None:
+            self.strategy.resume_seed(resume_from)
+
     def load_or_resume(self):
         """load or resume checkpoint."""
         if self._has_loaded:
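`FlexibleRunner.resume_seed` has to answer one question before any model or dataloader is built: which checkpoint, if any, should the seed come from? The branching above, restated as a pure function for readability (a hypothetical helper, not part of the patch):

```python
from typing import Optional, Union


def decide_resume_from(resume: Union[bool, str], load_from: Optional[str],
                       latest_in_work_dir: Optional[str]) -> Optional[str]:
    """Mirror the resume-source decision made by ``resume_seed``."""
    if isinstance(resume, str):
        return resume                 # resume='path/to/ckpt.pth'
    if resume and load_from is None:
        return latest_in_work_dir     # resume=True: latest ckpt in work_dir
    if resume and load_from is not None:
        return load_from              # resume=True plus an explicit load_from
    return None                       # resume=False: nothing to resume
```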
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 6a874a6ad6..46e3e35b44 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -271,6 +271,12 @@ def run(self) -> None:
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
+        if self._iter > 0:
+            print_log(
+                f'Advance dataloader {self._iter} steps to skip data '
+                'that has already been trained on', 'current')
+            for _ in range(self._iter):
+                next(self.dataloader_iterator)
         while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 68716ab253..2c495ee077 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -444,6 +444,9 @@ def __init__(
         self.logger.info(f'Hooks will be executed in the following '
                          f'order:\n{self.get_hooks_info()}')

+        # resume seed if needed
+        self.resume_seed()
+
         # dump `cfg` to `work_dir`
         self.dump_config()

@@ -1994,6 +1997,43 @@ def register_hooks(
         if custom_hooks is not None:
             self.register_custom_hooks(custom_hooks)

+    def _resume_seed(self, filename: str) -> None:
+        """Resume seed from checkpoint.
+
+        Args:
+            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+                ``open-mmlab://xxx``.
+        """
+
+        checkpoint = _load_checkpoint(filename, map_location='cpu')
+
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness_cfg.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                self.logger.warning(f'The value of random seed in the '
+                                    f'checkpoint "{resumed_seed}" is '
+                                    f'different from the value in '
+                                    f'`randomness` config "{current_seed}"')
+            self._randomness_cfg.update(seed=resumed_seed)
+            self.set_randomness(**self._randomness_cfg)
+
+    def resume_seed(self):
+        """Resume seed if needed."""
+
+        # decide which checkpoint, if any, to resume the seed from
+        resume_from = None
+        if self._resume and self._load_from is None:
+            # auto resume from the latest checkpoint
+            resume_from = find_latest_checkpoint(self.work_dir)
+            self.logger.info(f'Seed is auto resumed from the latest '
+                             f'checkpoint {resume_from}.')
+        elif self._resume and self._load_from is not None:
+            # resume from the specified checkpoint
+            resume_from = self._load_from
+        if resume_from is not None:
+            self._resume_seed(resume_from)
+
     def resume(self,
                filename: str,
                resume_optimizer: bool = True,
@@ -2047,18 +2087,6 @@ def resume(self,
                     'leaning rate will be adjusted according to the '
                     f'setting in auto_scale_lr={self.auto_scale_lr}')

-        # resume random seed
-        resumed_seed = checkpoint['meta'].get('seed', None)
-        current_seed = self._randomness_cfg.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
-                self.logger.warning(f'The value of random seed in the '
-                                    f'checkpoint "{resumed_seed}" is '
-                                    f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
-            self._randomness_cfg.update(seed=resumed_seed)
-            self.set_randomness(**self._randomness_cfg)
-
         resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None)
         dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None)
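The loop change is the reason the seed must be restored this early: once the RNG state matches the original run, advancing the infinite dataloader iterator by the resumed iteration count replays exactly the batches the first run consumed, so training continues on unseen data. A toy sketch of the fast-forward, with `itertools.cycle` standing in for mmengine's infinite dataloader iterator:

```python
from itertools import cycle

dataloader_iterator = cycle(range(8))  # stand-in for the infinite iterator
resumed_iter = 5                       # e.g. checkpoint['meta']['iter']

# Skip the batches the interrupted run already trained on.
for _ in range(resumed_iter):
    next(dataloader_iterator)

print(next(dataloader_iterator))  # first batch of the resumed run: 5
```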
From ec706dc33c5b319fff97a540c0360a8e68b7365d Mon Sep 17 00:00:00 2001
From: LZHgrla
Date: Wed, 10 Jan 2024 11:01:03 +0800
Subject: [PATCH 2/3] fix ut

---
 mmengine/runner/loops.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 46e3e35b44..1f6551ab62 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -274,7 +274,9 @@ def run(self) -> None:
         if self._iter > 0:
             print_log(
                 f'Advance dataloader {self._iter} steps to skip data '
-                'that has already been trained on', 'current')
+                'that has already been trained on',
+                logger='current',
+                level=logging.WARNING)
             for _ in range(self._iter):
                 next(self.dataloader_iterator)
         while self._iter < self._max_iters and not self.stop_training:

From 7d577b65b768812f10b82f84a058dccf6704a8f8 Mon Sep 17 00:00:00 2001
From: LZHgrla
Date: Wed, 10 Jan 2024 12:51:09 +0800
Subject: [PATCH 3/3] fix ut

---
 mmengine/_strategy/colossalai.py | 7 +++----
 mmengine/_strategy/deepspeed.py  | 7 +++----
 mmengine/runner/runner.py        | 7 +++----
 tests/test_runner/test_runner.py | 2 +-
 4 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py
index 4a019bac57..e8d72c652e 100644
--- a/mmengine/_strategy/colossalai.py
+++ b/mmengine/_strategy/colossalai.py
@@ -365,13 +365,12 @@ def resume_seed(self, filename: str):
             osp.join(filename, 'meta.pth'), map_location='cpu')

         resumed_seed = checkpoint['meta'].get('seed', None)
-        current_seed = self._randomness.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
+        if resumed_seed is not None and resumed_seed != self.seed:
+            if self.seed is not None:
                 self.logger.warning(f'The value of random seed in the '
                                     f'checkpoint "{resumed_seed}" is '
                                     f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
+                                    f'`randomness` config "{self.seed}"')
             self._randomness.update(seed=resumed_seed)
             self._set_randomness(**self._randomness)

diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py
index bc7a36a776..25851ae419 100644
--- a/mmengine/_strategy/deepspeed.py
+++ b/mmengine/_strategy/deepspeed.py
@@ -452,13 +452,12 @@ def resume_seed(self, filename: str):
         checkpoint = _load_checkpoint(filename, map_location='cpu')

         resumed_seed = checkpoint['meta'].get('seed', None)
-        current_seed = self._randomness.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
+        if resumed_seed is not None and resumed_seed != self.seed:
+            if self.seed is not None:
                 self.logger.warning(f'The value of random seed in the '
                                     f'checkpoint "{resumed_seed}" is '
                                     f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
+                                    f'`randomness` config "{self.seed}"')
             self._randomness.update(seed=resumed_seed)
             self._set_randomness(**self._randomness)

diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 2c495ee077..1a17a8533b 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -2008,13 +2008,12 @@ def _resume_seed(self, filename: str) -> None:
         checkpoint = _load_checkpoint(filename, map_location='cpu')

         resumed_seed = checkpoint['meta'].get('seed', None)
-        current_seed = self._randomness_cfg.get('seed')
-        if resumed_seed is not None and resumed_seed != current_seed:
-            if current_seed is not None:
+        if resumed_seed is not None and resumed_seed != self.seed:
+            if self.seed is not None:
                 self.logger.warning(f'The value of random seed in the '
                                     f'checkpoint "{resumed_seed}" is '
                                     f'different from the value in '
-                                    f'`randomness` config "{current_seed}"')
+                                    f'`randomness` config "{self.seed}"')
             self._randomness_cfg.update(seed=resumed_seed)
             self.set_randomness(**self._randomness_cfg)

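The third commit drops the local `current_seed = self._randomness.get('seed')` lookup in favor of `self.seed`. The difference: the config dict only holds what the user passed (possibly `None`), while the `seed` attribute reflects the value actually in effect after randomness was initialized, including an auto-generated seed when the config left it unset. A sketch of that pattern, assuming an attribute recorded during initialization; this is not mmengine's literal implementation:

```python
import random


class RandomnessMixin:
    """Records the seed actually used, even when none was configured."""

    def set_randomness(self, seed=None, **kwargs):
        # Fall back to a freshly generated seed when the config has None.
        self._seed = seed if seed is not None else random.randint(0, 2**31)

    @property
    def seed(self):
        """Seed currently in effect; never None after initialization."""
        return self._seed
```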
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index c8a58e9c8a..e0ce565502 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -2342,7 +2342,7 @@ def test_checkpoint(self):
         torch.save(ckpt_modified, path_modified)
         # Warning should be raised since seed is not matched
         with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
-            runner.resume(path_modified)
+            runner._resume_seed(path_modified)

         # 1.3.3 test resume with no seed and dataset meta
         ckpt_modified = copy.deepcopy(ckpt)
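Taken together, the series moves seed restoration out of `resume()` and into runner construction, so a resumed run restores its RNG state before any dataloader is built. A hedged sketch of the user-facing flow, assuming a standard mmengine config object `cfg`:

```python
from mmengine.runner import Runner

cfg.resume = True  # auto-resume from the latest checkpoint in cfg.work_dir
# or set cfg.load_from = 'work_dirs/my_run/iter_5000.pth' with resume=True

runner = Runner.from_cfg(cfg)  # __init__ now calls resume_seed() itself
runner.train()                 # IterBasedTrainLoop fast-forwards the data
```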