Skip to content

Instantly share code, notes, and snippets.

@Hunsin
Last active August 26, 2023 16:23
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save Hunsin/26b2021757e831554d4f59a52a5c9152 to your computer and use it in GitHub Desktop.
Save Hunsin/26b2021757e831554d4f59a52a5c9152 to your computer and use it in GitHub Desktop.
The package wraps julienschmidt's httprouter, adding support for middlewares and for grouped (sub-)routing under a shared path prefix. Written in Go (Golang).
package router
import (
"context"
"net/http"
gpath "path"
"github.com/julienschmidt/httprouter"
)
// Param returns the named URL parameter from a request context.
//
// It looks up the httprouter parameters stored in ctx and returns "" when
// ctx carries none; otherwise it returns whatever Params.ByName reports for
// name.
func Param(ctx context.Context, name string) string {
	params := httprouter.ParamsFromContext(ctx)
	if params == nil {
		return ""
	}
	return params.ByName(name)
}
// A Middleware chains http.Handlers: it receives the next handler in the
// chain and returns a handler that wraps it. A wrapper that never calls
// next.ServeHTTP stops the request from reaching the inner handlers.
type Middleware func(http.Handler) http.Handler
// A Router is a http.Handler which supports routing and middlewares.
//
// Routers created by Group share one underlying httprouter.Router, so a route
// registered through any of them is served by all of them.
type Router struct {
	// middlewares are applied by Handle in slice order, each wrapping the
	// previous result, so the last element is outermost and runs first.
	middlewares []Middleware
	// path is the prefix joined (via path.Join) onto every registered route.
	path string
	// root is the shared httprouter.Router that actually stores the routes.
	root *httprouter.Router
}
// New creates a new Router rooted at "/" with no middlewares, backed by a
// freshly allocated httprouter.Router.
func New() *Router {
	r := new(Router)
	r.root = httprouter.New()
	r.path = "/"
	return r
}
// Group returns a new Router with given path and middlewares.
// It should be used for handlers which have same path prefix or
// common middlewares.
//
// The group's prefix is path.Join(r.path, path), and it shares r's underlying
// httprouter.Router. Its middleware list is m followed by r's middlewares as
// they are right now; middlewares added to r later with Use do not reach the
// group. Because Handle makes the last middleware outermost, r's middlewares
// run before the group's own.
//
// NOTE(review): within m the last middleware runs first (Group("/x", a, b)
// runs b, then a), the reverse of argument order; kept for compatibility.
func (r *Router) Group(path string, m ...Middleware) *Router {
	// Build into a freshly allocated slice: append(m, ...) would write into the
	// caller's backing array whenever m has spare capacity (e.g. a call such as
	// Group(p, mws[:1]...)), corrupting the caller's slice and aliasing it.
	mws := make([]Middleware, 0, len(m)+len(r.middlewares))
	mws = append(mws, m...)
	mws = append(mws, r.middlewares...)
	return &Router{
		middlewares: mws,
		path:        gpath.Join(r.path, path),
		root:        r.root,
	}
}
// Use adds middlewares to the current Router and returns it for chaining.
//
// The new middlewares are placed in front of the existing ones, so they end
// up inside the previously added middlewares: middlewares from earlier Use
// calls run first. Only routes registered by Handle, and groups created by
// Group, after this call see the new middlewares, because both copy the list
// at that time.
//
// NOTE(review): within a single call the last middleware runs first
// (Use(a, b) runs b, then a); kept for compatibility.
func (r *Router) Use(m ...Middleware) *Router {
	// Copy into a fresh slice so neither the caller's backing array (when m has
	// spare capacity) is overwritten nor r.middlewares aliases it afterwards.
	mws := make([]Middleware, 0, len(m)+len(r.middlewares))
	mws = append(mws, m...)
	r.middlewares = append(mws, r.middlewares...)
	return r
}
// Handle registers handler for method on path, wrapped in this Router's
// middlewares.
//
// The middlewares are captured at registration time. handler is passed through
// each middleware in slice order, so the last one in the list ends up
// outermost. path is joined onto the Router's prefix with path.Join, which
// also cleans it; for example, a trailing slash is dropped, so "/a/" is
// registered as "/a".
func (r *Router) Handle(method, path string, handler http.Handler) {
	mws := r.middlewares
	wrapped := handler
	for i := range mws {
		wrapped = mws[i](wrapped)
	}
	full := gpath.Join(r.path, path)
	r.root.Handler(method, full, wrapped)
}
// GET registers handler for GET requests on path; it is shorthand for
// r.HandleFunc(http.MethodGet, path, handler).
func (r *Router) GET(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodGet, path, handler)
}

