Skip to content

Instantly share code, notes, and snippets.

@FZambia
Created September 21, 2019 14:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FZambia/d67f7ff4de9aa8a706c293b8eaf5532c to your computer and use it in GitHub Desktop.
Save FZambia/d67f7ff4de9aa8a706c293b8eaf5532c to your computer and use it in GitHub Desktop.
Using the klauspost/compress library as a replacement for the standard library's compress/flate in Gorilla WebSocket
diff --git a/compression.go b/compression.go
index 813ffb1..4c492e0 100644
--- a/compression.go
+++ b/compression.go
@@ -41,16 +41,66 @@ func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}
-func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
+// FlateWriter is the DEFLATE compressor used to build permessage-deflate
+// payloads. It lets callers plug in an implementation other than the
+// standard library's compress/flate, e.g. github.com/klauspost/compress/flate.
+//
+// Write, Reset and Flush behave like the corresponding *flate.Writer methods.
+// Close is called after Flush and must only release the writer (for example
+// by returning it to a pool): it must not write further output, because the
+// stream has to end with the 0x00 0x00 0xff 0xff marker emitted by Flush.
+type FlateWriter interface {
+ Write(data []byte) (n int, err error)
+ Reset(dst io.Writer)
+ Flush() error
+ Close() error
+}
+
+// defaultAcquireFlateWriter returns a *flate.Writer for level that writes to
+// w, reusing a pooled writer when one is available. Closing the result
+// returns the writer to its pool.
+func defaultAcquireFlateWriter(w io.Writer, level int) FlateWriter {
p := &flateWriterPools[level-minCompressionLevel]
- tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
- fw, _ = flate.NewWriter(tw, level)
+ // flate.NewWriter only fails for an out-of-range level.
+ fw, _ = flate.NewWriter(w, level)
} else {
- fw.Reset(tw)
+ fw.Reset(w)
}
- return &flateWriteWrapper{fw: fw, tw: tw, p: p}
+ return &poolFlateWriter{fw, p}
+}
+
+// poolFlateWriter adapts a pooled *flate.Writer to the FlateWriter interface.
+type poolFlateWriter struct {
+ *flate.Writer
+ p *sync.Pool
+}
+
+// Close returns the *flate.Writer to its pool instead of closing it, so no
+// final block is written. It always returns nil.
+func (w *poolFlateWriter) Close() error {
+ w.p.Put(w.Writer)
+ return nil
+}
+
+// compressNoContextTakeoverFlateWriter returns a compression-writer factory
+// with the same signature as compressNoContextTakeover that obtains its
+// compressor from acquireFlateWriter instead of the built-in pools.
+func compressNoContextTakeoverFlateWriter(acquireFlateWriter func(w io.Writer, level int) FlateWriter) func(w io.WriteCloser, level int) io.WriteCloser {
+ return func(w io.WriteCloser, level int) io.WriteCloser {
+ tw := &truncWriter{w: w}
+ fw := acquireFlateWriter(tw, level)
+ return &flateWriteWrapper{fw: fw, tw: tw}
+ }
+}
+
+// compressNoContextTakeover wraps w in a writer that compresses at level
+// using the built-in pooled *flate.Writer instances.
+func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
+ tw := &truncWriter{w: w}
+ fw := defaultAcquireFlateWriter(tw, level)
+ return &flateWriteWrapper{fw: fw, tw: tw}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
@@ -90,9 +140,8 @@ func (w *truncWriter) Write(p []byte) (int, error) {
}
type flateWriteWrapper struct {
- fw *flate.Writer
+ fw FlateWriter
tw *truncWriter
- p *sync.Pool
}
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
@@ -107,16 +156,20 @@ func (w *flateWriteWrapper) Close() error {
return errWriteClosed
}
err1 := w.fw.Flush()
- w.p.Put(w.fw)
+ // Release the compressor (e.g. back to its pool); it must not write more output.
+ err2 := w.fw.Close()
w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
- err2 := w.tw.w.Close()
+ err3 := w.tw.w.Close()
if err1 != nil {
return err1
}
- return err2
+ if err2 != nil {
+ return err2
+ }
+ return err3
}
type flateReadWrapper struct {
diff --git a/compression_test.go b/compression_test.go
index 8a26b30..b089e64 100644
--- a/compression_test.go
+++ b/compression_test.go
@@ -5,7 +5,10 @@ import (
"fmt"
"io"
"io/ioutil"
+ "sync"
"testing"
+
+ customFlate "github.com/klauspost/compress/flate"
)
type nopCloser struct{ io.Writer }
@@ -65,6 +68,56 @@ func BenchmarkWriteWithCompression(b *testing.B) {
b.ReportAllocs()
}
+// customFlateWriterPools holds reusable klauspost/compress flate writers,
+// one sync.Pool per supported compression level.
+var customFlateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
+
+// acquireCustomFlateWriter is a FlateWriter factory backed by
+// github.com/klauspost/compress/flate. It reuses a pooled writer for level
+// when one is available and otherwise creates a new one writing to w.
+func acquireCustomFlateWriter(w io.Writer, level int) FlateWriter {
+ p := &customFlateWriterPools[level-minCompressionLevel]
+ fw, _ := p.Get().(*customFlate.Writer)
+ if fw == nil {
+ fw, _ = customFlate.NewWriter(w, level)
+ } else {
+ fw.Reset(w)
+ }
+ return &customPoolFlateWriter{fw, p}
+}
+
+// customPoolFlateWriter adapts a pooled *customFlate.Writer to FlateWriter.
+type customPoolFlateWriter struct {
+ *customFlate.Writer
+ p *sync.Pool
+}
+
+// Compile-time check that customPoolFlateWriter satisfies FlateWriter.
+var _ FlateWriter = (*customPoolFlateWriter)(nil)
+
+// Close returns the writer to its pool without writing a final block.
+func (w *customPoolFlateWriter) Close() error {
+ w.p.Put(w.Writer)
+ return nil
+}
+
+// BenchmarkWriteWithCompressionCustom benchmarks WriteMessage with write
+// compression enabled, using klauspost/compress flate writers.
+func BenchmarkWriteWithCompressionCustom(b *testing.B) {
+ w := ioutil.Discard
+ c := newTestConn(nil, w, false)
+ messages := textMessages(100)
+ c.enableWriteCompression = true
+ c.newCompressionWriter = compressNoContextTakeoverFlateWriter(acquireCustomFlateWriter)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
func TestValidCompressionLevel(t *testing.T) {
c := newTestConn(nil, nil, false)
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go
index cb88cbb..75bff5b 100644
--- a/conn_broadcast_test.go
+++ b/conn_broadcast_test.go
@@ -71,7 +71,7 @@ func (b *broadcastBench) makeConns(numConns int) {
c := newTestConn(nil, b.w, true)
if b.compression {
c.enableWriteCompression = true
- c.newCompressionWriter = compressNoContextTakeover
+ c.newCompressionWriter = compressNoContextTakeoverFlateWriter(acquireCustomFlateWriter) // compress with klauspost/compress instead of the built-in flate pools
}
conns[i] = newBroadcastConn(c)
go func(c *broadcastConn) {
diff --git a/server.go b/server.go
index 887d558..6e6c5be 100644
--- a/server.go
+++ b/server.go
@@ -70,6 +70,13 @@ type Upgrader struct {
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
+
+ // AcquireFlateWriter, if non-nil, supplies the DEFLATE compressor used when
+ // compression is enabled for the connection. It is called with the
+ // destination writer and compression level each time the connection creates
+ // a compression writer; the returned FlateWriter's Close is called after
+ // Flush to release it. When nil, the built-in pooled flate writers are used.
+ AcquireFlateWriter func(w io.Writer, level int) FlateWriter
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
@@ -203,7 +210,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.subprotocol = subprotocol
if compress {
- c.newCompressionWriter = compressNoContextTakeover
+ if u.AcquireFlateWriter != nil {
+ c.newCompressionWriter = compressNoContextTakeoverFlateWriter(u.AcquireFlateWriter)
+ } else {
+ c.newCompressionWriter = compressNoContextTakeover
+ }
c.newDecompressionReader = decompressNoContextTakeover
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment