
[Question] Differences between MLX and JAX #1480

Open
smorad opened this issue Oct 12, 2024 · 1 comment

Comments


smorad commented Oct 12, 2024

This is somewhat related to #12, although I can see how mlx improves significantly upon torch. My question is why reinvent the wheel with mlx when the core of mlx seems to follow jax so closely?

The mlx documentation lists these differences from torch/jax:

The design of MLX is inspired by frameworks like PyTorch, Jax, and ArrayFire. A notable difference from these frameworks and MLX is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU.

Is there some part of jax that prevents utilizing shared memory? Unlike torch, you are generally not shuffling data around with .to(device) in jax. Couldn't Apple write an OpenXLA backend that avoids copies and executes operations on Neural Engine/GPU/CPU devices as necessary?
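For context, here is a minimal sketch of what the unified memory model looks like in practice (assuming mlx is installed on an Apple silicon Mac): the same arrays can be passed to ops targeting different devices without any explicit transfer.

```python
# Minimal sketch (assumes mlx is installed on an Apple silicon Mac).
# Arrays live in unified memory, so individual ops can be dispatched to the
# CPU or GPU without copying the data first -- there is no .to(device).
import mlx.core as mx

a = mx.random.normal((256, 256))
b = mx.random.normal((256, 256))

c = mx.matmul(a, b, stream=mx.gpu)  # run this op on the GPU
d = mx.add(c, b, stream=mx.cpu)     # run this op on the CPU, same buffers, no copy
mx.eval(d)                          # MLX is lazy, so force the computation
```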

The other major difference between jax and mlx seems to be lazy evaluation. Lazy evaluation is not possible in jax, but doesn't jax.jit solve a similar problem? If you compile the outermost function, then unused variables, unnecessary function calls, etc. are compiled away. Are there situations where one would prefer lazy evaluation over compilation?
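As a rough illustration of the contrast (a sketch, not a claim about either library's internals): MLX builds a lazy graph and only computes when a result is needed, while jax.jit traces the function once and lets XLA's dead-code elimination drop unused work at compile time.

```python
# Rough sketch assuming both mlx and jax are installed; names are illustrative.
import mlx.core as mx
import jax
import jax.numpy as jnp

# MLX: building the graph is cheap -- nothing runs here yet.
x = mx.ones((512, 512))
y = (x @ x) * 2.0
mx.eval(y)  # computation happens only when forced (or when y is actually used)

# JAX: the unused matmul is traced, then eliminated by the XLA compiler,
# so compiling the outermost function removes dead work much like laziness would.
@jax.jit
def f(a):
    unused = a @ a          # never contributes to the output; compiled away
    return (a * 2.0).sum()

print(f(jnp.ones((512, 512))))
```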

I suppose it seems like a shame to me that the mlx and jax interfaces are so similar yet separate, because I can imagine how nice it would be to prototype a model on my MacBook, then deploy to a CUDA cluster for larger-scale training. I would imagine this would also provide a smoother transition for researchers moving away from CUDA as Apple's ML hardware improves.

@smorad changed the title from "Differences between MLX and JAX" to "[Question] Differences between MLX and JAX" on Oct 12, 2024
@Downchuck

Long-standing issues like jax-ml/jax#16321 are also among the differences.
