-
-
Save cee-dub/277630 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'json'
require 'set'
require 'strscan'
# Coarse fingerprint of a SQL statement as filled in by SqlScanner#parse:
# the table name and the literal-masked text of the WHERE and ORDER BY
# clauses. Fields that were not found stay nil.
SqlQuery = Struct.new(:table, :where, :order) do
  # Serializes the query as the JSON array [table, where, order].
  #
  # JSON's generator calls to_json(state) on objects nested inside other
  # structures (e.g. JSON.generate([query])), so the optional state argument
  # must be accepted and forwarded rather than rejected.
  def to_json(*args)
    to_a.to_json(*args)
  end
end
# Normalizes one SQL statement by masking literal values with `?`, so that
# structurally identical queries compare equal. It then extracts a coarse
# fingerprint (table, WHERE text, ORDER BY text) via #parse. Only statements
# whose first word is in STARTERS are parsed.
class SqlScanner
  # First words of the statements #parse handles (compared case-insensitively).
  STARTERS = Set.new %w(SELECT UPDATE)

  # A clause keyword that ends the clause before it. The leading \s and the
  # trailing \b keep names such as `foo.where`, `whereabouts` or `limited_at`
  # from being taken for keywords.
  ANY_CLAUSE = /\s(WHERE|ORDER\s+BY|LIMIT|OFFSET|GROUP\s+BY)\b/i

  # A single-quoted literal, including backslash-escaped (\') and doubled ('')
  # quotes inside it, e.g. 'it\'s' or 'it''s'.
  STRING_LITERAL = /'(?:[^'\\]|\\.|'')*'/m

  # A comparison operator followed by an integer or decimal literal, optionally
  # negative. The operator is captured so it survives masking.
  NUMERIC_COMPARISON = /(<>|!=|<=|>=|=|<|>)\s*-?\d+(?:\.\d+)?/

  # The backtick-quoted table named by a FROM clause.
  TABLE_AFTER_FROM = /\bFROM\s+`([^`]+)`/i

  # Any backtick-quoted identifier (fallback, e.g. the table of an UPDATE).
  QUOTED_NAME = /`([^`]+)`/

  # @param s [String, nil] the raw SQL statement. It is not modified, so a
  #   frozen string is accepted. nil is treated as an empty statement.
  def initialize(s)
    @scanner = StringScanner.new(normalize(s))
  end

  # Parses the statement from its beginning, so repeated calls return equal
  # results.
  #
  # @return [SqlQuery] with table/where/order set where found. An empty
  #   SqlQuery is returned when the first word is not in STARTERS.
  def parse
    @scanner.reset
    query = SqlQuery.new
    if scannable?
      scan_table_name(query)
      scan_until_clause(query)
    end
    query
  end

  # Consumes the first word and reports whether it is one of STARTERS.
  def scannable?
    @scanner.scan(/\w+/)
    # matched is nil when no word was found; to_s turns that into "".
    STARTERS.include?(@scanner.matched.to_s.upcase)
  end

  # Sets query.table from a backtick-quoted name that appears before the first
  # clause keyword. It prefers the name after FROM, otherwise the first quoted
  # name. Unquoted table names leave query.table nil.
  #
  # The search is limited to the text before the first clause. Otherwise a
  # quoted column inside WHERE would be mistaken for the table, and the scanner
  # would skip past the WHERE clause itself.
  def scan_table_name(query)
    head = @scanner.check_until(ANY_CLAUSE) || @scanner.rest
    pattern = head =~ TABLE_AFTER_FROM ? TABLE_AFTER_FROM : QUOTED_NAME
    return unless head =~ pattern

    @scanner.scan_until(pattern)
    query.table = @scanner[1]
  end

  # Scans up to the next clause keyword, or to the end of the statement.
  # The scanned text (minus the keyword) is stored in +field_to_set+ when one
  # is given. It then dispatches on the keyword it stopped at:
  # - WHERE and ORDER BY are captured.
  # - Any other clause (GROUP BY, LIMIT, OFFSET) is skipped, so a later clause
  #   is still found.
  #
  # @param query [SqlQuery] receives the captured text
  # @param field_to_set [Symbol, nil] :where, :order, or nil to discard the text
  def scan_until_clause(query, field_to_set = nil)
    scanned = @scanner.scan_until(ANY_CLAUSE)
    # After scan_until, matched is just the keyword match (e.g. " WHERE").
    keyword = scanned && @scanner.matched
    body = scanned ? scanned.chomp(keyword) : @scanner.rest
    query.public_send("#{field_to_set}=", body.strip) if field_to_set

    return unless keyword

    case keyword
    when /WHERE/i then scan_where_clause(query)
    when /ORDER/i then scan_order_clause(query)
    else scan_until_clause(query)
    end
  end

  # Captures the text following WHERE into query.where.
  def scan_where_clause(query)
    scan_until_clause query, :where
  end

  # Captures the text following ORDER BY into query.order.
  def scan_order_clause(query)
    scan_until_clause query, :order
  end

  private

  # Returns a stripped copy of +sql+ with literal values replaced by `?`:
  # - quoted strings
  # - numbers compared with =, <>, !=, <, >, <=, >=
  # - BINARY ?
  # - IN (...) lists
  # The \b before IN keeps `JOIN (` from being collapsed.
  def normalize(sql)
    sql.to_s.strip
       .gsub(STRING_LITERAL, '?')
       .gsub(NUMERIC_COMPARISON, '\1 ?')
       .gsub(/\bBINARY\s+\?/i, '?')
       .gsub(/\b(IN)\s*\([^)]*\)/i, '\1 (?)')
  end
end
require 'test/unit' | |
# Regression tests for SqlScanner: table-name detection, literal masking,
# and WHERE / ORDER BY clause extraction.
class SqlScannerTest < Test::Unit::TestCase
  # Full statement with a quoted table, a WHERE clause, ORDER BY and LIMIT.
  FAQ_SQL = %(SELECT * FROM `faqs` WHERE (`faqs`.site_id = 1) AND ((faqs.published_at <= '2010-01-10 16:57:46') AND (`faqs`.site_id = 1)) ORDER BY faqs.updated_at desc LIMIT 5)

  # WHERE text expected once the literals in FAQ_SQL are masked.
  FAQ_WHERE = %((`faqs`.site_id = ?) AND ((faqs.published_at <= ?) AND (`faqs`.site_id = ?)))

  def setup
    # Give the scanner its own copy, so FAQ_SQL stays untouched even if the
    # constructor normalizes its argument in place.
    @string = FAQ_SQL.dup
    @scanner = SqlScanner.new(@string)
  end

  def test_parses_table_name
    assert_equal('faqs', @scanner.parse.table)
  end

  def test_parses_where_clause_at_end
    without_order = %(SELECT * FROM `faqs` WHERE (`faqs`.site_id = 1) AND ((faqs.published_at <= '2010-01-10 16:57:46') AND (`faqs`.site_id = 1)))
    assert_equal(FAQ_WHERE, parse_sql(without_order).where)
  end

  def test_parses_where_clause
    assert_equal(FAQ_WHERE, @scanner.parse.where)
  end

  def test_parses_order_clause
    assert_equal('faqs.updated_at desc', @scanner.parse.order)
  end

  def test_parses_without_quoted_table
    assert_nil(parse_sql(%(SELECT * FROM faqs)).table)
  end

  def test_parses_binary_where
    assert_equal('a = ?', parse_sql(%(SELECT * FROM `faqs` where a = BINARY '')).where)
  end

  def test_parses_not_equal_to_number
    assert_equal('a <> ?', parse_sql(%(SELECT * FROM `faqs` where a <> 1)).where)
  end

  def test_parses_not_equal_to_string
    assert_equal('a <> ?', parse_sql(%(SELECT * FROM `faqs` where a <> 'abc')).where)
  end

  def test_parses_string_array
    assert_equal('a IN (?)', parse_sql(%(SELECT * FROM `faqs` where a IN ('abc', 'def'))).where)
  end

  def test_parses_num_array
    assert_equal('a IN (?)', parse_sql(%(SELECT * FROM `faqs` where a IN (1, 2))).where)
  end

  def test_parses_null_array
    assert_equal('a IN (?)', parse_sql(%(SELECT * FROM `faqs` where a IN (NULL))).where)
  end

  def test_parses_order_clause_at_end
    assert_equal('faqs.updated_at desc', parse_sql(%(SELECT * FROM `faqs` ORDER BY faqs.updated_at desc)).order)
  end

  def test_where_clause_with_where_field
    assert_equal('foo.where = ?', parse_sql('SELECT `foo` WHERE foo.where = 1').where)
  end

  def test_order_clause_with_where_field
    assert_equal('foo.where', parse_sql('SELECT `foo` ORDER BY foo.where').order)
  end

  private

  # Builds a fresh scanner for +sql+ and returns its parsed SqlQuery.
  def parse_sql(sql)
    SqlScanner.new(sql).parse
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment