Skip to content

Instantly share code, notes, and snippets.

@chengchingwen
Created February 25, 2023 12:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chengchingwen/484fd339ca43cc05c8a13243571028c0 to your computer and use it in GitHub Desktop.
Save chengchingwen/484fd339ca43cc05c8a13243571028c0 to your computer and use it in GitHub Desktop.
using Flux
using CUDA
import Optimisers
using NNlibCUDA
# Make cuDNN softmax use the ACCURATE algorithm by (re)defining
# `NNlibCUDA.softmaxalgo()` to always return `CUDNN_SOFTMAX_ACCURATE`.
# NOTE(review): this overrides a method owned by NNlibCUDA from user code
# (method overwriting). Presumably it is here because NNlibCUDA would otherwise
# pick a faster, less accurate softmax under FAST_MATH, which is enabled on
# the next line. Confirm against the installed NNlibCUDA version.
NNlibCUDA.softmaxalgo() = NNlibCUDA.CUDNN_SOFTMAX_ACCURATE
# Switch CUDA.jl's global math mode to FAST_MATH with Float16 precision.
# Presumably this allows reduced-precision (e.g. tensor-core) paths for the
# Float16 model; verify the intended effect for your CUDA.jl version.
CUDA.math_mode!(CUDA.FAST_MATH; precision=:Float16)
# `grads!` mirrors `Optimisers._grads!`, but adds each new gradient into buffers
# kept in the `IdDict` instead of replacing them.
#
# No-gradient cases: when every incoming gradient is an `Optimisers.Zero`
# there is nothing to record, so both methods return `nothing`.
# The `Leaf`-specific method is required as a tie-breaker. Without it, a call
# with a `Leaf` and only zero gradients would match both the generic all-zero
# method and the accumulating `Leaf` method, with neither more specific.
grads!(::IdDict, ::Optimisers.Leaf, ::Any, ::Optimisers.Zero...) = nothing
grads!(::IdDict, ::Any, ::Any, ::Optimisers.Zero...) = nothing
"""
    grads!(dict::IdDict, ℓ::Optimisers.Leaf, x, x̄s...)

Add the gradient(s) `x̄s` of the parameter `x` into the accumulator `dict[ℓ]`.

The first time `ℓ` is seen, every array gradient is copied into a fresh buffer
created with `similar(x)`. As a result, the accumulator:

- always has the element type of `x`, e.g. the FP32 master parameter when this
  is called with the FP32 model, even if the first gradient to arrive is Float16;
- never aliases an array owned by the AD result, which may be immutable
  (e.g. a `Fill`) or shared.

Non-array gradients are stored unchanged. Later calls add in place with `.+=`,
which converts Float16 gradients to the buffer's element type. Returns `nothing`.
"""
function grads!(dict::IdDict, ℓ::Optimisers.Leaf, x, x̄s...)
    if haskey(dict, ℓ)
        x̄s₀ = dict[ℓ]
        foreach((x̄₀, x̄) -> (x̄₀ .+= x̄), x̄s₀, x̄s)
    else
        dict[ℓ] = map(x̄ -> _gradbuffer(x, x̄), x̄s)
    end
    return nothing
end

# Copy an array gradient into a new buffer with the shape and element type of the
# parameter `x`. Broadcasting (rather than `copyto!`) keeps this working for GPU
# arrays with a different element type.
function _gradbuffer(x, x̄::AbstractArray)
    buf = similar(x)
    buf .= x̄
    return buf
end
# Any non-array gradient is kept as given (same as storing it directly).
_gradbuffer(x, x̄) = x̄
"""
    grads!(dict::IdDict, tree, x, x̄s...)

Recurse through a non-leaf node of the optimiser state `tree`.

Each child state is paired with the matching child of `x` and with the matching
child of every gradient in `x̄s`, then `grads!` is called on each triple so that
leaf gradients end up accumulated in `dict`.
"""
function grads!(dict::IdDict, tree, x, x̄s...)
    # Split both the model and its gradients with `typeof(x)` (as
    # `Optimisers._grads!` does), so they decompose into the same children.
    children, _ = Optimisers.functor(typeof(x), x)
    child_grads = map(x̄s) do x̄
        first(Optimisers.functor(typeof(x), Optimisers.base(x̄)))
    end
    visit(tᵢ, xᵢ, x̄sᵢ...) = grads!(dict, tᵢ, xᵢ, x̄sᵢ...)
    return Optimisers.valueforeach(visit, tree, children, child_grads...)