// HEAD registers handler for HEAD requests on path; it is shorthand for
// r.HandleFunc(http.MethodHead, path, handler).
func (r *Router) HEAD(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodHead, path, handler)
}

// OPTIONS registers handler for OPTIONS requests on path; it is shorthand for
// r.HandleFunc(http.MethodOptions, path, handler).
func (r *Router) OPTIONS(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodOptions, path, handler)
}

// POST registers handler for POST requests on path; it is shorthand for
// r.HandleFunc(http.MethodPost, path, handler).
func (r *Router) POST(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodPost, path, handler)
}

// PUT registers handler for PUT requests on path; it is shorthand for
// r.HandleFunc(http.MethodPut, path, handler).
func (r *Router) PUT(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodPut, path, handler)
}

// PATCH registers handler for PATCH requests on path; it is shorthand for
// r.HandleFunc(http.MethodPatch, path, handler).
func (r *Router) PATCH(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodPatch, path, handler)
}

// DELETE registers handler for DELETE requests on path; it is shorthand for
// r.HandleFunc(http.MethodDelete, path, handler).
func (r *Router) DELETE(path string, handler http.HandlerFunc) {
	r.HandleFunc(http.MethodDelete, path, handler)
}
// HandleFunc is an adapter for http.HandlerFunc: it registers a plain function
// as the handler for method and path through Handle, so this Router's
// middlewares wrap it just like any other handler.
func (r *Router) HandleFunc(method, path string, handler http.HandlerFunc) {
	var h http.Handler = handler
	r.Handle(method, path, h)
}
// NotFound sets the handler the underlying httprouter.Router uses when the
// request path matches no route, replacing any previous setting.
//
// The underlying router is shared by every Router created via Group, so the
// setting applies to all of them, whichever one it is called on. The handler
// is stored as-is and is not wrapped in this Router's middlewares.
func (r *Router) NotFound(handler http.Handler) {
	root := r.root
	root.NotFound = handler
}
// Static serves files from the root directory under path, which must end with
// "/*filepath"; it panics otherwise.
//
// Only GET is registered. The Router's prefix plus the part of path before
// "*filepath" is stripped from the request URL before the remainder is looked
// up in root.
//
// NOTE(review): http.FileServer lists the contents of directories that have
// no index.html; confirm that is acceptable for the directories served here.
func (r *Router) Static(path, root string) {
	const suffix = "/*filepath"
	cut := len(path) - len(suffix)
	if cut < 0 || path[cut:] != suffix {
		panic("path should end with '/*filepath' in path '" + path + "'.")
	}
	// Keep the slash before "*filepath" so path.Join sees a directory prefix.
	prefix := gpath.Join(r.path, path[:cut+1])
	r.Handle(http.MethodGet, path, http.StripPrefix(prefix, http.FileServer(http.Dir(root))))
}
// File registers a GET route on path that answers every request with the file
// called name (via http.ServeFile), wrapped in this Router's middlewares.
func (r *Router) File(path, name string) {
	serve := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		http.ServeFile(w, req, name)
	})
	r.Handle(http.MethodGet, path, serve)
}
// ServeHTTP implements http.Handler by dispatching req to the underlying
// httprouter.Router. That router is shared with every group, so serving
// through any Router reaches all registered routes.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	root := r.root
	root.ServeHTTP(w, req)
}
package router
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
)
// TestHandle checks that a handler registered with Handle answers a request
// on its path.
func TestHandle(t *testing.T) {
	rt := New()
	rt.Handle("GET", "/", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	}))

	req := httptest.NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	rt.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusTeapot {
		t.Error("Test Handle failed")
	}
}
// TestHandleFunc checks that a plain function registered with HandleFunc
// answers a request on its path.
func TestHandleFunc(t *testing.T) {
	rt := New()
	rt.HandleFunc("GET", "/", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	})

	req := httptest.NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	rt.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusTeapot {
		t.Error("Test HandlerFunc failed")
	}
}
// TestMethod registers one route through each HTTP-method shortcut and checks
// that every route answers a request made with its own method.
func TestMethod(t *testing.T) {
	router := New()
	teapot := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	}
	cases := []struct {
		method, path string
		register     func(string, http.HandlerFunc)
	}{
		{"DELETE", "/delete", router.DELETE},
		{"GET", "/get", router.GET},
		{"HEAD", "/head", router.HEAD},
		{"OPTIONS", "/options", router.OPTIONS},
		{"PATCH", "/patch", router.PATCH},
		{"POST", "/post", router.POST},
		{"PUT", "/put", router.PUT},
	}
	for _, c := range cases {
		c.register(c.path, teapot)
	}

	for _, c := range cases {
		req := httptest.NewRequest(c.method, c.path, nil)
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		if rec.Code != http.StatusTeapot {
			t.Errorf("Path %s not registered", c.path)
		}
	}
}
// TestGroup checks that routes registered on nested and sibling groups are
// reachable at their joined paths, including an empty sub-path.
func TestGroup(t *testing.T) {
	router := New()
	foo := router.Group("/foo")
	bar := router.Group("/bar")
	baz := foo.Group("/baz")

	teapot := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	}
	routes := []struct {
		group *Router
		path  string
	}{
		{foo, ""},
		{foo, "/group"},
		{bar, "/group"},
		{baz, "/group"},
	}
	for _, rt := range routes {
		rt.group.HandleFunc("GET", rt.path, teapot)
	}

	for _, path := range []string{"/foo", "/foo/group", "/foo/baz/group", "/bar/group"} {
		req := httptest.NewRequest("GET", path, nil)
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		if rec.Code != http.StatusTeapot {
			t.Errorf("Grouped path %s not registered", path)
		}
	}
}
// TestMiddleware checks that middlewares registered with Use and with Group
// both run, that the parent Router's middleware runs before the group's, and
// that the request still reaches the final handler.
func TestMiddleware(t *testing.T) {
	var use, group bool
	var order []string
	router := New().Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			use = true
			order = append(order, "use")
			next.ServeHTTP(w, r)
		})
	})
	foo := router.Group("/foo", func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			group = true
			order = append(order, "group")
			next.ServeHTTP(w, r)
		})
	})
	foo.HandleFunc("GET", "/bar", func(w http.ResponseWriter, _ *http.Request) {
		order = append(order, "handler")
		w.WriteHeader(http.StatusTeapot)
	})

	r := httptest.NewRequest("GET", "/foo/bar", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, r)
	if !use {
		t.Error("Middleware registered by Use() under \"/\" not touched")
	}
	if !group {
		t.Error("Middleware registered by Group() under \"/foo\" not touched")
	}
	if w.Code != http.StatusTeapot {
		t.Errorf("handler behind middlewares not reached: got status %d, want %d", w.Code, http.StatusTeapot)
	}

	// Parent middlewares wrap the group's, so they must run first.
	want := []string{"use", "group", "handler"}
	mismatch := len(order) != len(want)
	for i := 0; !mismatch && i < len(want); i++ {
		mismatch = order[i] != want[i]
	}
	if mismatch {
		t.Errorf("call order = %v, want %v", order, want)
	}
}
// TestStatic serves a temporary directory through Static and checks that each
// file's content is returned at "/<name>". Files live in a private temp dir so
// the test neither pollutes nor depends on the working directory, and setup
// failures abort the test instead of surfacing as confusing content mismatches.
func TestStatic(t *testing.T) {
	dir, err := ioutil.TempDir("", "router-static")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(dir)

	files := []string{"temp_1", "temp_2"}
	strs := []string{"test content", "static contents"}
	for i, name := range files {
		p := dir + string(os.PathSeparator) + name
		if err := ioutil.WriteFile(p, []byte(strs[i]), 0644); err != nil {
			t.Fatal(err)
		}
	}

	router := New()
	router.Static("/*filepath", dir)
	for i, name := range files {
		r := httptest.NewRequest("GET", "/"+name, nil)
		w := httptest.NewRecorder()
		router.ServeHTTP(w, r)
		if got := w.Body.String(); got != strs[i] {
			t.Errorf("Test Static failed: GET /%s returned %q, want %q", name, got, strs[i])
		}
	}
}
// TestFile registers a single file with File and checks that its content is
// served at the route. The file lives in a private temp dir, and setup errors
// abort the test.
func TestFile(t *testing.T) {
	dir, err := ioutil.TempDir("", "router-file")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(dir)

	str := "test_content"
	name := dir + string(os.PathSeparator) + "temp_file"
	if err := ioutil.WriteFile(name, []byte(str), 0644); err != nil {
		t.Fatal(err)
	}

	router := New()
	router.File("/file", name)
	r := httptest.NewRequest("GET", "/file", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, r)
	if got := w.Body.String(); got != str {
		t.Errorf("Test File failed: got %q, want %q", got, str)
	}
}
@saeid-ir
Copy link

Hello @Hunsin
sorry for the late reply
here is my code

package main

import (
	"github.com/julienschmidt/httprouter"
	"./router"
	"net/http"
	"log"
)

// main is the example reported in this thread. It builds route groups on a
// plain httprouter.Router and serves them on :8000.
func main() {
	r := httprouter.New()

	// NOTE(review): these middlewares call r.ServeHTTP, which sends the request
	// back through the whole router instead of to the wrapped handle, so a route
	// that uses them re-enters itself (the loop explained in the reply below).
	// Calling handle(writer, request, params) forwards to the wrapped handler.
	mid1 := func(handle httprouter.Handle) httprouter.Handle {
		return func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
			log.Print("Hello 1")
			r.ServeHTTP(writer, request)
		}
	}
	mid2 := func(handle httprouter.Handle) httprouter.Handle {
		return func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
			log.Print("Hello 2")
			r.ServeHTTP(writer, request)
		}
	}

	// NOTE(review): router.NewGroup is not defined in the router package shown
	// above (it provides New and Router.Group, and its Middleware type wraps
	// http.Handler rather than httprouter.Handle); presumably this targets an
	// earlier revision of the gist.
	v1 := router.NewGroup(r, "/v1")
	v2 := router.NewGroup(r, "/v2")
	v21 := v2.Group("/v1", mid1, mid2)

	r.GET("/", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
		writer.Write([]byte("Hello from Index"))
	})
	v1.GET("/test", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
		writer.Write([]byte("Hello from v1"))
	})
	v2.GET("/test", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
		writer.Write([]byte("Hello from v2"))
	})
	v21.GET("/test", func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
		writer.Write([]byte("Hello from v21"))
	})

	log.Fatal(http.ListenAndServe(":8000", r))
}

