feat: update ConvMixer to support reactant #1063

Draft · wants to merge 3 commits into base: main
4 changes: 4 additions & 0 deletions examples/ConvMixer/Project.toml
@@ -2,6 +2,7 @@
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
@@ -15,6 +16,7 @@ PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -23,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Comonicon = "1.0.8"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3"
Enzyme = "0.13.14"
ImageCore = "0.10.2"
ImageShow = "0.3.8"
Interpolations = "0.15.1"
@@ -36,6 +39,7 @@ PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.5"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
3 changes: 3 additions & 0 deletions examples/ConvMixer/README.md
@@ -11,6 +11,9 @@ for new experiments on small datasets.
You can get around **90.0%** accuracy in just **25 epochs** by running the script with the
following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2.

+> [!NOTE]
+> To train the model using Reactant.jl, pass `--backend=reactant` to the script.

```bash
julia --startup-file=no \
--project=. \
    ...
```
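Since the entry point is a Comonicon `@main` function (see `main.jl` below), the same options can also be driven from a Julia session. A hedged sketch, assuming the project environment is already activated:

```julia
# Sketch: call the training entry point programmatically instead of via the CLI.
include("main.jl")

# Roughly equivalent to `julia main.jl --backend=reactant` with default options.
main(; backend="reactant")
```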
72 changes: 52 additions & 20 deletions examples/ConvMixer/main.jl
@@ -1,6 +1,7 @@
using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA,
MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random,
StableRNGs, Statistics, Zygote
+using Reactant, Enzyme

CUDA.allowscalar(false)

@@ -17,7 +18,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac
return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

-function get_dataloaders(batchsize)
+function get_dataloaders(batchsize; kwargs...)
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

@@ -29,10 +30,10 @@ function get_dataloaders(batchsize)
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

trainset = TensorDataset(CIFAR10(:train), train_transform)
-trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true)
+trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)

testset = TensorDataset(CIFAR10(:test), test_transform)
-testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true)
+testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)

return trainloader, testloader
end
@@ -43,10 +44,14 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
BatchNorm(dim),
[Chain(
-            SkipConnection(
-                Chain(Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim,
-                    pad=SamePad()), BatchNorm(dim)), +),
-            Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
+            SkipConnection(
+                Chain(
+                    Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
+                    BatchNorm(dim)
+                ),
+                +
+            ),
+            Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
for _ in 1:depth]...,
GlobalMeanPool(),
FlattenLayer(),
@@ -57,10 +62,11 @@ end

function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
+    cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
-        target_class = onecold(y)
-        predicted_class = onecold(first(model(x, ps, st)))
+        target_class = onecold(cdev(y))
+        predicted_class = onecold(cdev(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
@@ -69,23 +75,46 @@ end

Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
-        clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01)
+        clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
+        backend::String="gpu_if_available")
rng = StableRNG(seed)

-    gdev = gpu_device()
-    trainloader, testloader = get_dataloaders(batchsize) .|> gdev
+    if backend == "gpu_if_available"
+        accelerator_device = gpu_device()
+    elseif backend == "gpu"
+        accelerator_device = gpu_device(; force=true)
+    elseif backend == "reactant"
+        accelerator_device = reactant_device(; force=true)
+    elseif backend == "cpu"
+        accelerator_device = cpu_device()
+    else
+        error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \
+              `reactant`, and `cpu`.")
+    end
+
+    kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
+    trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device

model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)
-    ps, st = Lux.setup(rng, model) |> gdev
+    ps, st = Lux.setup(rng, model) |> accelerator_device

+    opt = AdamW(; eta=lr_max, lambda=weight_decay)
+    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

-    train_state = Training.TrainState(
-        model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay))
+    train_state = Training.TrainState(model, ps, st, opt)

lr_schedule = linear_interpolation(
-        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])
+        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
+    )

+    adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

if backend == "reactant"
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
model_compiled = @compile model(x_ra, ps, st)
else
model_compiled = model
end

loss = CrossEntropyLoss(; logits=Val(true))

@@ -96,14 +125,17 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
train_state = Optimisers.adjust!(train_state, lr)
(_, _, _, train_state) = Training.single_train_step!(
-            AutoZygote(), loss, (x, y), train_state)
+            adtype, loss, (x, y), train_state
+        )
end
ttime = time() - stime

train_acc = accuracy(
-            model, train_state.parameters, train_state.states, trainloader) * 100
-        test_acc = accuracy(model, train_state.parameters, train_state.states, testloader) *
-                   100
+            model_compiled, train_state.parameters, train_state.states, trainloader
+        ) * 100
+        test_acc = accuracy(
+            model_compiled, train_state.parameters, train_state.states, testloader
+        ) * 100

@printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \
Time: %.2f\n" epoch lr train_acc test_acc ttime
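The heart of the Reactant path above is compile-once, run-many: trace the model with one representative batch on the Reactant device, then reuse the compiled callable for every same-shaped batch. A self-contained sketch of that pattern, with a small `Dense` layer standing in for ConvMixer:

```julia
# Minimal sketch of the @compile pattern used in main.jl above. Assumes Reactant
# can initialize a backend (CPU suffices).
using Lux, Reactant, Random

rdev = reactant_device()

model = Dense(4 => 2, gelu)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev

x = rand(Float32, 4, 8) |> rdev             # features × batch, fixed shape
model_compiled = @compile model(x, ps, st)  # trace and compile once

y, _ = model_compiled(x, ps, st)            # cheap to re-invoke on same-shaped inputs
```

This is also why `accuracy` gains a `cdev = cpu_device()`: `onecold` runs on CPU copies of the targets and predictions rather than on device or traced arrays.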
3 changes: 3 additions & 0 deletions ext/LuxReactantExt/patches.jl
@@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x)
function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber)
return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y))
end

+# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
+Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
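With this patch, global pooling on a `TracedRArray` never constructs `PoolDims`: the `GlobalPoolMode` marker is handed straight to the pool op, which lowers to a plain reduction (see `src/layers/pooling.jl` below). A sketch of the equivalence on an ordinary array, using NNlib directly:

```julia
# Sketch: the reduction fallback matches the PoolDims-based global max pool.
using NNlib

x = rand(Float32, 8, 8, 4, 2)   # W × H × C × N

# PoolDims path (the one EnzymeJAX cannot yet differentiate through):
y_pool = maxpool(x, PoolDims(x, size(x)[1:(end - 2)]))

# Reduction path (what MaxPoolOp does when given GlobalPoolMode):
y_reduce = maximum(x; dims=1:(ndims(x) - 2), init=Float32(-Inf))

@assert y_pool ≈ y_reduce   # both are 1 × 1 × 4 × 2
```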
8 changes: 8 additions & 0 deletions src/layers/pooling.jl
@@ -40,15 +40,23 @@ symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode
abstract type AbstractPoolOp end

struct MaxPoolOp <: AbstractPoolOp end

(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)
+function (m::MaxPoolOp)(x, ::GlobalPoolMode)
+    return maximum(x; dims=1:(ndims(x) - 2), init=eltype(x)(-Inf))
+end

struct MeanPoolOp <: AbstractPoolOp end

(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)
+(m::MeanPoolOp)(x, ::GlobalPoolMode) = mean(x; dims=1:(ndims(x) - 2))

@concrete struct LpPoolOp <: AbstractPoolOp
p
end

(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)
+(m::LpPoolOp)(x, ::GlobalPoolMode) = lpnormpool(x, PoolDims(x, size(x)[1:(end - 2)]); m.p)

symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()
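From the layer side nothing changes: `GlobalMeanPool` and friends keep their API and simply take the reduction path. A short usage sketch:

```julia
# Sketch: GlobalMeanPool is unchanged for users; it now reduces directly over
# the spatial dimensions instead of building PoolDims.
using Lux, Random

layer = GlobalMeanPool()
ps, st = Lux.setup(Random.default_rng(), layer)

x = rand(Float32, 6, 6, 3, 2)   # W × H × C × N
y, _ = layer(x, ps, st)
@assert size(y) == (1, 1, 3, 2)
```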