Skip to content

Instantly share code, notes, and snippets.

@ebisawa
Created January 21, 2011 05:54
Show Gist options
  • Save ebisawa/789298 to your computer and use it in GitHub Desktop.
Save ebisawa/789298 to your computer and use it in GitHub Desktop.
mysql_ruby
require 'mysql'
# Default credentials used when callers do not pass user/pass explicitly.
# NOTE(review): a hard-coded password in source is a credential-leak risk;
# consider loading it from configuration or the environment.
DB_DEFAULT_USER = 'root'
DB_DEFAULT_PASS = 'hogehoge'

# Maximum age of a SimpleDB connection, in seconds, before it is reopened.
DB_REFRESH = 60 * 60 # 1 hour

# Mysql field types whose bound values must parse with Integer().
# DECIMAL is intentionally not listed: DECIMAL columns hold fractional
# values (e.g. 1.5), which Integer() would wrongly reject.
MYSQL_INT_TYPES = [
  Mysql::Field::TYPE_TINY, Mysql::Field::TYPE_SHORT,
  Mysql::Field::TYPE_LONG, Mysql::Field::TYPE_INT24,
  Mysql::Field::TYPE_LONGLONG,
].freeze
# Minimal wrapper around a Mysql connection that runs prepared statements
# and returns result rows as an Array of Symbol-keyed Hashes.
class SimpleBase
  # Statements for which #query returns the generated insert id.
  # \A anchors at the start of the whole string (^ would match after any
  # newline); leading whitespace and lowercase keywords are accepted.
  INSERT_STATEMENT = /\A\s*INSERT\b/i

  # When true, SQL text, bound arguments and Mysql errors go to STDERR.
  attr_accessor :sql_debug

  # Opens a connection to +dbname+ on +host+.
  #
  # An already-open connection is replaced, and the old one is closed
  # (best effort) only after the new one succeeds, so periodic reconnects
  # do not leak connections. The current +sql_debug+ setting is preserved;
  # it defaults to false on the first connect.
  def connect(host, dbname, user = DB_DEFAULT_USER, pass = DB_DEFAULT_PASS)
    old_db = @db
    @db = Mysql::new(host, user, pass, dbname)
    close_quietly(old_db)
    @sql_debug = false if @sql_debug.nil?
  end

  # Runs +sql+ as a prepared statement with +args+ bound to its placeholders.
  #
  # Returns:
  # - an Array of Hashes (column name Symbol => value) when the statement
  #   produces a result set;
  # - the generated insert id for INSERT statements;
  # - nil otherwise, or when preparing/executing raises Mysql::Error (the
  #   error text is printed only when sql_debug is on).
  # The prepared statement is always closed before returning.
  def query(sql, *args)
    STDERR.puts "DEBUG: sql=\"#{sql}\", args=[#{args.join(', ')}]" if @sql_debug
    stmt = nil
    begin
      stmt = @db.prepare(sql)
      stmt.execute(*args)
    rescue Mysql::Error => e
      STDERR.puts e.error if @sql_debug
      return nil
    end
    meta = stmt.result_metadata
    return result_hash(stmt, meta.fetch_fields) if meta
    (sql =~ INSERT_STATEMENT) ? stmt.insert_id : nil
  ensure
    close_quietly(stmt)
  end

  # Builds a SimpleTable for +table_name+ whose column map is
  # name String => field-info Hash, as returned by Mysql::Field#hash.
  def table(table_name)
    table_name = table_name.to_s
    cols = {}
    @db.list_fields(table_name).fetch_fields.each { |f| cols[f.name] = f.hash }
    SimpleTable.new(self, table_name, cols)
  end

  private

  # Collects every row of +stmt+ into a Hash keyed by the Symbol form of
  # each field name. With duplicate names, the later column wins.
  def result_hash(stmt, fields)
    keys = fields.map { |f| f.name.intern }
    rows = []
    stmt.each do |row|
      rhash = {}
      keys.each_with_index { |k, i| rhash[k] = row[i] }
      rows << rhash
    end
    rows
  end

  # Closes +resource+ (a statement or a connection) if present. Mysql errors
  # raised while closing are ignored: the handle is being discarded, and a
  # close failure must not mask the caller's result.
  def close_quietly(resource)
    resource.close if resource
  rescue Mysql::Error
    nil
  end
