@jiahao
Last active July 10, 2023 14:55
Minimal Julia implementation of NF4 floating point for QLoRA
using Statistics
using BFloat16s
using StaticArrays
import Base: getindex, setindex!, length, iterate
###########################################
# Implementation of the NormedFloat4 type
# and its container type, QLoRAArray
#
# Ref:
# https://arxiv.org/abs/2305.14314
# https://github.com/artidoro/qlora
###########################################
# A data type for one-dimensional UInt4 arrays that index like a normal array
# but are stored in packed nibbles (two 4-bit values per byte)
struct PackedUInt4Array{N}
    data::MVector{N, UInt8}
end
function PackedUInt4Array(N) # This N is the logical size; the type parameter is the byte count
    Nbytes = ceil(Int, N/2)
    data = Vector{UInt8}(undef, Nbytes)
    PackedUInt4Array{Nbytes}(data)
end
get_high_nibble(byte::UInt8) = byte >> 4
get_low_nibble(byte::UInt8) = byte & 0x0f
pack_high_nibble(byte::UInt8, v::UInt8) = (v << 4) + (byte & 0x0f) # replace the high nibble with v
pack_low_nibble(byte::UInt8, v::UInt8) = (byte & 0xf0) + v         # replace the low nibble with v
pack_nibbles(hi::UInt8, lo::UInt8) = (hi << 4) + lo
pack_nibbles(0xd, 0x8) # quick check: should return 0xd8
function getindex(A::PackedUInt4Array, idx::Integer)
    offset = (idx-1) >> 1      # which byte
    nibble_idx = (idx-1) & 0x1 # high or low nibble within the byte
    byte = A.data[offset+1]
    if nibble_idx == 0
        nibble = get_high_nibble(byte)
    else # nibble_idx == 1
        nibble = get_low_nibble(byte)
    end
    return nibble
end
function setindex!(A::PackedUInt4Array, X::UInt8, idx::Integer)
    # Silently truncate input
    z = X & 0xf
    offset = (idx-1) >> 1
    nibble_idx = (idx-1) & 0x1
    byte = A.data[offset+1]
    if nibble_idx == 0
        new_byte = pack_high_nibble(byte, z)
    else # nibble_idx == 1
        new_byte = pack_low_nibble(byte, z)
    end
    A.data[offset+1] = new_byte
end
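# Quick round-trip check (added, not in the original gist):
# four nibbles packed into two bytes survive storage and retrieval
A = PackedUInt4Array(4)
for (i, v) in enumerate((0x3, 0xa, 0xf, 0x0))
    A[i] = v
end
@assert [A[i] for i in 1:4] == [0x3, 0xa, 0xf, 0x0]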
# A quantized array storing the mean, standard deviation, and scale of the
# original data alongside the packed 4-bit quantile indices
struct QLoRAArray{T, N}
    μ::T
    σ::T
    scale::T
    data::PackedUInt4Array
end
# The 16 NF4 quantization levels (Appendix E of arXiv:2305.14314)
const NF4Quantiles = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]
function quantizeNF4(x)
    # Appendix E: map x ∈ [-1, 1] to the index of the nearest quantile
    q = searchsortedfirst(NF4Quantiles, x)
    q = min(q, length(NF4Quantiles)) # guard against x slightly above 1
    if q > 1 && (x - NF4Quantiles[q-1] < NF4Quantiles[q] - x) # RoundNearest
        q -= 1
    end
    UInt8(q - 1) # store quantile indices as 0..15
end
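# Spot checks (added): 0.0 is exactly the 8th quantile so it maps to index 7;
# anything at or below -1 maps to 0, and values near 1 map to 15
@assert quantizeNF4(0.0) == 0x07
@assert quantizeNF4(-1.0) == 0x00
@assert quantizeNF4(0.99) == 0x0f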
function QLoRAArray(x::AbstractArray{T}) where T
    # Standardize x - could probably use StatsBase.ZScoreTransform
    μ, σ = mean(x), std(x)
    z = (x .- μ) ./ σ
    # Maximum absolute scaling so that z*scale ∈ [-1, 1]
    zmin, zmax = extrema(z)
    scale = 1/max(-zmin, zmax)
    N = length(x)
    data = PackedUInt4Array(N)
    for i in eachindex(z)
        data[i] = quantizeNF4(z[i]*scale)
    end
    QLoRAArray{T, N}(μ, σ, scale, data)
end
function getindex(A::QLoRAArray{T}, idx::Integer) where T
    v = A.data[idx]                 # 4-bit quantile index
    z = NF4Quantiles[v+1]/A.scale   # dequantize to standardized units
    x = z*A.σ + A.μ                 # undo the standardization
    return x
end
length(::QLoRAArray{T, N}) where {T, N} = N
iterate(A::QLoRAArray{T, N}) where {T, N} = N==0 ? nothing : (A[1], 1)
iterate(A::QLoRAArray{T, N}, i) where {T, N} = i==N ? nothing : (A[i+1], i+1)
###########################################
# Tiny demo
###########################################
v = randn(BFloat16, 80)
z = QLoRAArray(v)
println("Quantization MAE = ", maximum(abs.(v.-z)))
using Plots
plot([v, BFloat16.(z)])
###########################################
# Implementation of the Float8 floating point types
#
# NVIDIA H100s now support two 8-bit floating point formats
# with different bit lengths for exponent and mantissa/significand
#
# Why, AI people?? WHY??
#
# Ref:
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html
# https://arxiv.org/abs/2209.05433
#
###########################################
primitive type Float8_E4M3 <: AbstractFloat 8 end
primitive type Float8_E5M2 <: AbstractFloat 8 end
const Float8 = Float8_E4M3 # Looks like this is the more common one for forward pass
const Float8s = Union{Float8_E4M3, Float8_E5M2}
Base.sign_mask(::Type{Float8_E4M3}) = 0x80
Base.sign_mask(::Type{Float8_E5M2}) = 0x80
Base.exponent_one(::Type{Float8_E4M3}) = 0x38 # bias 7 in the exponent field: 7 << 3
Base.exponent_one(::Type{Float8_E5M2}) = 0x3c
Base.exponent_half(::Type{Float8_E4M3}) = 0x30
Base.exponent_half(::Type{Float8_E5M2}) = 0x38
Base.exponent_bias(::Type{Float8_E4M3}) = 7
Base.exponent_bias(::Type{Float8_E5M2}) = 15
Base.exponent_bits(::Type{Float8_E4M3}) = 4
Base.exponent_bits(::Type{Float8_E5M2}) = 5
Base.exponent_mask(::Type{Float8_E4M3}) = 0b0111_1000
Base.exponent_mask(::Type{Float8_E5M2}) = 0b0111_1100
Base.significand_bits(::Type{Float8_E4M3}) = 3
Base.significand_bits(::Type{Float8_E5M2}) = 2
Base.significand_mask(::Type{Float8_E4M3}) = 0b0000_0111
Base.significand_mask(::Type{Float8_E5M2}) = 0b0000_0011
Base.signbit(x::Float8s) = (reinterpret(UInt8, x) & 0x80) != 0x00
# The infinities: E4M3 has none; E5M2 reserves exponent all-ones, mantissa zero
Base.isinf(x::Float8_E4M3) = false
Base.isinf(x::Float8_E5M2) = (reinterpret(UInt8, x) & 0b0111_1111) == 0b0111_1100
# The NaNs: E4M3 has a single NaN pattern (all non-sign bits set);
# E5M2 NaNs have exponent all-ones and nonzero mantissa
const NaN8_mask = 0b0111_1111
Base.isnan(x::Float8_E4M3) = (reinterpret(UInt8, x) & NaN8_mask) == NaN8_mask
Base.isnan(x::Float8_E5M2) = (reinterpret(UInt8, x) & NaN8_mask) > 0b0111_1100
const NaN8_E4M3 = reinterpret(Float8_E4M3, NaN8_mask)
const NaN8_E5M2 = reinterpret(Float8_E5M2, NaN8_mask)
# Inf8_E4M3 does not exist
const Inf8_E5M2 = reinterpret(Float8_E5M2, 0b0111_1100)
Base.floatmax(::Type{Float8_E4M3}) = reinterpret(Float8_E4M3, 0b0111_1110)
Base.floatmax(::Type{Float8_E5M2}) = reinterpret(Float8_E5M2, 0b0111_1011)
Base.floatmin(::Type{Float8_E4M3}) = reinterpret(Float8_E4M3, 0b0000_1000)
Base.floatmin(::Type{Float8_E5M2}) = reinterpret(Float8_E5M2, 0b0000_0100)
Base.typemin(::Type{Float8_E4M3}) = -Base.floatmax(Float8_E4M3)
Base.typemin(::Type{Float8_E5M2}) = -Inf8_E5M2
Base.typemax(::Type{Float8_E4M3}) = Base.floatmax(Float8_E4M3)
Base.typemax(::Type{Float8_E5M2}) = Inf8_E5M2
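# Sanity checks (added, not in the original gist) for the bit-level predicates above
@assert isnan(NaN8_E4M3) && isnan(NaN8_E5M2)
@assert isinf(Inf8_E5M2) && !isnan(Inf8_E5M2)
@assert !isinf(floatmax(Float8_E4M3)) # E4M3 spends the all-ones exponent on finite values
@assert signbit(reinterpret(Float8_E4M3, 0x80)) # negative zero has its sign bit set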
Base.promote_rule(::Type{S}, ::Type{T}) where {S<:AbstractFloat,T<:Float8s} = S
Base.promote_rule(::Type{Float8_E4M3}, ::Type{Float8_E5M2}) = Float8_E5M2
########################################################
# Interconversion with other floats (primarily Float16 and BFloat16)
########################################################
# E5M2 has the same structure as Float16 but with the last 8 bits of the mantissa lopped off
Base.convert(::Type{Float8_E5M2}, x::AbstractFloat) =
    reinterpret(Float8_E5M2, UInt8(reinterpret(UInt16, Float16(x)) >> 8))
Float8_E5M2(x) = Base.convert(Float8_E5M2, x)
Base.convert(::Type{Float8_E4M3}, x::AbstractFloat) = convert(Float8_E4M3, Float16(x))
Float8_E4M3(x) = Base.convert(Float8_E4M3, x)
Base.convert(::Type{Float8_E4M3}, x) = Base.convert(Float8_E4M3, Float16(x))
function Base.convert(::Type{Float8_E4M3}, x::Float16)
    # NB: truncates the significand (no rounding) and does not handle
    # subnormals or exponents outside the E4M3 range
    isnan(x) && return NaN8_E4M3
    isinf(x) && error("Infinities not representable")
    z = reinterpret(UInt16, x)
    sign_bits = UInt8((Base.sign_mask(Float16) & z) >> 8)
    significand_bitshift = Base.significand_bits(Float16) - Base.significand_bits(Float8_E4M3)
    significand_bits = UInt8((Base.significand_mask(Float16) & z) >> significand_bitshift)
    exponent = (Base.exponent_mask(Float16) & z) >> Base.significand_bits(Float16) - Base.exponent_bias(Float16)
    exponent_bits = UInt8(exponent + Base.exponent_bias(Float8_E4M3)) << Base.significand_bits(Float8_E4M3)
    final_bits = sign_bits | exponent_bits | significand_bits
    return reinterpret(Float8_E4M3, final_bits)
end
Base.Float16(x::Float8s) = Base.convert(Float16, x)
Base.Float32(x::Float8s) = Float32(Float16(x))
Base.Float64(x::Float8s) = Float64(Float16(x))
function Base.convert(::Type{Float16}, x::Float8s)
    # NB: widening is exact for normal values but does not handle Float8 subnormals
    T = typeof(x)
    isnan(x) && return NaN16
    z = reinterpret(UInt8, x)
    sign_bits = UInt16(Base.sign_mask(T) & z) << 8
    significand_bitshift = Base.significand_bits(Float16) - Base.significand_bits(T)
    significand_bits = UInt16(Base.significand_mask(T) & z) << significand_bitshift
    exponent = (Base.exponent_mask(T) & z) >> Base.significand_bits(T) - Base.exponent_bias(T)
    exponent_bits = UInt16(exponent + Base.exponent_bias(Float16)) << Base.significand_bits(Float16)
    final_bits = sign_bits | exponent_bits | significand_bits
    return reinterpret(Float16, final_bits)
end
z = convert(Float8_E4M3, 1.25) # 1.25 = 1.01b × 2^0 is exactly representable in E4M3
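# Round-trip checks (added): exactly representable values should survive
# conversion to 8 bits and back
@assert Float16(z) == Float16(1.25)
@assert Float16(Float8_E5M2(1.25)) == Float16(1.25)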
###########################################
# A very incomplete implementation of arithmetic mediated by Float16s.
# The docs seem to indicate that the Float8s are storage-only and
# native arithmetic is in some 16-bit format, possibly BFloat16
###########################################
import Base: +, -, *, /
+(x::T, y::T) where T<:Float8s = T(Float16(x) + Float16(y))
-(x::T, y::T) where T<:Float8s = T(Float16(x) - Float16(y))
*(x::T, y::T) where T<:Float8s = T(Float16(x) * Float16(y))
/(x::T, y::T) where T<:Float8s = T(Float16(x) / Float16(y))
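# Quick demo (added): arithmetic widens to Float16, operates, then re-quantizes.
# 1.5 + 0.25 = 1.75 = 1.11b × 2^0 is exactly representable in E4M3
a, b = Float8(1.5), Float8(0.25)
@assert Float16(a + b) == Float16(1.75)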
###########################################
# An incomplete implementation of the doubly quantized arrays
# used by QLoRA
###########################################
struct ChunkedQuantizedArray{Nx, Ny, S, Nb, B, T} <: StaticArray{Tuple{Nx, Ny}, S, 2}
    c :: SVector{Nb, S} # per-block quantization constants
    data :: Matrix{T}   # Nb × B matrix of quantized blocks
end
# (1) of arXiv:2305.14314v1: blockwise absmax quantization
function quantize(::Type{T}, X::AbstractMatrix{S},
                  B::Int=64, # Block size
                  rm::RoundingMode = RoundNearest) where {S,T}
    Nx, Ny = size(X)
    Nb = ceil(Int, length(X)/B) # Number of blocks
    c = zeros(MVector{Nb, S})   # @MVector needs a static size, so construct from the type
    data = Array{T}(undef, Nb, B)
    for ib in 1:Nb # Iterate over blocks
        Xb = view(X, ((ib-1)*B+1):min(ib*B, length(X)))
        c[ib] = typemax(T) / maximum(abs.(Xb))
        data[ib, 1:length(Xb)] = round.(T, c[ib]*Xb, rm)
    end
    ChunkedQuantizedArray{Nx, Ny, S, Nb, B, T}(c, data)
end
function getindex(X::ChunkedQuantizedArray{Nx, Ny, S, Nb, B, T}, idx::Int) where {Nx, Ny, S, Nb, B, T}
    ib = (idx-1) ÷ B + 1 # block containing linear index idx
    j  = (idx-1) % B + 1 # position within the block
    X.data[ib, j] / X.c[ib]
end
# Double quantization: quantize the per-block constants c themselves
function doublequantize(::Type{T2}, ::Type{T}, X::AbstractMatrix{S},
                        B::Int=64, B2::Int=256, # Block sizes
                        rm::RoundingMode = RoundNearest) where {S,T,T2}
    Xc = quantize(T, X, B, rm)
    Xc2 = quantize(T2, reshape(Xc.c, length(Xc.c), 1), B2, rm)
    (Xc2.c, Xc2.data, Xc.data)
end
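# Usage sketch (added): blockwise Int8 quantization of a 16×16 matrix.
# Int8 stands in here for the paper's 4-bit storage type; doublequantize would
# additionally need a round method for the 8-bit float used to store the constants.
W = randn(Float32, 16, 16)
Wq = quantize(Int8, W)
println("max dequantization error = ",
    maximum(abs(Wq[i] - W[i]) for i in eachindex(W)))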
###########################################
# A skeleton of the low-rank adapter
# "module"
###########################################
struct LowRankAdapter
    # Computes Y = X*W + s*X*L1*L2, where L1*L2 is a low-rank update
    # to the (frozen) weight matrix W
    W  # size (h, o)
    s  # scalar scaling factor
    L1 # size (h, r)
    L2 # size (r, o)
end
# This is how LowRankAdapter acts on input X in the forward pass
# TODO Figure out the canonical spelling in MLJ or Flux
function forward(A::LowRankAdapter, X)
    Y = X * A.W + A.s * X * A.L1 * A.L2
    return Y
end
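# Usage sketch (added): a rank-2 adapter on an h=4, o=8 weight with scale s=0.5.
# All dimensions are illustrative.
h, o, r = 4, 8, 2
adapter = LowRankAdapter(randn(h, o), 0.5, randn(h, r), randn(r, o))
Xin = randn(3, h) # a batch of 3 input rows
@assert size(forward(adapter, Xin)) == (3, o)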
svilupp commented Jul 10, 2023

on L#67, I think it should be q>1 otherwise you'd be indexing with zero for q=1

jiahao (Author) commented Jul 10, 2023

@svilupp good catch, thank you
