Created
January 27, 2019 03:44
-
-
Save appellation/e65a3cb736b57e82480da2dbfa9d255e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package state | |
import (
	"crypto/rand"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"time"
)
// Byte widths of the fixed-size fields in a binary-encoded State.
const (
	// nonceSize is the length of the random nonce.
	nonceSize = 16
	// uint32Size is the width of the encoded creation timestamp.
	uint32Size = 4
)

// State represents a single OAuth2 state: a redirect string, a random nonce,
// and the time the state was created.
type State struct {
	// RedirectURL is the redirect string the state was created with.
	RedirectURL string
	// Nonce is the random value that identifies this state.
	Nonce [nonceSize]byte
	// CreatedAt is the creation time of the state.
	CreatedAt time.Time
}
// New makes a new OAuth2 state | |
func New(redirect string) *State { | |
s := &State{ | |
RedirectURL: redirect, | |
CreatedAt: time.Now(), | |
} | |
rand.Read(s.Nonce[:]) | |
return s | |
} | |
// Equals compares this state to another state | |
func (s *State) Equals(o *State) bool { | |
if o == nil { | |
return false | |
} | |
return s.RedirectURL == o.RedirectURL && s.Nonce == o.Nonce && s.CreatedAt.Unix() == o.CreatedAt.Unix() | |
} | |
// UnmarshalBase64 puts a base64 encoded string into the struct | |
func (s *State) UnmarshalBase64(b64 string) error { | |
b, err := base64.URLEncoding.DecodeString(b64) | |
if err != nil { | |
return err | |
} | |
return s.UnmarshalBinary(b) | |
} | |
// MarshalBase64 creates a base64 encoding of this state | |
func (s *State) MarshalBase64() string { | |
b, _ := s.MarshalBinary() | |
return base64.URLEncoding.EncodeToString(b) | |
} | |
// MarshalBinary implements the binary marshaler interface | |
func (s *State) MarshalBinary() ([]byte, error) { | |
buf := make([]byte, nonceSize+uint32Size+len(s.RedirectURL)) | |
copy(buf, s.Nonce[:]) | |
off := nonceSize | |
binary.LittleEndian.PutUint32(buf[off:], uint32(s.CreatedAt.Unix())) | |
off += uint32Size | |
copy(buf[off:], []byte(s.RedirectURL)) | |
return buf, nil | |
} | |
// UnmarshalBinary implements the binary unmarshaler interface | |
func (s *State) UnmarshalBinary(data []byte) error { | |
copy(s.Nonce[:], data) | |
off := nonceSize | |
s.CreatedAt = time.Unix(int64(binary.LittleEndian.Uint32(data[off:])), 0) | |
off += uint32Size | |
s.RedirectURL = string(data[off:]) | |
return nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package state_test | |
import ( | |
"testing" | |
"github.com/appellation/pleb/auth/state" | |
) | |
func TestState(t *testing.T) { | |
s := state.New("") | |
t.Log(s) | |
b64 := s.MarshalBase64() | |
t.Log(b64) | |
compare := &state.State{} | |
compare.UnmarshalBase64(b64) | |
t.Log(compare) | |
if compare.MarshalBase64() != b64 { | |
t.Errorf("generated \"%s\" not equal to decoded \"%s\"", b64, compare.MarshalBase64()) | |
return | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package state | |
import ( | |
"sync" | |
"time" | |
) | |
// Store represents a store of states, keyed by nonce.
type Store struct {
	// Timeout is how old a state may be before Expired reports it as expired.
	Timeout time.Duration

	// states maps a state's nonce to the stored *State.
	states sync.Map
	// close carries the stop signal from Close to the sweeper goroutine.
	close chan bool
}
// NewStore makes a new state store | |
func NewStore() *Store { | |
s := &Store{time.Hour, sync.Map{}, make(chan bool)} | |
go s.sweepStates() | |
return s | |
} | |
// Generate makes a new OAuth2 state and stores it in the store. Returns base64 version. | |
func (s *Store) Generate(redirect string) string { | |
state := New(redirect) | |
s.states.Store(state.Nonce, state) | |
return state.MarshalBase64() | |
} | |
// Verify validates a base64 string as an existing state in the store. Returns nil if invalid. | |
func (s *Store) Verify(b64 string) *State { | |
state := &State{} | |
state.UnmarshalBase64(b64) | |
existing, ok := s.states.Load(state.Nonce) | |
if ok && state.Equals(existing.(*State)) && !s.Expired(state) { | |
s.states.Delete(state.Nonce) | |
return state | |
} | |
return nil | |
} | |
// Expired determines whether a state is expired | |
func (s *Store) Expired(state *State) bool { | |
return state.CreatedAt.Unix() < time.Now().Add(-s.Timeout).Unix() | |
} | |
// Close closes this state store | |
func (s *Store) Close() { | |
s.close <- true | |
} | |
func (s *Store) sweepStates() { | |
ticker := time.NewTicker(s.Timeout) | |
defer ticker.Stop() | |
for { | |
select { | |
case <-ticker.C: | |
s.states.Range(func(k, v interface{}) bool { | |
if s.Expired(v.(*State)) { | |
s.states.Delete(k) | |
} | |
return true | |
}) | |
case <-s.close: | |
s.states = sync.Map{} | |
return | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment