Skip to content

Instantly share code, notes, and snippets.

@onsi
Created March 4, 2014 05:44
Show Gist options
  • Save onsi/9340903 to your computer and use it in GitHub Desktop.
Save onsi/9340903 to your computer and use it in GitHub Desktop.
test_server
// Example usage, something like:
var _ = Describe("A Server", func() {
	var s *Server

	BeforeEach(func() {
		s = New()
		s.Append(CombineHandlers(
			VerifyRequest("GET", "/foo/bar"),
			VerifyBasicAuth("bob", "password"),
			Respond(200, []byte("foo")),
		))
		s.Append(func(w http.ResponseWriter, req *http.Request) {
			// make custom assertions for this request
			// write a response, put things in closure variables, etc...
		})
	})

	AfterEach(func() {
		// BeforeEach starts a fresh httptest server for every spec; close it
		// so listeners and serving goroutines do not accumulate across specs.
		s.Close()
	})

	It("should ...", func() {
	})
})
package test_server
// Small building blocks for composing test HTTP handlers.
import (
"encoding/base64"
"fmt"
. "github.com/onsi/gomega"
"net/http"
)
// CombineHandlers returns one handler that runs each of the given handlers,
// in the order supplied, against the same ResponseWriter and Request.
// With no handlers the returned handler does nothing.
func CombineHandlers(handlers ...http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		for i := range handlers {
			handlers[i](w, req)
		}
	}
}
// VerifyRequest returns a handler that asserts the request's method and URL
// path equal the expected values. If a rawQuery argument is supplied, the
// first one is also compared against the request's raw query string; any
// further rawQuery arguments are ignored.
func VerifyRequest(method string, path string, rawQuery ...string) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		u := req.URL
		Ω(req.Method).Should(Equal(method), "Method mismatch")
		Ω(u.Path).Should(Equal(path), "Path mismatch")
		if len(rawQuery) == 0 {
			return
		}
		Ω(u.RawQuery).Should(Equal(rawQuery[0]), "RawQuery mismatch")
	}
}
// VerifyBasicAuth returns a handler that asserts the request carries an HTTP
// Basic Authorization header whose decoded credentials are
// "username:password".
//
// A missing header, or one that does not use the "Basic " scheme, is
// reported as an assertion failure rather than causing a slice-bounds panic.
func VerifyBasicAuth(username string, password string) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		const prefix = "Basic " // 6 bytes; the base64 payload starts right after it
		auth := req.Header.Get("Authorization")
		// Should reports whether the assertion passed. Stop on failure so a
		// short or absent header never reaches the slice below.
		if !Ω(auth).Should(HavePrefix(prefix), "Authorization header must use Basic auth") {
			return
		}
		decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
		Ω(err).ShouldNot(HaveOccurred())
		Ω(string(decoded)).Should(Equal(fmt.Sprintf("%s:%s", username, password)), "Authorization mismatch")
	}
}
// VerifyHeader returns a handler that asserts, for every key in header, that
// the request's values for the canonicalized form of that key equal the
// expected values exactly. Headers not listed in header are not checked.
func VerifyHeader(header http.Header) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		for name, want := range header {
			canonical := http.CanonicalHeaderKey(name)
			got := req.Header[canonical]
			Ω(got).Should(Equal(want), "Header mismatch for key: %s", canonical)
		}
	}
}
// Respond returns a handler that writes statusCode followed by body. The
// request is not inspected, and the Write error is not checked.
func Respond(statusCode int, body []byte) http.HandlerFunc {
	return func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(statusCode)
		rw.Write(body)
	}
}
package test_server
// A very simple wrapper around httptest.Server that dispatches requests to queued handlers.
import (
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"sync"

	"github.com/onsi/ginkgo"
	"github.com/onsi/gomega/format"
)
// New starts a plain-HTTP httptest server that routes every request through
// the returned Server. Call Close on the result when finished with it.
func New() *Server {
	server := new(Server)
	server.HTTPTestServer = httptest.NewServer(server)
	return server
}
// NewTLS is like New but starts the underlying httptest server with TLS.
func NewTLS() *Server {
	server := new(Server)
	server.HTTPTestServer = httptest.NewTLSServer(server)
	return server
}
// Server wraps an httptest.Server and dispatches each incoming request to the
// next handler in RequestHandlers, in the order they were appended.
//
// Every request is recorded in ReceivedRequests. Once all handlers have been
// used, further requests either fail the current Ginkgo spec (when
// FailOnUnhandledRequests is set) or receive a 500 response.
//
// net/http may call ServeHTTP from several goroutines at once. mu serializes
// ServeHTTP and Append. Reading the exported slices directly while requests
// are still in flight is not synchronized.
type Server struct {
	HTTPTestServer          *httptest.Server
	ReceivedRequests        []*http.Request
	RequestHandlers         []http.HandlerFunc
	FailOnUnhandledRequests bool

	// mu guards ReceivedRequests, RequestHandlers and calls inside
	// ServeHTTP and Append.
	mu    sync.Mutex
	calls int
}

// Close shuts down the underlying httptest server. Calling it again after the
// server has been closed is a no-op instead of a nil-pointer panic.
func (s *Server) Close() {
	server := s.HTTPTestServer
	s.HTTPTestServer = nil
	if server != nil {
		server.Close()
	}
}

// ServeHTTP records req and hands it to the next unused handler. If no
// handler remains, it fails the spec when FailOnUnhandledRequests is set.
// Otherwise it drains and closes the body and responds with 500.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	s.mu.Lock()
	s.ReceivedRequests = append(s.ReceivedRequests, req)
	index := s.calls
	// Claim this request's slot before running the handler. If the handler
	// panics (e.g. a failed assertion aborting it), the next request must
	// still advance to the next handler instead of re-running this one.
	s.calls++
	handled := index < len(s.RequestHandlers)
	var handler http.HandlerFunc
	if handled {
		handler = s.RequestHandlers[index]
	}
	failOnUnhandled := s.FailOnUnhandledRequests
	s.mu.Unlock()

	// Run the handler outside the lock so a slow or blocking handler does not
	// stall other in-flight requests.
	if handled {
		handler(w, req)
		return
	}

	if failOnUnhandled {
		ginkgo.Fail(fmt.Sprintf("Received unhandled request:\n%s", format.Object(req, 1)))
		return
	}

	// Best-effort drain: the read error is deliberately ignored because the
	// request is rejected with a 500 regardless.
	ioutil.ReadAll(req.Body)
	req.Body.Close()
	w.WriteHeader(http.StatusInternalServerError)
}

// Append queues handlers to serve upcoming requests, one handler per request,
// in order.
func (s *Server) Append(handlers ...http.HandlerFunc) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.RequestHandlers = append(s.RequestHandlers, handlers...)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment