Commit
feat: Add conversion script for PyTorch model
This script converts the PAX (JAX/Flax) TimesFM checkpoint to a PyTorch model and saves the converted checkpoint for later loading.
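Example invocation (a sketch: the script uses a relative import, so it must be run as a module; the module path below is a placeholder, while the flag values are the script's defaults):

    python -m <package>.<this_script> --model_path=timesfm-1.0-200m/checkpoints/ --output_path=ckpts/timesfm-1.0-200m.pth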
1 parent 271aecf · commit 99130a9
Showing 1 changed file with 238 additions and 0 deletions.
@@ -0,0 +1,238 @@
import logging
import sys
from pathlib import Path

import numpy as np
import torch
from absl import flags
from paxml import checkpoints

from timesfm.timesfm import TimesFm

from .pytorch_patched_decoder import (
    PatchedTimeSeriesDecoder,
    TimesFMConfig,
)

logger = logging.getLogger(Path(__file__).stem)
logger.setLevel(logging.INFO)

ROOT = Path(__file__).parent.parent.parent
FLAGS = flags.FLAGS
_MODEL_PATH = flags.DEFINE_string(
    "model_path", "timesfm-1.0-200m/checkpoints/", "The model checkpoint path"
)
_OUTPUT_PATH = flags.DEFINE_string(
    "output_path", "ckpts/timesfm-1.0-200m.pth", "The output path for the PyTorch model"
)


def rmrf(path: Path):
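    """Recursively delete a file or a directory tree (like `rm -rf`)."""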
    if path.is_dir():
        for child in path.iterdir():
            rmrf(child)
        path.rmdir()
    else:
        path.unlink()


def dump(path: Path, data: dict):
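    """Recursively save a nested dict of numpy arrays as a directory tree of .npy files.

    Returns a nested dict recording the shape of every saved array.
    """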
    # Collect metadata
    meta_info = {}
    for key, value in data.items():
        if isinstance(value, dict):
            new_path = path / key
            new_path.mkdir(exist_ok=True)
            meta_info[key] = dump(new_path, value)
        elif isinstance(value, np.ndarray):
            np.save(path / (key + ".npy"), value)
            meta_info[key] = str(tuple(value.shape))
        else:
            raise ValueError(f"Unknown type {type(value)}, found in {path / key}")
    return meta_info


def convert(dump_path: Path):
    """Core function to convert the directory structure to match the PyTorch model."""
    for path in dump_path.iterdir():
        if path.is_dir():
            # Rename directories and files to match PyTorch model, recursively
            if (
                len(list(path.iterdir())) == 2
                and (path / "bias" / "b.npy").exists()
                and (path / "linear" / "w.npy").exists()
            ):
                # Move path/bias/b.npy to path/bias.npy
                (path / "bias.npy").write_bytes((path / "bias" / "b.npy").read_bytes())
                # Move path/linear/w.npy to path/weight.npy
                (path / "weight.npy").write_bytes(
                    (path / "linear" / "w.npy").read_bytes()
                )
                rmrf(path / "bias")
                rmrf(path / "linear")

            if path.name == "ff_layer":
                # Rename path to mlp
                path = path.rename(path.parent / "mlp")

            elif path.name == "ffn_layer1":
                # Rename path to gate_proj
                path = path.rename(path.parent / "gate_proj")

            elif path.name == "ffn_layer2":
                # Rename path to down_proj
                path = path.rename(path.parent / "down_proj")

            elif path.name == "freq_emb":
                # Move emb_var.npy to weight.npy
                (path / "weight.npy").write_bytes((path / "emb_var.npy").read_bytes())
                rmrf(path / "emb_var.npy")

            elif path.name.startswith("x_layers"):
                # Rename path to the index number
                idx = path.name.split("_")[-1]
                path = path.rename(path.parent / idx)

            elif path.name == "stacked_transformer_layer":
                # Move everything in path to path/layers
                (path / "layers").mkdir()
                for file in path.iterdir():
                    if file.name != "layers":
                        file.rename(path / "layers" / file.name)
                # Rename path to stacked_transformer
                path = path.rename(path.parent / "stacked_transformer")

            elif (
                path.name == "layer_norm"
                and (path / "bias.npy").exists()
                and (path / "scale.npy").exists()
            ):
                # Rename path/scale.npy to path/weight.npy
                (path / "weight.npy").write_bytes((path / "scale.npy").read_bytes())
                rmrf(path / "scale.npy")

            elif path.name == "layer_norm" and not (path / "bias.npy").exists():
                # Rename path to input_layernorm
                path = path.rename(path.parent / "input_layernorm")
                # Rename path/scale.npy to path/weight.npy
                (path / "weight.npy").write_bytes((path / "scale.npy").read_bytes())
                rmrf(path / "scale.npy")

            elif path.name == "self_attention":
                # Read everything in path
                w_k: np.ndarray = np.load(path / "key" / "w.npy")
                w_q: np.ndarray = np.load(path / "query" / "w.npy")
                w_v: np.ndarray = np.load(path / "value" / "w.npy")
                b_k: np.ndarray = np.load(path / "key" / "b.npy")
                b_q: np.ndarray = np.load(path / "query" / "b.npy")
                b_v: np.ndarray = np.load(path / "value" / "b.npy")
                w_o: np.ndarray = np.load(path / "post" / "w.npy")
                b_o: np.ndarray = np.load(path / "post" / "b.npy")
                model_dim, num_heads, head_dim = w_k.shape
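                # PAX stores the attention projections per head: each w_* has shape
                # [model_dim, num_heads, head_dim]. They are flattened to
                # [model_dim, num_heads * head_dim] and concatenated into one fused
                # qkv_proj; the generic 2D transpose pass below then flips the result
                # into PyTorch's [out_features, in_features] layout.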
                scaling = np.load(path / "per_dim_scale" / "per_dim_scale.npy")
                rmrf(path)
                # Rename path to self_attn
                path = path.parent / "self_attn"
                path.mkdir()
                # Combine qkv weights and biases
                w_qkv = np.concatenate(
                    [
                        w_q.reshape(model_dim, -1),
                        w_k.reshape(model_dim, -1),
                        w_v.reshape(model_dim, -1),
                    ],
                    axis=-1,
                )
                b_qkv = np.concatenate(
                    [b_q.flatten(), b_k.flatten(), b_v.flatten()], axis=-1
                )
                qkv_proj = path / "qkv_proj"
                qkv_proj.mkdir()
                np.save(qkv_proj / "weight.npy", w_qkv)
                np.save(qkv_proj / "bias.npy", b_qkv)
                # Rename post to o_proj
                o_proj = path / "o_proj"
                o_proj.mkdir()
                np.save(o_proj / "weight.npy", w_o.reshape(model_dim, -1))
                np.save(o_proj / "bias.npy", b_o)
                # Rename per_dim_scale to scaling
                np.save(path / "scaling.npy", scaling)

            convert(path)

        elif path.is_file():
            # Transpose the 2D weight matrices, as their layout differs between PyTorch and Flax
            match path.name:
                case "weight.npy":
                    if path.parent.name == "freq_emb":
                        # This weight already matches the target shape
                        return
                    data = np.load(path)
                    if data.ndim == 2:
                        data = data.T
                    np.save(path, data)
                case _:
                    pass

if __name__ == "__main__":
    FLAGS = flags.FLAGS
    FLAGS(sys.argv)
    model_path = ROOT.joinpath(_MODEL_PATH.value)
    temp_path = ROOT / "model_states"
    output_path = ROOT.joinpath(_OUTPUT_PATH.value)
    temp_path.mkdir(exist_ok=True)
    output_path.parent.mkdir(exist_ok=True)

    logger.info("Loading original model")
    model = TimesFm(
        context_len=512,
        horizon_len=96,
        input_patch_len=32,
        output_patch_len=128,
        num_layers=20,
        model_dims=1280,
        backend="cpu",
        per_core_batch_size=16,
        quantiles=list(np.arange(1, 10) / 10),
    )
    model.load_from_checkpoint(
        str(model_path), checkpoint_type=checkpoints.CheckpointType.FLAX
    )

    logger.info("Dumping model weights to numpy files")
    meta = dump(temp_path, model._train_state.mdl_vars["params"])
    del model

    logger.info("Converting directory structure to match PyTorch model")
    convert(temp_path)

    logger.info("Creating PyTorch model")
    config = TimesFMConfig(quantiles=list(np.arange(1, 10) / 10.0))
    model = PatchedTimeSeriesDecoder(config)
    state_dict = model.state_dict()

    logger.info("Loading weights into PyTorch model")
    for k, v in state_dict.items():
        parts = k.split(".")
        parts[-1] += ".npy"
        path = temp_path.joinpath(*parts)
        if not path.exists():
            # Report parameters that have no converted weight file
            print(k)
        else:
            npy = np.load(path)
            # Compare shapes before copying
            if tuple(v.shape) != npy.shape:
                print(
                    f"Shape mismatch at {k}: expected {tuple(v.shape)}, found {npy.shape}"
                )
            # Update the model state
            state_dict[k].copy_(torch.tensor(npy))

    logger.info("Saving PyTorch model to disk")
    torch.save(state_dict, output_path)

    logger.info("Cleaning up temporary files")
    rmrf(temp_path)

    logger.info("Done")