and the router is exactly as you wrote above.

@Hunsin
Copy link
Author

Hunsin commented Jan 5, 2018

Hello @saeidakbari

I found Github doesn't send notification here. So, feel free to contact me at: asky[dot]hunsin[at]gmail[dot]com

First

I modified the code again: I renamed func NewRouter() to New(). This doesn't affect your case; I mention it just for your convenience.

Second

It is recommended to put your source code in this structure:

gopath <-- the folder where environment variable GOPATH is set to
 └─ src
     ├─ github.com
     │   └─ saeidakbari
     │       ├─ router
     │       │   └─ router.go <-- package router
     │       └─ project
     │           └─ main.go   <-- your main package
     └─ others

Then you can import router package in main.go

import (
    // ...others
    "github.com/saeidakbari/router"
)

For more information please check here

Third

Declare the r variable with router.New(). httprouter.Router doesn't support middlewares and groups, but router.Router does.

r := router.New()

Your code gets stuck in a loop because the middlewares call r.ServeHTTP, which routes the request through the router again and so re-enters the same middleware.
Instead, call the wrapped handler: handle(writer, request, params).

mid1 := func(handle httprouter.Handle) httprouter.Handle {
    return func(writer http.ResponseWriter, request *http.Request, params httprouter.Params) {
        log.Print("Hello 1")
        handle(writer, request, params)
    }
}

And you can see the final code here!
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment