[Fix] Fix seed resuming #1468

Draft · wants to merge 3 commits into main
9 changes: 9 additions & 0 deletions mmengine/_strategy/base.py
@@ -893,6 +893,15 @@ def load_checkpoint(
Defaults to None.
"""

@abstractmethod
def resume_seed(self, filename: str):
"""Resume seed from given ``filename``.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
"""

@abstractmethod
def resume(
self,
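Note: every concrete strategy below implements this new hook with the same reconciliation rule: the seed stored in the checkpoint wins, with a warning when it disagrees with the configured one. A minimal standalone sketch of that rule (`reconcile_seed` is an illustrative name, not part of this PR):

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def reconcile_seed(resumed_seed: Optional[int],
                   current_seed: Optional[int]) -> Optional[int]:
    """Pick the seed to use on resume: the checkpoint's seed wins.

    Warns when both seeds are set but differ, mirroring the check
    duplicated in each strategy's ``resume_seed`` below.
    """
    if resumed_seed is None or resumed_seed == current_seed:
        return current_seed
    if current_seed is not None:
        logger.warning(
            'The value of random seed in the checkpoint "%s" is different '
            'from the value in `randomness` config "%s"', resumed_seed,
            current_seed)
    return resumed_seed
```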
36 changes: 23 additions & 13 deletions mmengine/_strategy/colossalai.py
@@ -352,6 +352,28 @@ def prepare(
self._prepared = True
return self._prepared_components()

def resume_seed(self, filename: str):
"""Resume seed from given ``filename``.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
"""
self.logger.info(f'Resume seed from {filename}')

checkpoint = _load_checkpoint(
osp.join(filename, 'meta.pth'), map_location='cpu')

resumed_seed = checkpoint['meta'].get('seed', None)
if resumed_seed is not None and resumed_seed != self.seed:
if self.seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{self.seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)

def resume(
self,
filename: str,
@@ -379,18 +401,6 @@ def resume(
self.booster.load_lr_scheduler(
scheduler, f'{schedulers_dir}/scheduler_{i}.pth')

# resume random seed
resumed_seed = extra_ckpt['meta'].get('seed', None)
current_seed = self._randomness.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)

# resume iter
self.dispatch_kwargs['cur_iter'] = extra_ckpt['meta']['iter']

@@ -408,7 +418,7 @@ def load_checkpoint(
"""Load checkpoint from given ``filename``.

Warning:
-            `map_localtion` and `callback` parameters are not supported yet.
+            `map_location` and `callback` parameters are not supported yet.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
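For the ColossalAI strategy the checkpoint is assumed to be a directory whose `meta.pth` carries the `meta` dict (seed, iter, and so on), matching what this strategy's `save_checkpoint` writes. A hedged round-trip sketch of that assumed layout:

```python
import os.path as osp
import tempfile

import torch

# Assumed layout: a checkpoint *directory* containing meta.pth, whose
# 'meta' dict holds the seed alongside other bookkeeping like 'iter'.
with tempfile.TemporaryDirectory() as ckpt_dir:
    torch.save({'meta': {'seed': 2023, 'iter': 1000}},
               osp.join(ckpt_dir, 'meta.pth'))
    checkpoint = torch.load(osp.join(ckpt_dir, 'meta.pth'),
                            map_location='cpu')
    assert checkpoint['meta'].get('seed') == 2023
```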
38 changes: 25 additions & 13 deletions mmengine/_strategy/deepspeed.py
@@ -18,6 +18,7 @@
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
STRATEGIES)
from mmengine.runner.checkpoint import _load_checkpoint
from mmengine.utils import apply_to, digit_version, get_git_hash
from .base import BaseStrategy

@@ -416,7 +417,7 @@ def load_checkpoint(
"""Load checkpoint from given ``filename``.

Warning:
-            `map_localtion` and `callback` parameters are not supported yet.
+            `map_location` and `callback` parameters are not supported yet.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
@@ -437,6 +438,29 @@ def load_checkpoint(

return extra_ckpt

def resume_seed(self, filename: str):
"""Resume seed from given ``filename``.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
"""
self.logger.info(f'Resume seed from {filename}')

from deepspeed.utils.zero_to_fp32 import get_model_state_files
filename = get_model_state_files(filename)[0]
checkpoint = _load_checkpoint(filename, map_location='cpu')

resumed_seed = checkpoint['meta'].get('seed', None)
if resumed_seed is not None and resumed_seed != self.seed:
if self.seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{self.seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)

def resume(
self,
filename: str,
@@ -480,18 +504,6 @@ def resume(
param_schedulers = extra_ckpt.pop('param_schedulers')
self.load_scheduler_state_dict(param_schedulers)

# resume random seed
resumed_seed = extra_ckpt['meta'].get('seed', None)
current_seed = self._randomness.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)

return extra_ckpt

def save_checkpoint(
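The DeepSpeed checkpoint is likewise a directory; the patch reuses `get_model_state_files` from `deepspeed.utils.zero_to_fp32` to locate the per-rank model state files and reads the seed from the first one. A sketch of that lookup, assuming DeepSpeed is installed and that `meta` is stored in the model state file (as this strategy's `save_checkpoint` arranges):

```python
from typing import Optional

import torch
from deepspeed.utils.zero_to_fp32 import get_model_state_files


def peek_seed(ckpt_dir: str) -> Optional[int]:
    # get_model_state_files lists the model state files inside a
    # DeepSpeed checkpoint directory; the first (rank-0) file suffices
    # here because 'meta' is replicated across ranks.
    state_file = get_model_state_files(ckpt_dir)[0]
    checkpoint = torch.load(state_file, map_location='cpu')
    return checkpoint['meta'].get('seed', None)
```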
36 changes: 23 additions & 13 deletions mmengine/_strategy/single_device.py
@@ -9,6 +9,7 @@
from mmengine.model import revert_sync_batchnorm
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import STRATEGIES
from mmengine.runner.checkpoint import _load_checkpoint
from mmengine.utils import get_git_hash
from .base import BaseStrategy

@@ -135,7 +136,6 @@ def load_checkpoint(
checkpoint after loading the checkpoint.
Defaults to None.
"""
from mmengine.runner.checkpoint import _load_checkpoint

self.logger.info(f'Load checkpoint from {filename}')

@@ -155,6 +155,28 @@

return checkpoint

def resume_seed(self, filename: str):
"""Resume seed from given ``filename``.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
"""
self.logger.info(f'Resume seed from {filename}')

checkpoint = _load_checkpoint(filename, map_location='cpu')

resumed_seed = checkpoint['meta'].get('seed', None)
current_seed = self._randomness.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)

def resume(
self,
filename: str,
@@ -200,18 +222,6 @@ def resume(
if resume_param_scheduler and hasattr(self, 'param_schedulers'):
self.load_scheduler_state_dict(checkpoint.pop('param_schedulers'))

# resume random seed
resumed_seed = checkpoint['meta'].get('seed', None)
current_seed = self._randomness.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness.update(seed=resumed_seed)
self._set_randomness(**self._randomness)

# resume iter
cur_iter = checkpoint['meta']['iter']

21 changes: 21 additions & 0 deletions mmengine/runner/_flexible_runner.py
@@ -405,6 +405,9 @@ def __init__(
self.logger.info(f'Hooks will be executed in the following '
f'order:\n{self.get_hooks_info()}')

# resume seed if needed
self.resume_seed()

# dump `cfg` to `work_dir`
self.dump_config()

@@ -1116,6 +1119,24 @@ def get_hooks_info(self) -> str:
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)

def resume_seed(self):
"""resume seed."""

# decide which checkpoint, if any, to resume the seed from
resume_from = None
if isinstance(self._resume, str):
resume_from = self._resume
elif self._resume and self._load_from is None:
# auto resume from the latest checkpoint
resume_from = find_latest_checkpoint(self.work_dir)
self.logger.info(
f'Auto resumed from the latest checkpoint {resume_from}.')
elif self._resume and self._load_from is not None:
# resume from the specified checkpoint
resume_from = self._load_from
if resume_from is not None:
self.strategy.resume_seed(resume_from)

def load_or_resume(self):
"""load or resume checkpoint."""
if self._has_loaded:
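Both runners resolve the target checkpoint with the same precedence before delegating to the strategy; `FlexibleRunner` additionally accepts `resume` as a path string, which is why that case is checked first. A standalone sketch of the precedence (hypothetical helper name; `find_latest_checkpoint` is assumed importable from `mmengine.runner`):

```python
from typing import Optional, Union

from mmengine.runner import find_latest_checkpoint


def resolve_resume_from(resume: Union[bool, str], load_from: Optional[str],
                        work_dir: str) -> Optional[str]:
    """Mirror the checkpoint-selection precedence used by resume_seed."""
    if isinstance(resume, str):
        return resume  # FlexibleRunner: resume directly from this path
    if resume and load_from is None:
        return find_latest_checkpoint(work_dir)  # auto-resume; may be None
    if resume and load_from is not None:
        return load_from  # resume from the explicitly given checkpoint
    return None  # fresh run (or load-only): leave the seed untouched
```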
8 changes: 8 additions & 0 deletions mmengine/runner/loops.py
@@ -271,6 +271,14 @@ def run(self) -> None:
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
if self._iter > 0:
print_log(
f'Advancing dataloader {self._iter} steps to skip the data '
'that has already been trained on',
logger='current',
level=logging.WARNING)
for _ in range(self._iter):
next(self.dataloader_iterator)
while self._iter < self._max_iters and not self.stop_training:
self.runner.model.train()

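This is the piece that makes seed resuming meaningful for `IterBasedTrainLoop`: once the sampler is re-seeded identically, discarding the first `self._iter` batches leaves exactly the batches an uninterrupted run would still see. A minimal demonstration of the skip pattern with a plain PyTorch `DataLoader` (values are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(
    dataset, batch_size=1, shuffle=True,
    generator=torch.Generator().manual_seed(2023))  # deterministic order

resumed_iter = 4  # pretend 4 iterations ran before the restart
it = iter(loader)
for _ in range(resumed_iter):
    next(it)  # throw away batches that were already consumed

remaining = [batch[0].item() for batch in it]
print(remaining)  # matches the tail of an uninterrupted seeded run
```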
51 changes: 39 additions & 12 deletions mmengine/runner/runner.py
@@ -444,6 +444,9 @@ def __init__(
self.logger.info(f'Hooks will be executed in the following '
f'order:\n{self.get_hooks_info()}')

# resume seed if needed
self.resume_seed()

# dump `cfg` to `work_dir`
self.dump_config()

@@ -1994,6 +1997,42 @@ def register_hooks(
if custom_hooks is not None:
self.register_custom_hooks(custom_hooks)

def _resume_seed(self, filename: str) -> None:
"""Resume seed from checkpoint.

Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
"""

checkpoint = _load_checkpoint(filename, map_location='cpu')

resumed_seed = checkpoint['meta'].get('seed', None)
if resumed_seed is not None and resumed_seed != self.seed:
if self.seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{self.seed}"')
self._randomness_cfg.update(seed=resumed_seed)
self.set_randomness(**self._randomness_cfg)

def resume_seed(self):
"""resume seed."""

# decide which checkpoint, if any, to resume the seed from
resume_from = None
if self._resume and self._load_from is None:
# auto resume from the latest checkpoint
resume_from = find_latest_checkpoint(self.work_dir)
self.logger.info(f'Seed is auto resumed from the latest '
f'checkpoint {resume_from}.')
elif self._resume and self._load_from is not None:
# resume from the specified checkpoint
resume_from = self._load_from
if resume_from is not None:
self._resume_seed(resume_from)

def resume(self,
filename: str,
resume_optimizer: bool = True,
@@ -2047,18 +2086,6 @@ def resume(self,
'learning rate will be adjusted according to the '
f'setting in auto_scale_lr={self.auto_scale_lr}')

# resume random seed
resumed_seed = checkpoint['meta'].get('seed', None)
current_seed = self._randomness_cfg.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
self.logger.warning(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness_cfg.update(seed=resumed_seed)
self.set_randomness(**self._randomness_cfg)

resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None)
dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None)

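Calling `resume_seed` from `__init__` is the heart of the fix: the seed must be restored before seed-dependent components (notably the sampler) are constructed, not only inside `resume()` as before. A toy illustration of why the ordering matters (`build_epoch_order` is a hypothetical stand-in for a shuffling sampler):

```python
import torch


def build_epoch_order(seed: int, n: int = 10):
    """Stand-in for a sampler that shuffles with a seeded generator."""
    g = torch.Generator().manual_seed(seed)
    return torch.randperm(n, generator=g).tolist()


# Before this PR the checkpoint seed was only applied inside resume(),
# after the sampler had already been built with the configured seed, so
# a resumed run shuffled differently from the original one.
order_original_run = build_epoch_order(seed=2023)
order_resumed_late = build_epoch_order(seed=12345)  # wrong shuffle order
assert order_original_run != order_resumed_late

# After this PR resume_seed() runs in __init__, so the sampler is built
# with the checkpoint's seed and reproduces the original order.
order_resumed_early = build_epoch_order(seed=2023)
assert order_original_run == order_resumed_early
```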
2 changes: 1 addition & 1 deletion tests/test_runner/test_runner.py
@@ -2342,7 +2342,7 @@ def test_checkpoint(self):
torch.save(ckpt_modified, path_modified)
# Warning should be raised since seed is not matched
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
runner.resume(path_modified)
runner._resume_seed(path_modified)

# 1.3.3 test resume with no seed and dataset meta
ckpt_modified = copy.deepcopy(ckpt)