end
"""
    foreachgrad(f, dict::IdDict)

Call `f` on every gradient stored in `dict`.

Each value of `dict` is a collection of gradients, such as the tuple stored per
leaf by `grads!`. `f` is applied to each of them in iteration order and is
typically an in-place operation. Returns `nothing`.
"""
function foreachgrad(f, dict::IdDict)
    foreach(f, Iterators.flatten(values(dict)))
    return nothing
end
# Create the model.
# `build_model` is defined elsewhere. `model` holds the master weights that the
# optimiser updates (presumably FP32, Flux's default — confirm). `model16` is a
# copy whose floating-point parameters are converted to Float16 by
# `Flux.paramtype`; it is used for the forward/backward pass.
model = build_model()
model16 = Flux.paramtype(Float16, model)
# Adam with learning rate 1f-6. The optimiser state tree is set up over the
# master `model`, so optimiser steps are applied to the master weights.
opt_rule = Optimisers.Adam(1f-6)
opt = Optimisers.setup(opt_rule, model)
"""
    train!(model, model16, opt, dataloader; update_size)

Run one pass over `dataloader`, training `model` in mixed precision with
gradient accumulation.

# Arguments
- `model`: master-weight model that the optimiser updates (presumably FP32).
- `model16`: Float16 copy used for the forward/backward pass. It is refreshed
  from `model` with `load_weight!` after every optimiser step.
- `opt`: Optimisers state tree for `model` (from `Optimisers.setup`).
- `dataloader`: iterable of `(batch_size, input)` pairs.

# Keywords
- `update_size`: take an optimiser step once at least this many samples have
  been accumulated since the last step.

# Notes
The first batch is differentiated through `model` on `togpu` data, so the
gradient buffers are first filled from full-precision gradients. Later batches
use `model16` on `togpu16` data.

Relies on `loss`, `togpu`, `togpu16` and `load_weight!`, all defined elsewhere.
Returns `nothing`.

NOTE(review): gradients from trailing batches after the last step (fewer than
`update_size` samples) are discarded when the loop ends.
"""
function train!(model, model16, opt, dataloader; update_size)
    grad = IdDict{Optimisers.Leaf, Any}()
    update_i = 0
    warm = false
    for (batch_size, input) in dataloader
        if !warm
            input32 = togpu(input) # move data to gpu
            grad_i = Flux.gradient(m -> loss(m, input32), model)
            warm = true
        else
            input16 = togpu16(input) # move data to gpu, convert to FP 16
            grad_i = Flux.gradient(m -> loss(m, input16), model16)
        end
        # `Flux.gradient` returns a tuple with one entry per differentiated
        # argument. The model is the only argument, so unwrap it before walking
        # the optimiser tree. This accumulates the (possibly FP16) gradient into
        # the gradient buffers.
        grads!(grad, opt, model, grad_i[1])
        update_i += batch_size
        if update_i >= update_size
            # `update_i` counts samples, so this divides the summed batch
            # gradients by the number of samples seen.
            # NOTE(review): this is the mean gradient only if `loss` sums over
            # the batch. If `loss` already averages, divide by the number of
            # batches instead — confirm.
            # Bind the count once so the closure does not capture the
            # reassigned `update_i`.
            nsamples = update_i
            foreachgrad(grad) do dx
                dx ./= convert(eltype(dx), nsamples)
            end
            # Apply the optimiser step to the master model.
            # NOTE(review): `_update!` is internal Optimisers API; its keyword
            # interface may change between versions.
            Optimisers._update!(opt, model; grads = grad, params = IdDict())
            update_i = 0
            # Zero the buffers in place so they can be reused.
            foreachgrad(Base.Fix2(fill!, 0), grad)
            # copy weight from model to model16
            load_weight!(model16, model)
        end
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment