Skip to content

Instantly share code, notes, and snippets.

@obeattie
Created December 2, 2018 16:36
Show Gist options
  • Save obeattie/bcf770518b2f698f6e6a085927bd1420 to your computer and use it in GitHub Desktop.
Save obeattie/bcf770518b2f698f6e6a085927bd1420 to your computer and use it in GitHub Desktop.
// Package h2c provides a HTTP/2.0 h2c client transport implementation
package h2c
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"
"golang.org/x/net/http2"
"golang.org/x/sync/singleflight"
)
// uncloseableConn wraps a net.Conn and turns Close into a no-op, so whoever is
// handed this wrapper cannot tear the connection down. Every other net.Conn
// method is promoted unchanged from the embedded Conn; the owner closes that
// embedded Conn directly when it is really done.
type uncloseableConn struct {
	net.Conn
}

// Close deliberately does nothing and never reports an error.
func (uncloseableConn) Close() error { return nil }
// h2cUpgradedConn is the net.Conn given to the HTTP/2 client connection once an
// h2c upgrade has succeeded. Its Write method removes the client's HEADERS
// frame on stream 1 from the outgoing bytes, and its Close method runs the
// onClose hook at most once before closing the wrapped connection.
type h2cUpgradedConn struct {
	net.Conn // the upgraded connection; Read is not overridden, so reads pass straight through

	onClose   func()    // hook run by the first Close, before the connection itself is closed
	closeOnce sync.Once // ensures onClose runs at most once

	swallowedFirstReq bool // set once the client's stream-1 HEADERS frame has been dropped by Write
}
// Close runs the onClose hook (only on the first call, however many times Close
// is invoked) and then closes the underlying connection, returning whatever
// error that close produces.
func (c *h2cUpgradedConn) Close() error {
	once := &c.closeOnce
	once.Do(c.onClose)
	underlying := c.Conn
	return underlying.Close()
}
// Write sends b over the underlying connection.
//
// Until the client's first request has been swallowed (see Transport.RoundTrip
// for why that is needed), b is parsed as a sequence of HTTP/2 frames. The
// sequence may be preceded by the client connection preface, which is not a
// frame. A HEADERS frame on stream 1 is removed before the remaining bytes are
// forwarded. Once that frame has been dropped, everything after it, in this
// Write and in all later ones, is forwarded without being parsed.
//
// On success Write reports len(b) bytes written even if some bytes were
// dropped, so the caller believes its whole frame went out. If b ends part-way
// through a frame while the stream-1 HEADERS frame is still being searched
// for, Write returns io.ErrUnexpectedEOF and sends nothing. Forwarding bytes it
// cannot parse would desynchronise the frame stream.
func (c *h2cUpgradedConn) Write(b []byte) (int, error) {
	length := len(b)
	if !c.swallowedFirstReq {
		rdr := bufio.NewReader(bytes.NewReader(b))
		cleaned := bytes.NewBuffer(make([]byte, 0, len(b)))

		// The "magic" client preface (RFC 7540 §3.5) isn't a valid HTTP/2 frame, so pass it through up front ✨
		magic := []byte(http2.ClientPreface)
		if p, err := rdr.Peek(len(magic)); err == nil && bytes.Equal(p, magic) {
			cleaned.Write(p)
			_, _ = rdr.Discard(len(p)) // cannot fail: Peek has just buffered exactly these bytes
		}

		// Frames can be batched in a single Write(), so walk them until the HEADERS(stream=1) frame is found and dropped
		frameHeader := make([]byte, 9) // 9 octet frame header (RFC 7540 §4.1)
		for {
			if _, err := io.ReadFull(rdr, frameHeader); err != nil {
				if err == io.EOF {
					break // b ended exactly on a frame boundary
				}
				return 0, err // io.ErrUnexpectedEOF: b ends inside a frame header
			}
			frameType := http2.FrameType(frameHeader[3])
			frameStream := binary.BigEndian.Uint32(frameHeader[5:]) & (1<<31 - 1) // mask off the reserved bit
			frameLength := int64(frameHeader[0])<<16 | int64(frameHeader[1])<<8 | int64(frameHeader[2])

			if frameType == http2.FrameHeaders && frameStream == 1 { // Swallow it
				if _, err := rdr.Discard(int(frameLength)); err != nil {
					return 0, io.ErrUnexpectedEOF // the frame continues past the end of b
				}
				c.swallowedFirstReq = true
				break
			}
			cleaned.Write(frameHeader)
			if _, err := io.CopyN(cleaned, rdr, frameLength); err != nil {
				return 0, io.ErrUnexpectedEOF // the frame continues past the end of b
			}
		}

		// Anything after the dropped frame is forwarded verbatim. A frame that is
		// split between this Write and the next therefore passes through intact.
		if _, err := rdr.WriteTo(cleaned); err != nil {
			return 0, err
		}
		b = cleaned.Bytes()
	}
	n, err := c.Conn.Write(b)
	if err == nil {
		return length, nil // Pretend to the caller all the data was written, even if we swallowed some
	}
	return n, err
}
// Transport is an http.RoundTripper that speaks HTTP/2 over plain-text TCP. It
// establishes connections with the HTTP/1.1 Upgrade ("h2c") mechanism described
// in RFC 7540 §3.2.
//
// Upgraded connections are created through the embedded http2.Transport's
// NewClientConn. Because Transport contains a sync.Map and a singleflight.Group,
// it must not be copied after first use.
type Transport struct {
	http2.Transport

	// h2Conns caches one upgraded connection per destination. Keys are
	// "scheme://host" strings and values are *http2.ClientConn.
	h2Conns sync.Map

	// sf collapses concurrent upgrade attempts for the same destination into one.
	sf singleflight.Group
}
// key returns the connection-cache key for u, which is its scheme and host
// joined as "scheme://host". The path, query and all other parts of the URL
// are ignored.
func (t *Transport) key(u *url.URL) string {
	return u.Scheme + "://" + u.Host
}
// RoundTrip implements http.RoundTripper. Only http:// URLs are accepted.
//
// If an upgraded connection to the request's destination already exists, the
// request is sent over it. Otherwise a new TCP connection is dialled and
// upgraded to h2c (RFC 7540 §3.2), and concurrent callers for the same
// destination share a single upgrade attempt. If the server does not answer the
// upgrade with 101 Switching Protocols, an error is returned; there is no
// fallback to HTTP/1.1.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.URL.Scheme != "http" {
		return nil, fmt.Errorf("Only http:// URLs are eligible for h2c")
	}

	// Check for an existing connection that can be reused
	key := t.key(req.URL)
	if c, ok := t.h2Conns.Load(key); ok {
		return c.(*http2.ClientConn).RoundTrip(req)
	}

	// We don't have an existing connection; start the process of establishing one.
	// Use of the singleflight here ensures that we only end up with one h2c connection per destination.
	rt, err, _ := t.sf.Do(key, func() (interface{}, error) {
		if c, ok := t.h2Conns.Load(key); ok {
			return c.(*http2.ClientConn), nil
		}
		cc, err := t.upgrade(key, req.URL)
		if err != nil {
			return nil, err // literal nil: never hand singleflight a typed-nil interface
		}
		return cc, nil
	})
	if err != nil {
		return nil, err
	}
	return rt.(http.RoundTripper).RoundTrip(req)
}

// upgrade dials the destination of u and performs the HTTP/1.1 → h2c upgrade
// handshake on that connection. On success it stores the resulting client
// connection in t.h2Conns under key and returns it. On failure the connection
// is closed.
func (t *Transport) upgrade(key string, u *url.URL) (*http2.ClientConn, error) {
	settings, err := h2cSettingsHeader()
	if err != nil {
		return nil, err
	}

	// Send an OPTIONS request, which is eligible for upgrading
	h1UpgradeReq, err := http.NewRequest("OPTIONS", key, nil)
	if err != nil {
		return nil, err
	}
	h1UpgradeReq.Header.Add("Connection", "Upgrade")
	h1UpgradeReq.Header.Add("Connection", "HTTP2-Settings")
	h1UpgradeReq.Header.Set("Upgrade", "h2c")
	h1UpgradeReq.Header.Set("HTTP2-Settings", settings)

	conn, err := net.Dial("tcp", h2cDialAddr(u))
	if err != nil {
		return nil, err
	}
	upgradeTransport := &http.Transport{
		DialContext: func(context.Context, string, string) (net.Conn, error) {
			// Hand out the already-dialled conn, wrapped so that http.Transport cannot close it.
			return uncloseableConn{Conn: conn}, nil
		},
	}
	upgradeRsp, err := upgradeTransport.RoundTrip(h1UpgradeReq)
	if err != nil {
		conn.Close()
		return nil, err
	}
	if upgradeRsp.StatusCode != http.StatusSwitchingProtocols {
		upgradeRsp.Body.Close()
		conn.Close()
		return nil, fmt.Errorf("h2c upgrade failed; expected 101 Switching Protocols, got %s", upgradeRsp.Status)
	}

	// Bytes the server sent straight after the 101 response (its SETTINGS frame,
	// for example) may already be in http.Transport's read buffer. Since Go 1.12
	// the body of a 101 response also implements io.Writer and yields those
	// buffered bytes before reading from the connection, so read through it when
	// it is available. Otherwise fall back to reading the raw connection.
	var raw net.Conn = conn
	if _, ok := upgradeRsp.Body.(io.Writer); ok {
		raw = upgradeReadConn{Conn: conn, r: upgradeRsp.Body}
	} else {
		upgradeRsp.Body.Close()
	}

	// The server is switching to h2; wrap the conn as a h2Conn
	h2Conn, err := t.Transport.NewClientConn(&h2cUpgradedConn{
		Conn:    raw,
		onClose: func() { t.h2Conns.Delete(key) },
	})
	if err != nil {
		conn.Close()
		return nil, err
	}

	// The server will now send the response to the original request via h2, so we need to consume it. This is
	// assigned stream ID 0x1 (RFC 7540 §3.2), and because all subsequent stream identifiers must be numerically
	// greater than this (RFC 7540 §5.1.1), we need to trick h2Conn into consuming the response and updating its
	// internal stream counter.
	//
	// This is hackishly achieved by round-tripping another "upgrade" request via h2Conn. The "request" half –
	// which isn't needed – is dropped by h2cUpgradedConn, but the response is still forwarded. 🤢
	h2UpgradeReq, err := http.NewRequest(h1UpgradeReq.Method, "/", nil)
	if err != nil {
		h2Conn.Close()
		return nil, err
	}
	h2UpgradeRsp, err := h2Conn.RoundTrip(h2UpgradeReq)
	if err != nil {
		h2Conn.Close()
		return nil, err
	}
	h2UpgradeRsp.Body.Close()

	// The connection is now fully upgraded and is useable for h2 requests
	t.h2Conns.Store(key, h2Conn)
	return h2Conn, nil
}

// h2cSettingsHeader returns the value of the HTTP2-Settings upgrade header.
// RFC 7540 §3.2.1 requires this to be the payload of a SETTINGS frame, without
// the 9-octet frame header, encoded as base64url with the trailing '='
// padding omitted.
func h2cSettingsHeader() (string, error) {
	var buf bytes.Buffer
	err := http2.NewFramer(&buf, nil).WriteSettings(
		http2.Setting{ID: http2.SettingEnablePush, Val: 0},
		http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 << 20},
	)
	if err != nil {
		return "", err
	}
	const frameHeaderLen = 9 // RFC 7540 §4.1
	return base64.RawURLEncoding.EncodeToString(buf.Bytes()[frameHeaderLen:]), nil
}

// h2cDialAddr returns the host:port to dial for u. When u names no port, it
// defaults to port 80, the default for the http scheme.
func h2cDialAddr(u *url.URL) string {
	if u.Port() != "" {
		return u.Host
	}
	return net.JoinHostPort(u.Hostname(), "80")
}

// upgradeReadConn is a net.Conn whose reads are served by r rather than by the
// embedded Conn. Writes, Close, deadlines and addresses all go to the embedded
// Conn.
type upgradeReadConn struct {
	net.Conn
	r io.Reader
}

// Read reads from r instead of the embedded connection.
func (c upgradeReadConn) Read(p []byte) (int, error) { return c.r.Read(p) }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment