diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index ce0e0cd06..c7065da3b 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -9,6 +9,7 @@ using Static: False using Lux: Lux, LuxOps, Training using Lux.Training: TrainingBackendCache, ReactantBackend +include("patches.jl") include("training.jl") end diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl new file mode 100644 index 000000000..e173ab023 --- /dev/null +++ b/ext/LuxReactantExt/patches.jl @@ -0,0 +1,2 @@ +# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint +(g::Lux.GlobalPoolMode)(::TracedRArray) = g diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 32c5d5e5f..da729191c 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -13,8 +13,7 @@ end struct GlobalPoolMode <: AbstractPoolMode end -# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint -(g::GlobalPoolMode)(::Any) = g +(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)]) @concrete struct AdaptivePoolMode <: AbstractPoolMode out_size <: Tuple{Vararg{IntegerType}}