Skip to content

Commit

Permalink
feat: Add conversion script for PyTorch model
Browse files Browse the repository at this point in the history
This script converts the PAX model to a PyTorch model and saves the converted checkpoint for future loading.
  • Loading branch information
TeddyHuang-00 committed Aug 29, 2024
1 parent 271aecf commit 99130a9
Showing 1 changed file with 238 additions and 0 deletions.
238 changes: 238 additions & 0 deletions src/timesfm_torch/convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import logging
import sys
from pathlib import Path

import numpy as np
import torch
from absl import flags
from paxml import checkpoints

from timesfm.timesfm import TimesFm

from .pytorch_patched_decoder import (
PatchedTimeSeriesDecoder,
TimesFMConfig,
)

# Module logger named after this file's stem; INFO level so the progress
# messages emitted by the script entry point are not filtered by this logger.
logger = logging.getLogger(Path(__file__).stem)
logger.setLevel(logging.INFO)

# Directory three levels above this file. Presumably the repository root when
# the file lives at src/timesfm_torch/ -- TODO confirm for installed packages.
ROOT = Path(__file__).parent.parent.parent
FLAGS = flags.FLAGS
# Command-line flags. Both values are joined onto ROOT by the script entry
# point, so relative values are interpreted relative to ROOT.
_MODEL_PATH = flags.DEFINE_string(
"model_path", "timesfm-1.0-200m/checkpoints/", "The model checkpoint path"
)
_OUTPUT_PATH = flags.DEFINE_string(
"output_path", "ckpts/timesfm-1.0-200m.pth", "The output path for the PyTorch model"
)


def rmrf(path: Path):
    """Recursively delete ``path``, similar to ``rm -r``.

    Real directories are emptied depth-first and then removed; every other
    kind of entry (regular file, symbolic link, ...) is unlinked.

    Symbolic links are removed themselves and are never followed.
    ``Path.is_dir()`` follows links, so without the explicit symlink check a
    link pointing at a directory would have the *target's* contents deleted
    and the final ``rmdir()`` on the link would then fail.

    Args:
        path: File, symlink or directory to delete.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if path.is_dir() and not path.is_symlink():
        for child in path.iterdir():
            rmrf(child)
        path.rmdir()
    else:
        path.unlink()


def dump(path: Path, data: dict):
    """Write a nested dict of NumPy arrays to disk as a directory tree.

    Each nested dict becomes a sub-directory named after its key, and each
    ``np.ndarray`` is saved as ``<key>.npy`` inside the current directory.

    Args:
        path: Directory to write into. It is created (including missing
            parents) if it does not exist yet.
        data: Mapping from names to either nested dicts or ``np.ndarray``s.

    Returns:
        A dict with the same nesting as ``data`` in which every array is
        replaced by the string form of its shape tuple, e.g. ``"(2, 3)"``.

    Raises:
        ValueError: If a value is neither a dict nor an ``np.ndarray``.
    """
    # Make the function usable on a fresh target directory; this also covers
    # the sub-directories created for nested dicts during recursion.
    path.mkdir(parents=True, exist_ok=True)

    # Collect metadata
    meta_info = {}
    for key, value in data.items():
        if isinstance(value, dict):
            meta_info[key] = dump(path / key, value)
        elif isinstance(value, np.ndarray):
            np.save(path / (key + ".npy"), value)
            meta_info[key] = str(tuple(value.shape))
        else:
            raise ValueError(f"Unknown type {type(value)}, found in {path / key}")
    return meta_info


def convert(dump_path: Path):
    """Core function to convert the directory structure to match PyTorch model.

    Walks ``dump_path`` recursively and rewrites it in place:

    * directories are renamed / merged so that their relative paths match the
      parameter names of the PyTorch model (see ``_restructure_dir``);
    * every 2-D ``weight.npy`` is transposed, because the layout differs
      between PyTorch and Flax. The weight directly inside a ``freq_emb``
      directory is left untouched.

    Args:
        dump_path: Root of a parameter tree made of directories and ``.npy``
            files (as written by ``dump``).
    """
    # Snapshot and sort the entries before touching anything: the loop renames
    # and deletes children of ``dump_path``, and a fixed order keeps the
    # conversion deterministic across platforms.
    for path in sorted(dump_path.iterdir()):
        if path.is_dir():
            # Rename directories and files to match PyTorch model, recursively
            convert(_restructure_dir(path))
        elif path.is_file() and path.name == "weight.npy":
            if path.parent.name == "freq_emb":
                # This weight already matches the target shape. Skip only this
                # file; the remaining entries of the directory still have to
                # be processed.
                continue
            # Transpose the 2D weight matrices as they are different in
            # PyTorch than Flax
            data = np.load(path)
            if data.ndim == 2:
                np.save(path, data.T)


def _restructure_dir(path: Path) -> Path:
    """Rename/merge one directory in place and return its (possibly new) path."""
    _flatten_linear(path)

    name = path.name
    parent = path.parent
    if name == "ff_layer":
        return path.rename(parent / "mlp")
    if name == "ffn_layer1":
        return path.rename(parent / "gate_proj")
    if name == "ffn_layer2":
        return path.rename(parent / "down_proj")
    if name == "freq_emb":
        # Move emb_var.npy to weight.npy
        (path / "emb_var.npy").replace(path / "weight.npy")
        return path
    if name.startswith("x_layers"):
        # Rename path to the index number, e.g. ``x_layers_3`` -> ``3``
        return path.rename(parent / name.split("_")[-1])
    if name == "stacked_transformer_layer":
        return _nest_transformer_layers(path)
    if name == "layer_norm":
        if not (path / "bias.npy").exists():
            # A layer norm without bias is renamed to input_layernorm
            path = path.rename(parent / "input_layernorm")
            (path / "scale.npy").replace(path / "weight.npy")
        elif (path / "scale.npy").exists():
            # Rename path/scale.npy to path/weight.npy
            (path / "scale.npy").replace(path / "weight.npy")
        return path
    if name == "self_attention":
        return _merge_self_attention(path)
    return path


def _flatten_linear(path: Path) -> None:
    """Collapse ``bias/b.npy`` + ``linear/w.npy`` into ``bias.npy`` + ``weight.npy``."""
    is_linear = (
        len(list(path.iterdir())) == 2
        and (path / "bias" / "b.npy").exists()
        and (path / "linear" / "w.npy").exists()
    )
    if not is_linear:
        return
    # Move path/bias/b.npy to path/bias.npy
    (path / "bias" / "b.npy").replace(path / "bias.npy")
    # Move path/linear/w.npy to path/weight.npy
    (path / "linear" / "w.npy").replace(path / "weight.npy")
    rmrf(path / "bias")
    rmrf(path / "linear")


def _nest_transformer_layers(path: Path) -> Path:
    """Move all children into ``layers/`` and rename to ``stacked_transformer``."""
    layers = path / "layers"
    layers.mkdir()
    # Snapshot the children because entries are moved while looping.
    for child in list(path.iterdir()):
        if child.name != "layers":
            child.rename(layers / child.name)
    return path.rename(path.parent / "stacked_transformer")


def _merge_self_attention(path: Path) -> Path:
    """Replace ``self_attention`` by ``self_attn`` with fused qkv weights."""
    # Read everything in path before it is deleted
    w_k: np.ndarray = np.load(path / "key" / "w.npy")
    w_q: np.ndarray = np.load(path / "query" / "w.npy")
    w_v: np.ndarray = np.load(path / "value" / "w.npy")
    b_k: np.ndarray = np.load(path / "key" / "b.npy")
    b_q: np.ndarray = np.load(path / "query" / "b.npy")
    b_v: np.ndarray = np.load(path / "value" / "b.npy")
    w_o: np.ndarray = np.load(path / "post" / "w.npy")
    b_o: np.ndarray = np.load(path / "post" / "b.npy")
    # The unpacking doubles as a check that the key weight is 3-D.
    model_dim, _num_heads, _head_dim = w_k.shape
    scaling = np.load(path / "per_dim_scale" / "per_dim_scale.npy")
    rmrf(path)

    # Rename path to self_attn
    path = path.parent / "self_attn"
    path.mkdir()

    # Combine qkv weights and biases (order: query, key, value)
    w_qkv = np.concatenate(
        [
            w_q.reshape(model_dim, -1),
            w_k.reshape(model_dim, -1),
            w_v.reshape(model_dim, -1),
        ],
        axis=-1,
    )
    b_qkv = np.concatenate([b_q.flatten(), b_k.flatten(), b_v.flatten()], axis=-1)
    qkv_proj = path / "qkv_proj"
    qkv_proj.mkdir()
    np.save(qkv_proj / "weight.npy", w_qkv)
    np.save(qkv_proj / "bias.npy", b_qkv)

    # Rename post to o_proj
    # NOTE(review): this 2-D weight is transposed again by the weight.npy pass
    # in convert(). If the matrix is square, a wrong orientation would not be
    # caught by a shape comparison -- confirm against the source layout.
    o_proj = path / "o_proj"
    o_proj.mkdir()
    np.save(o_proj / "weight.npy", w_o.reshape(model_dim, -1))
    np.save(o_proj / "bias.npy", b_o)

    # Rename per_dim_scale to scaling
    np.save(path / "scaling.npy", scaling)
    return path


if __name__ == "__main__":
    # Script entry point: load the original checkpoint, dump its parameters to
    # a temporary directory tree of .npy files, reshape that tree to mirror
    # the PyTorch model's parameter names, copy the arrays into a freshly
    # built PyTorch model and save its state dict.
    FLAGS = flags.FLAGS
    FLAGS(sys.argv)
    model_path = ROOT.joinpath(_MODEL_PATH.value)
    temp_path = ROOT / "model_states"
    output_path = ROOT.joinpath(_OUTPUT_PATH.value)
    temp_path.mkdir(exist_ok=True)
    # The output flag may point into a directory hierarchy that does not
    # exist yet, so create missing parents as well.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Loading original model")
    model = TimesFm(
        context_len=512,
        horizon_len=96,
        input_patch_len=32,
        output_patch_len=128,
        num_layers=20,
        model_dims=1280,
        backend="cpu",
        per_core_batch_size=16,
        quantiles=list(np.arange(1, 10) / 10),
    )
    model.load_from_checkpoint(
        str(model_path), checkpoint_type=checkpoints.CheckpointType.FLAX
    )

    logger.info("Dumping model weights to numpy files")
    dump(temp_path, model._train_state.mdl_vars["params"])
    # Drop the original model before the PyTorch model is built.
    del model

    logger.info("Converting directory structure to match PyTorch model")
    convert(temp_path)

    logger.info("Creating PyTorch model")
    config = TimesFMConfig(quantiles=list(np.arange(1, 10) / 10.0))
    model = PatchedTimeSeriesDecoder(config)
    state_dict = model.state_dict()

    logger.info("Loading weights into PyTorch model")
    for k, v in state_dict.items():
        # "a.b.weight" -> <temp_path>/a/b/weight.npy
        parts = k.split(".")
        parts[-1] += ".npy"
        path = temp_path.joinpath(*parts)
        if not path.exists():
            logger.warning("No converted weight found for %s", k)
            continue
        npy = np.load(path)
        # Compare shape
        if tuple(v.shape) != npy.shape:
            # Copying here would either raise or silently broadcast, so the
            # entry is reported and left at its initial value instead.
            logger.error(
                "Shape mismatch at %s: expect %s, found %s",
                k,
                tuple(v.shape),
                npy.shape,
            )
            continue
        # Update the model state
        v.copy_(torch.tensor(npy))

    logger.info("Saving PyTorch model to disk")
    # Honor the --output_path flag instead of a hard-coded, CWD-relative path.
    torch.save(state_dict, output_path)

    logger.info("Cleaning up temporary files")
    rmrf(temp_path)

    logger.info("Done")

0 comments on commit 99130a9

Please sign in to comment.