Skip to content

Instantly share code, notes, and snippets.

@kylege
Created April 3, 2026 07:16
Show Gist options
  • Select an option

  • Save kylege/88c87d1697dda64662280107e7f853e9 to your computer and use it in GitHub Desktop.

Select an option

Save kylege/88c87d1697dda64662280107e7f853e9 to your computer and use it in GitHub Desktop.
Split a SQL string into its individual statements; implemented in Go.
package sqlsplit
import (
"strings"
"unicode/utf8"
)
type (
	// SplitConfig holds the options that control how a Split behaves.
	SplitConfig struct {
		// Dialect names the SQL dialect of the input.
		// NOTE(review): it is stored by WithDialect but not read anywhere in
		// this file's scanning code — confirm whether it is meant to be used.
		Dialect string
		// StripSemicolon makes each returned statement drop one trailing ';'.
		StripSemicolon bool
	}

	// splitOption mutates a SplitConfig; it is the functional-option type
	// accepted by New.
	splitOption func(*SplitConfig)
)
// WithDialect returns an option that records dialect in SplitConfig.Dialect.
func WithDialect(dialect string) splitOption {
	setDialect := func(cfg *SplitConfig) {
		cfg.Dialect = dialect
	}
	return setDialect
}
// WithStripSemicolon returns an option that sets SplitConfig.StripSemicolon,
// so each returned statement has its trailing ';' removed.
func WithStripSemicolon() splitOption {
	return func(cfg *SplitConfig) { cfg.StripSemicolon = true }
}
// Split splits a SQL string into multiple independent statements.
//
// It is a stateful scanner: cursor and start advance through src as
// statements are produced. The zero value is not usable (idents and config
// are nil pointers that the scanning methods dereference); construct it with
// New.
type Split struct {
	// src is the source SQL string.
	src string
	// cursor is the current scan position, as a byte offset into src.
	cursor int
	// start is the byte offset in src where the current statement begins.
	start int
	// idents remembers the most recently scanned identifiers.
	idents *seenIdentifiers
	// stmtSet reports whether the scanner is inside a (Flink) STATEMENT SET
	// multi-statement block, where ';' does not end the statement until END.
	stmtSet bool
	// config holds the options applied by New.
	config *SplitConfig
}
// New returns a Split that will scan input, configured by the given options.
//
// The identifier history is given a small fixed capacity of 2. The scanner
// only ever consults the most recent identifier (for the STATEMENT SET and
// END checks).
func New(input string, opts ...splitOption) *Split {
	cfg := &SplitConfig{}
	for _, apply := range opts {
		apply(cfg)
	}
	return &Split{
		src:    input,
		idents: newSeenIdentifiers(2),
		config: cfg,
	}
}
// Split scans the remaining input and returns its statements in source order.
//
// Empty statements are skipped rather than being mistaken for end of input,
// so statements after them are not lost. An example is the gap between the
// two semicolons in "select 1;;select 2" when StripSemicolon is set. The
// result is nil if the input contains no statements.
func (s *Split) Split() []string {
	var statements []string
	// Each scan either stops at EOF, which moves start to len(src), or
	// consumes at least the terminating ';'. So start strictly increases and
	// the loop always terminates.
	for s.start < len(s.src) {
		if stmt := s.scan(); stmt != "" {
			statements = append(statements, stmt)
		}
	}
	return statements
}
// stmt closes the statement that ends at the cursor and returns its text.
//
// It first consumes the rune under the cursor. That is the terminating ';';
// at the end of src this is a no-op. It then absorbs any run of spaces,
// tabs, carriage returns and "--" comments (each comment including its
// newline), so a trailing comment stays with the statement it follows. A
// bare newline is deliberately not absorbed; this matches the sqlparse
// library's behaviour.
//
// The text from s.start to the cursor is trimmed of surrounding whitespace.
// If StripSemicolon is set it loses one trailing ';' and is trimmed again.
// s.start then moves to the cursor so the next statement begins there.
func (s *Split) stmt() string {
	_ = s.next()

skipTrailing:
	for {
		switch c := s.peek(); {
		case isWhitespace(c):
			s.scanWhitespace()
		case isSingleLineComment(c, s.lookAhead(1)):
			s.scanSingleLineComment()
		default:
			break skipTrailing
		}
	}

	text := strings.TrimSpace(s.src[s.start:s.cursor])
	s.start = s.cursor
	if s.config.StripSemicolon {
		text = strings.TrimSuffix(text, ";")
	}
	return strings.TrimSpace(text)
}
// scanIdentifier consumes the run of ASCII letters, digits and '_' starting
// at the cursor and records it in s.idents.
//
// It also detects the Flink "STATEMENT SET" keyword pair. When SET is scanned
// immediately after the identifier STATEMENT (compared case-insensitively),
// the scanner enters statement-set mode. In that mode tryStmt does not treat
// ';' as a terminator until END.
func (s *Split) scanIdentifier() {
	begin := s.cursor
	for ch := s.peek(); isLetter(ch) || isNumber(ch) || ch == '_'; ch = s.peek() {
		s.next()
	}
	ident := s.src[begin:s.cursor]
	// EqualFold compares without allocating upper-cased copies of every
	// identifier. Identifiers here are ASCII-only, so the result matches the
	// ToUpper comparison exactly.
	if strings.EqualFold(ident, "SET") && strings.EqualFold(s.idents.current(), "STATEMENT") {
		s.stmtSet = true
	}
	s.idents.add(ident)
}
// tryStmt is called when the cursor is on a ';'. It reports whether that ';'
// terminates the current statement and, if so, returns the statement.
//
// Outside a STATEMENT SET block every ';' is a terminator. Inside one, only a
// ';' whose most recently scanned identifier is END closes the block (and
// leaves statement-set mode). Any other ';' returns ("", false) and the
// caller steps over it. END is matched case-insensitively, the same way
// scanIdentifier matches STATEMENT SET. Otherwise a lowercase "end;" would
// never close the block and every later statement would merge into it.
func (s *Split) tryStmt() (string, bool) {
	if s.stmtSet && !strings.EqualFold(s.idents.current(), "END") {
		return "", false
	}
	s.stmtSet = false
	return s.stmt(), true
}
// scan advances through src until the current statement ends, then returns
// it via stmt. A statement ends at EOF or at a ';' that tryStmt accepts.
// Quoted literals, identifiers and comments are consumed as whole units, so a
// ';' inside a literal or comment never ends a statement. The returned string
// can be empty, for example when only whitespace remains before EOF.
func (s *Split) scan() string {
	for {
		ch := s.peek()
		switch {
		case isEOF(ch):
			return s.stmt()
		case ch == ';':
			if text, ok := s.tryStmt(); ok {
				return text
			}
			// Inside a STATEMENT SET block this ';' belongs to the block:
			// step over it and keep scanning.
			_ = s.next()
		case isSingleQuote(ch), isDoubleQuote(ch):
			// ch is the quote character itself, so it doubles as the delimiter.
			s.scanString(ch)
		case isSingleLineComment(ch, s.lookAhead(1)):
			s.scanSingleLineComment()
		case isMultiLineComment(ch, s.lookAhead(1)):
			s.scanMultiLineComment()
		case isLetter(ch):
			s.scanIdentifier()
		default:
			_ = s.next()
		}
	}
}
// scanString consumes a quoted literal whose opening delimiter (quote) is at
// the cursor. It leaves the cursor just past the closing quote. If the
// literal is unterminated, it stops at the end of src or at a NUL rune.
// A backslash escapes the character after it, so \' or \" does not close the
// literal. A doubled quote such as 'it''s' is handled implicitly: it closes
// the literal and the scanner immediately opens a new one.
//
// NOTE(review): backslash escaping is MySQL-style. In dialects where '\' is an
// ordinary character, an input like 'C:\' will swallow the rest of the
// statement. SplitConfig.Dialect is not consulted here — confirm this is
// acceptable for the inputs in use.
func (s *Split) scanString(quote rune) {
	for ch := s.next(); ; {
		switch {
		case ch == '\\':
			// Skip the backslash and the character it escapes.
			_ = s.next()
			ch = s.next()
		case ch == quote:
			_ = s.next()
			return
		case isEOF(ch):
			return
		default:
			ch = s.next()
		}
	}
}
// scanWhitespace consumes a run of spaces, tabs and carriage returns. Newlines
// are excluded (see isWhitespace). The rune under the cursor is consumed
// unconditionally; callers invoke this only after peeking whitespace there.
func (s *Split) scanWhitespace() {
	for isWhitespace(s.next()) {
		// next already advanced the cursor; nothing else to do per rune.
	}
}
// scanSingleLineComment consumes a "--" comment starting at the cursor, up to
// and including its terminating newline (or up to EOF). The newline is kept
// as part of the comment to stay consistent with the sqlparse library.
func (s *Split) scanSingleLineComment() {
	ch := s.nextBy(2) // step over the two dashes
	for !isEOF(ch) {
		if ch == '\n' {
			_ = s.next()
			return
		}
		ch = s.next()
	}
}
// scanMultiLineComment consumes a "/* ... */" comment starting at the cursor,
// including the closing "*/". An unterminated comment (for example in
// truncated input) runs to EOF.
func (s *Split) scanMultiLineComment() {
	for ch := s.nextBy(2); !isEOF(ch); ch = s.next() {
		if ch == '*' && s.lookAhead(1) == '/' {
			s.nextBy(2) // step over "*/"
			return
		}
	}
}
// lookAhead returns the rune that begins n bytes past the cursor. It returns
// 0 when that offset falls outside src; 0 doubles as the EOF marker (see
// isEOF).
//
// The cursor is a byte offset and usually advances one byte at a time, so it
// can sit inside a multi-byte UTF-8 sequence. DecodeRuneInString then yields
// utf8.RuneError, which never matches any of the ASCII characters the scanner
// dispatches on.
func (s *Split) lookAhead(n int) rune {
	at := s.cursor + n
	if at < 0 || at >= len(s.src) {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[at:])
	return r
}

