This is somewhat related to #12, although I can see how `mlx` improves significantly upon `torch`. My question is: why reinvent the wheel with `mlx` when the core of `mlx` seems to closely follow `jax`?
The `mlx` documentation lists these differences from `torch`/`jax`:
> The design of MLX is inspired by frameworks like PyTorch, Jax, and ArrayFire. A notable difference from these frameworks and MLX is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU.
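For concreteness, the unified-memory model means the same arrays can be dispatched to either device without any transfer step. A minimal sketch, adapted from the `stream` example in the MLX docs:

```python
import mlx.core as mx

a = mx.random.normal((4096, 4096))
b = mx.random.normal((4096, 4096))

# Both arrays live in unified memory, so the same buffers are visible
# to both devices; there is no .to(device) copy step.
mx.add(a, b, stream=mx.gpu)  # runs on the GPU
mx.add(a, b, stream=mx.cpu)  # runs on the CPU, same underlying data
```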
Is there some part of `jax` that prevents utilizing shared memory? Unlike `torch`, `jax` generally does not require shuffling data around with `.to(device)`. Couldn't Apple write an OpenXLA backend that avoids copies and executes operations on the Neural Engine/GPU/CPU as necessary?
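To spell out the contrast with `torch`: in `jax`, arrays are committed to a default device and operations dispatch there automatically, while explicit placement via `jax.device_put` is the exception rather than the rule. A minimal sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # lands on the default device (GPU/TPU if available)
y = x @ x                   # dispatched to that device automatically

# Explicit placement exists, but it is opt-in rather than pervasive:
x_cpu = jax.device_put(x, jax.devices("cpu")[0])
```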
The other major difference between `jax` and `mlx` seems to be lazy evaluation. Lazy evaluation is not possible in `jax`, but doesn't `jax.jit` solve a similar problem? If you `jit` the outermost function, then unused variables, unnecessary function calls, etc. are compiled away. Are there situations where one would prefer lazy evaluation over compilation?
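To make the comparison concrete, here is a minimal sketch of both mechanisms (assuming both libraries are installed): under `jax.jit`, XLA can eliminate an unused computation as dead code at compile time, while MLX never schedules it in the first place, because nothing runs until a result is forced.

```python
import jax.numpy as jnp
from jax import jit

import mlx.core as mx

# jax: compilation. Tracing f lets XLA drop the unused matmul as dead code.
@jit
def f(x):
    unused = x @ x  # traced, then eliminated by XLA's dead-code elimination
    return x + 1.0

f(jnp.ones((512, 512)))

# mlx: lazy evaluation. The unused matmul is never scheduled, even without
# a compile step, because evaluation only happens on demand.
def g(x):
    unused = x @ x  # builds a graph node only
    return x + 1.0

out = g(mx.ones((512, 512)))
mx.eval(out)  # materializes out; `unused` is never computed
```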
I suppose it seems like a shame to me that the `mlx` and `jax` interfaces are so similar yet incompatible, because I can imagine how nice it would be to prototype a model on my MacBook, then deploy to a CUDA cluster for larger-scale training. I would imagine this would also provide a smoother transition for researchers moving away from CUDA as Apple's ML hardware improves.