diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml
index d1ffac2cd..6e4b2dd96 100644
--- a/examples/ConvMixer/Project.toml
+++ b/examples/ConvMixer/Project.toml
@@ -2,6 +2,7 @@
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
 Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
@@ -15,6 +16,7 @@ PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -23,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 Comonicon = "1.0.8"
 ConcreteStructs = "0.2.3"
 DataAugmentation = "0.3"
+Enzyme = "0.13.14"
 ImageCore = "0.10.2"
 ImageShow = "0.3.8"
 Interpolations = "0.15.1"
@@ -36,6 +39,7 @@ PreferenceTools = "0.1.2"
 Printf = "1.10"
 ProgressBars = "1.5.1"
 Random = "1.10"
+Reactant = "0.2.5"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md
index f072c1074..560b2b1d3 100644
--- a/examples/ConvMixer/README.md
+++ b/examples/ConvMixer/README.md
@@ -11,6 +11,9 @@ for new experiments on small datasets.
 You can get around **90.0%** accuracy in just **25 epochs** by running the script
 with the following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2.
 
+> [!NOTE]
+> To train the model using Reactant.jl, pass `--backend=reactant` to the script.
+
 ```bash
 julia --startup-file=no \
     --project=. \
diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl
index 03ddc63a5..602ede83c 100644
--- a/examples/ConvMixer/main.jl
+++ b/examples/ConvMixer/main.jl
@@ -1,6 +1,7 @@
 using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA,
       MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random, StableRNGs,
       Statistics, Zygote
+using Reactant, Enzyme
 
 CUDA.allowscalar(false)
 
@@ -17,7 +18,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac
     return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
 end
 
-function get_dataloaders(batchsize)
+function get_dataloaders(batchsize; kwargs...)
     cifar10_mean = (0.4914, 0.4822, 0.4465)
     cifar10_std = (0.2471, 0.2435, 0.2616)
 
@@ -29,10 +30,10 @@ function get_dataloaders(batchsize)
     test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)
 
     trainset = TensorDataset(CIFAR10(:train), train_transform)
-    trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true)
+    trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)
 
     testset = TensorDataset(CIFAR10(:test), test_transform)
-    testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true)
+    testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)
 
     return trainloader, testloader
 end
@@ -43,10 +44,14 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
         Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
         BatchNorm(dim),
         [Chain(
-             SkipConnection(
-                 Chain(Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim,
-                     pad=SamePad()), BatchNorm(dim)), +),
-             Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
+             SkipConnection(
+                 Chain(
+                     Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
+                     BatchNorm(dim)
+                 ),
+                 +
+             ),
+             Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
          for _ in 1:depth]...,
         GlobalMeanPool(),
         FlattenLayer(),
@@ -57,10 +62,11 @@ end
 
 function accuracy(model, ps, st, dataloader)
     total_correct, total = 0, 0
+    cdev = cpu_device()
     st = Lux.testmode(st)
     for (x, y) in dataloader
-        target_class = onecold(y)
-        predicted_class = onecold(first(model(x, ps, st)))
+        target_class = onecold(cdev(y))
+        predicted_class = onecold(cdev(first(model(x, ps, st))))
         total_correct += sum(target_class .== predicted_class)
         total += length(target_class)
     end
@@ -69,23 +75,46 @@ end
 
 Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
         patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
-        clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01)
+        clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
+        backend::String="gpu_if_available")
     rng = StableRNG(seed)
 
-    gdev = gpu_device()
-    trainloader, testloader = get_dataloaders(batchsize) .|> gdev
+    if backend == "gpu_if_available"
+        accelerator_device = gpu_device()
+    elseif backend == "gpu"
+        accelerator_device = gpu_device(; force=true)
+    elseif backend == "reactant"
+        accelerator_device = reactant_device(; force=true)
+    elseif backend == "cpu"
+        accelerator_device = cpu_device()
+    else
+        error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \
+               `reactant`, and `cpu`.")
+    end
+
+    kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
+    trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device
 
     model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)
-    ps, st = Lux.setup(rng, model) |> gdev
+    ps, st = Lux.setup(rng, model) |> accelerator_device
 
     opt = AdamW(; eta=lr_max, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))
 
-    train_state = Training.TrainState(
-        model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay))
+    train_state = Training.TrainState(model, ps, st, opt)
 
     lr_schedule = linear_interpolation(
-        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])
+        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
+    )
+
+    adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()
+
+    if backend == "reactant"
+        x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
+        model_compiled = @compile model(x_ra, ps, st)
+    else
+        model_compiled = model
+    end
 
     loss = CrossEntropyLoss(; logits=Val(true))
 
@@ -96,14 +125,17 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
             lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
             train_state = Optimisers.adjust!(train_state, lr)
             (_, _, _, train_state) = Training.single_train_step!(
-                AutoZygote(), loss, (x, y), train_state)
+                adtype, loss, (x, y), train_state
+            )
         end
         ttime = time() - stime
 
         train_acc = accuracy(
-            model, train_state.parameters, train_state.states, trainloader) * 100
-        test_acc = accuracy(model, train_state.parameters, train_state.states, testloader) *
-                   100
+            model_compiled, train_state.parameters, train_state.states, trainloader
+        ) * 100
+        test_acc = accuracy(
+            model_compiled, train_state.parameters, train_state.states, testloader
+        ) * 100
 
         @printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \
                  Time: %.2f\n" epoch lr train_acc test_acc ttime
diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl
index 0af6705e4..56876d8e0 100644
--- a/ext/LuxReactantExt/patches.jl
+++ b/ext/LuxReactantExt/patches.jl
@@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x)
 function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber)
     return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y))
 end
+
+# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
+Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl
index 819aaaeeb..20ec3d6da 100644
--- a/src/layers/pooling.jl
+++ b/src/layers/pooling.jl
@@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode
 abstract type AbstractPoolOp end
 
 struct MaxPoolOp <: AbstractPoolOp end
+
 (m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)
+function (m::MaxPoolOp)(x, ::GlobalPoolMode)
+    return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf))
+end
 
 struct MeanPoolOp <: AbstractPoolOp end
+
 (m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)
+(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2))
 
 @concrete struct LpPoolOp <: AbstractPoolOp
     p
 end
+
 (m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)
+(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p)
 
 symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
 symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()