Skip to content

Instantly share code, notes, and snippets.

@noprompt
Last active February 27, 2023 22:12
Show Gist options
  • Save noprompt/176883f76e7051909a16438c357438dd to your computer and use it in GitHub Desktop.
Save noprompt/176883f76e7051909a16438c357438dd to your computer and use it in GitHub Desktop.
-- Delimiter table, keyed first by filetype and then by tree-sitter node type.
-- Each entry is a sequence: { left, right [, separator] }.
local delimiters = {
  clojure = {
    list_lit = { '(', ')' },
    map_lit  = { '{', '}' },
    str_lit  = { '"', '"' },
    set_lit  = { '#{', '}' },
    vec_lit  = { '[', ']' },
  },
  lua = {
    -- Only table constructors carry a separator.
    table_constructor        = { '{', '}', ',' },
    parenthesized_expression = { '(', ')' },
  },
}
-- Per-filetype slurp rules, keyed by tree-sitter node type.
-- A value of `false` marks a node type that must not be slurped.
local slurp_rules = {
  clojure = { comment = false },
}
--- Report whether `node` may be slurped in the current buffer.
-- The current buffer's filetype selects a rule set from `slurp_rules`; the
-- node's type selects the rule. A missing rule means "allowed".
-- @param node tree-sitter node (must provide `:type()`)
-- @return boolean true when slurping is permitted
local function buf_can_slurp(node)
  local language = vim.api.nvim_buf_get_option(0, 'filetype')
  local rules = slurp_rules[language]
  -- Filetypes without any rule set (e.g. lua) place no restrictions.
  -- Indexing the missing table used to raise an error here.
  if rules == nil then
    return true
  end
  local rule = rules[node:type()]
  return rule == nil or rule == true
end
--- Look up the delimiter table for the current buffer's filetype.
-- @return table|nil node-type -> { left, right [, separator] }, or nil when
--   the filetype has no entry in `delimiters`
local function buf_get_delimiters()
  local filetype = vim.api.nvim_buf_get_option(0, 'filetype')
  local by_node_type = delimiters[filetype]
  return by_node_type
end
-- Treesitter Helpers
--- Test whether `node` is a root node.
-- The check compares the node's type against the type of its tree's root.
-- @param node tree-sitter node
-- @return boolean
local function is_root(node)
  local own_type = node:type()
  local root_type = node:root():type()
  return own_type == root_type
end
--- Return the tree-sitter node under the cursor of a window.
-- @param winnr window handle; defaults to 0 (the current window)
-- @return tree-sitter node, or nil when the lookup finds nothing
local function get_node_at_cursor(winnr)
  winnr = winnr or 0
  local cursor = vim.api.nvim_win_get_cursor(winnr)
  -- The cursor row is 1-based, tree-sitter rows are 0-based; columns agree.
  local ts_row, ts_col = cursor[1] - 1, cursor[2]
  -- Query the buffer shown in `winnr`. The node used to be read from the
  -- current buffer even when another window was requested.
  local bufnr = vim.api.nvim_win_get_buf(winnr)
  -- Prefer the newer API when present; `get_node_at_pos` is deprecated.
  if vim.treesitter.get_node then
    return vim.treesitter.get_node({ bufnr = bufnr, pos = { ts_row, ts_col } })
  end
  return vim.treesitter.get_node_at_pos(bufnr, ts_row, ts_col)
end
--- Find the innermost delimited node enclosing the cursor.
-- Walks from the node under the cursor towards the root and returns the
-- first node whose type has an entry in the current filetype's delimiters.
-- @return tree-sitter node, or nil when the filetype has no delimiters, no
--   node is found under the cursor, or no enclosing node is delimited
local function get_delimited_node_at_cursor()
  local pairs_by_type = buf_get_delimiters()
  if not pairs_by_type then
    return nil
  end
  local node = get_node_at_cursor(0)
  -- `node` is nil when the lookup finds nothing; guard so that case returns
  -- nil instead of raising inside is_root.
  while node and not is_root(node) do
    if pairs_by_type[node:type()] then
      return node
    end
    node = node:parent()
  end
  return nil
end
--- Fetch the text covered by `node` from the current buffer.
-- @param node tree-sitter node
-- @return whatever nvim_buf_get_text returns for the node's range
local function get_node_text(node)
  local start_row, start_col, end_row, end_col = node:range()
  local lines = vim.api.nvim_buf_get_text(0, start_row, start_col, end_row, end_col, {})
  return lines
end
--- Test whether two nodes begin at the same position.
-- @param node_a tree-sitter node (may be nil)
-- @param node_b tree-sitter node (may be nil)
-- @return boolean false when either node is missing
local function start_equal(node_a, node_b)
  if not (node_a and node_b) then
    return false
  end
  local a_row, a_col = node_a:range()
  local b_row, b_col = node_b:range()
  return a_row == b_row and a_col == b_col
end
--- Test whether `node` starts exactly at the given position.
-- @param node tree-sitter node
-- @param ts_line row to compare with the node's start row
-- @param ts_col column to compare with the node's start column
-- @return boolean
local function starts_at(node, ts_line, ts_col)
  local start_line, start_col = node:start()
  if start_line ~= ts_line then
    return false
  end
  return start_col == ts_col
end
--- Test whether `node` starts strictly after the given position.
-- @param node tree-sitter node
-- @param ts_line reference row
-- @param ts_col reference column
-- @return boolean
local function starts_after(node, ts_line, ts_col)
  local node_line, node_col = node:start()
  -- A node on a later row is always after the position.
  if ts_line < node_line then
    return true
  end
  -- On the same row the column decides.
  return ts_line == node_line and ts_col < node_col
end
--- Exchange the buffer text of two nodes in the current buffer.
-- The node that starts later in the buffer is rewritten first, so the range
-- of the earlier node is still valid when it is rewritten. The original
-- always rewrote `node_b` first and was therefore only correct when `node_a`
-- preceded `node_b`.
-- @param node_a tree-sitter node
-- @param node_b tree-sitter node; assumed not to overlap `node_a`
local function swap(node_a, node_b)
  local text_a = get_node_text(node_a)
  local text_b = get_node_text(node_b)
  local a_sr, a_sc, a_er, a_ec = node_a:range()
  local b_sr, b_sc, b_er, b_ec = node_b:range()
  local a_is_first = a_sr < b_sr or (a_sr == b_sr and a_sc <= b_sc)
  if a_is_first then
    vim.api.nvim_buf_set_text(0, b_sr, b_sc, b_er, b_ec, text_a)
    vim.api.nvim_buf_set_text(0, a_sr, a_sc, a_er, a_ec, text_b)
  else
    vim.api.nvim_buf_set_text(0, a_sr, a_sc, a_er, a_ec, text_b)
    vim.api.nvim_buf_set_text(0, b_sr, b_sc, b_er, b_ec, text_a)
  end
end
--- Move the current window's cursor to the start of `node`.
-- @param node tree-sitter node
local function set_cursor_to_node_start(node)
  local start_row, start_col = node:range()
  -- The cursor row is 1-based while the node row is 0-based.
  vim.api.nvim_win_set_cursor(0, { start_row + 1, start_col })
end
--- Move the current window's cursor to the end of `node`.
-- @param node tree-sitter node
-- @param append when truthy, the cursor column is the node's end column
--   minus one; otherwise it is the end column itself
local function set_cursor_to_node_end(node, append)
  local _, _, end_row, end_col = node:range()
  local target_col = end_col
  if append then
    target_col = end_col - 1
  end
  -- The cursor row is 1-based while the node row is 0-based.
  vim.api.nvim_win_set_cursor(0, { end_row + 1, target_col })
end
--- Return the first named child of `node`.
-- @param node tree-sitter node
-- @return tree-sitter node, or nil when `node` has no named children
local function down(node)
  if node:named_child_count() > 0 then
    return node:named_child(0)
  end
  return nil
end
--- Return the last child (named or anonymous) of `node`'s parent.
-- @param node tree-sitter node
-- @return tree-sitter node, or nil when `node` is a root
local function rightmost(node)
  if is_root(node) then
    return nil
  end
  local parent = node:parent()
  local last_index = parent:child_count() - 1
  return parent:child(last_index)
end
--- Return the last named child of `node`'s parent.
-- @param node tree-sitter node
-- @return tree-sitter node, or nil when `node` is a root
local function rightmost_named(node)
  if is_root(node) then
    return nil
  end
  local parent = node:parent()
  local last_index = parent:named_child_count() - 1
  return parent:named_child(last_index)
end
--- Return the node that follows `node` in a depth-first walk of named nodes.
-- Preference order: first named child, then next named sibling, then the
-- next named sibling of the closest non-root ancestor that has one.
-- @param node tree-sitter node
-- @return tree-sitter node, or nil when `node` is a root or nothing follows
local function next_node(node)
  if is_root(node) then
    return nil
  end

  local child = down(node)
  if child then
    return child
  end

  local sibling = node:next_named_sibling()
  if sibling then
    return sibling
  end

  -- Climb until an ancestor has a following sibling or the root is reached.
  local ancestor = node:parent()
  while not is_root(ancestor) do
    sibling = ancestor:next_named_sibling()
    if sibling then
      break
    end
    ancestor = ancestor:parent()
  end
  return sibling
end
--- Return the node that precedes `node` when walking backwards.
-- With a previous named sibling: that sibling when it has no named children,
-- otherwise its last named child. Without one: the parent, unless the parent
-- starts at the same position as `node`, in which case the search continues
-- from the parent.
-- @param node tree-sitter node
-- @return tree-sitter node, or nil when `node` is a root
local function prev_node(node)
  if is_root(node) then
    return nil
  end

  local sibling = node:prev_named_sibling()
  if sibling then
    if sibling:named_child_count() == 0 then
      return sibling
    end
    return rightmost_named(down(sibling))
  end

  local parent = node:parent()
  if start_equal(parent, node) then
    return prev_node(parent)
  end
  return parent
end
--- Return the outermost ancestor that starts where `node` starts.
-- Climbing stops at a root, at a child of a root, or as soon as the parent
-- starts at a different position.
-- @param node tree-sitter node
-- @return tree-sitter node
local function largest(node)
  if is_root(node) or is_root(node:parent()) then
    return node
  end
  local parent = node:parent()
  -- NOTE: this should probably compare the whole range, not only the start.
  if not start_equal(node, parent) then
    return node
  end
  return largest(parent)
end
--- Pick the node at or immediately after a position.
-- Returns `node` itself when it starts at the position or has no named
-- children; otherwise returns its first named child that starts after the
-- position.
-- @param node tree-sitter node, or nil
-- @param ts_line row of the reference position
-- @param ts_col column of the reference position
-- @return tree-sitter node, or nil when `node` is nil or no child qualifies
local function node_closest_to_pos_forward(node, ts_line, ts_col)
  -- Callers pass the raw result of a node lookup, which may be nil.
  if node == nil then
    return nil
  end
  local child_count = node:named_child_count()
  if starts_at(node, ts_line, ts_col) or child_count == 0 then
    return node
  end
  for i = 0, child_count - 1 do
    local child = node:named_child(i)
    if starts_after(child, ts_line, ts_col) then
      return child
    end
  end
  return nil
end
--- Pick the node at or immediately after the cursor of the current window.
-- @return whatever node_closest_to_pos_forward returns for the node under
--   the cursor and the cursor position
local function node_closest_to_cursor_forward()
  local cursor = vim.api.nvim_win_get_cursor(0)
  -- The cursor row is 1-based, tree-sitter rows are 0-based.
  local ts_line, col = cursor[1] - 1, cursor[2]
  local node = vim.treesitter.get_node_at_pos(0, ts_line, col)
  return node_closest_to_pos_forward(node, ts_line, col)
end
--- Run `f` and put the current window's cursor back where it was.
-- The cursor is restored even when `f` raises; the error is then re-raised
-- unchanged. The original left the cursor wherever a failing callback had
-- moved it.
-- @param f function called with no arguments; its return values are dropped
local function save_excursion(f)
  local pos = vim.api.nvim_win_get_cursor(0)
  local ok, err = pcall(f)
  if not ok then
    -- Best-effort restore: the failure of `f` is the error worth reporting.
    pcall(vim.api.nvim_win_set_cursor, 0, pos)
    error(err, 0)
  end
  vim.api.nvim_win_set_cursor(0, pos)
end
--- Debug helper: print the `vim.inspect` rendering of a value.
-- @param x any value
local function I(x)
  local rendered = vim.inspect(x)
  print(rendered)
end
-- API
-- Module table: the functions attached to it below make up the public
-- interface and it is returned at the end of the Lua source.
local P = {}
-- Forward slurp: move the right delimiter of the delimited node at the
-- cursor past that node's next named sibling.
-- Does nothing when there is no delimited node at the cursor, no next named
-- sibling, or buf_can_slurp rejects the sibling.
-- The edits run inside save_excursion, which restores the cursor afterwards.
function P.forward_slurp()
local node = get_delimited_node_at_cursor()
if node then
local slurp_node = node:next_named_sibling()
if slurp_node and buf_can_slurp(slurp_node) then
local delimiters = buf_get_delimiters()[node:type()]
local right_delimiter = delimiters[2]
-- Use the separator only when the node already has named children; falls
-- back to '' when the entry defines no separator.
local separator = node:named_child_count() > 0 and delimiters[3] or ''
save_excursion(function ()
-- Append the right delimiter after the slurped node. This edit is made
-- first, at the later buffer position.
set_cursor_to_node_end(slurp_node, true)
vim.api.nvim_put({ right_delimiter }, 'c', true, false)
-- Delete the original right delimiter (assumed to be the node's last
-- child) into the black-hole register.
set_cursor_to_node_start(node:child(node:child_count() - 1))
vim.cmd(string.format([[normal! "_%dx]], #right_delimiter))
-- Insert the separator at the cursor position left by the deletion.
vim.api.nvim_put({ separator }, 'c', true, false)
-- Re-indent with `==`; the count spans from the node's start row to the
-- slurped node's start row.
local format_start_row, _, _, _ = node:range()
local format_end_row, _, _, _ = slurp_node:range()
vim.cmd(string.format([[normal! %d==]], format_end_row - format_start_row + 1))
end)
end
end
end
-- Forward barf: move the right delimiter of the delimited node at the cursor
-- so that the node's last named child ends up outside it.
-- Does nothing when there is no delimited node at the cursor or the node has
-- no named children.
-- The edits run inside save_excursion, which restores the cursor afterwards.
function P.forward_barf()
local node = get_delimited_node_at_cursor()
if node then
local child_count = node:named_child_count()
if child_count > 0 then
-- The child after which the right delimiter is re-inserted.
local new_rightmost_child
if child_count == 1 then
-- With a single named child the delimiter goes right after the node's
-- first child (presumably the left delimiter; see the assumption stated
-- in P.splice).
new_rightmost_child = node:child(0)
else
new_rightmost_child = node:named_child(child_count - 2)
end
local delimiters = buf_get_delimiters()[node:type()]
local right_delimiter = delimiters[2]
-- NOTE: `separator` is read but not used yet; see the TODO below.
local separator = delimiters[3]
save_excursion(function ()
-- Delete the original right delimiter (assumed to be the node's last
-- child) into the black-hole register. This edit is made first, at the
-- later buffer position.
set_cursor_to_node_start(node:child(node:child_count() - 1))
vim.cmd(string.format([[normal! "_%dx]], #right_delimiter))
-- TODO: Delete original separator
-- Append the new right delimiter after the new rightmost child.
local should_append = true
set_cursor_to_node_end(new_rightmost_child, should_append)
vim.api.nvim_put({ right_delimiter }, 'c', should_append, true)
-- Re-indent with `==`; the count spans from the new rightmost child's end
-- row to the node's end row.
local _, _, format_start_row, _ = new_rightmost_child:range()
local _, _, format_end_row, _ = node:range()
vim.cmd(string.format([[normal! %d==]], format_end_row - format_start_row + 1))
end)
end
end
end
--- Backward slurp: move the left delimiter of the delimited node at the
-- cursor in front of that node's previous named sibling.
-- Does nothing when there is no delimited node at the cursor, no previous
-- named sibling, or buf_can_slurp rejects the sibling.
-- The edits run inside save_excursion, which restores the cursor afterwards.
function P.backward_slurp()
  local node = get_delimited_node_at_cursor()
  if not node then
    return
  end
  local slurp_node = node:prev_named_sibling()
  if not (slurp_node and buf_can_slurp(slurp_node)) then
    return
  end
  local pair = buf_get_delimiters()[node:type()]
  local left_delimiter = pair[1]
  -- Use the separator only when the node already has named children; falls
  -- back to '' when the entry defines no separator.
  local separator = node:named_child_count() > 0 and pair[3] or ''
  save_excursion(function ()
    -- Delete the left delimiter (assumed to be the node's first child) into
    -- the black-hole register. It lies after the slurped node, so the
    -- slurped node's range stays valid. The stray trailing quote in the
    -- original command has been removed.
    set_cursor_to_node_start(node:child(0))
    vim.cmd(string.format([[normal! "_%dx]], #left_delimiter))
    -- Append the separator after the slurped node. The original positioned
    -- the cursor at the node's start, which put the separator after the
    -- node's first character.
    set_cursor_to_node_end(slurp_node, true)
    vim.api.nvim_put({ separator }, 'c', true, false)
    -- Insert the left delimiter before the slurped node.
    set_cursor_to_node_start(slurp_node)
    vim.api.nvim_put({ left_delimiter }, 'c', false, false)
    -- Re-indent with `==`; the count spans from the slurped node's start row
    -- to the node's start row.
    local format_start_row = slurp_node:range()
    local format_end_row = node:range()
    vim.cmd(string.format([[normal! %d==]], format_end_row - format_start_row + 1))
  end)
end
-- Backward barf: move the left delimiter of the delimited node at the cursor
-- so that the node's first named child ends up outside it.
-- Does nothing when there is no delimited node at the cursor or the node has
-- no named children.
-- The edits run inside save_excursion, which restores the cursor afterwards.
function P.backward_barf()
local node = get_delimited_node_at_cursor()
if node then
local child_count = node:named_child_count()
if child_count > 0 then
local barf_node = node:named_child(0)
local delimiters = buf_get_delimiters()[node:type()]
local left_delimiter = delimiters[1]
-- The separator to remove: only when another named child follows, and ''
-- when the entry defines no separator.
local barf_separator = barf_node:next_named_sibling() and delimiters[3] or ''
save_excursion(function ()
-- The new delimiter goes before the next named sibling, or before the
-- parent's last child (presumably the right delimiter) when the barfed node
-- is the only named child.
local right_of_barf_node = barf_node:next_named_sibling() or rightmost(barf_node)
-- Insert the left delimiter to the right of the barf node.
if right_of_barf_node then
set_cursor_to_node_start(right_of_barf_node)
vim.api.nvim_put({ left_delimiter }, 'c', false, false)
end
-- Delete the original left delimiter.
set_cursor_to_node_start(node:child(0))
vim.cmd(string.format([[normal! "_%dx]], #left_delimiter))
-- Delete the separator to the right of the barf node if there
-- was one.
-- NOTE(review): barf_node's range was captured before the edits above;
-- confirm the position is still right after the left delimiter was deleted
-- on the same row.
if #barf_separator > 0 then
set_cursor_to_node_end(barf_node)
vim.cmd(string.format([[normal! "_%dx]], #barf_separator))
end
-- TODO: Insert correct separator before barf node
-- Re-indent with `==`, starting at the node's start row and covering the
-- node's row span.
set_cursor_to_node_start(node)
local format_start_row, _, format_end_row, _ = node:range()
vim.cmd(string.format([[normal! %d==]], format_end_row - format_start_row + 1))
end)
end
end
end
--- Split the delimited node at the cursor into two forms on the cursor line.
-- Only the cursor's line is rewritten. Whitespace around the cursor is
-- trimmed and a closing delimiter, a space and an opening delimiter are
-- inserted there. When the character under the cursor equals the opening
-- delimiter, an empty pair is placed at the front of the line instead.
-- Does nothing when there is no delimited node at the cursor.
function P.split()
  local node = get_delimited_node_at_cursor()
  if not node then
    return
  end

  local cursor = vim.api.nvim_win_get_cursor(0)
  local row, col = cursor[1], cursor[2]
  local line = vim.api.nvim_buf_get_lines(0, row - 1, row, false)[1]

  -- Text before the cursor without trailing blanks, and text from the cursor
  -- onwards without leading blanks.
  local before = line:sub(1, col):match('(.*%S)%s*') or ''
  local after = line:sub(col + 1):match('^%s*(.*)') or ''

  local pair = buf_get_delimiters()[node:type()]
  local open, close = pair[1], pair[2]

  local edit
  if line:sub(col + 1, col + 1) == open then
    -- At the opening brace of a first form at the top level.
    edit = open .. close .. ' ' .. before .. after
  else
    edit = before .. close .. ' ' .. open .. after
  end
  vim.api.nvim_buf_set_lines(0, row - 1, row, false, { edit })
end
--- Splice: delete both delimiters of the delimited node at the cursor.
-- Does nothing when there is no delimited node at the cursor.
-- The edits run inside save_excursion, which restores the cursor afterwards.
function P.splice()
  local node = get_delimited_node_at_cursor()
  if not node then
    return
  end
  local pair = buf_get_delimiters()[node:type()]
  local open, close = pair[1], pair[2]
  save_excursion(function ()
    -- Assume the first and last child are the delimiters. The closing one is
    -- deleted first so the opening one's position is unaffected.
    local last_child = node:child(node:child_count() - 1)
    set_cursor_to_node_start(last_child)
    vim.cmd(string.format([[normal! "_%dx]], #close))
    local first_child = node:child(0)
    set_cursor_to_node_start(first_child)
    vim.cmd(string.format([[normal! "_%dx]], #open))
  end)
end
--- Move the cursor to the start of the node that next_node yields for the
-- node under the cursor.
-- Does nothing when no node is found under the cursor (the original raised
-- an error in that case) or when next_node returns nil.
function P.set_cursor_to_next()
  local node = get_node_at_cursor(0)
  if not node then
    return
  end
  local target = next_node(node)
  if target then
    set_cursor_to_node_start(target)
  end
end
--- Move the cursor to the start of the node that prev_node yields for the
-- node under the cursor.
-- Does nothing when no node is found under the cursor (the original raised
-- an error in that case) or when prev_node returns nil.
function P.set_cursor_to_prev()
  local node = get_node_at_cursor(0)
  if not node then
    return
  end
  local target = prev_node(node)
  if target then
    set_cursor_to_node_start(target)
  end
end
--- Move the cursor to the start of the previous named sibling of the node
-- under the cursor.
-- Does nothing when no node is found under the cursor (the original raised
-- an error in that case) or when there is no previous named sibling.
function P.set_cursor_to_left()
  local node = get_node_at_cursor(0)
  local sibling = node and node:prev_named_sibling()
  if sibling then
    set_cursor_to_node_start(sibling)
  end
end
--- Move the cursor to the start of the next named sibling of the node under
-- the cursor.
-- Does nothing when no node is found under the cursor (the original raised
-- an error in that case) or when there is no next named sibling.
function P.set_cursor_to_right()
  local node = get_node_at_cursor(0)
  local sibling = node and node:next_named_sibling()
  if sibling then
    set_cursor_to_node_start(sibling)
  end
end
--- Swap the form at the cursor with its previous named sibling.
-- The form is the result of `largest` applied to the node under the cursor.
-- The cursor is moved to the sibling's start before the texts are swapped.
-- Does nothing when no node is found under the cursor (the original raised
-- an error in that case) or when there is no previous named sibling.
function P.backward_swap()
  local node = get_node_at_cursor(0)
  if not node then
    return
  end
  local node_b = largest(node)
  local node_a = node_b:prev_named_sibling()
  if node_a then
    set_cursor_to_node_start(node_a)
    swap(node_a, node_b)
  end
end
--- Visually select the node at or after the cursor.
-- Sets the marks `a` and `b` at the node's start and last character, selects
-- the span between them in visual mode, then deletes both marks.
-- Does nothing when node_closest_to_cursor_forward returns nil.
function P.forward_select()
  local node = node_closest_to_cursor_forward()
  if not node then
    return
  end
  local start_row, start_col, end_row, end_col = node:range()
  -- Mark rows are 1-based; the node's end column is moved back by one.
  vim.api.nvim_buf_set_mark(0, 'a', start_row + 1, start_col, {})
  vim.api.nvim_buf_set_mark(0, 'b', end_row + 1, end_col - 1, {})
  vim.cmd([[normal! `av`b]])
  for _, mark in ipairs({ 'a', 'b' }) do
    vim.api.nvim_buf_del_mark(0, mark)
  end
end
--- Swap the form at the cursor with its next named sibling.
-- The form is the result of `largest` applied to the node under the cursor.
-- The cursor is moved to the sibling's start before the texts are swapped.
-- Does nothing when no node is found under the cursor (the original raised
-- an error in that case) or when there is no next named sibling.
function P.forward_swap()
  local node = get_node_at_cursor(0)
  if not node then
    return
  end
  local node_a = largest(node)
  local node_b = node_a:next_named_sibling()
  if node_b then
    set_cursor_to_node_start(node_b)
    swap(node_a, node_b)
  end
end
--- Wrap the node at or after the cursor in a pair of strings.
-- @param c1 string placed in front of the node's text
-- @param c2 string placed after the node's text
-- Does nothing when node_closest_to_cursor_forward returns nil. The
-- original raised an error in that case, unlike P.forward_select.
function P.wrap (c1, c2)
  local node = node_closest_to_cursor_forward()
  if not node then
    return
  end
  local row_a, col_a, row_b, col_b = node:range()
  local text = vim.api.nvim_buf_get_text(0, row_a, col_a, row_b, col_b, {})
  -- Index of the last line of the node's text (1 for a single-line node).
  local last = row_b - row_a + 1
  text[1] = c1 .. text[1]
  text[last] = text[last] .. c2
  vim.api.nvim_buf_set_text(0, row_a, col_a, row_b, col_b, text)
  -- Re-indent as many lines as the node spans, counted from the cursor line.
  vim.cmd(string.format([[normal! %d==]], last))
end
-- Wrap the node chosen by P.wrap in double quotes.
function P.wrap_double_quote()
P.wrap('"', '"')
end
-- Wrap the node chosen by P.wrap in single quotes.
function P.wrap_single_quote()
P.wrap("'", "'")
end
-- Wrap the node chosen by P.wrap in parentheses.
function P.wrap_round()
P.wrap('(', ')')
end
-- Wrap the node chosen by P.wrap in square brackets.
function P.wrap_square()
P.wrap('[', ']')
end
-- Wrap the node chosen by P.wrap in curly braces.
function P.wrap_curly()
P.wrap('{', '}')
end
return P
" When g:vscode exists, run `TSToggle highlight` twice in a row.
" NOTE(review): presumably this toggles highlighting off and on again to
" force a tree-sitter refresh under VSCode - confirm.
function! VSCodeTSSync()
if exists('g:vscode')
TSToggle highlight
TSToggle highlight
endif
endfunction
" Esper Start
" Development helper: call VSCodeTSSync(), then clear the cached 'esper'
" entry in package.loaded so the next require('esper') loads the module
" again.
function! EsperDev()
call VSCodeTSSync()
lua package.loaded['esper'] = nil
endfunction
" Run EsperDev(), then call backward_barf() from the 'esper' Lua module.
function! EsperBackwardBarf()
call EsperDev()
lua require('esper').backward_barf()
endfunction
" Run EsperDev(), then call backward_slurp() from the 'esper' Lua module.
function! EsperBackwardSlurp()
call EsperDev()
lua require('esper').backward_slurp()
endfunction
" Run EsperDev(), then call backward_swap() from the 'esper' Lua module.
function! EsperBackwardSwap()
call EsperDev()
lua require('esper').backward_swap()
endfunction
" Run EsperDev(), then call splice() from the 'esper' Lua module.
function! EsperSplice()
call EsperDev()
lua require('esper').splice()
endfunction
" Run EsperDev(), then call split() from the 'esper' Lua module.
function! EsperSplit()
call EsperDev()
lua require('esper').split()
endfunction
" Run EsperDev(), then call forward_barf() from the 'esper' Lua module.
function! EsperForwardBarf()
call EsperDev()
lua require('esper').forward_barf()
endfunction
" Run EsperDev(), then call forward_select() from the 'esper' Lua module.
function! EsperForwardSelect()
call EsperDev()
lua require('esper').forward_select()
endfunction
" Select forward with EsperForwardSelect(), then run `normal! y` to yank.
function! EsperForwardYank()
call EsperForwardSelect()
normal! y
endfunction
" Select forward with EsperForwardSelect(), then run `normal! d` to delete.
function! EsperForwardKill()
call EsperForwardSelect()
normal! d
endfunction
" Run EsperDev(), then call forward_slurp() from the 'esper' Lua module.
function! EsperForwardSlurp()
call EsperDev()
lua require('esper').forward_slurp()
endfunction
" Run EsperDev(), then call forward_swap() from the 'esper' Lua module.
function! EsperForwardSwap()
call EsperDev()
lua require('esper').forward_swap()
endfunction
" Run EsperDev(), then call set_cursor_to_next() from the 'esper' Lua module.
function! EsperNext()
call EsperDev()
lua require('esper').set_cursor_to_next()
endfunction
" Run EsperDev(), then call set_cursor_to_prev() from the 'esper' Lua module.
function! EsperPrev()
call EsperDev()
lua require('esper').set_cursor_to_prev()
endfunction
" Run EsperDev(), then call wrap_curly() from the 'esper' Lua module.
function! EsperWrapCurly()
call EsperDev()
lua require('esper').wrap_curly()
endfunction
" Run EsperDev(), then call wrap_round() from the 'esper' Lua module.
function! EsperWrapRound()
call EsperDev()
lua require('esper').wrap_round()
endfunction
" Run EsperDev(), then call wrap_square() from the 'esper' Lua module.
function! EsperWrapSquare()
call EsperDev()
lua require('esper').wrap_square()
endfunction
" Run EsperDev(), then call wrap_double_quote() from the 'esper' Lua module.
function! EsperWrapDoubleQuote()
call EsperDev()
lua require('esper').wrap_double_quote()
endfunction
" Run EsperDev(), then call wrap_single_quote() from the 'esper' Lua module.
" This wrapper used to call wrap_double_quote(), so it wrapped with the
" wrong quote character.
function! EsperWrapSingleQuote()
  call EsperDev()
  lua require('esper').wrap_single_quote()
endfunction
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment