Skip to content

Instantly share code, notes, and snippets.

@Rulexec
Last active July 14, 2020 07:15
Show Gist options
  • Save Rulexec/664747f12179d6bd94d0a112a138d679 to your computer and use it in GitHub Desktop.
Save Rulexec/664747f12179d6bd94d0a112a138d679 to your computer and use it in GitHub Desktop.
Port of the Punycode encode function to Lua, from https://github.com/bestiejs/punycode.js
local bit = require "bit"
local _M = {}

-- Punycode / Bootstring parameters (RFC 3492, section 5).
local base, tMin, tMax = 36, 1, 26
local skew, damp = 38, 700
local initialBias, initialN = 72, 128
local delimiter = 45 -- ASCII '-', written after the basic code points

-- 2^31 - 1: upper bound for delta, checked in raw_encode to detect overflow.
local maxInt = 2147483647

-- Derived constant used by adapt.
local baseMinusTMin = base - tMin
--- Decode a UTF-8 string into an array of Unicode code points.
-- Malformed input raises an error instead of yielding garbage values:
-- stray continuation bytes, invalid lead bytes (0x80-0xC1, 0xF5-0xFF),
-- truncated sequences, bad continuation bytes, overlong encodings and
-- values above U+10FFFF are all rejected. Overlong forms matter here
-- because e.g. "\192\174" would otherwise decode to "." and split a label.
-- Surrogate code points (U+D800-U+DFFF) are passed through unchanged so
-- that ucs2decode can still combine pairs of them.
-- @tparam string str UTF-8 encoded string
-- @treturn table array of code points
local function utf8_decode(str)
  local output = {}
  local length = #str
  local i = 1
  while i <= length do
    local c = string.byte(str, i)
    -- extra: number of continuation bytes; min_value: smallest code point
    -- that legitimately needs this sequence length (rejects overlongs).
    local extra, code_point, min_value
    if c <= 0x7f then
      extra, code_point, min_value = 0, c, 0
    elseif c >= 0xc2 and c <= 0xdf then
      extra, code_point, min_value = 1, c - 0xc0, 0x80
    elseif c >= 0xe0 and c <= 0xef then
      extra, code_point, min_value = 2, c - 0xe0, 0x800
    elseif c >= 0xf0 and c <= 0xf4 then
      extra, code_point, min_value = 3, c - 0xf0, 0x10000
    else
      error(string.format("invalid UTF-8 lead byte 0x%02x at position %d", c, i))
    end

    if i + extra > length then
      error(string.format("truncated UTF-8 sequence at position %d", i))
    end

    for j = i + 1, i + extra do
      local cc = string.byte(str, j)
      if cc < 0x80 or cc > 0xbf then
        error(string.format("invalid UTF-8 continuation byte 0x%02x at position %d", cc, j))
      end
      code_point = code_point * 0x40 + (cc - 0x80)
    end

    if code_point < min_value or code_point > 0x10ffff then
      error(string.format("invalid UTF-8 sequence at position %d", i))
    end

    output[#output + 1] = code_point
    i = i + extra + 1
  end
  return output
end
--- Combine UTF-16 surrogate pairs found in an array of code points.
-- Mirrors punycode.js ucs2decode, but the input is an array of numbers
-- (as produced by utf8_decode) rather than a JS string. A high surrogate
-- (U+D800-U+DBFF) immediately followed by a low surrogate (U+DC00-U+DFFF)
-- becomes one supplementary code point; every other value, including an
-- unpaired surrogate, is copied through unchanged.
-- @tparam table str array of code points (not modified)
-- @treturn table new array of code points
local function ucs2decode(str)
  local output = {}
  local length = #str
  local counter = 1
  while counter <= length do
    local value = str[counter]
    counter = counter + 1
    if value >= 0xD800 and value <= 0xDBFF and counter <= length then
      local extra = str[counter]
      -- Explicit range check instead of the JS `(extra & 0xFC00) == 0xDC00`
      -- mask: these values may exceed 0xFFFF, and the mask would also match
      -- e.g. 0x1DC00.
      if extra >= 0xDC00 and extra <= 0xDFFF then
        counter = counter + 1
        output[#output + 1] = (value - 0xD800) * 0x400 + (extra - 0xDC00) + 0x10000
      else
        output[#output + 1] = value
      end
    else
      output[#output + 1] = value
    end
  end
  return output
end
--- Map a base-36 digit to the ASCII code of its Punycode character.
-- Digits 0-25 map to 'a'-'z' (97-122) and 26-35 to '0'-'9' (48-57).
-- A non-zero `flag` subtracts 32 from the result (uppercase for letters).
-- @tparam number digit value in 0..35
-- @tparam number flag 0 for the lowercase form
-- @treturn number ASCII byte value
local function digit_to_basic(digit, flag)
  -- 26 + 22 = 48 ('0'); letters need a further 75 so that 0 + 22 + 75 = 97 ('a').
  local letter_offset = (digit < 26) and 75 or 0
  local case_offset = (flag ~= 0) and -32 or 0
  return digit + 22 + letter_offset + case_offset
end
--- Bias adaptation function (RFC 3492, section 6.1).
-- @tparam number delta the delta just encoded
-- @tparam number numPoints number of code points handled so far, plus one
-- @tparam boolean firstTime true for the first adaptation of a label
-- @treturn number the new bias
local function adapt(delta, numPoints, firstTime)
  -- Scale delta down: by `damp` the first time, by half afterwards.
  if firstTime then
    delta = math.floor(delta / damp)
  else
    delta = bit.rshift(delta, 1)
  end
  delta = delta + math.floor(delta / numPoints)

  -- Evaluates to 455 with this module's constants (35 * 13).
  local threshold = baseMinusTMin * bit.rshift(tMax, 1)
  local offset = 0
  while delta > threshold do
    delta = math.floor(delta / baseMinusTMin)
    offset = offset + base
  end
  return math.floor(offset + (baseMinusTMin + 1) * delta / (delta + skew))
end
--- Build a string from an array of byte values.
-- Each element is passed to string.char, so values must be in 0..255.
-- Iteration stops at the first nil, as with ipairs.
-- Collects the characters in a buffer and joins them once, avoiding the
-- quadratic cost of repeated `..` concatenation.
-- @tparam table arr array of byte values
-- @treturn string
local function from_char_codes(arr)
  local chars = {}
  for i, code in ipairs(arr) do
    chars[i] = string.char(code)
  end
  return table.concat(chars)
end
--- Punycode-encode one label given as an array of code points.
-- Port of punycode.js `encode` (RFC 3492, section 6.3). The output contains:
--   1. the basic code points (< 0x80), copied in order;
--   2. the delimiter '-', if there was at least one basic code point;
--   3. the variable-length integers describing where each remaining code
--      point is inserted.
-- Raises error('overflow') if delta would exceed maxInt.
-- @tparam table input array of code points (surrogate pairs are combined first)
-- @treturn table array of ASCII byte values, without the "xn--" prefix
local function raw_encode(input)
  local output = {}
  input = ucs2decode(input)
  local inputLength = #input
  local n = initialN
  local delta = 0
  local bias = initialBias

  -- Copy the basic code points verbatim.
  for _, val in ipairs(input) do
    if val < 0x80 then
      output[#output + 1] = val
    end
  end

  local basicLength = #output
  local handledCPCount = basicLength
  if basicLength > 0 then
    output[#output + 1] = delimiter
  end

  while handledCPCount < inputLength do
    -- Smallest code point >= n still present in the input.
    local m = maxInt
    for _, val in ipairs(input) do
      if val >= n and val < m then
        m = val
      end
    end

    -- Advance delta to the state <m, 0>, guarding against exceeding maxInt.
    local handledCPCountPlusOne = handledCPCount + 1
    if m - n > math.floor((maxInt - delta) / handledCPCountPlusOne) then
      error('overflow')
    end
    delta = delta + (m - n) * handledCPCountPlusOne
    n = m

    for _, val in ipairs(input) do
      if val < n then
        delta = delta + 1
        if delta > maxInt then
          error('overflow')
        end
      end
      if val == n then
        -- Emit delta as a generalized variable-length integer.
        local q = delta
        local k = base
        while true do
          local t
          if k <= bias then
            t = tMin
          elseif k >= bias + tMax then
            t = tMax
          else
            t = k - bias
          end
          if q < t then
            break
          end
          local qMinusT = q - t
          local baseMinusT = base - t
          output[#output + 1] = digit_to_basic(t + qMinusT % baseMinusT, 0)
          q = math.floor(qMinusT / baseMinusT)
          k = k + base
        end
        output[#output + 1] = digit_to_basic(q, 0)
        bias = adapt(delta, handledCPCountPlusOne, handledCPCount == basicLength)
        delta = 0
        handledCPCount = handledCPCount + 1
      end
    end
    delta = delta + 1
    n = n + 1
  end
  return output
end
-- Code points treated as label separators, matching punycode.js toASCII:
-- U+002E FULL STOP, U+3002 IDEOGRAPHIC FULL STOP,
-- U+FF0E FULLWIDTH FULL STOP, U+FF61 HALFWIDTH IDEOGRAPHIC FULL STOP.
local label_separators = {
  [0x002E] = true,
  [0x3002] = true,
  [0xFF0E] = true,
  [0xFF61] = true,
}

--- Convert a UTF-8 domain name to its ASCII-compatible form.
-- If every byte of `domain` is ASCII (0x00-0x7F), it is returned unchanged.
-- Otherwise the domain is decoded, split into labels on the separators
-- above, and the labels are rejoined with ".":
--   * a label containing any code point above U+007F is Punycode-encoded
--     and prefixed with "xn--";
--   * every other label is copied as-is.
-- No other IDNA mapping or normalization (e.g. case folding) is applied.
-- Errors from decoding or encoding (e.g. 'overflow') propagate to the caller.
-- @tparam string domain UTF-8 encoded domain name
-- @treturn string ASCII domain name
function _M.encode_domain(domain)
  -- Fast path: a pure-ASCII domain needs no conversion.
  local has_non_ascii = false
  for i = 1, #domain do
    if string.byte(domain, i) > 0x7f then
      has_non_ascii = true
      break
    end
  end
  if not has_non_ascii then
    return domain
  end

  local labels = {}
  local part = {}
  local part_has_non_ascii = false

  -- Append the code points collected in `part` as one finished label.
  local function finish_label()
    if part_has_non_ascii then
      labels[#labels + 1] = "xn--" .. from_char_codes(raw_encode(part))
    else
      labels[#labels + 1] = from_char_codes(part)
    end
    part = {}
    part_has_non_ascii = false
  end

  for _, c in ipairs(utf8_decode(domain)) do
    if label_separators[c] then
      finish_label()
    else
      -- 0x7F (DEL) is ASCII; only code points above it require Punycode.
      if c > 0x7f then
        part_has_non_ascii = true
      end
      part[#part + 1] = c
    end
  end
  finish_label()

  return table.concat(labels, ".")
end
-- Module table; the public function defined on it is _M.encode_domain.
return _M
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment