feat: inherit scalar indexing functionality from GPUArraysCore #268
Conversation
@mofeing can you check if this helps your case where you saw the scalar indexing warnings?
We confirm that this removes the infinite warnings we had in our code. Thanks @avik-pal! I would approve the PR, but it seems like this is breaking the tests?
The x86 ones are broken since we don't have the binaries in place. But I still need to add some tests before merging.
src/Reactant.jl
Outdated
@@ -110,12 +111,19 @@ function __init__()
end

function set_default_backend(backend::XLA.Client)
    if backend === XLA.backends["cpu"]
So this won't quite work, because we can end up with both CPU and GPU tensors.
For an XLA Buffer I can check whether the buffer is on the CPU, but I couldn't figure out how to do it for TracedRArray.
One solution is to set the task-local storage to ScalarAllowed for CPU when entering the compile function (a short sketch follows below).
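A minimal sketch of that idea, assuming GPUArraysCore's @allowscalar macro (which sets the task-local scalar-indexing state to allowed for the wrapped expression) and Reactant's XLA.backends table; the compile_on wrapper and is_cpu helper are hypothetical names, not Reactant API:

```julia
using GPUArraysCore: @allowscalar

# Hypothetical helper: true when the client targets the CPU backend.
is_cpu(client) = client === XLA.backends["cpu"]

# Hypothetical wrapper around the compile entry point.
function compile_on(client, f, args)
    if is_cpu(client)
        # CPU: scalar indexing is harmless here, so allow it for this task
        # while tracing/compiling.
        return @allowscalar compile(f, args)
    else
        # GPU/TPU: keep the default warn/error policy.
        return compile(f, args)
    end
end
```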
Yeah, for traced arrays we should always error there, because scalar indexing has its own problem of accidentally splitting tensor ops up into a bunch of scalars, regardless of the backend implementation.
TracedRArray doesn't know about the backend accelerator, be it CPU, GPU, or TPU. Actually, the HLO dialects don't know which backend they are going to run on either.
@wsmoses correct me if I'm wrong, but that step is done later in XLA, when compiling HLO to a native executable.
I mean, even if they did [and you're right, they don't], we should error for traced arrays.
I asked to remove them for CPU because it doesn't make much sense to raise a warning there, and they pollute stdout a lot.
I think removing it for a CPU ConcreteRArray is fine, but the problem is that it will equally pollute the IR we compile for traced arrays, so we should still warn (or require allowscalar).
Changed the behavior to match the default CUDA.jl behavior (see the sketch below):
- Allowed with a warning in the REPL
- Disallowed with an error in scripts
- Can be allowed locally, without a warning, using @allowscalar
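A short sketch of how that policy surfaces to users, assuming ConcreteRArray's getindex now routes through GPUArraysCore's assertscalar:

```julia
using Reactant
using GPUArraysCore: @allowscalar, allowscalar

a = Reactant.ConcreteRArray(ones(Float32, 3))

a[1]                 # interactive (REPL): warns once, then proceeds
                     # non-interactive (script/CI): throws a scalar-indexing error

@allowscalar a[1]    # explicitly allowed, no warning, scoped to the wrapped expression

allowscalar(false)   # opt in to strict mode for this task: errors even in the REPL
```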
should we reexport allowscalar?
+1
)
    getindex_warned[] = true
end
Can we have this call a function of our own, which calls GPUArraysCore's assertscalar if it's loaded?
We can, but GPUArraysCore is extremely lightweight:
julia> @time_imports using GPUArraysCore
0.2 ms Adapt
0.4 ms GPUArraysCore
eh okay I'm fine with this then
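For reference, this is roughly the pattern GPUArraysCore enables: array types call assertscalar in their scalar getindex/setindex! methods and let GPUArraysCore decide whether to allow, warn, or error. A self-contained sketch with a hypothetical stand-in type (not Reactant's actual implementation):

```julia
using GPUArraysCore: assertscalar

# Hypothetical demo type standing in for a device-backed array.
struct DemoDeviceArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::DemoDeviceArray) = size(a.data)

function Base.getindex(a::DemoDeviceArray, I::Int...)
    # Delegate the allow/warn/error policy to GPUArraysCore.
    assertscalar("getindex")
    return a.data[I...]
end
```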
@@ -16,7 +14,7 @@ using InteractiveUtils

a = Reactant.ConcreteRArray(x)

-c_res = sum(a)
+c_res = @allowscalar sum(a)
Okay, ideally this shouldn't be required. I feel like a load of a concrete number / traced number itself should automatically be allowscalar.
These work fine if the backend is CPU, but the default implementation of sum will just loop over the indices, which fails the GPU CI.
Wait, really? Shouldn't it fall back to a reduce?
If not, this is definitely a bug.
Oh sorry, this is for the ConcreteRArray, not traced.
Yeah, we should still eventually make this a reduce, but for another time.
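One possible shape for that follow-up, purely illustrative (the real fix might instead lower to a proper XLA reduce): route reductions through a single bulk host copy so no per-element getindex is triggered. The DemoDeviceVector type here is hypothetical.

```julia
# Hypothetical device-backed vector used only for illustration.
struct DemoDeviceVector{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(a::DemoDeviceVector) = size(a.data)
Base.getindex(a::DemoDeviceVector, i::Int) = a.data[i]  # would normally assertscalar

# Reductions materialize the data once instead of indexing element by element.
function Base.mapreduce(f, op, a::DemoDeviceVector; kwargs...)
    host = copy(a.data)               # one bulk transfer
    return mapreduce(f, op, host; kwargs...)
end

sum(DemoDeviceVector([1, 2, 3]))      # == 6, no scalar indexing involved
```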
needs some tests before merging
Example Usage
On CPU, no error is ever thrown unless the user manually opts in to disallowing scalar indexing.
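A hedged sketch of that usage, following the test diff above (importing @allowscalar from GPUArraysCore is an assumption; Reactant may re-export it):

```julia
using Reactant
using GPUArraysCore: @allowscalar

x = rand(Float32, 4, 4)
a = Reactant.ConcreteRArray(x)

# On CPU this succeeds without opting in; on accelerator backends the generic
# sum fallback indexes element by element, so the test opts in explicitly.
c_res = @allowscalar sum(a)
```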
fixes #232