
Created January 19, 2023 23:07
Matrix multiplication, by computing mantissae and exponents separately
import numpy as np

# suppose we want to matmul some matrix against some column vector..
# the usual way is this:
mat = np.array([[1.3e2, 8.e2], [1.6e1, 5.e-1]])
vec = np.array([2.1e2, 3.6e-4])
mat @ vec
# array([27300.288 , 3360.00018])
# but what if we wanted to exploit some efficiencies..
# - addition can use fewer cycles / less energy / silicon / time than multiplication
# - for machine learning training: we want to represent a wide range of exponents, but don't need as many bits of mantissa
# - multiplying powers of ten in isolation can be done cheaply, because it's just addition of exponents! 10^3 * 10^4 = 10^(3+4) = 10^7
# **can we matmul the mantissa separately from the exponent?**
# first, let's change our representation: shift the mantissa's digits left (multiply by 10) until it has no fractional part.
# decrement the exponent once per shift, so the value stays the same
# you won't be able to preserve *all* decimal detail (depends how many bits we allocate for our mantissa)
# so you may have to round down and lose some precision
# mat = np.array([[1.3e2, 8.e2], [1.6e1, 5.e-1]]) becomes:
# = np.array([[13.e1, 8.e2], [16.e0, 5.e-1]])
# vec = np.array([2.1e2, 3.6e-4]) becomes
# = np.array([21e1, 36e-5])
# now, let's represent the mantissae and exponents as separate ndarrays:
mat_m = np.array([[13, 8], [16, 5]])
mat_e = np.array([[1, 2], [0, -1]])
vec_m = np.array([21, 36])
vec_e = np.array([1, -5])
# elementwise-multiply mantissae:
# we're not ready to add their elements together (i.e. dot product),
# because the elements won't be the right magnitudes without the extra context of their exponents
mantissa_elementwise_product = mat_m * vec_m
# this one's important! powers of ten are multiplied by *adding their exponents*. cheap!
exponent_elementwise_product = mat_e + vec_e
# you have two options here.
# you could stay in separated mantissa/exponent representation..
# for example, undergo an activation (perhaps one that considers only your exponent)
# then do further "matmuls" like this, with more elementwise products over mantissa and exponent
# *or* you could say "I'm done" and return back to canonical representation
# this isn't free: it costs raising 10 to each exponent, plus another elementwise multiply.
# so the hope is that you manage to do a few more layers in mantissa/exponent space before doing this
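# use a float base: numpy refuses to raise integers to negative integer powers
elementwise_product = mantissa_elementwise_product * 10.0 ** exponent_elementwise_product
# sum along each row to finish the dot products
dot_product = elementwise_product.sum(axis=1)
dot_product
# array([27300.288 , 3360.00018])
# same result as our matmul earlier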
# what's the implication of this?
# perhaps we could have:
# int8 mantissa
# - integer multiplication is cheaper than float multiplication
# - mantissa doesn't need a wide range
# int8 exponent
# - integer ops are cheaper than float ops
# - addition is cheaper than multiplication
# int16 exponent?
# - representing a wide range of gradient magnitudes is helpful for training
# - cost of int16 addition may be more palatable than float16 multiplication
# a potential problem is what to do about carries when you multiply mantissae.
# I think if your mantissa is nominally an int4 (living inside an int8, i.e. with headroom reserved),
# then multiplying it by another int4 gives a result that fits in at most an int8 (e.g. unsigned: 15 * 15 = 225 < 256).
# you can probably then use bitwise operations (e.g. finding the highest set bit) to determine how much to increase the corresponding exponent by,
# then right-shift the mantissa back down afterwards until everything fits into int4 again, losing some precision.
# not ideal. quite a few steps, so might lose a lot of the perf we were hoping to gain.
# maybe we need int16 to have enough space to hold a decent carry, and to represent NaN/Infinity/negative.
# still, int16 multiplication could have an efficiency over float multiplication. not sure how int16 compares to float8.
# there's a lot of factors involved in the performance and practicality,
# so maybe there's an aspect of this that is a non-starter.
# but it'd be cool if it turns out there's a use for this!
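The carry handling described in the snippet (multiply int4 mantissae inside an int8, find how far the product overflowed, right-shift it back into 4 bits and bump the exponent) can be sketched in NumPy. This is a minimal sketch under assumptions the gist doesn't fix: mantissae are unsigned int4 values stored in `uint8`, exponents are base 2 (so the adjustment is a shift) and stored in `int8`, and precision is lost by truncation. `normalise_int4` is a hypothetical name.

```python
import numpy as np

MANT_BITS = 4  # assumption: nominal unsigned int4 mantissa, stored in a uint8 for headroom


def normalise_int4(product, exponent):
    """Squeeze int4 * int4 products back into 4 bits, moving the excess into a base-2 exponent.

    product:  uint8 array of mantissa products (at most 15 * 15 = 225, so no uint8 overflow)
    exponent: int8 array of summed exponents
    """
    # shift = bit_length(product) - 4, clipped at 0: count how many of the
    # thresholds 16, 32, 64, 128 the product reaches (comparisons only)
    shift = np.zeros_like(product)
    for k in range(MANT_BITS, 8):
        shift += (product >= (1 << k)).astype(product.dtype)
    # drop the excess low bits (losing precision) and add the same amount to the exponent
    return product >> shift, exponent + shift.astype(exponent.dtype)


a_m = np.array([15, 3, 9, 1], dtype=np.uint8)
a_e = np.array([0, 2, -1, 5], dtype=np.int8)
b_m = np.array([15, 5, 2, 7], dtype=np.uint8)
b_e = np.array([1, 0, 3, -2], dtype=np.int8)

# elementwise mantissa multiply, exponent add, then renormalise the carries
mant, exp = normalise_int4(a_m * b_m, a_e + b_e)
print(mant.tolist(), exp.tolist())  # [14, 15, 9, 7] [5, 2, 3, 3]
```

Values are only approximately preserved: the first pair, 15·2^0 × 15·2^1 = 450, comes back as 14·2^5 = 448, which is the precision loss the snippet anticipates.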

Birch-san commented Jan 20, 2023

We can make the 10^x step cheaper if we use 2 as our exponentiation base. That way it just becomes a bit shift (left for positive exponents, right for negative ones).
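With 2 as the base, the gist's example can be redone so that the "raise to the exponent" step is just a power-of-two scale. A minimal sketch, assuming 8-bit unsigned mantissa magnitudes (`MANTISSA_BITS` and `split_base2` are hypothetical names) and truncation when converting. It uses NumPy's `np.frexp` to split floats and `np.ldexp(x, e)` to compute `x * 2**e`; for an integer mantissa and a non-negative `e` that scale is the same as `x << e`.

```python
import numpy as np

MANTISSA_BITS = 8  # assumption: 8-bit unsigned mantissa magnitudes


def split_base2(x, bits=MANTISSA_BITS):
    """Split floats into integer mantissae and base-2 exponents: x ~= mant * 2**exp."""
    frac, exp = np.frexp(x)  # x == frac * 2**exp, with 0.5 <= |frac| < 1
    # shift the fraction left by `bits` and truncate towards zero (losing precision),
    # lowering the exponent by the same amount
    mant = np.trunc(frac * 2.0**bits).astype(np.int64)
    return mant, exp - bits


mat = np.array([[1.3e2, 8.e2], [1.6e1, 5.e-1]])
vec = np.array([2.1e2, 3.6e-4])

mat_m, mat_e = split_base2(mat)
vec_m, vec_e = split_base2(vec)

mantissa_elementwise_product = mat_m * vec_m  # integer multiplies
exponent_elementwise_product = mat_e + vec_e  # powers of 2 multiply by adding exponents
# scale each product by 2**exponent (exact, barring overflow/underflow)
elementwise_product = np.ldexp(mantissa_elementwise_product.astype(np.float64),
                               exponent_elementwise_product)
dot_product = elementwise_product.sum(axis=1)

print(dot_product)  # close to mat @ vec
print(mat @ vec)
```

Only `3.6e-4` needs truncating here (to 188·2^-19); the other inputs fit exactly in 8 bits, so the first row comes out as about 27300.2869 instead of 27300.288.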

Regarding the question "if we want to multiply an int3 mantissa by another int3 mantissa, does that mean we need to reserve int6 to store the multiplication result?"

We don't actually need to store this in the datatype. We can store the multiplication result elsewhere (in a temporary buffer) until we get an opportunity to flush it (work out how many powers of 2 we went up by, bit-shift right by that amount, and add the same amount to our exponent).
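The buffer-then-flush idea can be sketched with Python's arbitrary-width integers standing in for the temporary buffer. Assumptions not stated in the comment: unsigned int3 mantissae (values 0 to 7), base-2 exponents, and truncation on flush; `flush` and `multiply_deferred` are hypothetical names.

```python
MANT_BITS = 3  # assumption: unsigned int3 mantissa, values 0..7


def flush(buffer, exponent, bits=MANT_BITS):
    """Renormalise a wide mantissa buffer back into `bits` bits.

    Work out how many powers of 2 the buffer has grown past `bits`, shift right
    by that amount (truncating low bits) and add the same amount to the exponent.
    """
    shift = max(0, buffer.bit_length() - bits)
    return buffer >> shift, exponent + shift


def multiply_deferred(factors):
    """Multiply (mantissa, exponent) pairs, flushing once at the end instead of after each step."""
    buffer, exponent = 1, 0
    for mantissa, exp in factors:
        buffer *= mantissa  # the buffer grows by up to MANT_BITS bits per factor
        exponent += exp     # powers of 2 multiply by adding exponents
    return flush(buffer, exponent)


# 5*2^0 * 7*2^1 * 6*2^-2 = 105
print(multiply_deferred([(5, 0), (7, 1), (6, -2)]))  # (6, 4), i.e. 6*2^4 = 96
```

Here 105 truncated to 3 significant bits gives 96. The trade-off in this sketch is that low bits are dropped only once, at flush time, at the cost of a buffer wide enough to hold the unnormalised product.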

But storing the result inside the datatype could give a perf advantage, if the alternative of buffering the multiplication result elsewhere would be more expensive.

int3 is probably a sufficient mantissa size; float8 training can succeed even with a 2-bit significand:

Maybe int3*int3 has so few combinations (only 8 × 8 = 64 pairs) that you could store them in a lookup table? Could that be faster than doing an int3 multiply? It might benefit from custom silicon. You could also store the amount by which each multiplication result would change your exponent.
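A sketch of the lookup-table idea, assuming unsigned int3 mantissae and base-2 exponents (neither is fixed by the comment). Each entry stores both the product squeezed back into 3 bits and the amount it bumps the exponent by. `LUT` and `lut_multiply` are hypothetical names, and the sketch says nothing about whether a lookup would actually beat a multiplier in silicon.

```python
MANT_BITS = 3            # assumption: unsigned int3 mantissa, values 0..7
SIZE = 1 << MANT_BITS    # 8 possible mantissae -> 8 * 8 = 64 table entries

# LUT[a][b] = (a * b renormalised into MANT_BITS bits, exponent increment)
LUT = []
for a in range(SIZE):
    row = []
    for b in range(SIZE):
        product = a * b  # at most 7 * 7 = 49, which needs 6 bits
        shift = max(0, product.bit_length() - MANT_BITS)
        row.append((product >> shift, shift))
    LUT.append(row)


def lut_multiply(x, y):
    """Multiply two (mantissa, base-2 exponent) pairs with a table lookup instead of a multiply."""
    (ma, ea), (mb, eb) = x, y
    mantissa, exp_increment = LUT[ma][mb]
    return mantissa, ea + eb + exp_increment


print(lut_multiply((7, 0), (7, 1)))  # (6, 4): 49*2^1 = 98 truncated to 6*2^4 = 96
```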
