Skip to content

Commit

Permalink
fix: use direct mean or maximum only for reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 11, 2024
1 parent 43ac5ef commit 25f13de
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Static: False
using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

include("patches.jl")
include("training.jl")

end
2 changes: 2 additions & 0 deletions ext/LuxReactantExt/patches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
(g::Lux.GlobalPoolMode)(::TracedRArray) = g
3 changes: 1 addition & 2 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ end

struct GlobalPoolMode <: AbstractPoolMode end

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
(g::GlobalPoolMode)(::Any) = g
(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])

@concrete struct AdaptivePoolMode <: AbstractPoolMode
out_size <: Tuple{Vararg{IntegerType}}
Expand Down

0 comments on commit 25f13de

Please sign in to comment.