Skip to content

Instantly share code, notes, and snippets.

@klauspost
Created May 3, 2022 11:25
Show Gist options
  • Save klauspost/617e149f31f8967bc184f5a48c3834f4 to your computer and use it in GitHub Desktop.
Save klauspost/617e149f31f8967bc184f5a48c3834f4 to your computer and use it in GitHub Desktop.
//go:build amd64 && !appengine && !noasm && gc
// +build amd64,!appengine,!noasm,gc
// This file contains the specialisation of Decoder.Decompress4X
// that uses an asm implementation of its main loop.
package huff0
import (
"errors"
"fmt"
)
// decompress4x_main_loop_amd64_9 is an x86 assembler implementation
// of the Decompress4X main loop when tablelog == 9. Each iteration decodes
// two symbols per stream into ctx.buf. It returns the number of bytes written
// per stream, or 0 when a full 256 bytes per stream were written. It stops
// early once a refill leaves any stream with fewer than 4 input bytes.
//go:noescape
func decompress4x_main_loop_amd64_9(ctx *decompress4xContext) uint8

// decompress4x_main_loop_amd64_10 is decompress4x_main_loop_amd64_9
// specialised for tablelog == 10.
//go:noescape
func decompress4x_main_loop_amd64_10(ctx *decompress4xContext) uint8

// decompress4x_main_loop_amd64_11 is decompress4x_main_loop_amd64_9
// specialised for tablelog == 11.
//go:noescape
func decompress4x_main_loop_amd64_11(ctx *decompress4xContext) uint8

// decompress4x_8b_main_loop_amd64 is an x86 assembler implementation
// of the Decompress4X main loop when tablelog <= 8. It decodes 4 entries
// per stream per loop iteration and otherwise follows the same return
// contract as decompress4x_main_loop_amd64_9.
//go:noescape
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) uint8
// fallback8BitSize is the output capacity below which Decompress4X hands
// 8-bit tables to the pure Go decoder, since for such small outputs the Go
// version is faster than the assembler loop.
const (
	fallback8BitSize = 800
)
// decompress4xContext holds the arguments of the assembler main loops.
// The generator (gen.go, via avo's Dereference/Field) takes the field offsets
// from this type. Changing the layout therefore requires regenerating
// decompress_amd64.s.
type decompress4xContext struct {
	// Bit readers for streams 0-3.
	pbr0 *bitReaderShifted
	pbr1 *bitReaderShifted
	pbr2 *bitReaderShifted
	pbr3 *bitReaderShifted
	// peekBits is 64 - tablelog (masked to 6 bits). Only the tablelog <= 8
	// loop reads it; the other loops hard-code their shift.
	peekBits uint8
	// buf points at the first byte of the 4x256-byte scratch buffer.
	buf *byte
	// tbl points at the first entry of the single-symbol decoding table.
	tbl *dEntrySingle
}
// Decompress4X will decompress a 4X encoded stream.
// The length of the supplied input must match the end of a block exactly.
// The *capacity* of the dst slice must match the destination size of
// the uncompressed data exactly.
//
// Small outputs with table logs <= 8 use the pure Go decoder. Otherwise an
// assembler loop decodes the four streams in chunks of up to 256 bytes per
// stream into a scratch buffer, which is copied into dst. The tail of each
// stream is then decoded in Go.
func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
	if len(d.dt.single) == 0 {
		return nil, errors.New("no table loaded")
	}
	if len(src) < 6+(4*1) {
		return nil, errors.New("input too small")
	}

	use8BitTables := d.actualTableLog <= 8
	if cap(dst) < fallback8BitSize && use8BitTables {
		return d.decompress4X8bit(dst, src)
	}

	var br [4]bitReaderShifted
	// Decode "jump table": the first 6 bytes hold the little-endian uint16
	// lengths of streams 0-2; stream 3 takes the rest of src.
	start := 6
	for i := 0; i < 3; i++ {
		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
		if start+length >= len(src) {
			return nil, errors.New("truncated input (or invalid offset)")
		}
		if err := br[i].init(src[start : start+length]); err != nil {
			return nil, err
		}
		start += length
	}
	if err := br[3].init(src[start:]); err != nil {
		return nil, err
	}

	// destination, offset to match first output
	dstSize := cap(dst)
	dst = dst[:dstSize]
	out := dst
	dstEvery := (dstSize + 3) / 4

	const tlSize = 1 << tableLogMax
	const tlMask = tlSize - 1
	single := d.dt.single[:tlSize]

	// Use temp table to avoid bound checks/append penalty.
	buf := d.buffer()
	// Return the scratch buffer to the pool on every exit path, including
	// the error returned by br.close() below.
	defer d.bufs.Put(buf)

	var off uint8
	var decoded int

	const debug = false

	ctx := decompress4xContext{
		pbr0:     &br[0],
		pbr1:     &br[1],
		pbr2:     &br[2],
		pbr3:     &br[3],
		peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
		buf:      &buf[0][0],
		tbl:      &single[0],
	}

	// Each asm call writes up to bufoff bytes per stream into buf. It returns
	// 0 when all bufoff bytes were written; otherwise it returns the number
	// of bytes written per stream before a stream ran low on input.
	const bufoff = 256
	for {
		// The asm loops refill 4 input bytes at a time.
		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
			break
		}

		if use8BitTables {
			off = decompress4x_8b_main_loop_amd64(&ctx)
		} else {
			switch d.actualTableLog {
			case 9:
				off = decompress4x_main_loop_amd64_9(&ctx)
			case 10:
				off = decompress4x_main_loop_amd64_10(&ctx)
			case 11:
				off = decompress4x_main_loop_amd64_11(&ctx)
			default:
				// No asm loop exists for this table size. Continuing would copy
				// stale scratch data into dst, so report it instead.
				return nil, fmt.Errorf("unexpected tablelog: %d", d.actualTableLog)
			}
		}
		if debug {
			fmt.Print("DEBUG: ")
			fmt.Printf("off=%d,", off)
			for i := 0; i < 4; i++ {
				fmt.Printf(" br[%d]={bitsRead=%d, value=%x, off=%d}",
					i, br[i].bitsRead, br[i].value, br[i].off)
			}
			fmt.Println("")
		}

		if off != 0 {
			break
		}

		if bufoff > dstEvery {
			return nil, errors.New("corruption detected: stream overrun 1")
		}
		copy(out, buf[0][:])
		copy(out[dstEvery:], buf[1][:])
		copy(out[dstEvery*2:], buf[2][:])
		copy(out[dstEvery*3:], buf[3][:])
		out = out[bufoff:]
		decoded += bufoff * 4
		// There must at least be 3 buffers left.
		if len(out) < dstEvery*3 {
			return nil, errors.New("corruption detected: stream overrun 2")
		}
	}
	if off > 0 {
		ioff := int(off)
		if len(out) < dstEvery*3+ioff {
			return nil, errors.New("corruption detected: stream overrun 3")
		}
		copy(out, buf[0][:ioff])
		copy(out[dstEvery:], buf[1][:ioff])
		copy(out[dstEvery*2:], buf[2][:ioff])
		copy(out[dstEvery*3:], buf[3][:ioff])
		decoded += ioff * 4
		out = out[ioff:]
	}

	// Decode remaining.
	remainBytes := dstEvery - (decoded / 4)
	for i := range br {
		offset := dstEvery * i
		endsAt := offset + remainBytes
		if endsAt > len(out) {
			endsAt = len(out)
		}
		br := &br[i]
		bitsLeft := br.remaining()
		for bitsLeft > 0 {
			br.fill()
			if offset >= endsAt {
				return nil, errors.New("corruption detected: stream overrun 4")
			}

			// Read value and increment offset.
			val := br.peekBitsFast(d.actualTableLog)
			v := single[val&tlMask].entry
			nBits := uint8(v)
			br.advance(nBits)
			bitsLeft -= uint(nBits)
			out[offset] = uint8(v >> 8)
			offset++
		}
		if offset != endsAt {
			return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
		}
		decoded += offset - dstEvery*i
		if err := br.close(); err != nil {
			return nil, err
		}
	}
	if dstSize != decoded {
		return nil, errors.New("corruption detected: short output block")
	}
	return dst, nil
}
package main
//go:generate go run gen.go -out ../decompress_amd64.s -pkg=huff0
import (
"flag"
"fmt"
"strconv"
_ "github.com/klauspost/compress"
. "github.com/mmcloughlin/avo/build"
"github.com/mmcloughlin/avo/buildtags"
. "github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
// main emits the amd64 assembler for the huff0 4X decoding loops. It
// generates one procedure per table log in [9, 11] plus a dedicated loop for
// table logs <= 8. The output file and package come from the command-line
// flags (-out, -pkg) passed by the go:generate directive.
func main() {
	flag.Parse()

	// Mirror the constraints of the hand-written Go counterpart; amd64 is
	// implied by the _amd64.s output file name.
	Constraint(buildtags.Not("appengine").ToConstraint())
	Constraint(buildtags.Not("noasm").ToConstraint())
	Constraint(buildtags.Term("gc").ToConstraint())

	decompress := decompress4x{}
	for i := 9; i <= 11; i++ {
		decompress.nBits = i
		// n is the id base for this procedure. decodeTwoValues uses id%10 as
		// the stream index, so n must stay a multiple of 10.
		decompress.n = i * 10
		decompress.generateProcedure(fmt.Sprintf("decompress4x_main_loop_amd64_%d", i))
	}

	decompress8b := decompress4x{}
	decompress8b.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64")

	Generate()
}
// buffoff is the distance in bytes between the per-stream scratch buffers
// (see decompress.go, which uses a [4][256]byte table).
const buffoff = 256

// decompress4x carries the generator settings for one assembler procedure.
type decompress4x struct {
	bmi2  bool // emit SHLX (BMI2) instead of SHL for variable shifts
	nBits int  // table log the procedure is specialised for
	n     int  // id base for the per-stream ids of the tablelog > 8 loops
}
// generateProcedure emits the Decompress4X main loop for table logs > 8,
// specialised for d.nBits.
//
// Each iteration decodes two symbols from each of the four streams into the
// scratch buffer and advances the 8-bit offset by two. The loop ends when the
// offset wraps to 0 (256 bytes per stream written) or when a refill left any
// stream with fewer than 4 input bytes. The final offset is returned.
func (d decompress4x) generateProcedure(name string) {
	Package("github.com/klauspost/compress/huff0")
	TEXT(name, 0, "func(ctx* decompress4xContext) uint8")
	Doc(name + " is an x86 assembler implementation of Decompress4X when tablelog > 8.")
	Pragma("noescape")

	off := GP64()
	XORQ(off, off)

	exhausted := GP64()
	XORQ(exhausted.As64(), exhausted.As64()) // exhausted = false

	buffer := GP64()
	table := GP64()
	br0 := GP64()
	br1 := GP64()
	br2 := GP64()
	br3 := GP64()

	Comment("Preload values")
	{
		ctx := Dereference(Param("ctx"))
		Load(ctx.Field("buf"), buffer)
		Load(ctx.Field("tbl"), table)
		Load(ctx.Field("pbr0"), br0)
		Load(ctx.Field("pbr1"), br1)
		Load(ctx.Field("pbr2"), br2)
		Load(ctx.Field("pbr3"), br3)
	}

	Comment("Main loop")
	Label(name + "_main_loop")
	d.decodeTwoValues(d.n+0, br0, table, buffer, off, exhausted)
	d.decodeTwoValues(d.n+1, br1, table, buffer, off, exhausted)
	d.decodeTwoValues(d.n+2, br2, table, buffer, off, exhausted)
	d.decodeTwoValues(d.n+3, br3, table, buffer, off, exhausted)

	ADDB(U8(2), off.As8())                  // off += 2
	TESTB(exhausted.As8(), exhausted.As8()) // any br[i].off < 4?
	JNZ(LabelRef(name + "_done"))
	CMPB(off.As8(), U8(0)) // off wrapped to 0: 256 bytes per stream written
	JNZ(LabelRef(name + "_main_loop"))

	Label(name + "_done")
	offsetComp, err := ReturnIndex(0).Resolve()
	if err != nil {
		panic(err)
	}
	MOVB(off.As8(), offsetComp.Addr)
	RET()
}
// Byte offsets of the bitReaderShifted fields, used to address the bit
// readers through the raw pointers held in decompress4xContext.
//
// TODO [wmu]: avo can probably derive these, but it was unclear how to deal
// with arbitrary pointers to a given type, so they are hard-coded.
// NOTE(review): they assume the field order in, off, value, bitsRead with
// 8-byte words; keep them in sync with bitReaderShifted.
const (
	bitReader_in       = 0
	bitReader_off      = bitReader_in + 3*8 // in is a slice: {ptr, len, cap}
	bitReader_value    = bitReader_off + 8
	bitReader_bitsRead = bitReader_value + 8
)
// fillFast32 emits code that loads br.value, br.bitsRead and br.off into
// registers. When fewer than atLeast unread bits remain, it refills the value
// with 32 more bits read from the 4 bytes below b.off. It then stores the
// decremented b.off back to br and ORs (b.off < 4) into exhausted.
//
// It returns the registers holding value and bitsRead. The caller is
// responsible for writing them back to br. id must be unique within the
// procedure because it names the skip label.
func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) {
	Commentf("br%d.fillFast32()", id)
	brValue = GP64()
	brBitsRead = GP64()
	MOVQ(Mem{Base: br, Disp: bitReader_value}, brValue)
	MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead)
	brOffset := GP64()
	MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset)

	skipLabel := "skip_fill" + strconv.Itoa(id)
	// Skip the refill while at least atLeast bits are unread,
	// i.e. while bitsRead <= 64-atLeast.
	CMPQ(brBitsRead, U8(64-atLeast))
	JBE(LabelRef(skipLabel))

	SUBQ(U8(32), brBitsRead) // b.bitsRead -= 32
	SUBQ(U8(4), brOffset)    // b.off -= 4

	// v := b.in[b.off-4 : b.off]
	// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
	tmp := GP64()
	MOVQ(Mem{Base: br, Disp: bitReader_in}, tmp)
	Comment("b.value |= uint64(low) << (b.bitsRead & 63)")
	addr := Mem{Base: brOffset, Index: tmp.As64(), Scale: 1}
	// Load exactly 4 bytes on both paths. SHLXQ with a memory operand would
	// load 8 bytes, OR 32 extra bits over the still-unread bits of value,
	// and read 4 bytes beyond what the Go code reads.
	MOVL(addr, tmp.As32()) // tmp = uint32(b.in[b.off:b.off+4]), using the decremented b.off
	if d.bmi2 {
		SHLXQ(brBitsRead, tmp.As64(), tmp.As64()) // tmp <<= (b.bitsRead & 63)
	} else {
		CX := reg.CL
		MOVQ(brBitsRead, CX.As64())
		SHLQ(CX, tmp.As64()) // tmp <<= (b.bitsRead & 63)
	}
	ORQ(tmp.As64(), brValue)

	{
		Commentf("exhausted = exhausted || (br%d.off < 4)", id)
		CMPQ(brOffset, U8(4))
		lt := GP64()
		SETLT(lt.As8())
		ORB(lt.As8(), exhausted.As8())
	}
	MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off})
	Label(skipLabel)
	return brValue, brBitsRead
}
// fillFast56 is an unfinished experiment (TODO: WIP, does not work). It was
// meant to refill br so that at least 56 bits are available. That would make
// it possible to decode all table sizes with 4 bytes per loop. Nothing in
// this generator calls it.
//
// As written, it emits code that:
//   - moves b.off back by bitsRead/8 bytes;
//   - loads 8 raw bytes from b.in at the new offset into brValue;
//   - reduces bitsRead to bitsRead&7;
//   - ORs (b.off < 4) into exhausted;
//   - stores b.off back to br.
func (d decompress4x) fillFast56(id int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) {
	Commentf("br%d.fillFast56()", id)
	brBitsRead = GP64()
	brOffset := GP64()
	brPointer := GP64()
	MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset)
	MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead)
	MOVQ(Mem{Base: br, Disp: bitReader_in}, brPointer)
	off := GP64()
	MOVQ(brBitsRead, off)
	SHRQ(U8(3), off)    // off = brBitsRead / 8
	SUBQ(off, brOffset) // brOffset = brOffset - off
	brValue = GP64()
	MOVQ(Mem{Base: brPointer, Index: brOffset, Scale: 1}, brValue) // brValue = brPointer[brOffset]
	// NOTE(review): the residual bit count (bitsRead & 7) is neither applied
	// to brValue nor written back to br here; presumably part of the
	// unfinished work.
	ANDQ(U8(7), brBitsRead) // brBitsRead = brBitsRead & 7

	{
		Commentf("exhausted = exhausted || (br%d.off < 4)", id)
		CMPQ(brOffset, U8(4))
		lt := GP64()
		SETLT(lt.As8())
		ORB(lt.As8(), exhausted.As8())
	}
	MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off})
	return brValue, brBitsRead
}
// decodeTwoValues emits code for one stream of the tablelog > 8 loop. It
// decodes two symbols from bit reader br using the table pointed to by
// table, and stores them at buf[id%10][off:off+2] with a single 16-bit
// write. It then writes the updated value and bitsRead back to br.
//
// id labels the emitted comments and the refill label. generateProcedure
// passes d.n+stream, so id%10 is the stream index as long as d.n is a
// multiple of 10.
func (d decompress4x) decodeTwoValues(id int, br, table, buffer, off, exhausted reg.GPVirtual) {
	// Two symbols consume at most 2*nBits bits.
	brValue, brBitsRead := d.fillFast32(id, d.nBits*2, br, exhausted)

	CX := reg.CL
	val := GP64()
	v := reg.RDX   // table entry: low byte = code length, high byte (DH) = symbol
	out := reg.RAX // Fixed since we need AL and AH for the two symbols

	// decode emits one symbol: peek the top nBits bits, load the table entry,
	// store its symbol in outByte and advance the bit reader.
	decode := func(valID int, outByte reg.Register) {
		Commentf("val%d := br%d.peekTopBits(peekBits)", valID, id)
		// Shifting by CL is fastest by far on AMD Zen2+. A shift by immediate
		// is slow there, and BEXTRQ (BMI2) is not much faster.
		MOVQ(U32(64-d.nBits), CX.As64())
		MOVQ(brValue, val.As64())
		SHRQ(CX, val.As64()) // val = value >> (64 - nBits)

		Commentf("v%d := table[val%d&mask]", valID, valID)
		MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16())

		Commentf("br%d.advance(uint8(v%d.entry))", id, valID)
		MOVB(v.As8H(), outByte)     // outByte = uint8(v.entry >> 8)
		MOVBQZX(v.As8(), CX.As64()) // CX = n = uint8(v.entry)
		if d.bmi2 {
			SHLXQ(CX.As64(), brValue, brValue) // value <<= n
		} else {
			SHLQ(CX, brValue) // value <<= n
		}
		ADDQ(CX.As64(), brBitsRead) // bits_read += n
	}
	decode(0, out.As8())
	decode(1, out.As8H())

	Comment("these two writes get coalesced")
	Comment("buf[stream][off] = uint8(v0.entry >> 8)")
	Comment("buf[stream][off+1] = uint8(v1.entry >> 8)")
	MOVW(out.As16(), Mem{Base: buffer, Index: off, Scale: 1, Disp: (id % 10) * buffoff})

	Comment("update the bitreader structure")
	MOVQ(brValue, Mem{Base: br, Disp: bitReader_value})
	MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead})
}
// generateProcedure4x8bit emits the Decompress4X main loop for table logs
// <= 8, with the peek shift taken from ctx.peekBits.
//
// Each iteration decodes four symbols from each of the four streams into the
// scratch buffer and advances the 8-bit offset by four. The loop ends when
// the offset wraps to 0 (256 bytes per stream written) or when a refill left
// any stream with fewer than 4 input bytes. The final offset is returned.
func (d decompress4x) generateProcedure4x8bit(name string) {
	Package("github.com/klauspost/compress/huff0")
	TEXT(name, 0, "func(ctx* decompress4xContext) uint8")
	Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog <= 8.",
		"It decodes four symbols per stream per loop iteration.")
	Pragma("noescape")

	off := GP64()
	XORQ(off, off)

	exhausted := GP64()
	XORQ(exhausted.As64(), exhausted.As64()) // exhausted = false

	peekBits := GP64()
	buffer := GP64()
	table := GP64()
	br0 := GP64()
	br1 := GP64()
	br2 := GP64()
	br3 := GP64()

	Comment("Preload values")
	{
		ctx := Dereference(Param("ctx"))
		Load(ctx.Field("peekBits"), peekBits)
		Load(ctx.Field("buf"), buffer)
		Load(ctx.Field("tbl"), table)
		Load(ctx.Field("pbr0"), br0)
		Load(ctx.Field("pbr1"), br1)
		Load(ctx.Field("pbr2"), br2)
		Load(ctx.Field("pbr3"), br3)
	}

	Comment("Main loop")
	Label(name + "_main_loop")
	d.decodeFourValues(0, br0, peekBits, table, buffer, off, exhausted)
	d.decodeFourValues(1, br1, peekBits, table, buffer, off, exhausted)
	d.decodeFourValues(2, br2, peekBits, table, buffer, off, exhausted)
	d.decodeFourValues(3, br3, peekBits, table, buffer, off, exhausted)

	ADDB(U8(4), off.As8())                  // off += 4
	TESTB(exhausted.As8(), exhausted.As8()) // any br[i].off < 4?
	JNZ(LabelRef(name + "_done"))
	CMPB(off.As8(), U8(0)) // off wrapped to 0: 256 bytes per stream written
	JNZ(LabelRef(name + "_main_loop"))

	Label(name + "_done")
	offsetComp, err := ReturnIndex(0).Resolve()
	if err != nil {
		panic(err)
	}
	MOVB(off.As8(), offsetComp.Addr)
	RET()
}
// decodeFourValues emits code for stream id of the tablelog <= 8 loop. It
// decodes four symbols from bit reader br (peeking by the shift held in
// peekBits) and stores them at buf[id][off:off+4] with a single 32-bit
// write. It then writes the updated value and bitsRead back to br.
func (d decompress4x) decodeFourValues(id int, br, peekBits, table, buffer, off, exhausted reg.GPVirtual) {
	// Four symbols of a table log <= 8 consume at most 32 bits. The id is
	// offset by 1000 for the refill label, presumably to keep it distinct
	// from the ids used by the tablelog > 8 procedures.
	brValue, brBitsRead := d.fillFast32(id+1000, 32, br, exhausted)

	// decode emits one symbol: peek, load the table entry into CX, store its
	// symbol in outByte and advance the bit reader by its code length.
	decode := func(valID int, outByte reg.Register) {
		CX := reg.CL
		val := GP64()
		Commentf("val%d := br%d.peekTopBits(peekBits)", valID, id)
		MOVQ(brValue, val.As64())
		MOVQ(peekBits, CX.As64())
		SHRQ(CX, val.As64()) // val = value >> peekBits
		Commentf("v%d := table[val%d&mask]", valID, valID)
		MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, CX.As16())
		Commentf("br%d.advance(uint8(v%d.entry))", id, valID)
		MOVB(CX.As8H(), outByte)     // outByte = uint8(v.entry >> 8)
		MOVBQZX(CX.As8(), CX.As64()) // CX = n = uint8(v.entry)
		if d.bmi2 {
			SHLXQ(CX.As64(), brValue, brValue) // value <<= n
		} else {
			SHLQ(CX, brValue) // value <<= n
		}
		ADDQ(CX.As64(), brBitsRead) // bits_read += n
	}

	// Only the low two bytes of EAX are individually addressable (AL, AH).
	// Symbols 2 and 3 are therefore written between two BSWAPs, which leaves
	// EAX holding v0, v1, v2, v3 from the lowest to the highest byte.
	out := reg.RAX // Fixed since we need 8H
	decode(0, out.As8L())
	decode(1, out.As8H())
	BSWAPL(out.As32())
	decode(2, out.As8H())
	decode(3, out.As8L())
	BSWAPL(out.As32())

	Comment("these four writes get coalesced")
	Comment("buf[stream][off] = uint8(v0.entry >> 8)")
	Comment("buf[stream][off+1] = uint8(v1.entry >> 8)")
	Comment("buf[stream][off+2] = uint8(v2.entry >> 8)")
	Comment("buf[stream][off+3] = uint8(v3.entry >> 8)")
	MOVL(out.As32(), Mem{Base: buffer, Index: off, Scale: 1, Disp: id * buffoff})

	Comment("update the bitreader structure")
	MOVQ(brValue, Mem{Base: br, Disp: bitReader_value})
	MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment