Skip to content

Instantly share code, notes, and snippets.

@shaunlee
Last active September 29, 2017 08:11
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save shaunlee/8895120 to your computer and use it in GitHub Desktop.
Save shaunlee/8895120 to your computer and use it in GitHub Desktop.
Simple RESTful web dispatcher
package main
import (
"fmt"
"log"
"net/http"
"regexp"
"strings"
)
// DEFAULT_MAX_MEMORY is the memory limit handed to
// (*http.Request).ParseMultipartForm. It equals http.defaultMaxMemory.
const DEFAULT_MAX_MEMORY = 32 << 20

var (
	// RE_URL_PATTERNS finds one route placeholder, written either as <name>
	// or as <type:name> with type being int, float or path. Submatch 3 holds
	// the type (empty when omitted) and submatch 4 holds the field name.
	RE_URL_PATTERNS = regexp.MustCompile(`<(((int|float|path):)?(\w+))>`)

	// The templates below are fmt format strings: %s receives the field name
	// and the result is a named capture group for that field.

	// URL_DEFAULT_PATTERN captures one path segment (anything but a slash).
	URL_DEFAULT_PATTERN = `(?P<%s>[^/]+)`
	// URL_INT_PATTERN captures a run of decimal digits.
	URL_INT_PATTERN = `(?P<%s>\d+)`
	// URL_FLOAT_PATTERN captures digits, a dot, then more digits.
	URL_FLOAT_PATTERN = `(?P<%s>\d+\.\d+)`
	// URL_PATH_PATTERN captures everything, slashes included.
	URL_PATH_PATTERN = `(?P<%s>.+)`
)
// HandlerFunc is one step of a route's handler chain. The boolean result
// tells the dispatcher whether to go on to the next step (true) or to stop
// the chain (false).
type HandlerFunc func(w http.ResponseWriter, r *http.Request) bool

// PatternMethod associates each compiled route expression with the handler
// that serves it.
type PatternMethod map[*regexp.Regexp]http.HandlerFunc

// WebHandler is a regexp-based request dispatcher with per-status-code
// error handlers.
type WebHandler struct {
	// patterns is keyed by HTTP method name.
	patterns map[string]PatternMethod
	// errorHandlers is keyed by HTTP status code.
	errorHandlers map[int]http.HandlerFunc
}
// NewWebHandler returns a WebHandler whose route table and error-handler
// table are both allocated and empty.
func NewWebHandler() *WebHandler {
	return &WebHandler{
		patterns:      map[string]PatternMethod{},
		errorHandlers: map[int]http.HandlerFunc{},
	}
}
// register compiles the route template pattern and binds the handler chain
// fn to it for the given HTTP method.
//
// Template syntax:
//
//	/<field>/<int:field>/<float:field>/<path:field>
//
// Each placeholder becomes a named capture group (see the URL_*_PATTERN
// templates). Text outside placeholders is inserted into the regular
// expression as-is, so regexp metacharacters in it keep their meaning. The
// final expression is anchored with ^ and $.
//
// The handlers in fn run in order; the chain stops at the first one that
// returns false.
//
// register panics (via regexp.MustCompile) if the expanded template is not a
// valid regular expression.
func (p *WebHandler) register(method, pattern string, fn ...HandlerFunc) {
	// Tolerate a zero-value WebHandler (one not built by NewWebHandler):
	// assigning into a nil map would panic.
	if p.patterns == nil {
		p.patterns = make(map[string]PatternMethod)
	}
	if _, ok := p.patterns[method]; !ok {
		p.patterns[method] = make(PatternMethod)
	}

	for _, matches := range RE_URL_PATTERNS.FindAllStringSubmatch(pattern, -1) {
		// matches[0] is the whole placeholder, matches[3] its optional type
		// and matches[4] the field name.
		var replaceTo string
		switch matches[3] {
		case "int":
			replaceTo = fmt.Sprintf(URL_INT_PATTERN, matches[4])
		case "float":
			replaceTo = fmt.Sprintf(URL_FLOAT_PATTERN, matches[4])
		case "path":
			replaceTo = fmt.Sprintf(URL_PATH_PATTERN, matches[4])
		default:
			replaceTo = fmt.Sprintf(URL_DEFAULT_PATTERN, matches[4])
		}
		pattern = strings.Replace(pattern, matches[0], replaceTo, 1)
	}

	route := regexp.MustCompile(fmt.Sprintf(`^%s$`, pattern))
	p.patterns[method][route] = func(w http.ResponseWriter, r *http.Request) {
		for _, f := range fn {
			if ok := f(w, r); !ok {
				break
			}
		}
	}
}
// Get binds the handler chain fn to GET requests whose path matches pattern.
func (p *WebHandler) Get(pattern string, fn ...HandlerFunc) {
	p.register("GET", pattern, fn...)
}
// Post binds the handler chain fn to POST requests whose path matches
// pattern.
func (p *WebHandler) Post(pattern string, fn ...HandlerFunc) {
	p.register("POST", pattern, fn...)
}
// Put binds the handler chain fn to PUT requests whose path matches pattern.
func (p *WebHandler) Put(pattern string, fn ...HandlerFunc) {
	p.register("PUT", pattern, fn...)
}
// Delete binds the handler chain fn to DELETE requests whose path matches
// pattern.
func (p *WebHandler) Delete(pattern string, fn ...HandlerFunc) {
	p.register("DELETE", pattern, fn...)
}
// All binds the handler chain fn to pattern once for every entry of methods.
//
// A nil methods slice selects the default set GET, POST, PUT and DELETE.
// Only nil triggers the default: an empty, non-nil slice registers nothing.
func (p *WebHandler) All(pattern string, methods []string, fn ...HandlerFunc) {
	verbs := methods
	if verbs == nil {
		verbs = []string{"GET", "POST", "PUT", "DELETE"}
	}
	for i := range verbs {
		p.register(verbs[i], pattern, fn...)
	}
}
// Handle installs fn as the responder used by Raise for the given HTTP
// status code, replacing any responder previously installed for that code.
func (p *WebHandler) Handle(code int, fn http.HandlerFunc) {
	// Tolerate a zero-value WebHandler (one not built by NewWebHandler):
	// assigning into a nil map would panic.
	if p.errorHandlers == nil {
		p.errorHandlers = make(map[int]http.HandlerFunc)
	}
	p.errorHandlers[code] = fn
}
// Raise answers the request with an error response for the given status
// code. If a responder was installed for code through Handle, it is invoked
// and err is not used; otherwise http.Error writes err with that status.
func (p *WebHandler) Raise(w http.ResponseWriter, r *http.Request, err string, code int) {
	custom, registered := p.errorHandlers[code]
	if !registered {
		http.Error(w, err, code)
		return
	}
	custom(w, r)
}
// ServeHTTP implements http.Handler. It parses the request form, looks for a
// route registered under the request method whose expression matches the
// request path, copies the route's named captures into r.Form, and runs the
// route's handler.
//
// Responses produced by the dispatcher itself go through Raise:
//   - 400 when the request form cannot be parsed,
//   - 404 when no route matches,
//   - 500 when a handler panics (the panic value is logged).
//
// Path captures are stored with r.Form.Set, so they replace query or body
// values of the same name.
func (p *WebHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer func() {
		// TODO: let it go while debugging?
		if err := recover(); err != nil {
			log.Println(err)
			p.Raise(w, r, "500 internal server error", http.StatusInternalServerError)
		}
	}()

	// Parse eagerly so r.Form exists before path captures are added to it.
	// A parse failure means the client sent a malformed query or body;
	// reject it rather than dispatching on partially parsed input.
	var parseErr error
	if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") {
		parseErr = r.ParseMultipartForm(DEFAULT_MAX_MEMORY)
	} else {
		parseErr = r.ParseForm()
	}
	if parseErr != nil {
		p.Raise(w, r, "400 bad request", http.StatusBadRequest)
		return
	}

	// Routes are matched against the request URI with the query string cut off.
	path := r.URL.RequestURI()
	if n := strings.Index(path, "?"); n > -1 {
		path = path[:n]
	}

	// Indexing and ranging over a missing method table is a no-op, which
	// falls through to the 404 below.
	// NOTE: map iteration order is unspecified, so when several routes match
	// the same path, which one is served is not deterministic.
	for pattern, fn := range p.patterns[r.Method] {
		values := pattern.FindStringSubmatch(path)
		if values == nil {
			continue
		}
		// Unnamed groups (including the whole match at index 0) have an
		// empty name and are skipped.
		for i, field := range pattern.SubexpNames() {
			if field != "" {
				r.Form.Set(field, values[i])
			}
		}
		fn(w, r)
		return
	}

	p.Raise(w, r, "404 page not found", http.StatusNotFound)
}
// Run listens on the TCP address addr and serves incoming requests with p.
// It blocks until the server stops and returns the resulting error.
func (p *WebHandler) Run(addr string) error {
	server := &http.Server{Addr: addr, Handler: p}
	return server.ListenAndServe()
}
// main wires up a small demo application and serves it on 127.0.0.1:8080.
func main() {
	app := NewWebHandler()

	// Custom bodies for the 500 and 404 statuses the dispatcher raises itself.
	app.Handle(http.StatusInternalServerError, func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Internal server error!!!", http.StatusInternalServerError)
	})
	app.Handle(http.StatusNotFound, func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Page not found!!!", http.StatusNotFound)
	})

	// "/" runs a two-step chain: the second step only runs when the request
	// carries next=1.
	middleware := func(w http.ResponseWriter, r *http.Request) bool {
		fmt.Fprint(w, "Middleware")
		return r.FormValue("next") == "1"
	}
	home := func(w http.ResponseWriter, r *http.Request) bool {
		fmt.Fprint(w, "Home")
		return true
	}
	app.Get(`/`, middleware, home)

	// nil selects the default method set of All.
	blog := func(w http.ResponseWriter, r *http.Request) bool {
		fmt.Fprintf(w, "Blog %s", r.FormValue("id"))
		return true
	}
	app.All(`/blog/<int:id>/`, nil, blog)

	// One route exercising every placeholder type.
	test := func(w http.ResponseWriter, r *http.Request) bool {
		id, name := r.FormValue("id"), r.FormValue("name")
		lat, filename := r.FormValue("lat"), r.FormValue("filename")
		fmt.Fprintf(w, "Test %s %s %s %s", id, name, lat, filename)
		return true
	}
	app.Post(`/test/<name>/<int:id>/<float:lat>/<path:filename>`, test)

	log.Println("Listening on 127.0.0.1:8080 ...")
	log.Fatal(app.Run("127.0.0.1:8080"))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment