Skip to content

Instantly share code, notes, and snippets.

@appellation
Created January 27, 2019 03:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save appellation/e65a3cb736b57e82480da2dbfa9d255e to your computer and use it in GitHub Desktop.
Save appellation/e65a3cb736b57e82480da2dbfa9d255e to your computer and use it in GitHub Desktop.
package state
import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"time"
)
const (
	// uint32Size is the encoded width of the creation timestamp.
	uint32Size = 4
	// nonceSize is the length of the random nonce in bytes.
	nonceSize = 16
	// headerSize is the length of the fixed-size prefix of an encoded State:
	// the nonce followed by the timestamp. The redirect URL fills the rest.
	headerSize = nonceSize + uint32Size
)

// stateError is a string-backed error type so that the package's sentinel
// errors can be declared as constants.
type stateError string

// Error implements the error interface.
func (e stateError) Error() string { return string(e) }

// ErrShortState is returned when decoding data that is too short to hold the
// fixed-size nonce and timestamp prefix of an encoded State.
const ErrShortState = stateError("state: encoded state is too short")

// State represents an OAuth2 state: a random nonce, the URL to redirect to,
// and the time the state was created.
type State struct {
	RedirectURL string
	Nonce       [nonceSize]byte
	CreatedAt   time.Time
}

// New makes a new OAuth2 state for redirect, filled with a fresh nonce from
// crypto/rand and stamped with the current time.
//
// New panics if the system random source fails. It has no error return, and
// silently returning a state with a predictable all-zero nonce would be worse.
func New(redirect string) *State {
	s := &State{
		RedirectURL: redirect,
		CreatedAt:   time.Now(),
	}
	if _, err := rand.Read(s.Nonce[:]); err != nil {
		panic("state: reading random nonce: " + err.Error())
	}
	return s
}

// Equals reports whether s and o carry the same redirect URL, nonce, and
// creation time. Creation times are compared at one-second resolution, the
// precision kept by the binary encoding. A nil o is never equal.
func (s *State) Equals(o *State) bool {
	if o == nil {
		return false
	}
	return s.RedirectURL == o.RedirectURL &&
		s.Nonce == o.Nonce &&
		s.CreatedAt.Unix() == o.CreatedAt.Unix()
}

// UnmarshalBase64 decodes a URL-safe (padded) base64 string, as produced by
// MarshalBase64, into s. It returns the base64 decoding error, or any error
// from UnmarshalBinary such as ErrShortState.
func (s *State) UnmarshalBase64(b64 string) error {
	b, err := base64.URLEncoding.DecodeString(b64)
	if err != nil {
		return err
	}
	return s.UnmarshalBinary(b)
}

// MarshalBase64 returns the URL-safe (padded) base64 encoding of s's binary form.
func (s *State) MarshalBase64() string {
	// MarshalBinary never returns an error, so ignoring it is safe.
	b, _ := s.MarshalBinary()
	return base64.URLEncoding.EncodeToString(b)
}

// MarshalBinary implements encoding.BinaryMarshaler. The layout is:
//
//	[0:16)  nonce
//	[16:20) creation time as little-endian uint32 Unix seconds
//	[20:)   redirect URL bytes
//
// Sub-second precision is dropped. Times outside the uint32 range of Unix
// seconds (before 1970 or after early 2106) do not round-trip. MarshalBinary
// never returns an error.
func (s *State) MarshalBinary() ([]byte, error) {
	buf := make([]byte, headerSize+len(s.RedirectURL))
	copy(buf, s.Nonce[:])
	binary.LittleEndian.PutUint32(buf[nonceSize:], uint32(s.CreatedAt.Unix()))
	copy(buf[headerSize:], s.RedirectURL)
	return buf, nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler, decoding the layout
// written by MarshalBinary. If data is shorter than the fixed nonce and
// timestamp prefix, it returns ErrShortState and leaves s unchanged.
// NOTE(review): encoded states are presumably round-tripped through the
// client as the OAuth2 state parameter, so malformed input must not panic.
func (s *State) UnmarshalBinary(data []byte) error {
	if len(data) < headerSize {
		return ErrShortState
	}
	copy(s.Nonce[:], data[:nonceSize])
	s.CreatedAt = time.Unix(int64(binary.LittleEndian.Uint32(data[nonceSize:headerSize])), 0)
	s.RedirectURL = string(data[headerSize:])
	return nil
}
package state_test
import (
"testing"
"github.com/appellation/pleb/auth/state"
)
// TestState checks that a freshly generated state survives a base64 round
// trip: decoding succeeds, the decoded state equals the original, and
// re-encoding it reproduces the same string.
func TestState(t *testing.T) {
	// Use a non-empty redirect so the URL portion of the encoding is exercised.
	s := state.New("https://example.com/callback")
	b64 := s.MarshalBase64()

	decoded := &state.State{}
	if err := decoded.UnmarshalBase64(b64); err != nil {
		t.Fatalf("UnmarshalBase64(%q): %v", b64, err)
	}
	if !s.Equals(decoded) {
		t.Errorf("decoded state %+v not equal to original %+v", decoded, s)
	}
	if got := decoded.MarshalBase64(); got != b64 {
		t.Errorf("generated %q not equal to re-encoded %q", b64, got)
	}
}
package state
import (
"sync"
"time"
)
// Store holds outstanding OAuth2 states, keyed by nonce, until they are
// verified or expire. Create one with NewStore: it starts a background
// goroutine that periodically discards expired states, which Close stops.
type Store struct {
	// Timeout is how long a generated state remains valid; NewStore sets it
	// to one hour. The sweeper goroutine started by NewStore reads it
	// concurrently (including to size its ticker), so modifying it after
	// NewStore returns is a data race.
	Timeout time.Duration

	states sync.Map // nonce ([nonceSize]byte) -> *State

	// verifyMu makes Verify's look-up-compare-delete sequence atomic, so a
	// state can be redeemed at most once even under concurrent Verify calls.
	verifyMu sync.Mutex

	closeOnce sync.Once
	done      chan struct{} // closed by Close to stop the sweeper
	stopped   chan struct{} // closed by the sweeper once it has exited
}

// NewStore makes a new state store with a one-hour Timeout and starts its
// background sweeper. Call Close to stop the sweeper.
func NewStore() *Store {
	s := &Store{
		Timeout: time.Hour,
		done:    make(chan struct{}),
		stopped: make(chan struct{}),
	}
	go s.sweepStates()
	return s
}

// Generate makes a new OAuth2 state for redirect and keeps it in the store
// until it is verified or expires. It returns the state's URL-safe base64
// encoding.
func (s *Store) Generate(redirect string) string {
	state := New(redirect)
	s.states.Store(state.Nonce, state)
	return state.MarshalBase64()
}

// Verify validates a base64 string as an existing, unexpired state in the
// store. On success it removes the state, so each state verifies at most
// once, and returns the decoded state. It returns nil if b64 cannot be
// decoded, is unknown, does not match the stored state, or has expired.
func (s *Store) Verify(b64 string) *State {
	state := &State{}
	if err := state.UnmarshalBase64(b64); err != nil {
		return nil
	}

	s.verifyMu.Lock()
	defer s.verifyMu.Unlock()

	existing, ok := s.states.Load(state.Nonce)
	if !ok || !state.Equals(existing.(*State)) || s.Expired(state) {
		return nil
	}
	s.states.Delete(state.Nonce)
	return state
}

// Expired reports whether state was created more than s.Timeout ago,
// compared at one-second resolution.
func (s *Store) Expired(state *State) bool {
	return state.CreatedAt.Unix() < time.Now().Add(-s.Timeout).Unix()
}

// Close stops the background sweeper, waits for it to exit, and discards all
// stored states. It is safe to call more than once. On a Store not created by
// NewStore it does nothing.
func (s *Store) Close() {
	if s.done == nil {
		// No sweeper was started, so there is nothing to stop or wait for.
		return
	}
	s.closeOnce.Do(func() { close(s.done) })
	<-s.stopped
}

// sweepStates runs until Close is called, deleting expired states every
// s.Timeout. On shutdown it empties the store and signals s.stopped.
func (s *Store) sweepStates() {
	defer close(s.stopped)
	ticker := time.NewTicker(s.Timeout)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.states.Range(func(k, v interface{}) bool {
				if s.Expired(v.(*State)) {
					s.states.Delete(k)
				}
				return true
			})
		case <-s.done:
			// Delete entries in place rather than reassigning s.states, which
			// would race with concurrent Generate and Verify calls.
			s.states.Range(func(k, _ interface{}) bool {
				s.states.Delete(k)
				return true
			})
			return
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment