support reading/writing sharded numpy safetensors
mar-muel committed Aug 16, 2024
1 parent 984bc11 commit 9d8809b
Showing 2 changed files with 79 additions and 33 deletions.
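For context: a sharded safetensors checkpoint is a set of shard files plus a model.safetensors.index.json whose weight_map records which shard holds each tensor. A minimal sketch of writing such a layout with safetensors.numpy (the file names, tensor names, and two-way split are illustrative):

import json

import numpy as np
from safetensors.numpy import save_file

# Write two shards, each holding part of the parameters.
save_file({"params.embed.kernel": np.zeros((4, 4), dtype=np.float32)}, "model-00001-of-00002.safetensors")
save_file({"params.head.kernel": np.zeros((4, 2), dtype=np.float32)}, "model-00002-of-00002.safetensors")

# The index maps every tensor name to the shard that contains it.
index = {
    "metadata": {"total_size": (4 * 4 + 4 * 2) * 4},  # bytes, float32
    "weight_map": {
        "params.embed.kernel": "model-00001-of-00002.safetensors",
        "params.head.kernel": "model-00002-of-00002.safetensors",
    },
}
with open("model.safetensors.index.json", "w") as f:
    json.dump(index, f, indent=2)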
40 changes: 38 additions & 2 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -16,7 +16,8 @@

import os
from pickle import UnpicklingError
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple, Union
import json

import jax
import jax.numpy as jnp
@@ -36,15 +37,44 @@
if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.flax import load_file as safe_load_file
    from safetensors.torch import load_file as torch_safe_load_file


logger = logging.get_logger(__name__)
SAFETENSORS_FILE_EXTENSION = "safetensors"


#####################
# PyTorch => Flax #
#####################

def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return torch_safe_load_file(checkpoint_file, device="cpu")
        else:
            import torch  # imported lazily so this module does not require PyTorch

            return torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(f"Unable to load weights from checkpoint file '{checkpoint_file}'.")

def load_pytorch_checkpoint_in_flax_state_dict(
    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
@@ -60,6 +90,12 @@ def load_pytorch_checkpoint_in_flax_state_dict(
            with safe_open(pt_path, framework="flax") as f:
                for k in f.keys():
                    pt_state_dict[k] = f.get_tensor(k)
        elif pt_path.endswith("model.safetensors.index.json"):
            with open(pt_path) as f:
                index = json.load(f)
            pt_state_dict = {}
            for filename in set(index["weight_map"].values()):
                sd = load_state_dict(os.path.join(os.path.dirname(pt_path), filename))
                pt_state_dict.update(sd)
        else:
            try:
                import torch  # noqa: F401
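Distilled, the new elif branch above does the reverse of writing shards: read weight_map from the index and merge every referenced shard into one flat state dict. A standalone sketch of the same logic (using safetensors.numpy.load_file rather than the torch-backed load_state_dict helper, so it runs without PyTorch; load_sharded_state_dict is a hypothetical name):

import json
import os

from safetensors.numpy import load_file


def load_sharded_state_dict(index_path):
    # Merge all shards referenced by a model.safetensors.index.json into one
    # flat {tensor_name: array} dict, deduplicating shard files first.
    with open(index_path) as f:
        index = json.load(f)
    state_dict = {}
    for filename in set(index["weight_map"].values()):
        shard = load_file(os.path.join(os.path.dirname(index_path), filename))
        state_dict.update(shard)
    return state_dict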
72 changes: 41 additions & 31 deletions src/transformers/modeling_flax_utils.py
@@ -63,6 +63,7 @@
from safetensors import safe_open
from safetensors.flax import load_file as safe_load_file
from safetensors.flax import save_file as safe_save_file
from safetensors.numpy import load_file as np_safe_load_file

logger = logging.get_logger(__name__)
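The extra numpy import above is presumably what keeps sharded loading on the host: safetensors.flax.load_file returns jax.Array leaves placed on the default device, while safetensors.numpy.load_file returns plain np.ndarray leaves. A quick sketch of the difference (assumes both extras are installed):

import numpy as np
from safetensors.flax import load_file as flax_load_file
from safetensors.numpy import load_file as np_load_file
from safetensors.numpy import save_file

save_file({"w": np.ones(3, dtype=np.float32)}, "tiny.safetensors")
print(type(np_load_file("tiny.safetensors")["w"]))    # <class 'numpy.ndarray'>
print(type(flax_load_file("tiny.safetensors")["w"]))  # a jax.Array on the default device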

@@ -478,30 +479,36 @@ def load_flax_sharded_weights(cls, shard_files):
    state_sharded_dict = {}

    for shard_file in shard_files:
        # load using msgpack utils
        try:
            with open(shard_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            with open(shard_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please"
                        " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                        " folder you cloned."
                    )
                else:
                    raise ValueError from e
        except (UnicodeDecodeError, ValueError):
            raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
        if shard_file.endswith(".safetensors"):
            shard = np_safe_load_file(shard_file)
            state_sharded_dict.update(shard)
            sep = "."
        else:
            # load using msgpack utils
            try:
                with open(shard_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                with open(shard_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")

        state = flatten_dict(state, sep="/")
        state_sharded_dict.update(state)
        del state
        gc.collect()
            state = flatten_dict(state, sep="/")
            state_sharded_dict.update(state)
            del state
            gc.collect()
            sep = "/"

    # the state dict is unflattened to match the format of model.params
    return unflatten_dict(state_sharded_dict, sep="/")
    return unflatten_dict(state_sharded_dict, sep=sep)

@classmethod
def can_generate(cls) -> bool:
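One subtlety in the load_flax_sharded_weights hunk above: safetensors files store keys flattened with "." while the msgpack path flattens with "/", so the matching separator has to be threaded through to unflatten_dict, which is what the new sep variable does. A minimal round-trip sketch (assumes flax is installed):

from flax.traverse_util import flatten_dict, unflatten_dict

params = {"encoder": {"layer_0": {"kernel": 1.0}}}

flat_dot = flatten_dict(params, sep=".")    # keys like "encoder.layer_0.kernel"
flat_slash = flatten_dict(params, sep="/")  # keys like "encoder/layer_0/kernel"

# The nesting is only recovered when unflatten_dict is given the matching sep.
assert unflatten_dict(flat_dot, sep=".") == params
assert unflatten_dict(flat_slash, sep="/") == params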
@@ -748,7 +755,7 @@ def from_pretrained(
    # Load from a sharded safetensors checkpoint
    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
    is_sharded = True
    raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
    # raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
    raise EnvironmentError(
        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
@@ -768,7 +775,8 @@ def from_pretrained(
    resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
    if from_pt:
        filename = WEIGHTS_NAME
        # filename = WEIGHTS_NAME
        filename = SAFE_WEIGHTS_INDEX_NAME
    else:
        filename = FLAX_WEIGHTS_NAME

@@ -913,7 +921,9 @@ def from_pretrained(
# NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
if _do_init:
    state = jax.tree_util.tree_map(jnp.array, state)
    # state = jax.tree_util.tree_map(jnp.array, state)
    cpu_device = jax.devices("cpu")[0]
    state = jax.device_put(state, cpu_device)
else:
    # keep the params on CPU if we don't want to initialize
    state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
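Swapping jnp.array for jax.device_put in the hunk above keeps freshly loaded weights in host memory rather than copying them to the default accelerator. Since jax.device_put maps over pytree leaves, the whole nested state dict can be placed with one call; a small sketch:

import jax
import numpy as np

state = {"dense": {"kernel": np.ones((2, 2)), "bias": np.zeros(2)}}

# Place every leaf of the pytree on the host CPU device.
cpu_device = jax.devices("cpu")[0]
state = jax.device_put(state, cpu_device)

print(jax.tree_util.tree_map(lambda x: x.shape, state))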
@@ -1048,13 +1058,13 @@ def from_pretrained(
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)

if len(bf16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
# if len(bf16_params) > 0:
# logger.warning(
# f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
# f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
# "You should probably UPCAST the model weights to float32 if this was not intended. "
# "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
# )

# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
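Taken together, the two files let from_pretrained consume a sharded safetensors checkpoint directly. A hedged usage sketch (the path and model class are illustrative; the directory is assumed to contain model.safetensors.index.json plus the shards it references):

from transformers import FlaxAutoModelForCausalLM

# With this commit, from_pt=True resolves SAFE_WEIGHTS_INDEX_NAME and merges the
# referenced shards instead of raising NotImplementedError.
model = FlaxAutoModelForCausalLM.from_pretrained(
    "path/to/sharded-checkpoint",
    from_pt=True,
)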