Improving bf16 tests (#507)
* Add mlx.

* Add flax.

* Yaml is possibly worse than XML at this point.

* .

* Fixing BF16+BigEndian+Torch on Torch > 2.1

* Fix macos-13 so Python 3.8 is supported.

* Clippy fix.

* Adding macos-latest (ARM) to the test suite.

* Fmt.

* Installing cargo audit.

* I.

* Fixing format.

* Update 2 locations.

* tf.random.normal

* Random normal flax

* Fixing flax test.

* Remove bool test for flax.

* Attempting to upgrade numpy version.

* Revert pyproject.toml

* Fixing bool test (PT).

* Fixing new ARM element of matrix.

* Install mlx on ARM.

* Ignoring flax on macos arm.

* No bfloat16 on mlx.

* Reference issue in code.

* Re-enabling jax tests ?

* Removing ARM+macos target.

* Irrelevant changes.

* Trying to debug remotely.

* Different logs.

* This isn't counterproductive dev speed.

* Install mlx on old non arm macos.

* Adding a clone ?

* Checking bfloat16 serialized values.

* Linter.

* Hopefully fixing the byteswapping.

* Run mlx only on Python 3.11.

* Yaml..

* Fixing MLX tests (making it optional/skipped unfortunately).

* Much better fix that doesn't depend on torch version.

* Cleaner escape for MLX.

* Fixing F8XXX dtypes too on big endians + torch.

* inplace actually modifies storage.

* Inplace for numpy requires new holding place.
Narsil authored Jul 31, 2024
1 parent 2331974 commit c00471e
Showing 7 changed files with 118 additions and 39 deletions.
17 changes: 14 additions & 3 deletions .github/workflows/python.yml
@@ -9,10 +9,18 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, macos-13, windows-latest]
         # Lowest and highest, no version specified so that
         # new releases get automatically tested against
         version: [{torch: torch==1.10, python: "3.8"}, {torch: torch, python: "3.11"}]
+        # TODO this would include macos ARM target.
+        # however jax has an illegal instruction issue
+        # that exists only in CI (probably difference in instruction support).
+        # include:
+        #   - os: macos-latest
+        #     version:
+        #       torch: torch
+        #       python: "3.11"
     defaults:
       run:
         working-directory: ./bindings/python
@@ -27,12 +35,15 @@ jobs:
         toolchain: stable
         components: rustfmt, clippy

+    - name: Cargo install audit
+      run: cargo install cargo-audit
+
     - uses: Swatinem/rust-cache@v2
       with:
         workspaces: "bindings/python"

     - name: Install Python
-      uses: actions/setup-python@v2
+      uses: actions/setup-python@v5
       with:
         python-version: ${{ matrix.version.python }}
         architecture: "x64"
@@ -59,7 +70,7 @@ jobs:
         shell: bash

     - name: Install (mlx)
-      if: matrix.os == 'macos-latest' && matrix.version.python == "3.10"
+      if: matrix.os == 'macos-latest'
       run: |
         pip install .[mlx]
       shell: bash
5 changes: 4 additions & 1 deletion bindings/python/py_src/safetensors/torch.py
@@ -173,7 +173,9 @@ def save_model(
         raise ValueError(msg)


-def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu") -> Tuple[List[str], List[str]]:
+def load_model(
+    model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu"
+) -> Tuple[List[str], List[str]]:
     """
     Loads a given filename onto a torch model.
     This method exists specifically to avoid tensor sharing issues which are
@@ -340,6 +342,7 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
     flat = deserialize(data)
     return _view2torch(flat)

+
 # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
 _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
 _float8_e5m2 = getattr(torch, "float8_e5m2", None)
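These `getattr` guards are what lets the module import cleanly on torch < 2.1, where the float8 dtypes do not exist. A hedged illustration of how such guards are typically consumed downstream (the list name below is hypothetical, not the module's actual table):

```python
# Hypothetical consumer: keep only the float8 dtypes this torch build
# actually provides (both guards above are None on torch < 2.1).
_available_float8 = [d for d in (_float8_e4m3fn, _float8_e5m2) if d is not None]
```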
73 changes: 58 additions & 15 deletions bindings/python/src/lib.rs
@@ -595,27 +595,36 @@ impl Open
             if byteorder == "big" {
                 let inplace_kwargs =
                     [(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);
-                if info.dtype == Dtype::BF16 {
-                    let torch_f16: PyObject = get_pydtype(torch, Dtype::F16, false)?;
-                    tensor = tensor.getattr(intern!(py, "to"))?.call(
-                        (),
-                        Some(&[(intern!(py, "dtype"), torch_f16)].into_py_dict_bound(py)),
-                    )?;
-                }
+
+                let intermediary_dtype = match info.dtype {
+                    Dtype::BF16 => Some(Dtype::F16),
+                    Dtype::F8_E5M2 => Some(Dtype::U8),
+                    Dtype::F8_E4M3 => Some(Dtype::U8),
+                    _ => None,
+                };
+                if let Some(intermediary_dtype) = intermediary_dtype {
+                    // Reinterpret to an intermediary dtype for numpy compatibility.
+                    let dtype: PyObject = get_pydtype(torch, intermediary_dtype, false)?;
+                    let view_kwargs =
+                        [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
+                    tensor = tensor
+                        .getattr(intern!(py, "view"))?
+                        .call((), Some(&view_kwargs))?;
+                }
                 let numpy = tensor
                     .getattr(intern!(py, "numpy"))?
                     .call0()?
                     .getattr("byteswap")?
                     .call((), Some(&inplace_kwargs))?;
                 tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;

-                if info.dtype == Dtype::BF16 {
-                    let torch_bf16: PyObject = get_pydtype(torch, Dtype::BF16, false)?;
-                    tensor = tensor.getattr(intern!(py, "to"))?.call(
-                        (),
-                        Some(&[(intern!(py, "dtype"), torch_bf16)].into_py_dict_bound(py)),
-                    )?;
+                if intermediary_dtype.is_some() {
+                    // Reinterpret back to the original dtype.
+                    let dtype: PyObject = get_pydtype(torch, info.dtype, false)?;
+                    let view_kwargs =
+                        [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
+                    tensor = tensor
+                        .getattr(intern!(py, "view"))?
+                        .call((), Some(&view_kwargs))?;
                 }
             }
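The shape of this fix: the old code converted values with `Tensor.to` (a numeric cast), which is wrong here because the bytes have not yet been swapped into native order, so the values being cast are garbage. The new code reinterprets the raw bits with `Tensor.view`, swaps, then reinterprets back. A minimal Python sketch of the BF16 case (the helper name is made up, and this is not the library's actual code path, which is the PyO3 code above; the float8 dtypes get the same treatment via a uint8 view):

```python
import torch

def to_native_byteorder(tensor: torch.Tensor) -> torch.Tensor:
    # numpy has no bfloat16, so reinterpret to a same-width dtype it
    # understands before byteswapping, then reinterpret back.
    intermediary = {torch.bfloat16: torch.float16}.get(tensor.dtype)
    original = tensor.dtype
    if intermediary is not None:
        tensor = tensor.view(intermediary)  # bit reinterpretation, no numeric cast
    swapped = tensor.numpy().byteswap(inplace=False)  # out-of-place swap
    tensor = torch.from_numpy(swapped)
    return tensor.view(original) if intermediary is not None else tensor
```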

Expand Down Expand Up @@ -941,15 +950,39 @@ impl PySafeSlice {
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
if byteorder == "big" {
// Important, do NOT use inplace otherwise the slice itself
// is byteswapped, meaning multiple calls will fails
let inplace_kwargs =
[(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);

let intermediary_dtype = match self.info.dtype {
Dtype::BF16 => Some(Dtype::F16),
Dtype::F8_E5M2 => Some(Dtype::U8),
Dtype::F8_E4M3 => Some(Dtype::U8),
_ => None,
};
if let Some(intermediary_dtype) = intermediary_dtype {
// Reinterpret to f16 for numpy compatibility.
let dtype: PyObject = get_pydtype(torch, intermediary_dtype, false)?;
let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
tensor = tensor
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
}
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
if intermediary_dtype.is_some() {
// Reinterpret to f16 for numpy compatibility.
let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?;
let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
tensor = tensor
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
}
}
tensor = tensor
.getattr(intern!(py, "reshape"))?
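The new comment deserves emphasis: the numpy array produced here aliases the underlying buffer, so an in-place swap would mutate it and corrupt every subsequent read of the same slice. A small demonstration of the out-of-place contract (hypothetical values, not the library's code):

```python
import numpy as np

buf = np.arange(4, dtype=">u2")       # stand-in for the big-endian source bytes
first = buf.byteswap(inplace=False)   # swapped copy; `buf` itself is untouched
second = buf.byteswap(inplace=False)  # so every call returns the same values
assert np.array_equal(first, second)
# With inplace=True, the second call would swap the already-swapped buffer
# back again and yield garbage.
```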
@@ -1024,7 +1057,17 @@ fn create_tensor<'a>(
                 (intern!(py, "dtype"), dtype),
             ]
             .into_py_dict_bound(py);
-            module.call_method("frombuffer", (), Some(&kwargs))?
+            let mut tensor = module.call_method("frombuffer", (), Some(&kwargs))?;
+            let sys = PyModule::import_bound(py, intern!(py, "sys"))?;
+            let byteorder: String = sys.getattr(intern!(py, "byteorder"))?.extract()?;
+            if byteorder == "big" {
+                let inplace_kwargs =
+                    [(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);
+                tensor = tensor
+                    .getattr("byteswap")?
+                    .call((), Some(&inplace_kwargs))?;
+            }
+            tensor
         };
         let mut tensor: PyBound<'_, PyAny> = tensor.call_method1("reshape", (shape,))?;
         let tensor = match framework {
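The numpy path gets the same guard: safetensors data is little-endian on disk, so a big-endian host must swap after `frombuffer`. And since `np.frombuffer` returns a read-only view over the bytes, the swap must allocate a fresh buffer — the "new holding place" from the commit message. A Python sketch of the equivalent logic (the helper name is illustrative):

```python
import sys
import numpy as np

def array_from_le_bytes(data: bytes, dtype: np.dtype) -> np.ndarray:
    arr = np.frombuffer(data, dtype=dtype)
    if sys.byteorder == "big":
        # frombuffer arrays are read-only views over `data`, so the swap
        # has to write into a new array -- hence inplace=False here too.
        arr = arr.byteswap(inplace=False)
    return arr
```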
9 changes: 5 additions & 4 deletions bindings/python/tests/test_flax_comparison.py
@@ -6,6 +6,7 @@
     # This platform is not supported, we don't want to crash on import
     # This test will be skipped anyway.
     import jax.numpy as jnp
+    from jax import random
     from flax.serialization import msgpack_restore, msgpack_serialize
     from safetensors import safe_open
     from safetensors.flax import load_file, save_file
@@ -15,11 +16,11 @@
 @unittest.skipIf(platform.system() == "Windows", "Flax is not available on Windows")
 class LoadTestCase(unittest.TestCase):
     def setUp(self):
+        key = random.key(0)
         data = {
-            "test": jnp.zeros((1024, 1024), dtype=jnp.float32),
-            "test2": jnp.zeros((1024, 1024), dtype=jnp.float32),
-            "test3": jnp.zeros((1024, 1024), dtype=jnp.float32),
-            "test4": jnp.zeros((1024, 1024), dtype=jnp.bfloat16),
+            "test": random.normal(key, (1024, 1024), dtype=jnp.float32),
+            "test2": random.normal(key, (1024, 1024), dtype=jnp.float16),
+            "test3": random.normal(key, (1024, 1024), dtype=jnp.bfloat16),
         }
         self.flax_filename = "./tests/data/flax_load.msgpack"
         self.sf_filename = "./tests/data/flax_load.safetensors"
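One compatibility note on the new `key = random.key(0)` line: `jax.random.key` is the newer typed-PRNG-key API; older jax releases spell it as below, and `random.normal` accepts either form (a hedged sketch, not part of this diff):

```python
# Older-jax equivalent of `key = random.key(0)`: a raw uint32 key array.
from jax import random

key = random.PRNGKey(0)
```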
27 changes: 19 additions & 8 deletions bindings/python/tests/test_mlx_comparison.py
@@ -2,28 +2,39 @@
 import unittest


+HAS_MLX = False
 if platform.system() == "Darwin":
     # This platform is not supported, we don't want to crash on import
     # This test will be skipped anyway.
-    import mlx.core as mx
-    from safetensors import safe_open
-    from safetensors.mlx import load_file, save_file
+    try:
+        import mlx.core as mx
+
+        HAS_MLX = True
+    except ImportError:
+        pass
+if HAS_MLX:
+    from safetensors import safe_open
+    from safetensors.mlx import load_file, save_file


 # MLX only exists on Mac
-@unittest.skipIf(platform.system() != "Darwin", "Mlx is not available on non Mac")
+@unittest.skipIf(not HAS_MLX, "Mlx is not available.")
 class LoadTestCase(unittest.TestCase):
     def setUp(self):
         data = {
-            "test": mx.zeros((1024, 1024), dtype=mx.float32),
-            "test2": mx.zeros((1024, 1024), dtype=mx.float32),
-            "test3": mx.zeros((1024, 1024), dtype=mx.float32),
-            "test4": mx.zeros((1024, 1024), dtype=mx.bfloat16),
+            "test": mx.randn((1024, 1024), dtype=mx.float32),
+            "test2": mx.randn((1024, 1024), dtype=mx.float32),
+            "test3": mx.randn((1024, 1024), dtype=mx.float32),
+            # This doesn't work because bfloat16 is not implemented
+            # with similar workarounds as jax/tensorflow.
+            # https://github.com/ml-explore/mlx/issues/1296
+            # "test4": mx.randn((1024, 1024), dtype=mx.bfloat16),
         }
         self.mlx_filename = "./tests/data/mlx_load.npz"
         self.sf_filename = "./tests/data/mlx_load.safetensors"

-        serialized = mx.savez(self.mlx_filename, **data)
+        mx.savez(self.mlx_filename, **data)
         save_file(data, self.sf_filename)

     def test_zero_sized(self):
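For context on the commented-out `test4` entry: the linked MLX issue tracks bfloat16 interop, which jax and tensorflow sidestep because numpy-side bfloat16 exists through the `ml_dtypes` extension package. A sketch of that kind of workaround (an assumption about what the comment alludes to, not code from this repository):

```python
import ml_dtypes  # numpy dtype extensions that jax/tensorflow rely on
import numpy as np

raw = np.zeros((4, 4), dtype=np.uint16)  # bf16 payload carried as uint16
bf16 = raw.view(ml_dtypes.bfloat16)      # zero-copy reinterpretation
```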
24 changes: 17 additions & 7 deletions bindings/python/tests/test_pt_comparison.py
@@ -50,12 +50,22 @@ def test_serialization(self):
             b" \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00",
         )

+        data = torch.ones((2, 2), dtype=torch.bfloat16)
+        data[0, 0] = 2.25
+        out = save({"test": data})
+        self.assertEqual(
+            out,
+            b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"BF16","shape":[2,2],"data_offsets":[0,8]}} \x10@\x80?\x80?\x80?',
+        )
+
     def test_odd_dtype(self):
         data = {
-            "test": torch.zeros((2, 2), dtype=torch.bfloat16),
-            "test2": torch.zeros((2, 2), dtype=torch.float16),
+            "test": torch.randn((2, 2), dtype=torch.bfloat16),
+            "test2": torch.randn((2, 2), dtype=torch.float16),
             "test3": torch.zeros((2, 2), dtype=torch.bool),
         }
+        # Modify bool to have both values.
+        data["test3"][0, 0] = True
         local = "./tests/data/out_safe_pt_mmap_small.safetensors"

         save_file(data, local)
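The expected payload in the new assertion can be checked by hand: bfloat16 is the top 16 bits of the IEEE float32 encoding, stored little-endian, so 2.25 (0x4010) serializes as `\x10@` and each 1.0 (0x3F80) as `\x80?` — four elements, eight bytes, matching `data_offsets: [0, 8]`. A quick sketch verifying this (the helper name is made up):

```python
import struct

def bf16_le_bytes(x: float) -> bytes:
    # bfloat16 == the two high bytes of the big-endian float32 encoding
    # (exact for values like 1.0 and 2.25 that need no rounding),
    # emitted little-endian as safetensors stores them on disk.
    return struct.pack(">f", x)[:2][::-1]

assert bf16_le_bytes(2.25) == b"\x10@"  # 0x4010; 0x40 is ASCII "@"
assert bf16_le_bytes(1.0) == b"\x80?"   # 0x3F80; 0x3F is ASCII "?"
```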
@@ -66,7 +76,7 @@ def test_odd_dtype(self):

     def test_odd_dtype_fp8(self):
         if torch.__version__ < "2.1":
-            return # torch.float8 requires 2.1
+            return  # torch.float8 requires 2.1

         data = {
             "test1": torch.tensor([-0.5], dtype=torch.float8_e4m3fn),
@@ -77,10 +87,10 @@ def test_odd_dtype_fp8(self):
         save_file(data, local)
         reloaded = load_file(local)
         # note: PyTorch doesn't implement torch.equal for float8 so we just compare the single element
-        self.assertEqual(data["test1"].dtype, torch.float8_e4m3fn)
-        self.assertEqual(data["test1"].item(), -0.5)
-        self.assertEqual(data["test2"].dtype, torch.float8_e5m2)
-        self.assertEqual(data["test2"].item(), -0.5)
+        self.assertEqual(reloaded["test1"].dtype, torch.float8_e4m3fn)
+        self.assertEqual(reloaded["test1"].item(), -0.5)
+        self.assertEqual(reloaded["test2"].dtype, torch.float8_e5m2)
+        self.assertEqual(reloaded["test2"].item(), -0.5)

     def test_zero_sized(self):
         data = {
2 changes: 1 addition & 1 deletion bindings/python/tests/test_tf_comparison.py
@@ -64,7 +64,7 @@ def test_deserialization_safe(self):

     def test_bfloat16(self):
         data = {
-            "test": tf.zeros((1024, 1024), dtype=tf.bfloat16),
+            "test": tf.random.normal((1024, 1024), dtype=tf.bfloat16),
         }
         save_file(data, self.sf_filename)
         weights = {}
