Skip to content

Commit

Permalink
flaxlib in cc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696286840
  • Loading branch information
IvyZX authored and Flax Authors committed Nov 14, 2024
1 parent ac3e85a commit 0b19b33
Show file tree
Hide file tree
Showing 15 changed files with 37 additions and 24 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,10 @@ jobs:
uses: astral-sh/setup-uv@v2
with:
version: "0.3.0"
- name: Setup Rust (flaxlib)
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Install dependencies
run: |
uv sync --extra all --extra testing --extra docs
uv pip install ./flaxlib
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ build/
docs*/**/_autosummary
docs*/_build
docs*/**/tmp
flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects

# used by direnv
.envrc
Expand Down
15 changes: 0 additions & 15 deletions flaxlib/flaxlib/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 14 additions & 0 deletions flaxlib_src/meson.build
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
project(
'flaxlib',
'cpp',
version: '0.0.1',
default_options: ['cpp_std=c++17'],
)
py = import('python').find_installation()
nanobind_dep = dependency('nanobind', static: true)
py.extension_module(
'flaxlib',
sources: ['src/lib.cc'],
dependencies: [nanobind_dep],
install: true,
)
8 changes: 3 additions & 5 deletions flaxlib/pyproject.toml → flaxlib_src/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[build-system]
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
requires = ['meson-python']
build-backend = 'mesonpy'

[project]
name = "flaxlib"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: C++",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
Expand All @@ -15,5 +15,3 @@ dynamic = ["version"]
tests = [
"pytest",
]
[tool.maturin]
features = ["pyo3/extension-module"]
14 changes: 14 additions & 0 deletions flaxlib_src/src/lib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <string>

#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h"

namespace flaxlib {
std::string sum_as_string(int a, int b) {
return std::to_string(a + b);
}

NB_MODULE(flaxlib, m) {
m.def("sum_as_string", &sum_as_string);
}
} // namespace flaxlib
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion tests/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ if $RUN_PYTEST; then
echo "=== RUNNING PYTESTS ==="
# Run some test on separate process, avoiding device configs poluting each other
PYTEST_IGNORE=
for file in "tests/jax_utils_test.py"; do
# TODO(ivyzheng): Remove flaxlib_test.py once we get CI build for flaxlib CC.
for file in "tests/jax_utils_test.py" "tests/flaxlib_test.py"; do
echo "pytest -n auto $file $PYTEST_OPTS"
pytest -n auto $file $PYTEST_OPTS
PYTEST_IGNORE+=" --ignore=$file"
Expand Down

0 comments on commit 0b19b33

Please sign in to comment.