Skip to content

Instantly share code, notes, and snippets.

@littletsu
Last active June 10, 2024 00:17
Show Gist options
  • Save littletsu/09b9d3bf581b759e1dfbf40d2275df29 to your computer and use it in GitHub Desktop.
Save littletsu/09b9d3bf581b759e1dfbf40d2275df29 to your computer and use it in GitHub Desktop.
WebSocket and HTTP server for OBS Lua scripting
-- Based on https://github.com/stonetoad/obs-lua-httpd
local obs = obslua
-- From https://github.com/stonetoad/obs-lua-httpd/blob/e1c167f6c5231e605cf8531750153e728765f587/ljsocket.lua
local socket = require("ljsocket")
-- From https://gist.githubusercontent.com/PedroAlvesV/872a108f187f57c2a5b7b5bc34398496/raw/4ee8e36c9ee4b55a3d6bef768258ec8f9c6c3bc2/sha1.lua
local sha1 = require("sha1")
local bit = require("bit")
-- From https://devforum.roblox.com/t/base64-encoding-and-decoding-in-lua/1719860
--- Encode a byte string as standard base64 ("+/" alphabet, "=" padding).
-- Works on groups of three input bytes (24 bits), emitting one alphabet
-- character per 6-bit slice; a short final group emits only the characters
-- that carry input bits and is then padded with "=".
-- @tparam string data raw bytes
-- @treturn string base64 text
function to_base64(data)
  local alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
  local out = {}
  local function emit(index)
    out[#out + 1] = alphabet:sub(index + 1, index + 1)
  end
  for i = 1, #data, 3 do
    local b1, b2, b3 = data:byte(i, i + 2)
    local group = b1 * 65536 + (b2 or 0) * 256 + (b3 or 0)
    emit(math.floor(group / 262144) % 64)
    emit(math.floor(group / 4096) % 64)
    if b2 then
      emit(math.floor(group / 64) % 64)
    end
    if b3 then
      emit(group % 64)
    end
  end
  local padding = ({ '', '==', '=' })[#data % 3 + 1]
  return table.concat(out) .. padding
end
--- Return the base64 encoding of `sha1_binary(str)`.
-- Used to compute the Sec-WebSocket-Accept value of the handshake.
-- NOTE(review): `sha1_binary` is resolved as a global at call time;
-- presumably the sha1 module required at the top of the file defines it
-- and returns the raw (binary) digest. Confirm against that module.
function b64_sha1(str)
  local digest = sha1_binary(str)
  return to_base64(digest)
end
--- Print diagnostic output via `print`.
-- Forwards every argument, so calls such as `debug_print(value, str)` no
-- longer silently drop everything after the first argument.
-- @param ... values to print (tab-separated, as `print` does)
function debug_print(...)
  print(...)
end
-- Handle returned by open_server() while the script is loaded, else nil.
local server = nil

--- Load hook: start the HTTP/WebSocket server on 127.0.0.1:12464.
-- WebSocket clients connect on "/ws" and receive "Hi bro :3!" when they send
-- "hi"; the HTTP path "/" serves a small HTML index page.
-- The whole handle is kept (not just its raw socket) so script_unload can
-- shut everything down.
function script_load()
  server = open_server(
    "127.0.0.1", 12464, "/ws",
    function(client, text, wss)
      if text == "hi" then
        wss.send_text(client, "Hi bro :3!")
      end
    end,
    {
      ["/"] = function()
        return "<h1>Index</h1>"
      end
    }
  )
end

--- Unload hook: shut the server down.
-- Prefers the handle's close() when it provides one (closing connected
-- clients and letting the poll timers stop); otherwise falls back to closing
-- just the listening socket in `server.socket`.
function script_unload()
  if server == nil then
    return
  end
  if type(server.close) == "function" then
    server.close()
  elseif server.socket then
    server.socket:close()
  end
  server = nil -- make a repeated unload a no-op instead of a double close
end
--- Split `inputstr` into the non-empty runs of characters not in `sep`.
-- `sep` is placed inside a Lua pattern character class, so it names a set of
-- separator characters (default "%s", i.e. whitespace). Empty fields between
-- adjacent separators are dropped.
-- @tparam string inputstr text to split
-- @tparam[opt] string sep separator character set
-- @treturn table array of pieces, in order
function split(inputstr, sep)
  local class = sep
  if class == nil then
    class = "%s"
  end
  local pieces = {}
  for piece in string.gmatch(inputstr, "([^" .. class .. "]+)") do
    pieces[#pieces + 1] = piece
  end
  return pieces
end
--- Check whether a comma-separated HTTP header value lists a given token.
-- Each element is trimmed of all surrounding whitespace and compared
-- case-insensitively, as HTTP connection/upgrade options are
-- case-insensitive (e.g. "keep-alive, Upgrade" includes "upgrade").
-- @tparam string value raw header value
-- @tparam string str token to look for
-- @treturn boolean true if some element equals `str`; false otherwise,
--   including when either argument is not a string
function read_header_list_includes(value, str)
  if type(value) ~= "string" or type(str) ~= "string" then
    return false
  end
  local wanted = str:lower()
  for element in value:gmatch("[^,]+") do
    local token = element:match("^%s*(.-)%s*$")
    if token:lower() == wanted then
      return true
    end
  end
  return false
end
--- Value appended to Sec-WebSocket-Key before hashing (RFC 6455 section 1.3).
local WS_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

-- WebSocket opcodes (RFC 6455 section 5.2).
local OP_CONTINUATION, OP_TEXT, OP_BINARY = 0x0, 0x1, 0x2
local OP_CLOSE, OP_PING, OP_PONG = 0x8, 0x9, 0xA

--- Encode one unmasked, final (FIN=1) server-to-client WebSocket frame.
-- Chooses the 7-bit, 16-bit or 64-bit length form from the payload size.
-- @tparam number op opcode (0-15)
-- @tparam string payload frame body
-- @treturn string the encoded frame
local function encode_ws_frame(op, payload)
  local first = bit.bor(0x80, bit.band(op, 0x0F)) -- FIN set, RSV1-3 clear
  local len = #payload
  if len <= 125 then
    return string.char(first, len) .. payload
  elseif len <= 0xFFFF then
    return string.char(first, 126, math.floor(len / 256), len % 256) .. payload
  end
  -- 64-bit big-endian length.
  local length_bytes = {}
  local rest = len
  for i = 8, 1, -1 do
    length_bytes[i] = string.char(rest % 256)
    rest = math.floor(rest / 256)
  end
  return string.char(first, 127) .. table.concat(length_bytes) .. payload
end

--- Try to decode one client-to-server frame from the start of `buf`.
-- @tparam string buf bytes received so far on the connection
-- @tparam number max_payload largest payload length accepted
-- @return on success: `{ fin = boolean, opcode = number, payload = string }`
--   (payload already unmasked) and the number of bytes of `buf` the frame
--   occupied; `nil` when `buf` does not yet hold a whole frame; `false`, a
--   reason string and a close status code when the frame must be rejected.
local function parse_ws_frame(buf, max_payload)
  if #buf < 2 then
    return nil
  end
  local b1, b2 = buf:byte(1, 2)
  if bit.band(b1, 0x70) ~= 0 then
    return false, "reserved bits set without a negotiated extension", 1002
  end
  if bit.band(b2, 0x80) == 0 then
    -- Clients must mask every frame (RFC 6455 section 5.1).
    return false, "unmasked client frame", 1002
  end
  local len = bit.band(b2, 0x7F)
  local pos = 3 -- index of the first byte after the length field(s)
  if len == 126 then
    if #buf < 4 then
      return nil
    end
    local hi, lo = buf:byte(3, 4)
    len = hi * 256 + lo
    pos = 5
  elseif len == 127 then
    if #buf < 10 then
      return nil
    end
    len = 0
    for i = 3, 10 do
      len = len * 256 + buf:byte(i)
    end
    pos = 11
  end
  if len > max_payload then
    return false, "frame too large (" .. tostring(len) .. " bytes)", 1009
  end
  if #buf < pos + 3 + len then
    return nil -- mask key or payload still in flight
  end
  local mask = { buf:byte(pos, pos + 3) }
  pos = pos + 4
  local chars = {}
  for i = 0, len - 1 do
    chars[i + 1] = string.char(bit.bxor(buf:byte(pos + i), mask[i % 4 + 1]))
  end
  return {
    fin = bit.band(b1, 0x80) ~= 0,
    opcode = bit.band(b1, 0x0F),
    payload = table.concat(chars),
  }, pos + len - 1
end

--- Decode %XX escapes in a URL path.
local function url_decode(s)
  return (s:gsub("%%(%x%x)", function(hex)
    return string.char(tonumber(hex, 16))
  end))
end

--- Build a complete HTTP/1.1 response: status line, headers, blank line,
-- body, all separated by CRLF.
-- @tparam string status e.g. "404 Not Found"
-- @tparam string content_type value of the Content-Type header
-- @tparam string body response body
-- @tparam[opt] table extra_headers additional "Name: value" lines
-- @treturn string the raw response
local function http_response(status, content_type, body, extra_headers)
  local lines = {
    "HTTP/1.1 " .. status,
    "Connection: Close",
    "Content-Type: " .. content_type,
    "Content-Length: " .. #body,
    "Access-Control-Allow-Origin: *",
    "Cross-Origin-Opener-Policy: same-origin",
    "Cross-Origin-Embedder-Policy: require-corp",
  }
  for _, header in ipairs(extra_headers or {}) do
    lines[#lines + 1] = header
  end
  return table.concat(lines, "\r\n") .. "\r\n\r\n" .. body
end

--- Start a non-blocking HTTP + WebSocket server driven by OBS timers.
--
-- The listening socket is polled from `obs.timer_add` callbacks, every
-- `poll_interval` ms while idle and every `poll_interval_fast` ms after recent
-- activity. GET requests for `wspath` are upgraded to WebSocket; other GET
-- paths are looked up in `routes`. Every helper is local, so several servers
-- can coexist without overwriting each other's state.
-- @tparam string bind_host address to bind, e.g. "127.0.0.1"
-- @tparam number port TCP port
-- @tparam string wspath request path that is upgraded to a WebSocket
-- @tparam function on_text `on_text(client, text, wss)`, called once per
--   complete text message (fragmented messages are reassembled first)
-- @tparam[opt] table routes map of path -> handler. A handler is called as
--   `handler(path, headers)` (header names lower-cased) and returns the body
--   plus an optional Content-Type (default "text/html; charset=utf-8").
-- @treturn table `wss` with fields `socket`, `send(client, op, payload)`,
--   `send_text(client, text)`, `broadcast(op, payload)`,
--   `broadcast_text(text)` and `close()`
function open_server(bind_host, port, wspath, on_text, routes)
  routes = routes or {}
  local poll_interval = 500 -- ms between polls while idle
  local poll_interval_fast = 30 -- ms between polls after recent activity
  local max_fast_idle = 20 -- idle fast polls before falling back to slow polling
  local max_request_bytes = 16 * 1024 -- largest accepted HTTP request head
  local request_timeout_s = 10 -- time a new client has to send its request head
  local max_ws_payload = 1024 * 1024 -- largest accepted WebSocket frame / message

  local sock = assert(socket.create("inet", "stream", "tcp"))
  assert(sock:set_blocking(false)) -- critical! don't hang obs UI!
  assert(sock:set_option("reuseaddr", true))
  assert(sock:bind(bind_host, port))
  assert(sock:listen())
  print(string.format("Listening on %s:%s", tostring(bind_host), tostring(port)))

  local connected_ws = {} -- connection id -> upgraded client socket
  local ws_state = {} -- connection id -> { buf, frag_op, frags, frag_len }
  local pending_http = {} -- client socket -> { buf, deadline } until a full request head arrives
  local client_i = 1 -- suffix that keeps connection ids unique
  local fast_poll = false
  local idle_count = 0

  local not_found_body = '<div style="background-color: darkgrey; color: white">\n'
    .. "<h1>Route was not found.</h1>\n</div>\n"

  --- Send one frame to `client`; returns the socket's send() results.
  -- NOTE(review): on a non-blocking socket send() may accept only part of
  -- the data; the remainder is not retried.
  local function send_ws(client, op, payload)
    return client:send(encode_ws_frame(op, payload))
  end

  local function send_text_ws(client, text)
    return send_ws(client, OP_TEXT, text)
  end

  --- Send the same frame to every connected WebSocket client.
  local function broadcast(op, payload)
    local frame = encode_ws_frame(op, payload) -- encode once, send to all
    for _, client in pairs(connected_ws) do
      client:send(frame)
    end
  end

  --- Forget a WebSocket connection and close its socket.
  local function drop_ws(id)
    local client = connected_ws[id]
    connected_ws[id] = nil
    ws_state[id] = nil
    if client then
      client:close()
    end
  end

  --- Send a Close frame with status `code` (default 1002, protocol error)
  -- and drop the connection. Returns false so callers can `return fail_ws()`.
  local function fail_ws(id, reason, code)
    code = code or 1002
    print("Closing websocket connection: " .. tostring(reason))
    local client = connected_ws[id]
    if client then
      send_ws(client, OP_CLOSE, string.char(math.floor(code / 256), code % 256))
    end
    drop_ws(id)
    return false
  end

  --- Stop the server: send 1001 (going away) to WebSocket clients, close all
  -- sockets, and leave `sock` nil so the poll timers unregister themselves.
  local function close_server()
    for id, client in pairs(connected_ws) do
      send_ws(client, OP_CLOSE, string.char(0x03, 0xE9))
      drop_ws(id)
    end
    for client in pairs(pending_http) do
      pending_http[client] = nil
      client:close()
    end
    if sock == nil then
      return true
    end
    local listener = sock
    sock = nil
    return listener:close()
  end

  local wss = {
    socket = sock,
    broadcast = broadcast,
    broadcast_text = function(text)
      return broadcast(OP_TEXT, text)
    end,
    send_text = send_text_ws,
    send = send_ws,
    close = close_server,
  }

  local do_poll -- forward declaration; assigned below

  local function do_fast_poll()
    idle_count = idle_count + 1
    if idle_count > max_fast_idle then
      obs.remove_current_callback()
      fast_poll = false
    else
      do_poll()
    end
  end

  local function do_slow_poll()
    do_poll()
  end

  --- Reset the idle counter and make sure the fast poll timer is running.
  local function note_activity()
    idle_count = 0
    if not fast_poll then
      fast_poll = true
      obs.timer_add(do_fast_poll, poll_interval_fast)
    end
  end

  --- Send a complete HTML response and close the connection.
  local function respond_and_close(client, status, body, extra_headers)
    client:send(http_response(status, "text/html; charset=utf-8", body, extra_headers))
    client:close()
  end

  local function response_bad(client, msg)
    debug_print(msg)
    respond_and_close(client, "400 Bad Request", msg,
      { "Sec-WebSocket-Version: 13", "Cache-Control: max-age=15" })
  end

  --- Validate the upgrade headers, answer 101 and register the connection.
  local function upgrade_to_websocket(client, headers)
    if not read_header_list_includes((headers["upgrade"] or ""):lower(), "websocket") then
      return response_bad(client, "bad upgrade")
    end
    if not read_header_list_includes((headers["connection"] or ""):lower(), "upgrade") then
      return response_bad(client, "bad connection")
    end
    if headers["sec-websocket-version"] ~= "13" then
      return response_bad(client, "bad version")
    end
    local key = headers["sec-websocket-key"]
    if key == nil or key == "" then
      return response_bad(client, "no key")
    end
    local accept = b64_sha1(key .. WS_ACCEPT_GUID)
    client:send(table.concat({
      "HTTP/1.1 101 Switching Protocols",
      "Upgrade: websocket",
      "Connection: Upgrade",
      "Sec-WebSocket-Accept: " .. accept,
    }, "\r\n") .. "\r\n\r\n")
    local id = accept .. client_i
    client_i = client_i + 1
    connected_ws[id] = client
    ws_state[id] = { buf = "", frag_op = nil, frags = nil, frag_len = 0 }
  end

  --- Answer one complete HTTP request head (request line + header lines).
  local function handle_http_request(client, request_head)
    debug_print(request_head)
    local next_line = request_head:gmatch("[^\r\n]+")
    local request_line = next_line() or ""
    local method, target, ver = request_line:match("^(%S+) (%S+) HTTP/(%S+)$")
    if method == nil then
      return response_bad(client, "malformed request line")
    end
    if method ~= "GET" then
      return respond_and_close(client, "405 Method Not Allowed",
        "Only GET is supported.", { "Allow: GET" })
    end
    if ver ~= "1.1" then
      return respond_and_close(client, "505 HTTP Version Not Supported",
        "Only HTTP/1.1 is supported.")
    end
    -- Cut off the query string / fragment before decoding so an escaped
    -- "%3F" stays part of the path.
    local url = url_decode(target:match("^[^?#]*"))
    print("\trequest for " .. url)

    local headers = {}
    for header_line in next_line do
      local name, value = header_line:match("^([^:]+):%s*(.-)%s*$")
      if name then
        headers[name:lower()] = value -- header names are case-insensitive
      end
    end

    if url == wspath then
      return upgrade_to_websocket(client, headers)
    end

    local route = routes[url]
    if route == nil then
      return respond_and_close(client, "404 Not Found", not_found_body,
        { "Cache-Control: max-age=15" })
    end
    local ok, content, content_type = pcall(route, url, headers)
    if not ok then
      print("Route " .. url .. " failed: " .. tostring(content))
      return respond_and_close(client, "500 Internal Server Error", "Internal server error.")
    end
    content = tostring(content or "")
    content_type = content_type or "text/html; charset=utf-8"
    print(string.format("Serving request for %s (%s, %.1fkB)", url, content_type, #content / 1024))
    client:send(http_response("200 OK", content_type, content))
    client:close()
  end

  --- React to one decoded client frame.
  local function on_ws_frame(id, client, frame)
    local op = frame.opcode
    debug_print(("ws frame: fin: %s, op: %d, len: %d"):format(tostring(frame.fin), op, #frame.payload))
    if op >= 0x8 then
      -- Control frames must be final and carry at most 125 bytes (5.5).
      if not frame.fin or #frame.payload > 125 then
        return fail_ws(id, "invalid control frame")
      end
      if op == OP_CLOSE then
        -- Echo the status code to complete the closing handshake (5.5.1).
        send_ws(client, OP_CLOSE, frame.payload:sub(1, 2))
        drop_ws(id)
      elseif op == OP_PING then
        -- A Pong must carry the Ping's application data (5.5.3).
        send_ws(client, OP_PONG, frame.payload)
      elseif op ~= OP_PONG then
        return fail_ws(id, "unknown control opcode " .. op)
      end
      return
    end

    local state = ws_state[id]
    if op == OP_CONTINUATION then
      if state.frag_op == nil then
        return fail_ws(id, "continuation frame without a message to continue")
      end
    elseif op == OP_TEXT or op == OP_BINARY then
      if state.frag_op ~= nil then
        return fail_ws(id, "new message started before the previous one ended")
      end
      state.frag_op, state.frags, state.frag_len = op, {}, 0
    else
      return fail_ws(id, "unknown data opcode " .. op)
    end

    state.frag_len = state.frag_len + #frame.payload
    if state.frag_len > max_ws_payload then
      return fail_ws(id, "message too large", 1009)
    end
    state.frags[#state.frags + 1] = frame.payload
    if not frame.fin then
      return -- wait for the remaining fragments
    end

    local message_op, message = state.frag_op, table.concat(state.frags)
    state.frag_op, state.frags, state.frag_len = nil, nil, 0
    if message_op == OP_TEXT then
      debug_print("decoded: " .. message)
      local ok, err = pcall(on_text, client, message, wss)
      if not ok then
        print("on_text handler failed: " .. tostring(err))
      end
    end
    -- Binary messages have no handler and are discarded.
  end

  --- Read whatever a WebSocket client has sent and process complete frames.
  local function service_ws(id, client)
    if not client:is_connected() then
      return drop_ws(id)
    end
    local data, err = client:receive()
    if data == nil or data == "" then
      if err ~= "timeout" then
        drop_ws(id) -- peer closed the connection or the read failed
      end
      return
    end
    note_activity()
    local state = ws_state[id]
    local buf = state.buf .. data
    -- One read may carry several frames, or only part of one.
    while connected_ws[id] == client do
      local frame, info, close_code = parse_ws_frame(buf, max_ws_payload)
      if frame == nil then
        state.buf = buf -- incomplete: keep the bytes for the next read
        return
      elseif frame == false then
        return fail_ws(id, info, close_code)
      end
      buf = buf:sub(info + 1)
      on_ws_frame(id, client, frame)
    end
  end

  --- Accumulate a new client's request head and answer it once complete.
  local function service_http(client, pending)
    local data, err = client:receive()
    if data ~= nil and data ~= "" then
      note_activity()
      pending.buf = pending.buf .. data
    elseif err ~= "timeout" then
      pending_http[client] = nil -- peer closed or the read failed
      client:close()
      return
    end
    local head_end = pending.buf:find("\r\n\r\n", 1, true)
    if head_end then
      pending_http[client] = nil
      local ok, handler_err = pcall(handle_http_request, client, pending.buf:sub(1, head_end + 1))
      if not ok then
        -- Errors here happen before a response is sent, so the socket is
        -- still open and must not leak.
        print("Error handling request: " .. tostring(handler_err))
        client:close()
      end
    elseif #pending.buf > max_request_bytes then
      pending_http[client] = nil
      response_bad(client, "request header too large")
    elseif os.time() > pending.deadline then
      pending_http[client] = nil
      print("Client timed out before sending a complete request")
      client:close()
    end
  end

  do_poll = function()
    if sock == nil then
      obs.remove_current_callback()
      return
    end
    local client, err = sock:accept()
    if client then
      if client:is_connected() then
        note_activity()
        debug_print("Got client " .. tostring(client))
        debug_print("\tName is " .. tostring(client:get_name()))
        debug_print("\tPeername is " .. tostring(client:get_peer_name()))
        assert(client:set_blocking(false)) -- critical! don't hang obs UI!
        pending_http[client] = { buf = "", deadline = os.time() + request_timeout_s }
      else
        client:close()
      end
    elseif err ~= "timeout" then
      print("accept failed: " .. tostring(err))
    end
    for pending_client, pending in pairs(pending_http) do
      service_http(pending_client, pending)
    end
    for id, ws_client in pairs(connected_ws) do
      service_ws(id, ws_client)
    end
  end

  obs.timer_add(do_slow_poll, poll_interval)
  do_poll() -- poll once right away instead of waiting a full interval
  return wss
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment