Skip to content

Instantly share code, notes, and snippets.

@codingconcepts codingconcepts/main.go
Last active Dec 17, 2019

Embed
What would you like to do?
package main
import (
"crypto/rand"
"database/sql"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
)
// main wires the example together: it connects to the database, builds the
// shared server configuration, registers the HTTP handlers, launches the
// background token sweeper and then serves HTTP until the listener fails.
func main() {
	// Connect to the database and prove the connection works. Either
	// failure is fatal for the application.
	pool, err := sql.Open("postgres", "postgres://root@localhost:26257/defaultdb?sslmode=disable")
	if err != nil {
		log.Fatalf("error opening database connection: %v", err)
	}
	err = pool.Ping()
	if err != nil {
		log.Fatalf("error testing database connection: %v", err)
	}
	defer pool.Close()

	// Default configuration shared by every handler and by the sweeper.
	cfg := &server{
		db:                    pool,
		hashCost:              14,
		tokenSize:             80,
		tokenExpiry:           time.Hour,
		resetURLFormat:        "http://localhost:1234/reset/%s",
		sweepLimit:            1000,
		sweepInterval:         time.Minute * 30,
		sweepOverflowInterval: time.Second * 1,
	}

	// Every endpoint in this example is POST-only: /register creates
	// users, /login shows that passwords really change, /forgot requests
	// a reset token and /reset/{token} sets the new password.
	routes := []struct {
		path   string
		handle http.HandlerFunc
	}{
		{"/register", register(cfg)},
		{"/login", login(cfg)},
		{"/forgot", forgot(cfg)},
		{"/reset/{token}", reset(cfg)},
	}
	router := mux.NewRouter()
	for _, rt := range routes {
		router.HandleFunc(rt.path, rt.handle).Methods(http.MethodPost)
	}

	// HTTP server with explicit timeouts rather than the unlimited defaults.
	srv := &http.Server{
		Addr:              ":1234",
		Handler:           router,
		ReadHeaderTimeout: time.Second * 2,
		ReadTimeout:       time.Second * 2,
		WriteTimeout:      time.Second * 3,
		IdleTimeout:       time.Second * 10,
	}

	// The sweeper runs for the lifetime of the process.
	go sweep(cfg)

	// ListenAndServe only returns on failure; log that error and exit.
	log.Fatal(srv.ListenAndServe())
}
// server bundles the database handle and the tunables that the HTTP
// handlers and the background sweeper read.
type server struct {
	// db is the database connection pool.
	db *sql.DB

	// hashCost is the bcrypt cost used when hashing passwords.
	hashCost int

	// tokenSize is the length of the reset token sent to the user.
	// Longer tokens are harder to guess.
	tokenSize int
	// tokenExpiry is how long a reset token stays valid after creation.
	tokenExpiry time.Duration
	// resetURLFormat is the format string used to build the reset URL
	// sent to the user; it receives the token as its only argument.
	resetURLFormat string

	// sweepLimit caps how many tokens are deleted per statement. Bigger
	// batches remove more rows but hold locks on the table for longer.
	sweepLimit int64
	// sweepInterval is the pause between clear-downs of the token table.
	sweepInterval time.Duration
	// sweepOverflowInterval is the pause between successive batch
	// deletions when more than sweepLimit tokens need removing.
	sweepOverflowInterval time.Duration
}
// simpleResponse is the JSON envelope several endpoints use when all they
// need to return is a human-readable message.
type simpleResponse struct {
	// Message is the text shown to the caller, encoded as "message".
	Message string `json:"message"`
}
// register returns the handler for POST /register. It decodes an email and
// password from the JSON body, stores the email with a bcrypt hash of the
// password and replies with the ID of the new user. A duplicate email
// produces 409 Conflict.
func register(s *server) http.HandlerFunc {
	type request struct {
		Email    string `json:"email"`
		Password string `json:"password"`
	}
	type response struct {
		ID string `json:"id"`
	}

	const insertUser = `insert into "user" ("email", "password") values ($1, $2) returning id`

	// PostgreSQL error code for a unique constraint violation.
	const uniqueViolation = "23505"

	return func(w http.ResponseWriter, r *http.Request) {
		// Decode the JSON body.
		var body request
		if err := bindClose(r, &body); err != nil {
			http.Error(w, err.Error(), http.StatusUnprocessableEntity)
			return
		}

		// Never store the plaintext password, only its bcrypt hash.
		hashed, err := bcrypt.GenerateFromPassword([]byte(body.Password), s.hashCost)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Insert the user and read back the generated ID in one round trip.
		var newID string
		err = s.db.QueryRow(insertUser, body.Email, hashed).Scan(&newID)
		if pqErr, isPQ := err.(*pq.Error); isPQ && pqErr.Code == uniqueViolation {
			respond(w, http.StatusConflict, simpleResponse{Message: "email address already registered"})
			return
		}
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		respond(w, http.StatusOK, response{ID: newID})
	}
}
// login returns the handler for POST /login. It decodes an email and
// password from the JSON body, loads the stored bcrypt hash for that email
// and compares it with the supplied password.
//
// Responses:
//   - 422 if the body cannot be decoded.
//   - 401 with a generic message if the email is unknown or the password is
//     wrong; the two cases are indistinguishable to the caller.
//   - 500 if the database lookup fails or the stored hash cannot be checked.
//   - 200 on success.
func login(s *server) http.HandlerFunc {
	type request struct {
		Email    string `json:"email"`
		Password string `json:"password"`
	}
	return func(w http.ResponseWriter, r *http.Request) {
		var req request
		if err := bindClose(r, &req); err != nil {
			http.Error(w, err.Error(), http.StatusUnprocessableEntity)
			return
		}

		// Lookup the user's hashed password from the database.
		row := s.db.QueryRow(
			`select "password" from "user" where "email" = $1`,
			req.Email)
		var password []byte
		switch err := row.Scan(&password); {
		case err == sql.ErrNoRows:
			// Unknown email: answer exactly as for a wrong password.
			respond(w, http.StatusUnauthorized, simpleResponse{Message: "invalid email or password"})
			return
		case err != nil:
			// Anything else is a database failure, not bad credentials.
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Compare the user's stored password against the one they've
		// just provided at login. check only returns an error when the
		// comparison itself could not be performed, which is a
		// server-side fault rather than a problem with the request.
		ok, err := check(password, []byte(req.Password))
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if !ok {
			respond(w, http.StatusUnauthorized, simpleResponse{Message: "invalid email or password"})
			return
		}

		respond(w, http.StatusOK, simpleResponse{Message: "implement JWTs"})
	}
}
// forgot returns the handler for POST /forgot. It decodes an email from the
// JSON body and, if a user with that email exists, stores a fresh reset
// token for them and logs the reset URL (standing in for sending an email).
//
// Responses:
//   - 422 if the body cannot be decoded.
//   - 200 with the same message whether or not the email is registered.
//   - 500 if the database or the token generator fails.
func forgot(s *server) http.HandlerFunc {
	type request struct {
		Email string `json:"email"`
	}
	return func(w http.ResponseWriter, r *http.Request) {
		var req request
		if err := bindClose(r, &req); err != nil {
			http.Error(w, err.Error(), http.StatusUnprocessableEntity)
			return
		}

		// We respond with the same success message whether or not a user
		// is found with this email address, otherwise we'd be providing
		// malicious users with a tool to check for the existence of
		// accounts.
		successMessage := fmt.Sprintf(
			"a reset url has been emailed to %s, it will expire in %s",
			req.Email,
			s.tokenExpiry)

		// Get the user's ID to store alongside the token. Checking this
		// ensures that we're only storing reset tokens for valid email
		// addresses.
		row := s.db.QueryRow(
			`select "id" from "user" where "email" = $1`,
			req.Email)
		var userID string
		switch err := row.Scan(&userID); {
		case err == sql.ErrNoRows:
			// Unknown email: pretend everything went fine.
			respond(w, http.StatusOK, simpleResponse{Message: successMessage})
			return
		case err != nil:
			// A failed lookup must not be reported as success.
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Generate and insert the token against the user. Failures from
		// here on are server-side, so they map to 500.
		resetToken, err := token(s.tokenSize)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		_, err = s.db.Exec(
			`insert into "reset" ("user_id", "token", "expiry") values ($1, $2, $3)`,
			userID,
			resetToken,
			time.Now().UTC().Add(s.tokenExpiry))
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Simulate reset URL being emailed.
		log.Printf(s.resetURLFormat, resetToken)

		respond(w, http.StatusOK, simpleResponse{Message: successMessage})
	}
}
// reset returns the handler for POST /reset/{token}. It looks up an
// unexpired reset token, hashes the new password from the JSON body, stores
// it against the token's user and deletes the token so it cannot be reused.
//
// Responses:
//   - 422 if the token path variable is missing or the body cannot be decoded.
//   - 401 if the token is unknown or has expired.
//   - 500 if hashing or any database operation fails.
//   - 200 once the password is updated and the token removed.
func reset(s *server) http.HandlerFunc {
	type request struct {
		Password string `json:"password"`
	}
	return func(w http.ResponseWriter, r *http.Request) {
		// Get the reset token from the URL path.
		resetToken, ok := mux.Vars(r)["token"]
		if !ok {
			http.Error(w, "missing token", http.StatusUnprocessableEntity)
			return
		}

		var req request
		if err := bindClose(r, &req); err != nil {
			http.Error(w, err.Error(), http.StatusUnprocessableEntity)
			return
		}

		// Lookup the token and user's id. A token is only usable while
		// its stored expiry is still in the future.
		row := s.db.QueryRow(
			`select "id", "user_id" from "reset" where "token" = $1 and "expiry" > $2`,
			resetToken,
			time.Now().UTC())
		var tokenID string
		var userID string
		switch err := row.Scan(&tokenID, &userID); {
		case err == sql.ErrNoRows:
			respond(w, http.StatusUnauthorized, simpleResponse{Message: "invalid or expired reset token"})
			return
		case err != nil:
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Hash the new password in the same way we did when the user
		// registered.
		hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), s.hashCost)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Change the password and consume the token in one transaction,
		// so a failure can never leave a changed password next to a
		// still-valid token.
		tx, err := s.db.Begin()
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Rollback after a successful Commit only returns sql.ErrTxDone,
		// so it is safe to defer unconditionally and ignore the result.
		defer tx.Rollback()

		// Update the user's password.
		_, err = tx.Exec(
			`update "user" set "password" = $1 where "id" = $2`,
			hash,
			userID)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Delete the reset token.
		_, err = tx.Exec(
			`delete from "reset" where "id" = $1`,
			tokenID)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		if err = tx.Commit(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		respond(w, http.StatusOK, simpleResponse{Message: "password reset"})
	}
}
/*********
* Crypto *
*********/
// check compares a bcrypt hash with a plaintext password. A mismatch is an
// expected outcome and is reported as (false, nil); any other failure from
// the comparison is returned as (false, err).
func check(hash, pw []byte) (bool, error) {
	switch err := bcrypt.CompareHashAndPassword(hash, pw); err {
	case nil:
		return true, nil
	case bcrypt.ErrMismatchedHashAndPassword:
		return false, nil
	default:
		return false, err
	}
}
var (
	// tokenRunes is the alphabet reset tokens are drawn from.
	tokenRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
)

// token returns a random string of size characters drawn uniformly from
// tokenRunes, using the operating system's cryptographic random source.
//
// It returns an error if size is negative or if random data cannot be
// read. A size of zero yields the empty string.
func token(size int) (string, error) {
	if size < 0 {
		return "", fmt.Errorf("invalid token size %d", size)
	}

	// Mapping a random byte onto the alphabet with a plain modulo would
	// favour the first 256 % len(tokenRunes) characters. Bytes at or above
	// the largest multiple of the alphabet size that fits in a byte are
	// discarded instead, which keeps every character equally likely.
	// This assumes the alphabet holds at most 256 characters.
	limit := 256 - 256%len(tokenRunes)

	output := make([]rune, 0, size)
	buf := make([]byte, size)
	for len(output) < size {
		// Only ask for as many bytes as there are characters missing.
		chunk := buf[:size-len(output)]
		if _, err := rand.Read(chunk); err != nil {
			return "", fmt.Errorf("cryptographic prng not available: %v", err)
		}
		for _, b := range chunk {
			if int(b) >= limit {
				continue
			}
			output = append(output, tokenRunes[int(b)%len(tokenRunes)])
		}
	}
	return string(output), nil
}
/**********
* Helpers *
**********/
// bindClose decodes the JSON request body of r into val and closes the
// body, whether or not decoding succeeds.
//
// It returns the decode error if there is one, otherwise the error from
// closing the body (nil on full success).
func bindClose(r *http.Request, val interface{}) (err error) {
	defer func() {
		// Always close; only report the close error when decoding
		// succeeded so the more useful decode error is not masked.
		if cerr := r.Body.Close(); err == nil {
			err = cerr
		}
	}()
	return json.NewDecoder(r.Body).Decode(val)
}
// respond writes val to w as a JSON document with the given status code.
//
// The value is marshalled before anything is sent, so if it cannot be
// encoded the client receives a 500 with a plain-text error message instead
// of the requested status and a broken body. Nothing should write to the
// ResponseWriter after a call to this function is made.
func respond(w http.ResponseWriter, code int, val interface{}) {
	body, err := json.Marshal(val)
	if err != nil {
		// Headers and status can still be chosen freely here because
		// nothing has been written yet.
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprintln(w, "error writing response")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// The trailing newline matches what json.Encoder emits. A failed write
	// means the client has gone away; the status is already sent, so there
	// is nothing useful left to do with the error.
	_, _ = w.Write(append(body, '\n'))
}
// sweep clears expired rows out of the reset token table. At every
// sweepInterval it starts a clear-down that deletes up to sweepLimit
// expired tokens per statement, pausing sweepOverflowInterval between
// statements, until a statement removes fewer than sweepLimit rows.
//
// Note: I could have used Redis or similar, which provides TTLs for
// rows but to keep things simple, I'm doing everything in the
// database. This function never returns, so call it in a goroutine.
func sweep(s *server) {
	for range time.Tick(s.sweepInterval) {
		sweepExpired(s)
	}
}

// sweepExpired performs one clear-down of expired reset tokens. It deletes
// in batches of at most s.sweepLimit and stops as soon as a batch comes back
// short or a database error occurs; errors are logged and the remaining
// rows are left for the next clear-down.
func sweepExpired(s *server) {
	const stmt = `delete from "reset" where "expiry" < $1 limit $2`

	// Stop the ticker on the way out so each clear-down releases its timer.
	ticker := time.NewTicker(s.sweepOverflowInterval)
	defer ticker.Stop()

	for range ticker.C {
		result, err := s.db.Exec(stmt, time.Now().UTC(), s.sweepLimit)
		if err != nil {
			// result is nil when Exec fails, so it must not be used.
			log.Printf("error deleting tokens: %v", err)
			return
		}

		affected, err := result.RowsAffected()
		if err != nil {
			log.Printf("error counting deleted tokens: %v", err)
			return
		}

		// A short batch means no more expired tokens are waiting.
		if affected < s.sweepLimit {
			return
		}
	}
}
-- Schema used by the Go example above.

-- One row per registered account. "password" holds the bcrypt hash that the
-- register and reset handlers write, never the plaintext. The unique
-- constraint on "email" is what makes a duplicate registration fail.
create table "user" (
"id" uuid primary key default uuid_v4()::uuid,
"email" string not null,
"password" bytes not null,
unique ("email")
);
-- One row per outstanding password reset token. "expiry" is written by the
-- forgot handler as creation time plus the configured token lifetime; rows
-- are removed by the reset handler and by the background sweep.
-- NOTE(review): tokens are looked up by "token" and swept by "expiry", but
-- neither column is indexed or unique here; confirm whether that is intended.
create table "reset" (
"id" uuid primary key default uuid_v4()::uuid,
"user_id" uuid not null references "user" ("id"),
"token" string not null,
"expiry" timestamp not null
);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.