Skip to content

Commit

Permalink
Add return_context to trainstep.step()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699904187
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 25, 2024
1 parent 69e940d commit c315b56
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 214 deletions.
1 change: 0 additions & 1 deletion kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def evaluate(
step=step,
aux=merged_aux,
schedules={},
model_with_aux=self.model_with_aux,
log_summaries=True,
)
return merged_aux
Expand Down
4 changes: 3 additions & 1 deletion kauldron/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from kauldron.train.setup_utils import Setup
from kauldron.train.setup_utils import TqdmInfo
from kauldron.train.train_step import Auxiliaries
from kauldron.train.train_step import ModelWithAux
from kauldron.train.train_step import AuxiliariesRef
from kauldron.train.train_step import forward
from kauldron.train.train_step import forward_with_loss
from kauldron.train.train_step import TrainState
from kauldron.train.train_step import TrainStep
from kauldron.train.trainer_lib import Trainer
Expand Down
21 changes: 21 additions & 0 deletions kauldron/train/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class Context:
opt_state: The state of the optimizer prior to the update. (available after
the backward pass, e.g. for metrics). The old state is chosen to be
consistent with parameters which are also pre-update.
metric_states: The states of the metrics (after the backward pass).
summary_states: The states of the summaries (after the backward pass).
"""

# These are always available:
Expand All @@ -80,6 +82,9 @@ class Context:
grads: Any = None
updates: Any = None
opt_state: Any = None
# Become available after the metrics computation
metric_states: Any = None
summary_states: Any = None

replace = dataclasses.replace

Expand All @@ -100,3 +105,19 @@ def from_state_and_batch(

def flatten(self) -> dict[str, Any]:
return kontext.flatten_with_path(self)

def get_aux(
    self,
    *,
    return_losses: bool = False,
    return_metrics: bool = False,
    return_summaries: bool = False,
) -> train_step.Auxiliaries:
  """Builds the `Auxiliaries` for the step from the states on this context.

  Each group of states is forwarded only when its flag is set; otherwise the
  corresponding `Auxiliaries` field is `None`.

  Args:
    return_losses: If True, forward `self.loss_states`.
    return_metrics: If True, forward `self.metric_states`.
    return_summaries: If True, forward `self.summary_states`.

  Returns:
    A `train_step.Auxiliaries` holding the selected states.
  """
  # Function-scope import, kept as in the original (presumably to avoid a
  # circular import between `context` and `train_step` -- TODO confirm).
  from kauldron.train import train_step  # pylint: disable=g-import-not-at-top

  selected_losses = None
  if return_losses:
    selected_losses = self.loss_states

  selected_metrics = None
  if return_metrics:
    selected_metrics = self.metric_states

  selected_summaries = None
  if return_summaries:
    selected_summaries = self.summary_states

  return train_step.Auxiliaries(
      loss_states=selected_losses,
      metric_states=selected_metrics,
      summary_states=selected_summaries,
  )
12 changes: 1 addition & 11 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down Expand Up @@ -201,15 +200,7 @@ def write_step_metrics(

if log_summaries:
with jax.transfer_guard("allow"):
# TODO(klausg): remove once all summaries are migrated to new protocol
# image summaries
image_summaries_old = {
name: summary.get_images(**aux.summary_kwargs[name])
for name, summary in model_with_aux.summaries.items()
if isinstance(summary, summaries.ImageSummary)
}

image_summaries = image_summaries_old | {
image_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
Expand Down Expand Up @@ -586,7 +577,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down
1 change: 0 additions & 1 deletion kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def train_impl(
step=i,
aux=aux,
schedules=trainer.schedules,
model_with_aux=trainstep.model_with_aux,
timer=chrono,
log_summaries=log_summaries,
)
Expand Down
Loading

0 comments on commit c315b56

Please sign in to comment.