// peek returns the rune under the cursor without consuming it.
func (s *Split) peek() rune { return s.lookAhead(0) }
// nextBy moves the cursor forward n bytes and returns the rune at the new
// position. It returns 0 once the cursor reaches the end of src. If the move
// would overshoot the end, the cursor stays where it is and 0 is returned.
func (s *Split) nextBy(n int) rune {
	target := s.cursor + n
	if target > len(s.src) {
		return 0
	}
	s.cursor = target
	if target == len(s.src) {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[target:])
	return r
}

// next advances the cursor by one byte and returns the rune now under it.
func (s *Split) next() rune { return s.nextBy(1) }
// isEOF reports whether ch is the 0 sentinel that lookAhead and nextBy return
// past the end of input. A literal NUL in the input is indistinguishable from
// it.
func isEOF(ch rune) bool { return ch == 0 }

// isSingleLineComment reports whether ch followed by nextCh opens a "--"
// comment.
func isSingleLineComment(ch rune, nextCh rune) bool {
	return ch == '-' && nextCh == '-'
}

// isMultiLineComment reports whether ch followed by nextCh opens a "/*"
// comment.
func isMultiLineComment(ch rune, nextCh rune) bool {
	return ch == '/' && nextCh == '*'
}
// isSingleQuote reports whether ch is the single-quote literal delimiter.
func isSingleQuote(ch rune) bool { return ch == '\'' }

