Skip to content

Instantly share code, notes, and snippets.

@wolfiestyle
Created February 1, 2015 23:50
Show Gist options
  • Save wolfiestyle/605aadeadc1db32ef07c to your computer and use it in GitHub Desktop.
Save wolfiestyle/605aadeadc1db32ef07c to your computer and use it in GitHub Desktop.
-- Micro web framework for writing REST applications.
-- Made by @wolfiestyle
-- Last modified: 2015-02-01
-- License: MIT/X11
local error, ipairs, math_random, next, pairs, require, select, setmetatable, string_char, table_concat, tonumber, tostring, type, unpack =
error, ipairs, math.random, next, pairs, require, select, setmetatable, string.char, table.concat, tonumber, tostring, type, unpack
local stderr = io.stderr
local Request = require "wsapi.request"
local Response = require "wsapi.response"
local json = require "dkjson"
local base64 = require "base64"
local crypto = require "crypto"
local pretty = require "pl.pretty"
local template = require "pl.template"
-- Module table: everything this file exports is attached to `_M`.
local _M = {}
-- Prototype for application objects created by `app.new`; `__index` points
-- back at `app` so instances find the methods defined on it below.
local app = {}
app.__index = app
_M.app = app
--( Session )--
--- Enables signed-cookie session support for this application.
-- @param key secret used to sign the session cookie (required; raises if missing)
-- @param name cookie name, "session" when omitted
function app:setup_session(key, name)
  if key then
    self.config.session = { key = key, cookie = name or "session" }
  else
    error "setup_session: missing key"
  end
end
-- Signs `data` with `key` using HMAC-SHA1, returning whatever LuaCrypto's
-- `crypto.hmac.digest` produces for it.
local function hmac(data, key)
  local digest = crypto.hmac.digest
  local algorithm = "SHA1"
  return digest(algorithm, data, key)
end
-- Serializes `obj` to JSON and returns "<base64 of the JSON>:<signature>".
-- The signature is the HMAC of the raw JSON text (not of its base64 form),
-- computed with `key`.
local function session_encode(obj, key)
  local payload = json.encode(obj)
  local signature = hmac(payload, key)
  local encoded = base64.encode(payload)
  return encoded .. ":" .. signature
end
-- Decodes and verifies a session cookie produced by `session_encode`.
-- The cookie has the form "<base64 JSON>:<signature>", where the signature
-- is the HMAC of the raw JSON text under `key`.
-- @param str cookie value
-- @param key secret the session was signed with
-- @return the decoded session table; or nil plus an error message when the
--   cookie is malformed, the signature does not match, the JSON cannot be
--   decoded, or the decoded value is not a table
local function session_decode(str, key)
  local data, sig = str:match "^([^:]+):([^:]+)$"
  if data == nil then
    return nil, "error: can't decode session data"
  end
  data = base64.decode(data)
  if data == nil then
    return nil, "error: can't decode base64"
  end
  -- the signature covers the decoded JSON, so it can only be checked now
  if hmac(data, key) ~= sig then
    return nil, "error: invalid session signature"
  end
  local res, _, err = json.decode(data)
  if res == nil then
    -- a document that decodes to nil (e.g. JSON `null`) may come without an
    -- error message; never concatenate a nil
    return nil, "error: " .. tostring(err or "empty session data")
  end
  if type(res) ~= "table" then
    -- handlers index req.session as a table
    return nil, "error: session data is not an object"
  end
  return res
end
-- Loads the session for `req` from its cookie (called by `process_request`).
-- Does nothing when sessions are not configured (see `app:setup_session`).
-- Otherwise sets `req.session` to the decoded table, or to a fresh empty
-- table when the cookie is absent or fails to decode/verify. Decode failures
-- are logged to stderr, one message per line.
local function session_start(self, req)
  local cfg = self.config.session
  if not cfg then
    return
  end
  local session, err
  local cookie = req.cookies[cfg.cookie]
  if cookie ~= nil then
    session, err = session_decode(cookie, cfg.key)
    if err then
      -- terminate the line so consecutive log entries stay separate
      stderr:write(tostring(err) .. "\n")
    end
  end
  -- handlers index req.session as a table; never hand them anything else
  if type(session) ~= "table" then
    session = {}
  end
  req.session = session
end
-- Persists the session into the response (called by `process_request`).
-- A truthy `req.session` is signed into the session cookie; `false` (set by
-- `req:session_end()`) deletes the cookie. Both use `req.script_name` as the
-- cookie path. Does nothing when sessions are not configured.
local function session_finish(self, resp, req)
  local cfg = self.config.session
  -- sessions disabled: a handler may still have called req:session_end(),
  -- but there is no cookie name or key to act on
  if not cfg then
    return
  end
  local session = req.session
  if session then
    local data = session_encode(session, cfg.key)
    resp:set_cookie(cfg.cookie, { path = req.script_name, value = data })
  elseif session == false then
    resp:delete_cookie(cfg.cookie, req.script_name)
  end
end
-- Marks the current session for deletion. Installed as `req:session_end()`,
-- so `req` is the request; `session_finish` later turns the `false` marker
-- into a cookie deletion.
local function session_end(req)
  local ENDED = false
  req["session"] = ENDED
end
--( Router )--
-- Parses `str` as a base-16 number; returns nil when it is not valid hex.
local function tonumber_hex(str)
  local HEX_BASE = 16
  return tonumber(str, HEX_BASE)
end
-- Data types allowed in path variables, keyed by the name written after ":"
-- in a route (e.g. "/users/:int"). `re` is the single-capture Lua pattern
-- substituted for the placeholder; `conv` converts the captured string
-- before it reaches the handler, or is `false` to pass the string through.
local arg_types = {
str = { re = "([^/]+)", conv = false }, -- one or more characters other than "/"
int = { re = "([%+%-]?%d+)", conv = tonumber }, -- optional sign, then digits
uint = { re = "(%d+)", conv = tonumber }, -- digits only, no sign
num = { re = "([%+%-]?%d*%.?%d+)", conv = tonumber }, -- e.g. "3", "-1.5", ".5" (but not "5.")
hex = { re = "(%x+)", conv = tonumber_hex }, -- hex digits, parsed base 16
alnum = { re = "(%w+)", conv = false }, -- letters and digits only
date = { re = "(%d%d%d%d%-%d%d%-%d%d)", conv = false }, -- NNNN-NN-NN shape only; not validated as a real date
}
-- Escapes Lua pattern magic characters so `str` matches itself literally.
local function escape_pattern(str)
  return (str:gsub("[%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%0"))
end

-- Compiles a route path (like "/users/:int/profile") into an anchored Lua
-- pattern plus the list of converters for its captures.
-- Literal parts of the path are escaped, so characters such as "-" or "."
-- ("/user-info", "/robots.txt") match themselves instead of acting as
-- pattern operators.
-- @param path route path; each ":name" must name an entry of `arg_types`
-- @return pattern string, array of converters (one per placeholder, in
--   order; `false` means keep the raw captured string)
-- @raise when a placeholder names an unknown argument type
local function parse_path(path)
  local convs = {}
  local pieces = {}
  local literal_start = 1
  for placeholder_start, tname, placeholder_end in path:gmatch("():(%w+)()") do
    local timpl = arg_types[tname]
    if not timpl then
      error("parse_path: unknown argument type: " .. tname)
    end
    pieces[#pieces + 1] = escape_pattern(path:sub(literal_start, placeholder_start - 1))
    pieces[#pieces + 1] = timpl.re
    convs[#convs + 1] = timpl.conv
    literal_start = placeholder_end
  end
  pieces[#pieces + 1] = escape_pattern(path:sub(literal_start))
  return "^" .. table_concat(pieces) .. "$", convs
end
-- A `route_item` is one declared path of the application (obtained through
-- `app:path`) together with its callbacks, grouped by HTTP method.
-- It is the main interface of the routing system.
local route_item = {}
route_item.__index = route_item

-- Builds a route_item for `path`, compiling the path with `parse_path`.
-- Fields: `path` (as given), `re` (anchored pattern), `convs` (capture
-- converters) and `fn` (method -> list of callbacks, initially empty).
function route_item.new(path)
  local pattern, converters = parse_path(path)
  local item = setmetatable({}, route_item)
  item.path = path
  item.re = pattern
  item.convs = converters
  item.fn = {}
  return item
end
-- Adds a callback for the specified HTTP method. Callbacks for a method are
-- kept in registration order; the list is created on first use.
-- Returns self so registrations can be chained.
function route_item:add_callback(method, callback)
  local callbacks = self.fn[method] or {}
  self.fn[method] = callbacks
  callbacks[#callbacks + 1] = callback
  return self
end
-- HTTP verbs that get a shorthand registration method on route_item, with
-- the CRUD operation each is conventionally used for.
local http_methods = {
  "POST", -- Create
  "GET", -- Read
  "PUT", -- Update
  "DELETE", -- Delete
}
-- Generate route_item:POST(cb), route_item:GET(cb), ... as thin wrappers
-- around route_item:add_callback.
for i = 1, #http_methods do
  local verb = http_methods[i]
  route_item[verb] = function(self, callback)
    return self:add_callback(verb, callback)
  end
end
-- Turns the results of `string.match` into handler arguments.
-- Returns nil when the match failed (a single nil result). Otherwise returns
-- the number of match results and an array with one entry per converter:
-- each capture is passed through its converter when one is set, falling back
-- to the raw capture if the converter yields nil/false.
local function pack_args(convs, ...)
  local captures = { ... }
  local count = select("#", ...)
  if count == 1 and captures[1] == nil then
    return nil
  end
  local args = {}
  for idx = 1, #convs do
    local raw, convert = captures[idx], convs[idx]
    if convert then
      args[idx] = convert(raw) or raw
    else
      args[idx] = raw
    end
  end
  return count, args
end
-- Tests this route against a request. Returns the callback list for
-- `method` and the converted path arguments when this route has callbacks
-- for that method and `req_path` matches its pattern; returns nothing
-- otherwise.
function route_item:match(method, req_path)
  local callbacks = self.fn[method]
  if not (callbacks and next(callbacks)) then
    return
  end
  local matched, args = pack_args(self.convs, req_path:match(self.re))
  if not matched then
    return
  end
  return callbacks, args
end
-- Canonicalizes a path: ensures it starts with "/" and drops one trailing
-- "/" (the root path "/" is left as is).
local function normalize_path(str)
  local path = str
  if path:sub(1, 1) ~= "/" then
    path = "/" .. path
  end
  local has_trailing_slash = path ~= "/" and path:sub(-1) == "/"
  return has_trailing_slash and path:sub(1, -2) or path
end
-- Returns the route_item for `path` (after normalization), creating and
-- registering it on first use. Routes are stored in declaration order,
-- which is the order `select_route` tries them in.
-- This is the entry point for declaring app routes. Example:
--
--   app:path "/users/:int/profile"
--     :GET(function(resp, req, user_id)
--       resp:write("user " .. user_id)
--     end)
--
function app:path(path)
  local key = normalize_path(path)
  local routes = self.routes
  local existing = routes.index[key]
  if existing then
    return existing
  end
  local created = route_item.new(key)
  routes[#routes + 1] = created
  routes.index[key] = created
  return created
end
-- Returns the callbacks and arguments of the first route, in declaration
-- order, that matches `method` and `req_path`; returns nothing if none do.
local function select_route(routes, method, req_path)
  for i = 1, #routes do
    local fns, args = routes[i]:match(method, req_path)
    if fns then
      return fns, args
    end
  end
end
-- Calls each callback in order with the same arguments, stopping right
-- after the first one whose (first) return value is truthy.
local function run_callbacks(fns, ...)
  for _, callback in ipairs(fns) do
    if callback(...) then
      return
    end
  end
end
-- Dispatches a request to the matching route and finishes the response.
-- Unmatched requests go to `self.routes.default`. For a matched route the
-- DB connection is opened and the session loaded before the callbacks run,
-- with `self.req` pointing at the current request while they do. Afterwards
-- the session is written back and the DB closed.
-- If anything in that phase raises, the DB connection is still closed and
-- `self.req` cleared before the error is re-raised (with a traceback
-- appended) for the WSAPI layer to report.
local function process_request(self, resp, req)
  local callbacks, args = select_route(self.routes, req.method, normalize_path(req.path_info))
  if callbacks == nil then
    self.routes.default(resp, req)
    return resp:finish()
  end
  self:open_db()
  -- Lua 5.1's xpcall does not forward arguments, hence the closure
  local ok, err = xpcall(function()
    session_start(self, req)
    self.req = req
    run_callbacks(callbacks, resp, req, unpack(args))
    self.req = nil
    session_finish(self, resp, req)
  end, debug.traceback)
  -- release per-request state whether or not the handlers succeeded
  self.req = nil
  self:close_db()
  if not ok then
    error(err, 0)
  end
  return resp:finish()
end
--( App )--
local NOT_FOUND = 404

-- Fallback handler used when no route matches the request: responds with
-- status 404 and a short error message in the body.
local function default_route(resp)
  resp.status = NOT_FOUND
  resp:write("error: route not found")
end
-- Dumps a Lua object as text using `pl.pretty.write` (installed as
-- `resp:dump(obj, only_ret)`).
-- When the response's Content-Type is HTML (its media type is "text/html",
-- case-insensitive, with or without parameters such as a charset), the text
-- is HTML-escaped and wrapped in <pre>.
-- @param resp response object
-- @param obj value to dump
-- @param only_ret when truthy, return the text without writing it
-- @return the (possibly wrapped) text
local function dump(resp, obj, only_ret)
  local str = pretty.write(obj)
  local ctype = resp.headers["Content-Type"]
  -- compare the bare media type so "text/html; charset=utf-8" is still
  -- treated as HTML and escaped
  local mime = type(ctype) == "string" and ctype:match("^%s*([^;%s]+)")
  if mime and mime:lower() == "text/html" then
    str = "<pre>" .. _M.sanitize(str) .. "</pre>"
  end
  if not only_ret then
    resp:write(str)
  end
  return str
end
-- Renders `tmpl` with `pl.template.substitute`, using `resp.stash` as the
-- template environment, and writes the result to the response (installed as
-- `resp:write_template(tmpl)`).
-- The stash prepared in `wsapi_main` sets `_brackets = "{}"` so template
-- expressions don't clash with jQuery's $(...) syntax.
-- @raise when `substitute` reports a compile/runtime error. The error is
--   raised rather than written, so it is not rendered into the page.
local function write_template(resp, tmpl)
  local out, err = template.substitute(tmpl, resp.stash)
  if out == nil then
    error("write_template: " .. tostring(err), 2)
  end
  resp:write(out)
end
-- Issues an HTTP 303 (See Other) redirect to `url`, also writing the URL as
-- the response body. Installed as `resp:redirect(url)`.
local function redirect(resp, url)
  local headers = resp.headers
  resp.status = 303
  headers.Location = url
  resp:write(url)
end
-- Entry point for the WSAPI application. It builds the wsapi request and
-- response objects for `env`, attaches this framework's helpers, and returns
-- whatever `process_request` returns.
-- Helpers attached:
--   * `req:session_end()`
--   * `resp:dump`, `resp:write_template`, `resp:redirect`
--   * `resp:redirect_app(path)`, which redirects to `req:link(path)`
--   * `resp.stash`, a per-request table that also serves as the template
--     environment
local function wsapi_main(application, env)
  local req = Request.new(env)
  local resp = Response.new()
  req.session_end = session_end

  local headers = resp.headers
  headers["X-Powered-By"] = _VERSION

  resp.dump = dump
  resp.write_template = write_template
  resp.redirect = redirect
  function resp.redirect_app(self, path)
    return self:redirect(req:link(path))
  end
  -- used for templates and general storage during the request; "{}"
  -- brackets keep template syntax from clashing with jQuery's $(...)
  resp.stash = { req = req, _brackets = "{}" }

  return process_request(application, resp, req)
end
-- Creates a new application. It starts with no routes (unmatched requests
-- fall through to `default_route`) and an empty configuration. `app.run(env)`
-- is a plain function, not a method, that serves one request through
-- `wsapi_main`.
function app.new()
  local instance = setmetatable({ config = {} }, app)
  instance.routes = { index = {}, default = default_route }
  instance.run = function(env)
    return wsapi_main(instance, env)
  end
  return instance
end
--( Database )--
-- Configures and enables database support. `args` must be a table with at
-- least `driver` and `dbname`. The table is stored as-is, and `app:open_db`
-- also forwards `user`, `pass`, `host` and `port` from it to DBI.Connect.
-- Raises on an invalid argument.
function app:setup_db(args)
  local valid = type(args) == "table" and args.driver ~= nil and args.dbname ~= nil
  if not valid then
    error("setup_db: invalid argument")
  end
  self.config.db = args
end
-- Connects to the database using `DBI` with the `setup_db` configuration,
-- enables autocommit, and stores the connection in `self.db`.
-- Returns the connection, or nothing when database support is not
-- configured.
-- A connection still held from an earlier request is closed first, so it is
-- not leaked. This happens when a handler raised before `close_db` ran.
-- @raise when DBI cannot connect
function app:open_db()
  local cfg = self.config.db
  if cfg == nil then
    return
  end
  if self.db then
    self:close_db()
  end
  local db, err = require("DBI").Connect(cfg.driver, cfg.dbname, cfg.user, cfg.pass, cfg.host, cfg.port)
  if db == nil then
    -- tostring: don't let a missing error message turn into a concat error
    error("open_db: " .. tostring(err))
  end
  db:autocommit(true)
  self.db = db
  return db
end
-- Closes the connection opened by `app:open_db`, if there is one, and then
-- forgets it.
function app:close_db()
  local db = self.db
  if not db then
    return
  end
  db:close()
  self.db = nil
end
--( Utils )---
-- Lazily generated helpers that wrap content in an HTML element, e.g.
--   tag.h1 "title"   --> "<h1>title</h1>"
--   tag.td {1, 2, 3} --> "<td>1</td><td>2</td><td>3</td>"
-- Table content is joined with table.concat; anything else goes through
-- tostring. Content is NOT HTML-escaped. Each generated helper is cached
-- in the table under its tag name.
_M.tag = setmetatable({}, {
  __index = function(cache, name)
    local open_tag = "<" .. name .. ">"
    local close_tag = "</" .. name .. ">"
    local separator = close_tag .. open_tag
    local function wrap(content)
      local inner
      if type(content) == "table" then
        inner = table_concat(content, separator)
      else
        inner = tostring(content)
      end
      return open_tag .. inner .. close_tag
    end
    cache[name] = wrap
    return wrap
  end,
})
-- Seeds math.random with the current time mixed with the address of `_G`.
-- The address is taken from its tostring form when that ends in hex digits;
-- otherwise the time alone is used.
-- The seed is reduced modulo 2^31 - 1 because Lua 5.1's math.randomseed
-- converts its argument to a C int. The raw sum (typically a 64-bit address
-- plus the time) is out of that range, and such a conversion is undefined
-- in C. It can collapse every run onto the same seed.
function _M.randomize()
  local addr = tostring(_G):match("(%x+)$")
  local mix = addr and tonumber_hex(addr) or 0
  math.randomseed((os.time() + mix) % 2147483647)
end
-- Generates a string of `len` random printable ASCII characters (codes 32
-- through 126, i.e. space through "~"). Returns "" when `len` is 0.
-- NOTE(review): this uses math.random, which is not a cryptographically
-- secure generator; avoid it for secrets such as session keys.
function _M.random_string(len)
  local chars = {}
  for i = 1, len do
    -- one char per slot + concat: avoids `unpack`, whose argument count is
    -- bounded by the C stack and fails for long strings
    chars[i] = string_char(math_random(32, 126))
  end
  return table_concat(chars)
end
-- Entity replacements for the characters that are special in HTML text and
-- attribute values.
local html_entities = {
  ["<"] = "&lt;",
  [">"] = "&gt;",
  ["&"] = "&amp;",
  ['"'] = "&quot;",
  ["'"] = "&#39;",
}

-- Replaces special HTML characters (< > & " ') in `str` with entities.
-- Non-string values are converted with tostring first.
-- Returns only the escaped string. string.gsub also returns a replacement
-- count, which would otherwise leak into calls such as
-- `resp:write(_M.sanitize(x))` as an extra argument.
function _M.sanitize(str)
  return (tostring(str):gsub("[<>&\"']", html_entities))
end
-- Parses and validates submitted form fields.
-- @param fields_decl table mapping field name -> conversion, where the
--   conversion is one of:
--   * `true`: accept the raw value;
--   * a Lua pattern string: the value must be a string, and the result is
--     `value:match(pattern)`;
--   * a function: called with the value, returning `parsed[, err]`.
--   A nil result rejects the field.
-- @param data table of submitted values (presumably `req.POST`; confirm)
-- @return table of parsed fields; or nil plus an error message when a field
--   is missing or fails its conversion
-- @raise when a conversion in `fields_decl` has an unsupported type
function _M.parse_fields(fields_decl, data)
  local result = {}
  for name, conv in pairs(fields_decl) do
    local val = data[name]
    if val == nil then
      return nil, "error: incomplete request"
    end
    local parsed, err
    local t_conv = type(conv)
    if conv == true then
      parsed = val
    elseif t_conv == "string" then
      -- a field may arrive as a non-string (e.g. a table for repeated or
      -- uploaded fields, presumably); only strings can be pattern-matched,
      -- so anything else is rejected instead of crashing on :match
      if type(val) == "string" then
        parsed = val:match(conv)
      end
    elseif t_conv == "function" then
      parsed, err = conv(val)
    else
      error "parse_fields: invalid conversion in table"
    end
    if parsed == nil then
      return nil, err or "error: invalid request"
    end
    result[name] = parsed
  end
  return result
end
-- Seed math.random once when the module is loaded. This reseeds the global
-- generator shared with any other code in the process.
_M.randomize()
return _M
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment