Skip to content

Commit

Permalink
Fix sweep serialization for string
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700273671
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 26, 2024
1 parent 8b0a965 commit d71bdb1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
14 changes: 11 additions & 3 deletions kauldron/utils/colab.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def iter_sweep_configs(
if sweep_mode == SweepMode.FIRST:
all_sweep_items = all_sweep_items[:1]

output_cfg = []
with ecolab.collapse(f'Resolving {len(all_sweep_items)} sweeps configs'):
for i, sweep_item in enumerate(all_sweep_items):
# Re-create the config to avoid mutations leak between iterations.
Expand All @@ -76,13 +77,20 @@ def iter_sweep_configs(
# TODO(epot): Display the sweep short name (workdir) and config.
for k, v in sweep_kwargs.items():
kontext.set_by_path(cfg, k, v)

# Only for visualization.
sweep_cfg_overwrites = konfig.ConfigDict(sweep_kwargs)
print(f'Work-unit {i+1}:', flush=True)
ecolab.disp(sweep_cfg_overwrites)
# TODO(epot): Should report IPython.display.HTML bug to Colab team.
# Use `print` rather than `ecolab.disp`, as the HTML display is buggy here
# for some reason and mangles the previous outputs.
print(sweep_cfg_overwrites)
print()

# Somehow, yielding here removes some of the outputs, so instead append and
# return everything at the end.
output_cfg.append(cfg)

yield cfg
return output_cfg


def _update_sweep_names_forms(module: types.ModuleType) -> None:
Expand Down
8 changes: 8 additions & 0 deletions kauldron/utils/sweep_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def sweep():
yield {
'eval_ds.batch_size': bs,
'train_ds.batch_size': bs,
'aux.model_size': 'big' if bs == 16 else 'small',
}


Expand All @@ -53,10 +54,17 @@ def test_sweep():
assert len(all_sweep_info) == 4 # Cross product

sweep0 = kauldron_utils._encode_sweep_item(all_sweep_info[0])
assert sweep0.job_kwargs == {
'cfg.eval_ds.batch_size': '16',
'cfg.train_ds.batch_size': '16',
'cfg.aux.model_size': 'big',
'cfg.model': '{"__qualname__": "flax.linen:Dense", "0": 12}',
}
sweep0 = kauldron_utils.deserialize_job_kwargs(sweep0.job_kwargs)
assert sweep0 == {
'eval_ds.batch_size': 16,
'train_ds.batch_size': 16,
'aux.model_size': 'big',
'model': {'__qualname__': 'flax.linen:Dense', '0': 12},
}

Expand Down
20 changes: 18 additions & 2 deletions kauldron/xm/_src/kauldron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,27 @@ def _encode_sweep_item(


def _serialize_job_kwargs(job_kwargs: dict[str, _Json]) -> dict[str, _Json]:
return {f"cfg.{k}": _JsonEncoder().encode(v) for k, v in job_kwargs.items()}
return {
f"cfg.{k}": v if isinstance(v, str) else _JsonEncoder().encode(v)
for k, v in job_kwargs.items()
}


def deserialize_job_kwargs(job_kwargs: dict[str, _Json]) -> dict[str, _Json]:
return {k.removeprefix("cfg."): json.loads(v) for k, v in job_kwargs.items()}
return {
k.removeprefix("cfg."): _decode_json_or_str(v)
for k, v in job_kwargs.items()
}


def _decode_json_or_str(v: _Json) -> _Json:
"""Decodes the JSON string or returns the string itself."""
# The decoded values should always have been encoded JSON strings from
# `_serialize_job_kwargs`, so there shouldn't be risk of badly formatted JSON.
try:
return json.loads(v)
except json.JSONDecodeError:
return v


def _ui_repr(v):
Expand Down

0 comments on commit d71bdb1

Please sign in to comment.