end
# SimpleBase that reopens its connection once it is older than DB_REFRESH
# seconds. The check runs lazily, at the start of each #query / #table call.
class SimpleDB < SimpleBase
  # Stores the connection parameters and connects immediately.
  def initialize(host, dbname, user = DB_DEFAULT_USER, pass = DB_DEFAULT_PASS)
    @host, @dbname, @user, @pass = host, dbname, user, pass
    reconnect
  end

  # Same as SimpleBase#query, reconnecting first if the connection is stale.
  def query(sql, *args)
    refresh
    super(sql, *args)
  end

  # Same as SimpleBase#table, reconnecting first if the connection is stale.
  # SimpleBase#table reads column info from the raw connection, so without
  # this override it would bypass the refresh logic.
  def table(table_name)
    refresh
    super(table_name)
  end

  private

  # Reconnects when the current connection has exceeded DB_REFRESH seconds.
  def refresh
    reconnect if (Time.now - @tconn) > DB_REFRESH
  end

  # Opens a fresh connection and records when it was made.
  def reconnect
    connect(@host, @dbname, @user, @pass)
    @tconn = Time.now
  end
end
# Builds simple CRUD SQL for one table and runs it through +db+, which is any
# object responding to #query(sql, *args) (e.g. SimpleBase / SimpleDB).
#
# +columns+, when given, maps column name String => field-info Hash with a
# 'type' entry. It whitelists column names interpolated into SQL, and values
# for MYSQL_INT_TYPES columns are checked with Integer(). When +columns+ is
# nil, no validation is done.
class SimpleTable
  # Accepted LIMIT specs: "n", "offset, n" or "n OFFSET offset" (digits only).
  LIMIT_SPEC = /\A\s*\d+\s*(?:(?:,|OFFSET)\s*\d+\s*)?\z/i

  def initialize(db, table, columns = nil)
    @db = db
    @table = table
    @columns = columns
  end

  # Returns every row of the table.
  def all
    select
  end

  # primitives

  # SELECT * with equality +conditions+ (column => value, ANDed; nil or
  # 'NULL' means IS NULL) and +options+ :group_by, :order_by (one column
  # each) and :limit. Returns whatever db.query returns (rows, or nil on error).
  # Raises ArgumentError for a malformed :limit, RuntimeError for an unknown column.
  def select(conditions = {}, options = {})
    options ||= {}
    where, params = make_where(conditions)
    # SQL clause order is fixed: WHERE, GROUP BY, ORDER BY, LIMIT.
    group_by = make_group_by(options[:group_by])
    order_by = make_order_by(options[:order_by])
    limit = make_limit(options[:limit])
    @db.query("SELECT * FROM #{@table} #{where} #{group_by} #{order_by} #{limit}", *params)
  end

  # Returns the distinct values of +column+ (String or Symbol) among rows
  # matching +conditions+, or nil if the query failed.
  def select_distinct(column, conditions = {})
    col = column.to_s
    validate_column(col)
    where, params = make_where(conditions)
    res = @db.query("SELECT DISTINCT #{col} FROM #{@table} #{where}", *params)
    return nil if res.nil?
    # Rows are Symbol-keyed, so look up by Symbol even for a String +column+.
    key = col.intern
    res.map { |r| r[key] }
  end

  # Deletes rows matching +conditions+.
  # NOTE: empty conditions produce no WHERE clause and delete every row.
  def delete(conditions)
    where, params = make_where(conditions)
    @db.query("DELETE FROM #{@table} #{where}", *params)
  end

  # Inserts one row built from +insdata+ (column => value); values are bound as-is.
  def insert(insdata)
    cols, vals = columns_and_values(insdata)
    @db.query("INSERT INTO #{@table} (#{cols.join(',')}) VALUES (#{placeholders(vals.size)})", *vals)
  end

  # Assigns +updates+ (column => value; nil or 'NULL' assigns SQL NULL) on
  # rows matching +conditions+.
  def update(updates, conditions)
    ups, upa = expand_assignments(updates)
    where, params = make_where(conditions)
    @db.query("UPDATE #{@table} SET #{ups} #{where}", *(upa + params))
  end

  # Inserts key.merge(updates); on a duplicate key, applies +updates+ instead.
  def upsert(updates, key)
    cols, vals = columns_and_values(key.merge(updates))
    ups, upa = expand_assignments(updates)
    sql = "INSERT INTO #{@table} (#{cols.join(',')}) VALUES (#{placeholders(vals.size)}) " \
          "ON DUPLICATE KEY UPDATE #{ups}"
    @db.query(sql, *(vals + upa))
  end

  private

  # Validates and splits +data+ into parallel column-name and value arrays.
  def columns_and_values(data)
    cols = []
    vals = []
    data.each do |k, v|
      c = k.to_s
      validate_data(c, v)
      cols << c
      vals << v
    end
    [cols, vals]
  end

  # "?,?,...,?" with +count+ placeholders.
  def placeholders(count)
    (['?'] * count).join(',')
  end

  # WHERE-style fragments: "col=?" or "col IS NULL", joined by +join_str+.
  def expand_conditions(conditions, join_str = ' AND ')
    expand_pairs(conditions, join_str, ' IS NULL')
  end

  # SET-style fragments: "col=?" or "col=NULL", joined by ", ".
  def expand_assignments(updates)
    expand_pairs(updates, ', ', '=NULL')
  end

  # Turns +pairs+ into [sql_fragment, bound_params]. A nil value (or any value
  # whose string form is 'NULL', case-insensitive) renders inline as
  # "col#{null_sql}". Other values are bound as their to_s strings.
  def expand_pairs(pairs, join_str, null_sql)
    return "", [] if pairs.nil? || pairs.empty?
    frags = []
    params = []
    pairs.each do |k, v|
      col = k.to_s
      val = v.to_s
      if v.nil? || val.upcase == 'NULL'
        # The column name is interpolated into SQL, so it must be validated here too.
        validate_column(col)
        frags << "#{col}#{null_sql}"
      else
        validate_data(col, v)
        frags << "#{col}=?"
        params << val
      end
    end
    [frags.join(join_str), params]
  end

  def make_where(conditions)
    where, params = expand_conditions(conditions)
    where = "WHERE #{where}" unless where.empty?
    [where, params]
  end

  def make_order_by(colname)
    make_column_clause('ORDER BY', colname)
  end

  def make_group_by(colname)
    make_column_clause('GROUP BY', colname)
  end

  # "#{keyword} col" for a validated column, or "" when +colname+ is nil.
  def make_column_clause(keyword, colname)
    return "" if colname.nil?
    colname = colname.to_s
    validate_column(colname)
    "#{keyword} #{colname}"
  end

  # "LIMIT spec", or "" when +limit+ is nil. The spec is interpolated into
  # SQL, so anything other than digit-only LIMIT forms raises ArgumentError.
  def make_limit(limit)
    return "" if limit.nil?
    spec = limit.to_s
    raise ArgumentError, "invalid limit: #{spec}" unless spec =~ LIMIT_SPEC
    "LIMIT #{spec}"
  end

  def validate_data(colname, value)
    validate_column(colname)
    validate_value(colname, value)
  end

  # Raises RuntimeError if a column whitelist is present and lacks +colname+.
  def validate_column(colname)
    if @columns != nil && @columns[colname] == nil
      raise "invalid column: #{colname}"
    end
  end

  # For integer-typed columns, raises ArgumentError/TypeError unless +value+
  # parses with Integer(). nil (SQL NULL) and table-without-columns skip checks.
  def validate_value(colname, value)
    return if @columns.nil? || value.nil?
    type = @columns[colname]['type']
    Integer(value) if MYSQL_INT_TYPES.include?(type)
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment