Skip to content

Instantly share code, notes, and snippets.

@philippta
Last active August 13, 2020 18:32
Show Gist options
  • Save philippta/41b34e2ab7a7e5f4b0143151501b62e0 to your computer and use it in GitHub Desktop.
Save philippta/41b34e2ab7a7e5f4b0143151501b62e0 to your computer and use it in GitHub Desktop.
Cleaner Go API example with access policies
package main
import (
"encoding/json"
"errors"
"net/http"
"github.com/go-chi/chi"
)
// Domain types for the example API.
type (
	// User is an API caller, identified by ID.
	User struct {
		ID string
	}

	// Todo is a single task; UserID is the ID of the User that owns it.
	Todo struct {
		UserID string
		Task   string
		Done   bool
	}
)
// In-memory fixtures standing in for a data store, keyed by ID.
// The handlers in this example only read these maps; guard them with a
// mutex if anything ever starts writing to them at runtime.
var (
	// users holds the known callers; the ?user= query value is looked up here.
	users = map[string]User{
		"1": {ID: "1"},
		"2": {ID: "2"},
	}

	// todos holds every task; each entry's UserID names its owner.
	todos = map[string]Todo{
		"1": {UserID: "1", Task: "hello"},
		"2": {UserID: "2", Task: "hello"},
	}
)
// Handler signatures that receive request data already resolved by the
// withUser* adapters. user is nil when the request named no known user.
type (
	// userTodoHandlerFunc handles a request about one todo.
	userTodoHandlerFunc func(w http.ResponseWriter, r *http.Request, user *User, todo *Todo)

	// userTodosHandlerFunc handles a request about a collection of todos.
	userTodosHandlerFunc func(w http.ResponseWriter, r *http.Request, user *User, todos map[string]Todo)
)
// parseUser resolves the caller from the "user" query parameter.
// It returns an error when the parameter is absent or names no entry in
// users. The returned pointer refers to a copy of the stored User.
func parseUser(r *http.Request) (*User, error) {
	if u, found := users[r.URL.Query().Get("user")]; found {
		return &u, nil
	}
	return nil, errors.New("user not found")
}
// parseTodo resolves the todo named by the {todoID} route parameter.
// It returns an error when no todo with that ID exists. The returned
// pointer refers to a copy of the stored Todo.
func parseTodo(r *http.Request) (*Todo, error) {
	if t, found := todos[chi.URLParam(r, "todoID")]; found {
		return &t, nil
	}
	return nil, errors.New("todo not found")
}
// withUserAndTodo adapts h into a plain http.HandlerFunc by resolving the
// caller and the requested todo first. An unknown todo ends the request
// with 404 Not Found. An unknown or missing user is not an error here: h
// receives a nil user and the policy it is wrapped in decides the outcome.
//
// NOTE(review): the todo is looked up before any access check runs, so an
// anonymous caller can distinguish missing IDs (404) from existing ones
// (rejected later by the policy) — confirm this is acceptable.
func withUserAndTodo(h userTodoHandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Error deliberately dropped: a nil user means "anonymous".
		user, _ := parseUser(r)
		todo, err := parseTodo(r)
		if err == nil {
			h(w, r, user, todo)
			return
		}
		http.NotFound(w, r)
	}
}
// withUserAndTodos adapts h into a plain http.HandlerFunc. h receives the
// resolved caller (nil when unknown) and the complete, unfiltered
// package-level todos map; any per-user filtering must happen inside h.
func withUserAndTodos(h userTodosHandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// A failed lookup is not fatal here: a nil caller means "anonymous".
		caller, _ := parseUser(r)
		h(w, r, caller, todos)
	}
}
// canViewTodos is the access policy for listing todos.
// Anonymous requests (nil user) are rejected with 403 Forbidden.
// Authenticated callers only see the todos they own. This mirrors the
// ownership rule canViewTodo applies to single todos. Previously every
// known user received every user's todos.
func canViewTodos(h userTodosHandlerFunc) userTodosHandlerFunc {
	return func(w http.ResponseWriter, r *http.Request, user *User, todos map[string]Todo) {
		if user == nil {
			http.Error(w, "403 forbidden", http.StatusForbidden)
			return
		}
		// Build a fresh map instead of deleting from todos: the incoming map
		// may be shared state (e.g. the package-level todos fixture).
		owned := make(map[string]Todo)
		for id, todo := range todos {
			if todo.UserID == user.ID {
				owned[id] = todo
			}
		}
		h(w, r, user, owned)
	}
}
// canViewTodo is the access policy for a single todo. The caller must be
// known and must own the todo. Any failure responds 403 Forbidden without
// invoking h.
func canViewTodo(h userTodoHandlerFunc) userTodoHandlerFunc {
	return func(w http.ResponseWriter, r *http.Request, user *User, todo *Todo) {
		allowed := user != nil && todo != nil && todo.UserID == user.ID
		if !allowed {
			http.Error(w, "403 forbidden", http.StatusForbidden)
			return
		}
		h(w, r, user, todo)
	}
}
// index writes todos to the response as a JSON object keyed by todo ID.
// user is unused; access control is applied by whatever policy wraps index
// (canViewTodos in main).
func index(w http.ResponseWriter, r *http.Request, user *User, todos map[string]Todo) {
	// Marshal before touching the response so an encoding failure can still
	// be reported with a proper status code.
	body, err := json.Marshal(todos)
	if err != nil {
		http.Error(w, "500 internal server error", http.StatusInternalServerError)
		return
	}
	// Without an explicit type, net/http sniffs the body and labels JSON as
	// text/plain.
	w.Header().Set("Content-Type", "application/json")
	// Trailing newline keeps the body byte-identical to json.Encoder output.
	// A write error means the client went away; there is nothing to recover.
	_, _ = w.Write(append(body, '\n'))
}
// view writes a single todo to the response as a JSON object.
// user is unused; access control is applied by whatever policy wraps view
// (canViewTodo in main).
func view(w http.ResponseWriter, r *http.Request, user *User, todo *Todo) {
	// Marshal before touching the response so an encoding failure can still
	// be reported with a proper status code.
	body, err := json.Marshal(todo)
	if err != nil {
		http.Error(w, "500 internal server error", http.StatusInternalServerError)
		return
	}
	// Without an explicit type, net/http sniffs the body and labels JSON as
	// text/plain.
	w.Header().Set("Content-Type", "application/json")
	// Trailing newline keeps the body byte-identical to json.Encoder output.
	// A write error means the client went away; there is nothing to recover.
	_, _ = w.Write(append(body, '\n'))
}
// main registers the todo routes, each one a loader wrapped around an
// access policy, and serves them on :8081.
func main() {
	r := chi.NewRouter()

	// http://localhost:8081/todos -> Forbidden
	// http://localhost:8081/todos?user=1 -> OK
	r.Get("/todos", withUserAndTodos(canViewTodos(index)))

	// http://localhost:8081/todos/1?user=1 -> OK
	// http://localhost:8081/todos/2?user=2 -> OK
	// http://localhost:8081/todos/1?user=2 -> Forbidden
	// http://localhost:8081/todos/2?user=1 -> Forbidden
	// http://localhost:8081/todos/3?user=1 -> Not found
	r.Get("/todos/{todoID}", withUserAndTodo(canViewTodo(view)))

	// http.ListenAndServe always returns a non-nil error (e.g. the port is
	// already in use). Surface it instead of exiting silently with status 0.
	// NOTE(review): ListenAndServe sets no read/write timeouts; use an
	// http.Server with ReadHeaderTimeout etc. before exposing this publicly.
	if err := http.ListenAndServe(":8081", r); err != nil {
		panic(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment