Skip to content

Commit

Permalink
Remove model_with_aux from write metrics
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700281575
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 26, 2024
1 parent 8f744cf commit 8b0a965
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 13 deletions.
1 change: 0 additions & 1 deletion kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def evaluate(
step=step,
aux=merged_aux,
schedules={},
model_with_aux=self.model_with_aux,
log_summaries=True,
)
return merged_aux
Expand Down
12 changes: 1 addition & 11 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down Expand Up @@ -201,15 +200,7 @@ def write_step_metrics(

if log_summaries:
with jax.transfer_guard("allow"):
# TODO(klausg): remove once all summaries are migrated to new protocol
# image summaries
image_summaries_old = {
name: summary.get_images(**aux.summary_kwargs[name])
for name, summary in model_with_aux.summaries.items()
if isinstance(summary, summaries.ImageSummary)
}

image_summaries = image_summaries_old | {
image_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
Expand Down Expand Up @@ -586,7 +577,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down
1 change: 0 additions & 1 deletion kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def train_impl(
step=i,
aux=aux,
schedules=trainer.schedules,
model_with_aux=trainstep.model_with_aux,
timer=chrono,
log_summaries=log_summaries,
)
Expand Down

0 comments on commit 8b0a965

Please sign in to comment.