@ityonemo
Last active November 18, 2023 20:07
Floating point to GPTQ in Nx
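defmodule NxExt do
  import Nx.Defn

  # Column of nybble weights: in each pair, the first value is multiplied by 1
  # (low nybble) and the second by 16 (high nybble).
  @bitshift Nx.tensor([[1], [16]], type: {:u, 8})
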
  @doc """
  Takes an N-vector of floats (of any float type) and converts it into 4-bit GPTQ values,
  which have a range of -8..7; values outside that range are clipped. N must be even: the
  values are packed two per byte, with the lower-indexed value in the less significant nybble.

  ### TODO: check that the sub-endianness is correct.

  ```elixir
  iex> [-6.0, 1.0, 7.0, -3.0]
  ...> |> Nx.tensor(type: {:f, 16})
  ...> |> NxExt.to_gptq()
  ...> |> Nx.to_binary()
  <<1::signed-size(4), -6::signed-size(4), -3::signed-size(4), 7::signed-size(4)>>
  ```
  """
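  defn to_gptq(tensor) do
    tensor
    # saturate to the signed 4-bit range
    |> Nx.clip(-8, 7)
    |> Nx.as_type({:s, 8})
    # reinterpret the two's-complement bytes as unsigned, then keep the low 4 bits
    |> Nx.bitcast({:u, 8})
    |> Nx.bitwise_and(15)
    # pair up neighbours and combine each pair into one byte: low + 16 * high
    |> Nx.reshape({:auto, 2})
    |> Nx.dot(@bitshift)
    # flatten the {n, 1} result back into a vector of bytes
    |> Nx.reshape({:auto})
  end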
end
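
To help check the nybble order flagged in the TODO, here is a minimal sketch of the inverse operation in plain Elixir, without Nx. The module name `NibbleUnpack` and the function `unpack/1` are hypothetical and not part of the gist. It assumes the layout described in the docstring: two signed 4-bit values per byte, with the lower-indexed value in the less significant nybble.

```elixir
defmodule NibbleUnpack do
  # Hypothetical helper for illustration only; not part of NxExt.
  # Splits each byte into its high and low signed nybbles and emits the
  # low nybble first, matching the packing order described in the docstring.
  def unpack(binary) when is_binary(binary) do
    pairs = for <<hi::signed-size(4), lo::signed-size(4) <- binary>>, do: [lo, hi]
    Enum.concat(pairs)
  end
end

# The two bytes from the doctest, written nybble by nybble.
packed = <<1::signed-size(4), -6::signed-size(4), -3::signed-size(4), 7::signed-size(4)>>
IO.inspect(NibbleUnpack.unpack(packed))
```

Applied to the bytes in the doctest, this should give back the original values in their original order, which is one way to sanity-check the packing against whatever consumer reads the GPTQ data.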