[Feature] searchsorted #1255
Comments
We don't have it, but you can certainly do a binary search with existing ops. Here's something for 1D arrays that works; it should be fairly straightforward to support an axis parameter if you need it:

    import math
    import mlx.core as mx

    def searchsorted(a, b):
        axis = 0
        size = a.shape[axis]
        # Number of halving steps for an array of this size.
        steps = math.ceil(math.log2(size))
        upper = size
        lower = 0
        # Start every query at the midpoint of a.
        indices = mx.full(b.shape, vals=size // 2, dtype=mx.uint32)
        for _ in range(steps):
            # True where the query is smaller than the element being probed.
            lt = b < a[indices]
            new_indices = mx.where(lt, (lower + indices) // 2, (indices + upper) // 2)
            lower = mx.where(lt, lower, indices)
            upper = mx.where(lt, indices, upper)
            indices = new_indices
        return indices

Also it will be a lot faster if you |
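For readers who want to check index conventions, here is a minimal plain-Python sketch of the same fixed-step idea. It is not the MLX code above: the name `searchsorted_fixed_steps` and the explicit `lower`/`upper` bookkeeping are assumptions made for illustration, and it targets the leftmost insertion point (the `side="left"` convention of `np.searchsorted`). It is worth comparing any vectorized version against a reference like this on a few inputs.

```python
from bisect import bisect_left


def searchsorted_fixed_steps(a, queries):
    """Leftmost insertion index of each query in the sorted list `a`."""
    size = len(a)
    lower = [0] * len(queries)     # smallest index still possible, per query
    upper = [size] * len(queries)  # largest index still possible, per query
    # Each step at least halves upper - lower, so size.bit_length() steps
    # are enough to shrink every interval to a single index.
    for _ in range(size.bit_length()):
        for i, q in enumerate(queries):
            if lower[i] == upper[i]:
                continue  # this query has already converged
            mid = (lower[i] + upper[i]) // 2
            if a[mid] < q:
                lower[i] = mid + 1  # answer is strictly to the right of mid
            else:
                upper[i] = mid      # answer is at mid or to its left
    return lower


a = [1, 3, 5, 7]
queries = [0, 1, 4, 7, 8]
result = searchsorted_fixed_steps(a, queries)
# Cross-check against the standard library's binary search.
assert result == [bisect_left(a, q) for q in queries]
```

The inner loop over queries is what an `mx.where`-based version would do for all queries at once; running a fixed number of steps regardless of the data is what makes that vectorization possible.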
I'm open to adding a little binary search implementation like that into MLX to support |
Another option is to do something like the following. It's linear in

    def searchsorted(a, b):
        return (a[None, :] < b[:, None]).sum(axis=1)

|
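The one-liner works because, in a sorted array, the number of elements strictly smaller than a query is exactly its leftmost insertion index. A small plain-Python sketch of that counting idea follows; the function name `searchsorted_by_counting` and its `side` parameter are hypothetical, added only to show how `<` versus `<=` maps to the left and right conventions.

```python
def searchsorted_by_counting(a, queries, side="left"):
    """Insertion index of each query in sorted `a`, found by counting."""
    if side == "left":
        # Elements strictly smaller than q all sit before the insertion point.
        return [sum(1 for x in a if x < q) for q in queries]
    # Counting with <= also skips past elements equal to q (rightmost slot).
    return [sum(1 for x in a if x <= q) for q in queries]


a = [1, 3, 3, 5]
left = searchsorted_by_counting(a, [0, 3, 6])
right = searchsorted_by_counting(a, [0, 3, 6], side="right")
```

Every query is compared against every element, which is why the cost grows with the product of the two lengths.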
Ah okay, I think I can make these workarounds for the time being. Thanks!! |
@awni, can I start working on this issue, since it was last active on July 7? |
By all means |
Should we go with this approach or with the one NumPy uses? |
I would compare

    def searchsorted(a, b):
        axis = 0
        size = a.shape[axis]
        steps = math.ceil(math.log2(size))
        upper = size
        lower = 0
        indices = mx.full(b.shape, vals=size // 2, dtype=mx.uint32)
        for _ in range(steps):
            lt = b < a[indices]
            new_indices = mx.where(lt, (lower + indices) // 2, (indices + upper) // 2)
            lower = mx.where(lt, lower, indices)
            upper = mx.where(lt, indices, upper)
            indices = new_indices
        return indices

and

    def searchsorted(a, b):
        return (a[None, :] < b[:, None]).sum(axis=1)

and see which is faster. Presumably there will be a size at which the first is faster, but it will start out slower. We could try to dispatch based on that. Or just use the more scalable version. |
Assuming B and A denote the lengths of arrays b and a: from a scalability point of view (comparing space and time complexity), I think the first approach looks more appropriate, right? |
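To make the asymptotic comparison concrete, here is a rough sketch that counts element comparisons only. The helper `comparison_counts` is hypothetical, and the counts deliberately ignore constant factors, memory traffic, and how well each approach parallelizes, so they say nothing about actual wall-clock time.

```python
import math


def comparison_counts(sorted_size, num_queries):
    """Rough comparison counts for the two approaches (no constant factors)."""
    # Binary search: about ceil(log2(A)) probes for each of the B queries.
    steps = math.ceil(math.log2(sorted_size)) if sorted_size > 1 else 0
    return {
        "binary": steps * num_queries,
        # Broadcast-and-sum: every query is compared with every element.
        "linear": sorted_size * num_queries,
    }


counts = comparison_counts(16_384, 1_000)
```

For A = 16,384 and B = 1,000 this gives 14,000 probes for binary search against 16,384,000 comparisons for the linear version, which also materializes a B-by-A boolean intermediate.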
The constant factors of the logarithmic approach are quite a bit larger, so it is not as simple as that. The following is on my laptop. Also note that
The TL;DR is that if you want to search in fewer than 16k elements, or if you are only searching for 1-2 elements, it doesn't make much sense to use the binary search. If, on the other hand, you are searching for a lot of elements in a large sorted array (in the millions of elements), then you can expect a 100x improvement using binary search :-) . |
Thank you @angeloskath for the info. From an application point of view, I think cases involving more than 16k elements will be rare. Shall I implement the linear search then? Another idea is to implement both and pick one based on a condition such as the size of the sorted array. |
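The "implement both and pick one" idea could look roughly like the plain-Python sketch below. The cutoffs are assumptions loosely taken from the rule of thumb in this thread (about 16k sorted elements, more than 1-2 queries); they are not measured values and would need benchmarking on the target hardware. The names `choose_strategy` and `searchsorted_dispatch` are hypothetical, and `bisect_left` merely stands in for a vectorized binary search.

```python
from bisect import bisect_left

# Hypothetical cutoffs based on the rule of thumb quoted in this thread.
MIN_SORTED_SIZE_FOR_BINARY = 16_384
MIN_QUERIES_FOR_BINARY = 3


def choose_strategy(sorted_size, num_queries):
    """Pick "linear" for small problems and "binary" for large ones."""
    if sorted_size < MIN_SORTED_SIZE_FOR_BINARY:
        return "linear"
    if num_queries < MIN_QUERIES_FOR_BINARY:
        return "linear"
    return "binary"


def searchsorted_dispatch(a, queries):
    """Leftmost insertion indices, computed with the chosen strategy."""
    if choose_strategy(len(a), len(queries)) == "linear":
        # Count elements strictly smaller than each query.
        return [sum(1 for x in a if x < q) for q in queries]
    # Stand-in for the fixed-step binary search.
    return [bisect_left(a, q) for q in queries]


small = searchsorted_dispatch([1, 3, 5, 7], [4, 8])
```

Both branches return the same indices, so the dispatch only affects speed, not results.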
@awni in which particular module shall we implement this feature? |
Is there an equivalent to np.searchsorted, or a way that I could reasonably implement something similar with the existing ops?