Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 493079776
  • Loading branch information
Orbax Authors authored and copybara-github committed Dec 5, 2022
1 parent 3328d20 commit d8cd5b3
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
5 changes: 4 additions & 1 deletion orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import functools
import time
from typing import Any, Optional

from absl import logging
Expand Down Expand Up @@ -68,6 +69,7 @@ def save(self,
Raises:
ValueError if the provided directory already exists.
"""
checkpoint_start_time = time.time()
directory = epath.Path(directory)
logging.info('Saving item to %s. Waiting for thread to finish save.',
directory)
Expand All @@ -91,7 +93,8 @@ def save(self,
self._add_futures(commit_ops)
# Directory is the final directory
self._start_async_commit(
functools.partial(utils.ensure_atomic_save, tmpdir, directory))
functools.partial(utils.on_commit_callback, tmpdir, directory,
checkpoint_start_time))

def restore(self,
directory: epath.PathLike,
Expand Down
6 changes: 4 additions & 2 deletions orbax/checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Synchronous Checkpointer implementation."""

import time
from typing import Any, Optional

from absl import logging
Expand Down Expand Up @@ -58,6 +59,7 @@ def save(self,
Raises:
ValueError if the provided directory already exists.
"""
checkpoint_start_time = time.time()
directory = epath.Path(directory)
logging.info('Saving item to %s.', directory)
if directory.exists():
Expand All @@ -72,9 +74,9 @@ def save(self,
self._handler.save(tmpdir, item, *args, **kwargs)
multihost_utils.sync_global_devices('Checkpointer:write')

# Ensure save operation atomicity.
# Ensure save operation atomicity and record time saved by checkpoint.
if jax.process_index() == 0:
utils.ensure_atomic_save(tmpdir, directory)
utils.on_commit_callback(tmpdir, directory, checkpoint_start_time)
multihost_utils.sync_global_devices('Checkpointer:save')

def restore(self,
Expand Down
34 changes: 34 additions & 0 deletions orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
_PLACEHOLDER_PREFIX = 'PLACEHOLDER://'
_COMMIT_SUCCESS_FILE = 'commit_success.txt'
_GCS_PATH_PREFIX = 'gs://'
_LAST_CHECKPOINT_WRITE_TIME = time.time()
CheckpointDirs = Tuple[str, str]
PyTree = type(jax.tree_util.tree_structure(None))

Expand Down Expand Up @@ -241,6 +242,39 @@ def ensure_atomic_save(temp_ckpt_dir: epath.Path, final_ckpt_dir: epath.Path):
logging.info('Finished saving checkpoint to `%s`.', final_ckpt_dir)


def record_saved_duration(checkpoint_start_time: float):
"""Record program duration that is accounted for by this checkpoint.
For the very first checkpoint, this is the interval between program init and
current checkpoint start time.
Note that we use the checkpoint start time instead of end time. The saved
duration should not include prallel training duration while the async
checkpoint is being written in the background.
Args:
checkpoint_start_time: Start time of current checkpoint.
"""
global _LAST_CHECKPOINT_WRITE_TIME
# Note: for the very first checkpoint, this is the interval between program
# init and the current checkpoint start time.
duration_since_last_checkpoint = (
checkpoint_start_time - _LAST_CHECKPOINT_WRITE_TIME)
# TODO(hanyangtay): Remove version guard.
if jax.version.__version_info__ > (0, 3, 25):
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/duration_since_last_checkpoint_secs',
duration_since_last_checkpoint)
_LAST_CHECKPOINT_WRITE_TIME = checkpoint_start_time


def on_commit_callback(temp_ckpt_dir: epath.Path, final_ckpt_dir: epath.Path,
checkpoint_start_time: float):
"""Finalize atomic save and record training duration saved in a checkpoint."""
ensure_atomic_save(temp_ckpt_dir, final_ckpt_dir)
record_saved_duration(checkpoint_start_time)


def is_scalar(x):
return isinstance(x, (int, float, np.number))

Expand Down

0 comments on commit d8cd5b3

Please sign in to comment.