// isDoubleQuote reports whether ch is the double-quote literal delimiter.
func isDoubleQuote(ch rune) bool { return ch == '"' }
// isLetter reports whether ch is an ASCII letter (a-z or A-Z). Non-ASCII
// letters are not recognised.
func isLetter(ch rune) bool {
	switch {
	case 'a' <= ch && ch <= 'z', 'A' <= ch && ch <= 'Z':
		return true
	}
	return false
}

// isNumber reports whether ch is an ASCII decimal digit.
func isNumber(ch rune) bool { return '0' <= ch && ch <= '9' }

// isWhitespace reports whether ch is a space, tab or carriage return.
// Newline is deliberately excluded so callers can treat line ends specially.
func isWhitespace(ch rune) bool {
	switch ch {
	case ' ', '\t', '\r':
		return true
	}
	return false
}
// seenIdentifiers is a fixed-capacity ring buffer of the most recently
// scanned identifiers. Once it is full, each add overwrites the oldest entry.
type seenIdentifiers struct {
	size        int      // capacity; 0 makes add, current and String no-ops
	identifiers []string // ring storage, len == size; "" marks an unused slot
	pos         int      // slot the next add writes to
}

// newSeenIdentifiers returns an empty buffer that remembers up to size
// identifiers. size must be non-negative, because make panics otherwise.
func newSeenIdentifiers(size int) *seenIdentifiers {
	return &seenIdentifiers{size: size, identifiers: make([]string, size)}
}

// add records identifier as the most recent entry.
func (s *seenIdentifiers) add(identifier string) {
	if s.size == 0 {
		return
	}
	s.identifiers[s.pos] = identifier
	s.pos++
	if s.pos == s.size {
		s.pos = 0
	}
}

// current returns the most recently added identifier, or "" if nothing has
// been added yet.
func (s *seenIdentifiers) current() string {
	if s.size == 0 {
		return ""
	}
	last := s.pos - 1
	if last < 0 {
		last += s.size
	}
	return s.identifiers[last]
}

// String joins the remembered identifiers, oldest first, with single spaces.
// It relies on identifiers never being "", since "" marks an unused slot.
func (s *seenIdentifiers) String() string {
	if s.size == 0 || s.identifiers[0] == "" {
		return ""
	}
	if s.identifiers[s.pos] == "" {
		// Not yet wrapped: slots [0, pos) hold entries in insertion order.
		return strings.Join(s.identifiers[:s.pos], " ")
	}
	// Wrapped: the oldest entry is the one about to be overwritten, at pos.
	ordered := make([]string, 0, s.size)
	ordered = append(ordered, s.identifiers[s.pos:]...)
	ordered = append(ordered, s.identifiers[:s.pos]...)
	return strings.Join(ordered, " ")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment