Skip to content

Instantly share code, notes, and snippets.

@rybit
Created August 8, 2017 19:00
Show Gist options
  • Save rybit/190c8129fb649123f5f61e6995cc9216 to your computer and use it in GitHub Desktop.
Save rybit/190c8129fb649123f5f61e6995cc9216 to your computer and use it in GitHub Desktop.
example of jws
package main
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"time"
)
// sharedSecret is the HMAC key used by makeRequest to sign the outgoing request
// and by the handler in startServer to recompute the expected signature.
// It is a hard-coded value for this example.
var sharedSecret = "secret"
// jwsPayload is the JSON document that gets signed. signRequest marshals it as
// the JWS payload, and startBadActor unmarshals it again from the signature header.
type jwsPayload struct {
// Timestamp is the request time in Unix seconds (signRequest fills it with now.Unix()).
Timestamp int64 `json:"timestamp"`
// UserID is the same user identifier that signRequest writes to the user header.
UserID string `json:"user_id"`
// AccountID is the same account identifier that signRequest writes to the account header.
AccountID string `json:"account_id"`
}
// Header names used when signing requests.
// NOTE(review): signatureHeader, userHeader and accountHeader are referenced by
// signRequest, startServer and startBadActor but are not declared in the visible
// source — confirm they are defined elsewhere in the package.
const (
// dateHeader is the request header that signRequest reads an existing
// timestamp from and writes the (UTC) signing time to.
dateHeader = "date"
)
// main starts the verifying test server, sends one signed request to it and
// lets makeRequest report whether the signature was accepted.
func main() {
	srv := startServer()
	defer srv.Close()

	// Uncomment to route the request through the tampering proxy instead:
	// srv = startBadActor(srv.URL)
	// defer srv.Close()

	// Fire the signed request at whichever server is in front.
	makeRequest(srv.URL)
}
// startBadActor starts a reverse proxy that plays a man-in-the-middle: it decodes
// the protected header and payload out of the incoming signature header, re-signs
// the request with the same user/account values but a secret it does not actually
// know, and forwards the request to nextStage.
//
// nextStage is the URL of the server to forward to; the function panics (via
// checkErr) if it cannot be parsed. Malformed or missing signature headers make
// the proxy's Director panic, which is acceptable for this demo only.
func startBadActor(nextStage string) *httptest.Server {
	nextURL, err := url.Parse(nextStage)
	checkErr(err)

	rp := httputil.ReverseProxy{
		Director: func(r *http.Request) {
			// The compact form is protected.payload.signature; pull the parts apart.
			total := r.Header.Get(signatureHeader)
			if total == "" {
				panic(errors.New("bad actor: Missing the right header - def not us"))
			}
			parts := strings.Split(total, ".")
			if len(parts) != 3 {
				panic(errors.New("bad actor: malformed header"))
			}
			encodedProtected := parts[0]
			encodedPayload := parts[1]
			// parts[2] is the signature, which the bad actor cannot reproduce.

			// Neither part is encrypted, so both can simply be decoded and read.
			rawProtected, err := base64Decode(encodedProtected)
			checkErr(err)
			protected := map[string]interface{}{}
			checkErr(json.Unmarshal(rawProtected, &protected))
			log.Printf("protected header: %v\n", protected)

			// The protected header would reveal the signing method, but it is
			// already known here, so only the payload is of interest.
			rawPayload, err := base64Decode(encodedPayload)
			checkErr(err)
			payload := new(jwsPayload)
			checkErr(json.Unmarshal(rawPayload, payload))
			log.Printf("signing payload: %v\n", payload)

			// Re-generate the signature header from the same data but with a
			// guessed secret. The error is checked instead of being dropped.
			checkErr(signRequest(r, payload.UserID, payload.AccountID, "I DON'T KNOW"))

			// Forward to the next stage. Each request gets its own copy so that
			// all proxied requests do not share one *url.URL value.
			target := *nextURL
			r.URL = &target
		},
	}
	return httptest.NewServer(&rp)
}
// makeRequest sends a signed GET request to addr and prints whether the server
// accepted the signature (any status other than 200 is reported as a failure).
// Any error while building, signing or sending the request panics via checkErr.
func makeRequest(addr string) {
	req, err := http.NewRequest(http.MethodGet, addr, nil)
	checkErr(err)

	// signRequest can fail (e.g. an unparsable date header); do not ignore it.
	checkErr(signRequest(req, "some-user", "some-account", sharedSecret))

	rsp, err := http.DefaultClient.Do(req)
	checkErr(err)
	// The body must be closed even though it is never read.
	defer rsp.Body.Close()

	if rsp.StatusCode != http.StatusOK {
		fmt.Printf("Failed to verify signature. Got a %d\n", rsp.StatusCode)
	} else {
		fmt.Printf("Signatures match!")
	}
}
// signRequest stamps req with the date, user and account headers and then adds
// a signature header whose value is the compact JWS produced by sign over a
// jwsPayload carrying the same values.
//
// If req already has a date header it is parsed and reused as the signing time;
// otherwise the current time is used. The time is always normalised to UTC.
// It returns an error if the existing date header cannot be parsed, or if
// marshalling or signing the payload fails. The date, user and account headers
// are set before marshalling, so they remain set even when an error is returned
// after that point.
func signRequest(req *http.Request, userID, accountID, secret string) error {
	stamp := time.Now()
	if existing := req.Header.Get(dateHeader); existing != "" {
		parsed, err := http.ParseTime(existing)
		if err != nil {
			return err
		}
		stamp = parsed
	}
	stamp = stamp.UTC()

	// These headers carry the same values that go into the signed payload.
	req.Header.Set(dateHeader, stamp.Format(http.TimeFormat))
	req.Header.Set(userHeader, userID)
	req.Header.Set(accountHeader, accountID)

	// Build the payload mirroring the headers and serialise it for signing.
	body, err := json.Marshal(&jwsPayload{
		Timestamp: stamp.Unix(),
		UserID:    userID,
		AccountID: accountID,
	})
	if err != nil {
		return err
	}

	token, err := sign(body, []byte(secret))
	if err != nil {
		return err
	}

	req.Header.Set(signatureHeader, token)
	return nil
}
// startServer starts a test server that verifies the signature header of every
// request against sharedSecret.
//
// Responses:
//   - 403 when the signature header is missing,
//   - 400 when it does not consist of exactly three dot-separated parts,
//   - 403 when the signature part does not match the recomputed one,
//   - 200 when the signature is valid.
func startServer() *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		total := r.Header.Get(signatureHeader)
		if total == "" {
			log.Println("Missing the right header - def not us")
			w.WriteHeader(http.StatusForbidden)
			return
		}

		parts := strings.Split(total, ".")
		if len(parts) != 3 {
			log.Println("Malformed header - something is def wrong here")
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// Recompute the signature over "protected.payload" exactly as sign does.
		shouldBe := base64Encode(computeMAC(parts[0]+"."+parts[1], []byte(sharedSecret)))

		// hmac.Equal compares without leaking timing information.
		if !hmac.Equal([]byte(parts[2]), []byte(shouldBe)) {
			// NOTE(review): this logs the expected signature, which is fine for
			// a demo but should not be done with a real secret.
			log.Printf("Mismatch between signature vs expected: '%s' vs '%s'\n", total, shouldBe)
			w.WriteHeader(http.StatusForbidden)
			return
		}

		// now we can trust the values in the header and payload
		w.WriteHeader(http.StatusOK)
	}))
}

// ---------------------------------------------------------------------------------------------------------------------
// JWS Signing
// ---------------------------------------------------------------------------------------------------------------------

// Values placed in the JWS protected header.
const (
	signingAlg  = "HS256"
	signingType = "JWT"
)

// protectedHeader is the JWS protected header; it is marshalled to JSON by sign
// and is the same for every signature.
var protectedHeader = struct {
	Type string `json:"typ"`
	Alg  string `json:"alg"`
}{
	Type: signingType,
	Alg:  signingAlg,
}

// newHash is the hash constructor handed to hmac.New (SHA-256, matching HS256).
// It does not vary between signatures.
var newHash = func() hash.Hash {
	return sha256.New()
}

// computeMAC returns the raw HMAC of signingInput keyed with secret, using newHash.
//
// The input has to be fed through Write: Sum(b) only appends the current digest
// to b and does not hash b, so passing the message to Sum would produce
// "message || HMAC(empty)" — a MAC that is independent of the message.
func computeMAC(signingInput string, secret []byte) []byte {
	mac := hmac.New(newHash, secret)
	mac.Write([]byte(signingInput)) // hash.Hash.Write never returns an error
	return mac.Sum(nil)
}

/*
sign produces the JWS compact serialization of rawPayload, signed with secret:

	BASE64URL(UTF8(JWS Protected Header)) || '.' ||
	BASE64URL(JWS Payload) || '.' ||
	BASE64URL(JWS Signature)

The signature is the HMAC of the first two parts joined by '.'.
It returns an error only if the protected header cannot be marshalled.
*/
func sign(rawPayload, secret []byte) (string, error) {
	rawProtected, err := json.Marshal(protectedHeader)
	if err != nil {
		return "", err
	}

	signingInput := base64Encode(rawProtected) + "." + base64Encode(rawPayload)
	signature := base64Encode(computeMAC(signingInput, secret))
	return signingInput + "." + signature, nil
}
// base64Encode returns the URL-safe base64 encoding of data without trailing
// '=' padding, as used for JWS parts.
// See https://tools.ietf.org/html/rfc7515#appendix-C
//
// base64.RawURLEncoding already uses the '-'/'_' alphabet and omits padding, so
// no manual trimming or character replacement is needed.
func base64Encode(data []byte) string {
	return base64.RawURLEncoding.EncodeToString(data)
}
// base64Decode decodes a URL-safe base64 string whose trailing '=' padding may
// have been stripped (the inverse of base64Encode).
//
// It returns an error when the length cannot correspond to any valid encoding
// (length % 4 == 1) or when the string contains characters outside the URL-safe
// alphabet.
func base64Decode(str string) ([]byte, error) {
	// The input is decoded with the URL-safe alphabet directly. Translating
	// '-'/'_' to '+'/'/' first would make base64.URLEncoding reject the input.

	// Restore the padding that was stripped on encode.
	switch len(str) % 4 {
	case 0:
	case 1:
		return nil, errors.New("Invalid base64 string")
	case 2:
		str += "=="
	case 3:
		str += "="
	}
	return base64.URLEncoding.DecodeString(str)
}
// ---------------------------------------------------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------------------------------------------------
// extractPayload reads the whole response body, closes it and returns the bytes.
// A read error panics via checkErr.
func extractPayload(rsp *http.Response) []byte {
	// Register the close first so the body is released even when the read fails.
	defer rsp.Body.Close()

	b, err := ioutil.ReadAll(rsp.Body)
	checkErr(err)
	return b
}
// checkErr panics with err when it is non-nil and does nothing otherwise.
// Demo-only shortcut in place of real error handling.
func checkErr(err error) {
	if err == nil {
		return
	}
	panic(err)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment