-
-
Save mcabbott/6154bb78b735e8f0a9348767a7d59c86 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##### Some code I had for normalisation | |
using ChainRulesCore, Statistics | |
# Normalise each column of `x` (reduction over dim 1) to zero mean and unit
# (uncorrected) variance, dividing by sqrt(σ² + eps) element-by-element.
# This mirrors what Flux's LayerNorm does; the division takes a sqrt per element.
function normal_now(x::AbstractArray)
    mu = mean(x; dims=1)
    sigma2 = var(x; dims=1, mean=mu, corrected=false)
    return (x .- mu) ./ sqrt.(sigma2 .+ eps(eltype(x)))
end
# Same result as `normal_now`, but the inverse standard deviation is computed once
# per column and then multiplied in, instead of taking a sqrt for every element.
# (On GPU, plain `(x .- μ) ./ sqrt.(σ² .+ ε)` under @fastmath was found quicker.)
function normal_new(x::AbstractArray)
    mu = mean(x; dims=1)
    sigma2 = var(x; dims=1, mean=mu, corrected=false)
    ε = eps(eltype(x))
    invstd = @. inv(sqrt(sigma2 + ε))
    return (x .- mu) .* invstd
end
# Mostly I thought a bit about how to write a fused gradient here: | |
# Hand-written reverse rule for `normal_new`, with a fused gradient.
# Forward: y = (x - μ) · istd, istd = 1/sqrt(σ² + ε), statistics along dim 1.
# Pullback: dx = istd · (dy - mean(dy) - y · istd · mean(dy · (x - μ))),
# written with column sums scaled by iN = 1/size(x, 1).
function ChainRulesCore.rrule(::typeof(normal_new), x::AbstractArray)
    mu = mean(x; dims=1)
    sigma2 = var(x; dims=1, mean=mu, corrected=false)
    ε = eps(eltype(x))
    istd = inv.(sqrt.(sigma2 .+ ε))
    y = (x .- mu) .* istd
    iN = one(eltype(y)) / size(x, 1)
    function back(Δ)
        dy = unthunk(Δ)
        # One full-size buffer: first holds dy·(x - μ), then is overwritten with dx.
        # Both column sums are taken BEFORE the in-place write below.
        # The author measured ~25% of the memory of the naive version.
        buf = dy .* (x .- mu)
        sum_dy = sum(dy; dims=1)
        sum_buf = sum(buf; dims=1)
        dx = buf .= istd .* (dy .- iN .* sum_dy .- y .* iN .* istd .* sum_buf)
        # Division-based alternative (reported quicker on GPU with @fastmath), same maths:
        # dx = buf .= (dy .- iN .* (sum_dy .+ y .* sum_buf ./ sqrt.(sigma2 .+ ε))) ./ sqrt.(sigma2 .+ ε)
        return (NoTangent(), dx)
    end
    return y, back
end
##### Newer version... | |
using GPUArraysCore, Statistics, ChainRulesCore | |
# Mean and variance of `A` along `dims`, returned as a NamedTuple `(mean=…, var=…)`.
# The variance reuses the computed mean; `corrected=false` (default) divides by N.
function mean_var(A::AbstractArray; dims, corrected::Bool=false)
    m = mean(A; dims=dims)
    v = var(A; dims=dims, corrected=corrected, mean=m)
    return (mean=m, var=v)
end
# This is something like Welford's algorithm. It's pretty slow on CPU. | |
# Single-pass mean and variance for Float32 GPU arrays, in the style of Welford's
# algorithm. Each element becomes a (mean, scaled-M2, count) triple, and triples
# are merged pairwise by `mapreduce`. The M2 term is pre-scaled by 1/(N - corrected),
# so the reduced second component is already the variance. Slow on CPU.
function mean_var(A::AbstractGPUArray{Float32}; dims, corrected::Bool=false)
    empty_state = (0.0f0, 0.0f0, Int32(0))
    singleton(x::Float32) = (x, 0.0f0, Int32(1))
    scale = Float32(1/(sum_length(A, dims) - corrected))
    function merge_states(a::Tuple, b::Tuple)
        (mean_a, var_a, n_a) = a
        (mean_b, var_b, n_b) = b
        n = n_a + n_b
        # Merging two empty states must give the empty state back instead of 0/0;
        # this is essential on GPU.
        if iszero(n)
            return empty_state
        end
        inv_n = inv(Float32(n))
        merged_mean = (n_a * mean_a + n_b * mean_b) * inv_n
        merged_var = var_a + var_b + scale * (mean_b - mean_a)^2 * n_a * n_b * inv_n
        return (Float32(merged_mean), Float32(merged_var), Int32(n))
    end
    R = mapreduce(singleton, merge_states, A; dims, init=empty_state)
    if dims isa Colon
        mean, var, _ = R
    else
        # View the array of 3-tuples as a 3×… Float32 array and split off its rows
        # (the count row is discarded). This yields views of the GPU array.
        mean, var, _ = eachslice(reinterpret(reshape, Float32, R); dims=1)
    end
    return (; mean, var)
end
# Number of elements of `x` that a reduction over `dims` folds into each output entry.
# Repeated dimensions in a collection of dims are counted once.
sum_length(x::AbstractArray, dims::Integer) = size(x, dims)
sum_length(x::AbstractArray, ::Colon) = length(x)
sum_length(x::AbstractArray, dims) = mapreduce(d -> size(x, d), *, unique(dims); init=1)
# Here too it seems to pay to have a separate path for GPU arrays | |
# Normalise `x` along `dims` (default 1) to zero mean and unit uncorrected variance.
# Statistics come from `mean_var`. 1/sqrt(σ² + ε) is computed once per slice, so
# there are no per-element sqrt evaluations.
function normal_newer(x::AbstractArray; dims=1)
    stats = mean_var(x; dims, corrected=false)
    ε = eps(eltype(x))
    invstd = inv.(sqrt.(stats.var .+ ε))
    return (x .- stats.mean) .* invstd
end
# GPU method of `normal_newer`: same result, but written as one fused broadcast
# dividing by sqrt(σ² + ε), which was measured to be quicker on GPU.
function normal_newer(x::AbstractGPUArray; dims=1)
    stats = mean_var(x; dims, corrected=false)
    return (x .- stats.mean) ./ sqrt.(stats.var .+ eps(eltype(x)))
end
# Having two paths (generic and GPU) here starts to get really ugly...
# Reverse rule for `normal_newer(x; dims)` on generic arrays.
# Forward: y = (x - μ) · istd with istd = 1/sqrt(σ² + ε), statistics over `dims`.
# Pullback (fused): dx = istd · (dy - iN·Σdy - y · iN · istd · Σ(dy·(x - μ))),
# where every Σ reduces over the SAME `dims` as the forward pass, and
# iN = 1/(number of elements reduced per slice).
function ChainRulesCore.rrule(::typeof(normal_newer), x::AbstractArray; dims=1)
    μ, σ2 = mean_var(x; dims, corrected=false)
    ε = eps(eltype(x))
    istd = inv.(sqrt.(σ2 .+ ε))
    y = (x .- μ) .* istd
    iN = one(eltype(y)) / sum_length(x, dims)
    function back(Δ)
        dy = unthunk(Δ)
        tmp = dy .* (x .- μ)
        # Reduce over `dims` (previously hard-coded to 1, which gave wrong gradients
        # for any other dims). Non-dotted `sum` calls are evaluated before the
        # in-place write, so reading `tmp` here is safe.
        dx = tmp .= istd .* (dy .- iN .* sum(dy; dims) .- y .* iN .* istd .* sum(tmp; dims))
        return (NoTangent(), dx)
    end
    return y, back
end
# Reverse rule for `normal_newer(x; dims)` on GPU arrays, using division by
# sqrt(σ² + ε) (quicker on GPU) instead of a stored inverse std.
# With s = 1/sqrt(σ² + ε) the gradient is
#   dx = s · (dy - iN · (Σdy + y · s · Σ(dy·(x - μ)))),
# with all sums over `dims`. The previous expression had the y-term's sign flipped,
# was missing one factor of s, and divided only the correction term by sqrt(σ² + ε).
function ChainRulesCore.rrule(::typeof(normal_newer), x::AbstractGPUArray; dims=1)
    μ, σ2 = mean_var(x; dims, corrected=false)
    ε = eps(eltype(x))
    y = @fastmath (x .- μ) ./ sqrt.(σ2 .+ ε)  # quicker on GPU
    iN = one(eltype(y)) / sum_length(x, dims)
    function back(Δ)
        dy = unthunk(Δ)
        tmp = dy .* (x .- μ)
        # Both sums are computed before `tmp` is overwritten with dx.
        dx = @fastmath tmp .= (dy .- iN .* (sum(dy; dims) .+ y .* sum(tmp; dims) ./ sqrt.(σ2 .+ ε))) ./ sqrt.(σ2 .+ ε)
        return (NoTangent(), dx)
    end
    return y, back
end
##### Some code from https://github.com/FluxML/NNlib.jl/pull/452/files | |
import ChainRulesCore: rrule, @ignore_derivatives | |
# Mean and biased (uncorrected) variance of `x` along `dims`, returned as a tuple.
# The variance reuses the computed mean.
function norm_stats(x, dims)
    m = mean(x; dims=dims)
    v = var(x; dims=dims, mean=m, corrected=false)
    return (m, v)
end
# Reverse rule for `norm_stats(x, dims)`, built from the ChainRules rules for
# `mean` and `var`. The variance rule is given the already-computed mean.
# The pullback receives cotangents (dμ, dσ²) for the two outputs. Each one must go
# to the pullback of the output it belongs to: dμ → mean, dσ² → var. The two
# were previously swapped, which gave wrong gradients through `layernorm`.
function rrule(::typeof(norm_stats), x, dims)
    μ, mean_pullback = rrule(mean, x; dims)
    σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false)
    function norm_stats_pullback(dargs)
        dμ, dσ² = unthunk(dargs)
        dx = ChainRulesCore.add!!(mean_pullback(dμ)[2], var_pullback(dσ²)[2])
        return (NoTangent(), dx, NoTangent())
    end
    return (μ, σ²), norm_stats_pullback
end
# Reshape an optional affine parameter to `dims`; `nothing` passes straight through.
_maybe_reshape(x::Nothing, _) = x
_maybe_reshape(x, dims) = reshape(x, dims)
# Apply the optional affine transform `x * scale + bias` elementwise;
# with neither scale nor bias, return `x` itself unchanged (no copy).
_apply_scale_bias(x, ::Nothing, ::Nothing) = x
_apply_scale_bias(x, scale, bias) = @. x * scale + bias
# Convert `y` to the floating-point counterpart of `x`'s element type
# (e.g. an Int array gives Float64; a Float32 array gives Float32).
function ofeltype(x, y)
    F = float(eltype(x))
    return convert(F, y)
end
# Normalise `x` with precomputed statistics, `(x - μ) / sqrt(σ² + ϵ)`, then apply the
# optional affine parameters reshaped to `affine_size` (default: the shape of μ).
# `scale` and `bias` must both be arrays or both be `nothing`; otherwise this errors.
function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                     bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
    @ignore_derivatives if isnothing(scale) ⊻ isnothing(bias)
        error("both scale and bias must be provided or left as nothing")
    end
    shaped_scale = _maybe_reshape(scale, affine_size)
    shaped_bias = _maybe_reshape(bias, affine_size)
    normalised = (x .- μ) ./ sqrt.(σ² .+ ϵ)
    return _apply_scale_bias(normalised, shaped_scale, shaped_bias)
end
# Layer normalisation over the first `S` dimensions of `x` (default `Val(1)`), with
# optional affine `scale`/`bias` shaped like those leading dimensions.
# `ϵ` defaults to 1e-5 in x's float element type.
# Throws DimensionMismatch if S exceeds ndims(x).
function layernorm(x::AbstractArray{<:Any, N}, ::Val{S} = Val(1), scale = nothing, bias = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N, S}
    @ignore_derivatives if S > N
        throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
    end
    reduce_dims = ntuple(identity, Val(S))
    μ, σ² = norm_stats(x, reduce_dims)
    leading_size = ntuple(d -> size(x, d), Val(S))
    return norm_helper(x, μ, σ², scale, bias, ϵ, leading_size)
end
##### Some code from https://github.com/chengchingwen/NeuralAttentionlib.jl/blob/b418c0d2a9e99c960e88879a5fd879d47d8e4c22/src/functional/layernorm.jl | |
# Pair-producing maps for single-pass reductions with `.+`:
# _x_y2(x, y) gives (x, x*y); _x_x2(x) gives (x, x*x), so reducing yields (Σx, Σx²).
_x_y2(x, y) = tuple(x, x * y)
_x_x2(x) = _x_y2(x, x)
# Normalise one element `x` using its slice's running sums (Σx, Σx²) and inN = 1/N:
# μ = Σx/N, σ = sqrt(Σx²/N - μ²), floored at ϵ (σ = max(σ, ϵ), rather than adding
# ϵ to the variance). Returns (x - μ) / σ.
function _normalize(inN::T, ϵ::T, x::T, sum_sum2::NTuple{2, T}) where T
    μ, s = sum_sum2 .* inN
    # E[x²] - E[x]² can round to a tiny negative number (e.g. a constant slice);
    # clamp it at zero so sqrt cannot throw a DomainError. NaN still propagates.
    σ₀ = sqrt(max(fma(μ, -μ, s), zero(T)))  # @fastmath sqrt may help, ~2x?
    σ = max(σ₀, ϵ)  # @fastmath max helped ~30% here
    return (x - μ) / σ
end
# Reciprocal standard deviation of a slice from its sums (Σx, Σx²) and inN = 1/N:
# 1 / max(sqrt(Σx²/N - μ²), ϵ). The same σ as `_normalize` uses.
function _rstd(inN::T, ϵ::T, sum_sum2::NTuple{2, T}) where T
    μ, s = sum_sum2 .* inN
    # Clamp the cancellation-prone variance at zero so sqrt cannot throw.
    σ₀ = sqrt(max(fma(μ, -μ, s), zero(T)))
    σ = max(σ₀, ϵ)  # @fastmath max had no effect here
    return inv(σ)
end
# Layer normalisation over the first dimension of `x`, followed by y = alpha*n + beta
# (via fma). The 3-argument form uses epsilon = 1e-5. `nothing` for alpha / beta
# means scale 1 / shift 0.
layer_norm(alpha, beta, x) = layer_norm(1e-5, alpha, beta, x)

function layer_norm(epsilon, alpha, beta, x)
    T = eltype(x)
    ϵ = convert(T, epsilon)
    α = alpha === nothing ? one(T) : alpha
    β = beta === nothing ? zero(T) : beta
    # One pass over x builds an array of per-column tuples (Σx, Σx²),
    # equal to tuple.(sum(x; dims=1), sum(abs2, x; dims=1)).
    sum_sum2 = mapreduce(_x_x2, .+, x; dims=1, init=(zero(T), zero(T)))
    invN = convert(T, 1//size(x, 1))
    # `_normalize` recomputes its column's σ for every element, so the sqrt runs
    # once per element rather than once per column.
    return fma.(α, _normalize.(invN, ϵ, x, sum_sum2), β)
end
# Per-element input gradient of layer_norm: ((n·Σ(dy·n) + Σdy)·inN + dy) · is,
# where `dya` = (Σdy, Σ(dy·n)) for the column and `is` is its reciprocal std.
# Callers pass inN = -1/N.
function _fma2(dy::T, dya::NTuple{2, T}, n::T, inN::T, is::T) where T
    sum_dy, sum_dyn = dya
    return fma(fma(n, sum_dyn, sum_dy), inN, dy) * is
end
# Input cotangent for layer_norm:
#   ∂x = rstd · (dy - (Σdy + n·Σ(dy·n)) / N),   with dy = Ȳ·α,
# where the sums run over dim 1 and `n` is the normalised forward output.
function Δlayer_norm_dx(Ȳ, ϵ, α, n, x, sum_sum2)
    T = eltype(x)
    invN = convert(T, 1//size(x, 1))
    # Lazy (instantiated but not materialised) broadcasts, fused into the final one.
    rstd = Broadcast.instantiate(Broadcast.broadcasted(_rstd, invN, ϵ, sum_sum2))
    dy = Broadcast.instantiate(Broadcast.broadcasted(*, Ȳ, α))
    # Per-column (Σdy, Σ(dy·n)). Noted as slow on CPU with two array arguments.
    column_sums = mapreduce(_x_y2, .+, dy, n; dims=1, init=(zero(T), zero(T)))
    return _fma2.(dy, column_sums, n, -invN, rstd)
end
# All dimensions of `Ȳ` except the first, as a tuple (2, 3, …, ndims(Ȳ)).
_taildims(Ȳ) = ntuple(d -> d + 1, Val(ndims(Ȳ) - 1))
# Reverse rule for the 3-argument `layer_norm(alpha, beta, x)`. It delegates to the
# epsilon = 1e-5 rule and drops that rule's cotangent slot for epsilon.
function ChainRulesCore.rrule(::typeof(layer_norm), alpha, beta, x)
    y, eps_pullback = rrule(layer_norm, 1e-5, alpha, beta, x)
    function layer_norm_pullback(Ȳ)
        _, _, ∂α, ∂β, ∂x = eps_pullback(Ȳ)
        return (NoTangent(), ∂α, ∂β, ∂x)
    end
    return y, layer_norm_pullback
end
# Reverse rule for `layer_norm(epsilon, alpha, beta, x)`. The forward pass normalises
# x along dim 1 and returns y = α·n + β. `nothing` for alpha/beta gets a NoTangent().
# An array alpha/beta has its cotangent summed over every dim except the first;
# a non-array one is summed over all elements. All cotangents are returned as thunks.
function ChainRulesCore.rrule(::typeof(layer_norm), epsilon, alpha, beta, x)
    T = eltype(x)
    ϵ = convert(T, epsilon)
    invN = convert(T, 1//size(x, 1))
    # Flags wrapped with Static.static (presumably so the branches fold at compile time).
    skip_α = static(isnothing(alpha))
    skip_β = static(isnothing(beta))
    array_α = static(alpha isa AbstractArray)
    array_β = static(beta isa AbstractArray)
    α = as_bool(skip_α) ? one(T) : alpha
    β = as_bool(skip_β) ? zero(T) : beta
    # Per-column (Σx, Σx²), the normalised values n, then the affine output.
    sum_sum2 = mapreduce(_x_x2, .+, x; dims=1, init=(zero(T), zero(T)))
    n = _normalize.(invN, ϵ, x, sum_sum2)
    y = fma.(α, n, β)
    function layer_norm_pullback(Ybar)
        Ȳ = unthunk(Ybar)
        ∂α = if as_bool(skip_α)
            NoTangent()
        else
            @thunk sum(Broadcast.instantiate(Broadcast.broadcasted(*, Ȳ, n));
                       dims = as_bool(array_α) ? _taildims(Ȳ) : (:), init = zero(eltype(Ȳ)))
        end
        ∂β = if as_bool(skip_β)
            NoTangent()
        else
            @thunk sum(Ȳ; dims = as_bool(array_β) ? _taildims(Ȳ) : (:))
        end
        ∂x = @thunk Δlayer_norm_dx(Ȳ, ϵ, α, n, x, sum_sum2)
        return (NoTangent(), NoTangent(), ∂α, ∂β, ∂x)
    end
    return y, layer_norm_pullback
end
using Static | |
# Unwrap either a plain Bool or a Static.jl StaticBool into a Bool.
as_bool(b::StaticBool) = Bool(b)
as_bool(b::Bool) = b
# The last `n` elements of tuple `s`, as a tuple; `n` is expected to be a static
# integer (e.g. `static(3)`), matching the static length used for the offset.
function last_n(s::Tuple, n)
    skip = static(length(s)) - n
    return ntuple(i -> s[i + skip], n)
end
##### Testing forwards | |
# Forward-pass smoke test: run every implementation on the same small 4×5 input.
# `rand(Int8, …) .+ 0.0` gives a Float64 matrix with integer-valued entries.
x = rand(Int8, 4, 5).+0.0 | |
normal_now(x) | |
normal_new(x) | |
normal_newer(x) | |
layernorm(x, Val(1)) # Val(1) is also layernorm's default
layer_norm(true, false, x) | |
layer_norm(nothing, nothing, x) # `nothing` alpha/beta means no scale/shift, same as (true, false)
julia> @btime normal_now($(rand(100, 100))); # most of the time in sqrt! | |
min 14.541 μs, mean 19.148 μs (10 allocations, 80.45 KiB) | |
min 12.875 μs, mean 16.387 μs (10 allocations, 80.45 KiB) # with @fastmath sqrt | |
julia> @btime normal_new($(rand(100, 100))); # much quicker, same memory | |
min 5.048 μs, mean 8.829 μs (11 allocations, 81.33 KiB) | |
min 5.056 μs, mean 9.781 μs (11 allocations, 81.33 KiB) # with @fastmath sqrt | |
julia> @btime normal_newer($(rand(100, 100))); | |
min 5.076 μs, mean 11.021 μs (11 allocations, 81.33 KiB) | |
julia> @btime layernorm($(rand(100, 100))); | |
min 14.500 μs, mean 18.573 μs (18 allocations, 80.73 KiB) | |
julia> @btime layer_norm(true, false, $(rand(100, 100))); | |
min 16.834 μs, mean 20.344 μs (3 allocations, 79.94 KiB) # as above | |
min 6.917 μs, mean 10.702 μs (3 allocations, 79.94 KiB) # with @fastmath max & sqrt | |
#= | |
# Components, to see where the time is... | |
@btime copy($(rand(100, 100))); # 78.17 KiB | |
@btime sum($(rand(100, 100)); dims=1); | |
m1 = @btime mean($(rand(100, 100)); dims=1); | |
@btime var($(rand(100, 100)); dims=1, mean=$m1, corrected=false); # does not allocate a copy | |
@btime std($(rand(100, 100)); dims=1, mean=$m1, corrected=false); | |
@btime sqrt.($m1); | |
@btime inv.(sqrt.($m1)); | |
@btime inv.(sqrt.($(rand(100, 100)))); | |
=# | |
##### Testing gradient | |
# The custom rrules save a lot of memory, but give little speedup, at least on an M1 Mac.
using Zygote, BenchmarkTools | |
# Gradient check: differentiate the sum of the first output row for each version.
# All five results should agree.
Zygote.gradient(x -> sum(normal_now(x)[1,:]), x)[1] | |
Zygote.gradient(x -> sum(normal_new(x)[1,:]), x)[1] | |
Zygote.gradient(x -> sum(normal_newer(x)[1,:]), x)[1] | |
# NOTE(review): this result was observed to differ. `layernorm` differentiates through
# the custom rrule for `norm_stats`; check that dμ reaches the mean pullback and dσ²
# reaches the var pullback.
Zygote.gradient(x -> sum(layernorm(x)[1,:]), x)[1] # this does not look the same
Zygote.gradient(x -> sum(layer_norm(true, false, x)[1,:]), x)[1] | |
julia> @btime Zygote.gradient(x -> sum(abs2, x), $(rand(100, 100))); # baseline, no norm! | |
min 3.391 μs, mean 11.354 μs (2 allocations, 78.17 KiB) | |
julia> @btime Zygote.gradient(x -> sum(abs2, normal_now(x)), $(rand(100, 100))); | |
min 26.334 μs, mean 85.340 μs (57 allocations, 637.97 KiB) | |
julia> @btime Zygote.gradient(x -> sum(abs2, normal_new(x)), $(rand(100, 100))); | |
min 26.750 μs, mean 48.510 μs (33 allocations, 240.11 KiB) | |
julia> @btime Zygote.gradient(x -> sum(abs2, normal_newer(x)), $(rand(100, 100))); | |
min 26.959 μs, mean 53.320 μs (33 allocations, 240.11 KiB) | |
julia> (637.97-78.17) / (240.11-78.17) | |
3.4568358651352358 | |
julia> @btime Zygote.gradient(x -> sum(abs2, layernorm(x)), $(rand(100, 100))); | |
min 81.666 μs, mean 130.307 μs (252 allocations, 570.58 KiB) | |
julia> @btime Zygote.gradient(x -> sum(abs2, layer_norm(true, false, x)), $(rand(100, 100))); | |
min 77.500 μs, mean 125.485 μs (37 allocations, 473.62 KiB) | |
##### GPU times | |
# The conclusion here is that NVIDIA did this optimisation better than I did. | |
# No point optimising the GPU case; can LayerNorm call BatchNorm's backend too?
using CUDA, Flux | |
julia> cx = cu(randn(100, 1000)); | |
julia> CUDA.@time cx .+ 1; # baseline | |
0.000124 seconds (40 CPU allocations: 1.750 KiB) (1 GPU allocation: 390.625 KiB, 16.67% memmgmt time) | |
julia> CUDA.@time Flux.normalise(cx); | |
0.000431 seconds (441 CPU allocations: 21.453 KiB) (9 GPU allocations: 784.766 KiB, 11.47% memmgmt time) | |
julia> CUDA.@time normal_now(cx); | |
0.000292 seconds (258 CPU allocations: 11.516 KiB) (6 GPU allocations: 796.875 KiB, 13.49% memmgmt time) | |
julia> CUDA.@time normal_new(cx); | |
0.000307 seconds (296 CPU allocations: 13.203 KiB) (7 GPU allocations: 800.781 KiB, 13.97% memmgmt time) | |
# gradients | |
julia> CUDA.@time Zygote.gradient(x -> sum(sin, x), cx); # baseline | |
0.060339 seconds (16.40 k CPU allocations: 963.312 KiB) (6 GPU allocations: 1.908 MiB, 0.13% memmgmt time) | |
julia> CUDA.@time Zygote.gradient(x -> sum(sin, Flux.normalise(x)), cx); | |
0.156524 seconds (38.36 k CPU allocations: 2.249 MiB) (29 GPU allocations: 5.348 MiB, 0.60% memmgmt time) | |
julia> CUDA.@time Zygote.gradient(x -> sum(sin, normal_now(x)), cx); | |
0.079000 seconds (23.59 k CPU allocations: 1.313 MiB) (28 GPU allocations: 5.387 MiB, 0.21% memmgmt time) | |
julia> CUDA.@time Zygote.gradient(x -> sum(sin, normal_new(x)), cx); | |
0.074505 seconds (24.45 k CPU allocations: 1.396 MiB) (16 GPU allocations: 3.079 MiB, 0.15% memmgmt time) | |
# batchnorm | |
# bn2 = BatchNorm(100, affine=false) |> gpu # gives an error! | |
# But the default (affine) BatchNorm does work, and is more efficient than mine: 2.672 MiB < 3.079 MiB
julia> bn1 = BatchNorm(100) |> gpu | |
BatchNorm(100) # 200 parameters, plus 200 non-trainable | |
julia> CUDA.@time bn1(cx); | |
0.000291 seconds (42 CPU allocations: 1.312 KiB) (1 GPU allocation: 390.625 KiB, 7.79% memmgmt time) | |
julia> CUDA.@time Zygote.gradient(x -> sum(sin, bn1(x)), cx); | |
0.099942 seconds (36.84 k CPU allocations: 2.064 MiB) (10 GPU allocations: 2.672 MiB, 0.13% memmgmt time) | |
# layernorm | |
# here it is literally calling Flux.normalise | |
# When the dims work, can it call BatchNorm instead? | |
julia> ln1 = LayerNorm(100) |> gpu | |
LayerNorm(100) # 200 parameters | |
julia> ln2 = LayerNorm(100, affine=false) |> gpu | |
LayerNorm(100) # 200 parameters | |
julia> CUDA.@time ln2(cx); | |
0.000369 seconds (320 CPU allocations: 15.688 KiB) (7 GPU allocations: 800.781 KiB, 16.56% memmgmt time) | |
julia> CUDA.@time Zygote.gradient(x -> sum(sin, ln2(x)), cx); | |
0.136035 seconds (37.73 k CPU allocations: 2.104 MiB, 23.66% gc time) (26 GPU allocations: 5.376 MiB, 0.19% memmgmt time) | |
##### GPU, January | |
# Now updated to include https://github.com/JuliaGPU/GPUArrays.jl/pull/443 | |
julia> let x = CUDA.randn(100, 1000) # Forward pass, @btime | |
@btime CUDA.@sync copy($x) # baseline | |
println() | |
@btime CUDA.@sync normal_now($x) | |
@btime CUDA.@sync normal_new($x) | |
@btime CUDA.@sync normal_newer($x) | |
println() | |
@btime CUDA.@sync layernorm($x) | |
@btime CUDA.@sync layer_norm(nothing, nothing, $x) | |
println() | |
μ = @btime CUDA.@sync mean($x; dims=1) | |
@btime CUDA.@sync var($x; mean=$μ, corrected=false, dims=1) | |
end; | |
17.749 μs (13 allocations: 400 bytes) | |
96.518 μs (219 allocations: 10.88 KiB) | |
106.690 μs (257 allocations: 12.56 KiB) | |
56.772 μs (144 allocations: 9.45 KiB) | |
93.800 μs (225 allocations: 11.41 KiB) | |
43.795 μs (93 allocations: 4.33 KiB) | |
43.810 μs (99 allocations: 4.59 KiB) | |
41.918 μs (92 allocations: 5.17 KiB) | |
julia> let x = CUDA.randn(1000, 10_000) # Forward pass, alloc | |
CUDA.@time copy(x) # baseline | |
println() | |
CUDA.@time normal_now(x) | |
CUDA.@time normal_new(x) # surprisingly large alloc, why? fixed. | |
CUDA.@time normal_newer(x) | |
println() | |
CUDA.@time layernorm(x) | |
CUDA.@time layer_norm(nothing, nothing, x) | |
println() | |
μ = CUDA.@time mean(x; dims=1) | |
CUDA.@time var(x; mean=μ, corrected=false, dims=1) # allocates a big array... before 443 | |
end; | |
0.000192 seconds (13 CPU allocations: 400 bytes) (1 GPU allocation: 38.147 MiB, 13.82% memmgmt time) | |
0.001525 seconds (279 CPU allocations: 14.281 KiB) (4 GPU allocations: 38.261 MiB, 1.92% memmgmt time) | |
0.000904 seconds (315 CPU allocations: 15.859 KiB) (5 GPU allocations: 38.300 MiB, 3.44% memmgmt time) | |
0.006859 seconds (216 CPU allocations: 13.625 KiB) (2 GPU allocations: 38.261 MiB, 0.36% memmgmt time) | |
0.009269 seconds (365 CPU allocations: 19.094 KiB) (4 GPU allocations: 38.261 MiB, 0.38% memmgmt time) | |
0.005154 seconds (110 CPU allocations: 5.188 KiB) (2 GPU allocations: 38.223 MiB, 0.44% memmgmt time) | |
0.000320 seconds (101 CPU allocations: 4.625 KiB) (2 GPU allocations: 78.125 KiB, 3.75% memmgmt time) | |
0.015003 seconds (156 CPU allocations: 8.453 KiB) (1 GPU allocation: 39.062 KiB, 0.09% memmgmt time) | |
julia> let x = CUDA.randn(100, 1000) # Gradient, @btime | |
@btime CUDA.@sync Zygote.gradient(x -> sum(abs2, x), $x) # baseline, no norm! | |
println() | |
@btime CUDA.@sync Zygote.gradient(x -> sum(abs2, normal_now(x)), $x) | |
@btime CUDA.@sync Zygote.gradient(x -> sum(abs2, normal_new(x)), $x) | |
@btime CUDA.@sync Zygote.gradient(x -> sum(abs2, normal_newer(x)), $x) | |
println() | |
@btime CUDA.@sync Zygote.gradient(x -> sum(abs2, layernorm(x, Val(1))), $x) | |
@btime CUDA.@sync Zygote.gradient(x -> sum(abs2, layer_norm(nothing, nothing, x)), $x) | |
end | |
176.829 μs (325 allocations: 14.14 KiB) | |
685.050 μs (1250 allocations: 60.31 KiB) | |
413.809 μs (789 allocations: 42.55 KiB) | |
373.748 μs (666 allocations: 39.12 KiB) | |
993.590 μs (1494 allocations: 74.17 KiB) | |
365.185 μs (565 allocations: 25.61 KiB) | |
julia> let x = CUDA.randn(1000, 10_000) # Gradient, alloc | |
CUDA.@time Zygote.gradient(x -> sum(abs2, x), x) # baseline, no norm! | |
println() | |
CUDA.@time Zygote.gradient(x -> sum(abs2, normal_now(x)), x) | |
CUDA.@time Zygote.gradient(x -> sum(abs2, normal_new(x)), x) | |
CUDA.@time Zygote.gradient(x -> sum(abs2, normal_newer(x)), x) | |
println() | |
CUDA.@time Zygote.gradient(x -> sum(abs2, layernorm(x, Val(1))), x) | |
CUDA.@time Zygote.gradient(x -> sum(abs2, layer_norm(nothing, nothing, x)), x) | |
end; | |
0.055923 seconds (16.29 k CPU allocations: 959.204 KiB) (6 GPU allocations: 190.735 MiB, 0.09% memmgmt time) | |
0.079773 seconds (23.45 k CPU allocations: 1.308 MiB) (26 GPU allocations: 458.222 MiB, 0.18% memmgmt time) | |
0.076155 seconds (24.29 k CPU allocations: 1.388 MiB) (14 GPU allocations: 267.258 MiB, 0.12% memmgmt time) | |
0.079763 seconds (24.18 k CPU allocations: 1.392 MiB) (11 GPU allocations: 267.220 MiB, 0.11% memmgmt time) | |
0.106857 seconds (30.34 k CPU allocations: 1.698 MiB) (27 GPU allocations: 496.369 MiB, 0.15% memmgmt time) | |
0.076044 seconds (24.26 k CPU allocations: 1.388 MiB) (11 GPU allocations: 305.329 MiB, 0.11% memmgmt time) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment