Skip to content

Instantly share code, notes, and snippets.

@blixt
Created January 9, 2019 11:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save blixt/48f4581437eab9863b977ca7dc3001af to your computer and use it in GitHub Desktop.
Save blixt/48f4581437eab9863b977ca7dc3001af to your computer and use it in GitHub Desktop.
Testing ModifyResponse with ReverseProxy and WebSockets
package main
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"
)
// TestReverseProxyWebSocketModifyResponse checks two things about
// httputil.ReverseProxy and WebSocket upgrades:
//   - it calls ModifyResponse for a "101 Switching Protocols" response, so
//     headers set there reach the client;
//   - the upgraded connection still relays data in both directions afterwards.
func TestReverseProxyWebSocketModifyResponse(t *testing.T) {
	// The backend accepts only WebSocket upgrade requests. It hijacks the
	// connection and answers 101 by hand. It then replies to one line from
	// the client with `backend got "<line>"`.
	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			t.Error("unexpected backend request")
			http.Error(w, "unexpected request", http.StatusBadRequest)
			return
		}
		hj, ok := w.(http.Hijacker)
		if !ok {
			// Handlers run on a server goroutine, so use t.Error, not t.Fatal.
			t.Errorf("backend ResponseWriter %T does not implement http.Hijacker", w)
			http.Error(w, "hijacking not supported", http.StatusInternalServerError)
			return
		}
		c, brw, err := hj.Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer c.Close()
		if _, err := io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"); err != nil {
			t.Errorf("backend failed to write upgrade response: %v", err)
			return
		}
		// Read through the hijacked bufio.Reader rather than c directly:
		// Hijack documents that it may already hold buffered client data.
		bs := bufio.NewScanner(brw.Reader)
		if !bs.Scan() {
			t.Errorf("backend failed to read line from client: %v", bs.Err())
			return
		}
		if _, err := fmt.Fprintf(c, "backend got %q\n", bs.Text()); err != nil {
			t.Errorf("backend failed to write reply: %v", err)
		}
	}))
	defer backendServer.Close()

	backURL, err := url.Parse(backendServer.URL)
	if err != nil {
		t.Fatalf("parsing backend URL %q: %v", backendServer.URL, err)
	}
	rproxy := httputil.NewSingleHostReverseProxy(backURL)
	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
	// The header under test: it must appear on the client's 101 response.
	rproxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("X-Important-Header", "HelloWorld")
		return nil
	}
	frontendProxy := httptest.NewServer(rproxy)
	defer frontendProxy.Close()

	req, err := http.NewRequest(http.MethodGet, frontendProxy.URL, nil)
	if err != nil {
		t.Fatalf("building upgrade request: %v", err)
	}
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")
	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Close the body on every exit path, including the t.Fatalf checks
	// below. For a 101 response the body is the upgraded connection itself.
	defer res.Body.Close()

	if res.StatusCode != http.StatusSwitchingProtocols {
		t.Fatalf("status = %v; want 101", res.Status)
	}
	if !strings.EqualFold(res.Header.Get("Upgrade"), "websocket") {
		t.Fatalf("not websocket upgrade; got %#v", res.Header)
	}
	if res.Header.Get("X-Important-Header") != "HelloWorld" {
		t.Fatalf("missing/invalid custom header; got %#v", res.Header)
	}

	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
	}
	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("writing to upgraded connection: %v", err)
	}
	bs := bufio.NewScanner(rwc)
	if !bs.Scan() {
		t.Fatalf("Scan: %v", bs.Err())
	}
	got := bs.Text()
	want := `backend got "Hello"`
	if got != want {
		t.Errorf("got %#q, want %#q", got, want)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment