Last active
December 17, 2019 06:26
-
-
Save codingconcepts/59559a10794935df2cc9fabf210dfc9c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"crypto/rand" | |
"database/sql" | |
"encoding/json" | |
"fmt" | |
"log" | |
"net/http" | |
"time" | |
"github.com/gorilla/mux" | |
"github.com/lib/pq" | |
"golang.org/x/crypto/bcrypt" | |
) | |
func main() { | |
// Open a connection to the database and test it. Stopping the | |
// application if either fails. | |
db, err := sql.Open("postgres", "postgres://root@localhost:26257/defaultdb?sslmode=disable") | |
if err != nil { | |
log.Fatalf("error opening database connection: %v", err) | |
} | |
if err = db.Ping(); err != nil { | |
log.Fatalf("error testing database connection: %v", err) | |
} | |
defer db.Close() | |
// Configure some sensible config defaults. | |
s := &server{ | |
db: db, | |
hashCost: 14, | |
tokenSize: 80, | |
tokenExpiry: time.Hour, | |
resetURLFormat: "http://localhost:1234/reset/%s", | |
sweepLimit: 1000, | |
sweepInterval: time.Minute * 30, | |
sweepOverflowInterval: time.Second * 1, | |
} | |
// Spin up an HTTP router and register the handlers we'll need for | |
// this example. /register to create new users, /login to demonstrate | |
// that passwords are being changed, /forgot to request a new password, | |
// and /reset to configure a new password. | |
handler := mux.NewRouter() | |
handler.HandleFunc("/register", register(s)).Methods(http.MethodPost) | |
handler.HandleFunc("/login", login(s)).Methods(http.MethodPost) | |
handler.HandleFunc("/forgot", forgot(s)).Methods(http.MethodPost) | |
handler.HandleFunc("/reset/{token}", reset(s)).Methods(http.MethodPost) | |
// Configure an HTTP server with some sensible timeouts. | |
hs := &http.Server{ | |
Addr: ":1234", | |
Handler: handler, | |
ReadHeaderTimeout: time.Second * 2, | |
ReadTimeout: time.Second * 2, | |
WriteTimeout: time.Second * 3, | |
IdleTimeout: time.Second * 10, | |
} | |
// Start the sweep worker. | |
go sweep(s) | |
// Start the HTTP server and log the error it returns (if any). | |
log.Fatal(hs.ListenAndServe()) | |
} | |
// server bundles the dependencies and tunables shared by the HTTP
// handlers and the background token sweeper.
type server struct {
	// db is the database connection pool used by every handler.
	db *sql.DB

	// hashCost is the bcrypt cost (work factor) passed to
	// bcrypt.GenerateFromPassword when hashing passwords.
	hashCost int

	// tokenSize is the number of characters in each reset token sent to
	// the user; longer tokens are harder to guess.
	tokenSize int

	// tokenExpiry is how long a reset token stays valid after it is issued.
	tokenExpiry time.Duration

	// resetURLFormat is a printf-style format for the reset URL; the
	// token is substituted into it.
	resetURLFormat string

	// sweepInterval is the time between clear-downs of the reset table.
	sweepInterval time.Duration

	// sweepOverflowInterval is the pause given to the database between
	// batch deletions when more than sweepLimit tokens need removing.
	sweepOverflowInterval time.Duration

	// sweepLimit caps how many tokens one batch deletes. Larger values
	// delete more per batch but lock the reset table for longer.
	sweepLimit int64
}
// simpleResponse is the JSON body several endpoints use to return one
// human-readable message, serialised as {"message": "..."}.
type simpleResponse struct {
	Message string `json:"message"` // text shown to the caller
}
func register(s *server) http.HandlerFunc { | |
type request struct { | |
Email string `json:"email"` | |
Password string `json:"password"` | |
} | |
type response struct { | |
ID string `json:"id"` | |
} | |
return func(w http.ResponseWriter, r *http.Request) { | |
// Bind to the user's request object. | |
var req request | |
if err := bindClose(r, &req); err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
// Hash the user's plaintext password. | |
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), s.hashCost) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
// Insert to the user into the database and fetch their ID. | |
row := s.db.QueryRow( | |
`insert into "user" ("email", "password") values ($1, $2) returning id`, | |
req.Email, | |
hash) | |
var id string | |
if err := row.Scan(&id); err != nil { | |
if pqerr, ok := err.(*pq.Error); ok && pqerr.Code == "23505" { | |
respond(w, http.StatusConflict, simpleResponse{Message: "email address already registered"}) | |
return | |
} | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
respond(w, http.StatusOK, response{ID: id}) | |
} | |
} | |
func login(s *server) http.HandlerFunc { | |
type request struct { | |
Email string `json:"email"` | |
Password string `json:"password"` | |
} | |
return func(w http.ResponseWriter, r *http.Request) { | |
var req request | |
if err := bindClose(r, &req); err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
// Lookup the user's hashed password from the database. | |
row := s.db.QueryRow( | |
`select "password" from "user" where "email" = $1`, | |
req.Email) | |
var password []byte | |
if err := row.Scan(&password); err != nil { | |
respond(w, http.StatusUnauthorized, simpleResponse{Message: "invalid email or password"}) | |
return | |
} | |
// Compare the user's stored password against the one they've | |
// just provided at login. | |
ok, err := check(password, []byte(req.Password)) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
if !ok { | |
respond(w, http.StatusUnauthorized, simpleResponse{Message: "invalid email or password"}) | |
return | |
} | |
respond(w, http.StatusOK, simpleResponse{Message: "implement JWTs"}) | |
} | |
} | |
func forgot(s *server) http.HandlerFunc { | |
type request struct { | |
Email string `json:"email"` | |
} | |
return func(w http.ResponseWriter, r *http.Request) { | |
var req request | |
if err := bindClose(r, &req); err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
// We'll want to respond with a succcess message in the event that | |
// a user is found with this email address or not, otherwise we're | |
// providing malicious users with a tool to check for the existence | |
// of accounts. | |
successMessage := fmt.Sprintf( | |
"a reset url has been emailed to %s, it will expire in %s", | |
req.Email, | |
s.tokenExpiry) | |
// Get user's ID to store alongside token. Checking this ensures | |
// that we're only storing reset tokens for valid email addresses. | |
row := s.db.QueryRow( | |
`select "id" from "user" where "email" = $1`, | |
req.Email) | |
var userID string | |
if err := row.Scan(&userID); err != nil { | |
respond(w, http.StatusOK, simpleResponse{Message: successMessage}) | |
return | |
} | |
// Generate and insert the token against the user. | |
token, err := token(s.tokenSize) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
_, err = s.db.Exec( | |
`insert into "reset" ("user_id", "token", "expiry") values ($1, $2, $3)`, | |
userID, | |
token, | |
time.Now().UTC().Add(s.tokenExpiry)) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
// Simulate reset URL being emailed. | |
log.Printf(s.resetURLFormat, token) | |
respond(w, http.StatusOK, simpleResponse{Message: successMessage}) | |
} | |
} | |
func reset(s *server) http.HandlerFunc { | |
type request struct { | |
Password string `json:"password"` | |
} | |
return func(w http.ResponseWriter, r *http.Request) { | |
// Get the reset token from the URL path. | |
token, ok := mux.Vars(r)["token"] | |
if !ok { | |
http.Error(w, "missing token", http.StatusUnprocessableEntity) | |
return | |
} | |
var req request | |
if err := bindClose(r, &req); err != nil { | |
http.Error(w, err.Error(), http.StatusUnprocessableEntity) | |
return | |
} | |
// Lookup the token and user's id. | |
row := s.db.QueryRow( | |
`select "id", "user_id" from "reset" where "token" = $1 and "expiry" < $2`, | |
token, | |
time.Now().UTC().Add(s.tokenExpiry)) | |
var tokenID string | |
var userID string | |
if err := row.Scan(&tokenID, &userID); err != nil { | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
// Hash the new password in the same way we did when the user | |
// registered. | |
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), s.hashCost) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
// Update the user's password. | |
_, err = s.db.Exec( | |
`update "user" set "password" = $1 where "id" = $2`, | |
hash, | |
userID) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
// Delete the reset token. | |
_, err = s.db.Exec( | |
`delete from "reset" where "id" = $1`, | |
tokenID) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
respond(w, http.StatusOK, simpleResponse{Message: "password reset"}) | |
} | |
} | |
/********* | |
* Crypto * | |
*********/ | |
// check takes a hashed password and a plaintext password and | |
// checks them for equality. If the password's don't match, false | |
// will be returned without an error, as this is an expected flow. | |
func check(hash, pw []byte) (bool, error) { | |
if err := bcrypt.CompareHashAndPassword(hash, pw); err != nil { | |
if err == bcrypt.ErrMismatchedHashAndPassword { | |
return false, nil | |
} | |
return false, err | |
} | |
return true, nil | |
} | |
// tokenRunes is the alphabet reset tokens are drawn from: the 62 ASCII
// letters and digits, so tokens can be embedded in a URL path unescaped.
var (
	tokenRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
)

// token returns a string of size characters drawn uniformly at random from
// tokenRunes, using crypto/rand as the entropy source.
//
// A random byte is mapped onto the alphabet only if it falls below the
// largest multiple of len(tokenRunes) that fits in a byte; other bytes are
// discarded and more are read. Reducing every byte modulo 62 directly
// would make the first 256%62 = 8 characters more likely than the rest.
//
// An error is returned if size is negative or the secure random source
// fails. A size of 0 yields "".
func token(size int) (string, error) {
	if size < 0 {
		return "", fmt.Errorf("invalid token size %d", size)
	}
	alphabet := len(tokenRunes)
	// Bytes at or above limit are rejected to keep the distribution uniform.
	limit := 256 - 256%alphabet

	output := make([]rune, 0, size)
	buf := make([]byte, size)
	for len(output) < size {
		if _, err := rand.Read(buf); err != nil {
			return "", fmt.Errorf("reading random bytes: %w", err)
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			output = append(output, tokenRunes[int(b)%alphabet])
			if len(output) == size {
				break
			}
		}
	}
	return string(output), nil
}
/********** | |
* Helpers * | |
**********/ | |
// bindClose decodes the JSON request body into val and always closes the
// body, whether or not decoding succeeds. A decode error takes precedence
// over an error from closing the body.
func bindClose(r *http.Request, val interface{}) (err error) {
	defer func() {
		if cerr := r.Body.Close(); err == nil {
			err = cerr
		}
	}()
	return json.NewDecoder(r.Body).Decode(val)
}
// respond writes val to w as JSON with the given status code and a
// Content-Type of application/json. The body is newline-terminated, as
// json.Encoder would produce.
//
// val is marshalled before anything is written. If marshalling fails, the
// client receives a plain-text 500 "error writing response" instead of
// the requested status with an empty body. Nothing should write to w
// after calling respond.
func respond(w http.ResponseWriter, code int, val interface{}) {
	body, err := json.Marshal(val)
	if err != nil {
		http.Error(w, "error writing response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// Once the status line is sent, a write error cannot be reported to
	// the client, so it is deliberately ignored.
	_, _ = w.Write(append(body, '\n'))
}
// sweep clears down the reset token table. At every sweepInterval, | |
// it will begin a process of deleting sweepLimit items until none | |
// are left, continually sweeping for sweepLimit items at every | |
// sweepOverflowInterval, until the table is clear. | |
// Note: I could have used Redis or similar, which provides TTLs for | |
// rows but to keep things simple, I'm doing everything in the | |
// database. Note that this method will only exit in the event of an | |
// error, so call in it a goroutine. | |
func sweep(s *server) { | |
const stmt = `delete from "reset" where "expiry" < $1 limit $2` | |
for range time.Tick(s.sweepInterval) { | |
for range time.NewTicker(s.sweepOverflowInterval).C { | |
result, err := s.db.Exec(stmt, time.Now().UTC(), s.sweepLimit) | |
if err != nil { | |
log.Printf("error deleting tokens: %v", err) | |
} | |
affected, err := result.RowsAffected() | |
if err != nil { | |
log.Printf("error deleting tokens: %v", err) | |
} | |
if affected < s.sweepLimit { | |
break | |
} | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Registered accounts. "password" holds the bcrypt hash produced by the
-- application, not the plaintext; "email" must be unique across users.
CREATE TABLE "user" (
    "id"       UUID   PRIMARY KEY DEFAULT uuid_v4()::UUID,
    "email"    STRING NOT NULL,
    "password" BYTES  NOT NULL,
    UNIQUE ("email")
);
-- Outstanding password-reset tokens, one row per successful /forgot
-- request for a registered email. A row is deleted when its token is used,
-- or by the background sweeper once its expiry has passed.
CREATE TABLE "reset" (
    "id"      UUID      PRIMARY KEY DEFAULT uuid_v4()::UUID,
    "user_id" UUID      NOT NULL REFERENCES "user" ("id"),
    "token"   STRING    NOT NULL,
    "expiry"  TIMESTAMP NOT NULL,
    -- Every /reset request looks a token up by value. The unique index
    -- makes that a point lookup and guarantees one row per token.
    UNIQUE ("token"),
    -- The sweeper deletes rows with "expiry" < now; index it so sweeps do
    -- not scan the whole table.
    INDEX ("expiry")
);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment