Created January 21, 2011 05:54
Save ebisawa/789298 to your computer and use it in GitHub Desktop.
mysql_ruby
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'mysql'

# Default credentials used when callers do not pass their own.
# NOTE(review): a hard-coded password in source is a security smell; prefer
# passing credentials explicitly or loading them from configuration.
DB_DEFAULT_USER = 'root'
DB_DEFAULT_PASS = 'hogehoge'

# Age in seconds after which SimpleDB re-opens its connection (1 hour).
DB_REFRESH = 60 * 60

# MySQL column types whose values SimpleTable validates with Integer().
# TYPE_DECIMAL is deliberately not listed: DECIMAL columns legitimately hold
# fractional values such as "1.50", which Integer() would reject.
MYSQL_INT_TYPES = [
  Mysql::Field::TYPE_TINY, Mysql::Field::TYPE_SHORT,
  Mysql::Field::TYPE_LONG, Mysql::Field::TYPE_INT24,
  Mysql::Field::TYPE_LONGLONG,
].freeze
# Thin wrapper around a MySQL connection: runs prepared statements and returns
# result sets as Arrays of Hashes keyed by Symbol column names.
class SimpleBase
  # Statements whose generated insert id is returned by #query.
  INSERT_PATTERN = /\A\s*INSERT\b/i

  # When true, each query's SQL/arguments and any Mysql::Error from
  # prepare/execute are written to STDERR.
  attr_accessor :sql_debug

  # Opens a new connection and makes it the current one.
  # The previous connection (if any) is closed only after the new one has been
  # established, so a failed connect leaves the old handle in place.
  # sql_debug is initialised to false on the first connect only, so that a
  # reconnect does not silently switch debugging off.
  def connect(host, dbname, user = DB_DEFAULT_USER, pass = DB_DEFAULT_PASS)
    previous = @db
    @db = Mysql.new(host, user, pass, dbname)
    close_quietly(previous)
    @sql_debug = false if @sql_debug.nil?
    @sql_debug
  end

  # Runs +sql+ as a prepared statement with +args+ bound to its ? placeholders.
  #
  # Returns:
  #   - an Array of Hashes (Symbol column name => value) when the statement
  #     produces a result set;
  #   - the generated insert id when +sql+ starts with INSERT (any case,
  #     leading whitespace allowed);
  #   - nil otherwise, and nil when prepare/execute raises Mysql::Error.
  # Errors raised while reading the result propagate to the caller.
  # The prepared statement is closed before returning in every case.
  def query(sql, *args)
    STDERR.puts "DEBUG: sql=\"#{sql}\", args=[#{args.join(', ')}]" if @sql_debug
    stmt = nil
    begin
      stmt = @db.prepare(sql)
      stmt.execute(*args)
    rescue Mysql::Error => e
      STDERR.puts e.error if @sql_debug
      close_quietly(stmt)
      return nil
    end

    begin
      meta = stmt.result_metadata
      if meta
        result_hash(stmt, meta.fetch_fields)
      elsif sql =~ INSERT_PATTERN
        stmt.insert_id
      end
    ensure
      # Free the server-side statement; the result has been fully read above.
      close_quietly(stmt)
    end
  end

  # Builds a SimpleTable for +table_name+ whose column info comes from the
  # server (column name String => Mysql::Field#hash).
  def table(table_name)
    table_name = table_name.to_s
    cols = {}
    @db.list_fields(table_name).fetch_fields.each { |f| cols[f.name] = f.hash }
    SimpleTable.new(self, table_name, cols)
  end

  private

  # Reads every row from +stmt+ into a Hash keyed by the Symbol field names.
  def result_hash(stmt, fields)
    keys = fields.map { |f| f.name.intern }
    rows = []
    stmt.each { |row| rows << Hash[keys.zip(row)] }
    rows
  end

  # Closes a statement or connection handle, ignoring nil and Mysql::Error
  # (e.g. a handle that the server has already dropped).
  def close_quietly(handle)
    handle.close if handle
  rescue Mysql::Error
    nil
  end
end
# SimpleBase that remembers its connection parameters and re-opens the
# connection once it is older than DB_REFRESH seconds.
class SimpleDB < SimpleBase
  # Stores the connection parameters and connects immediately; errors from
  # SimpleBase#connect propagate to the caller.
  def initialize(host, dbname, user = DB_DEFAULT_USER, pass = DB_DEFAULT_PASS)
    @host, @dbname, @user, @pass = host, dbname, user, pass
    reconnect
  end

  # Same as SimpleBase#query, after refreshing an aged connection.
  def query(sql, *args)
    refresh
    super(sql, *args)
  end

  # Same as SimpleBase#table, after refreshing an aged connection, so the
  # column lookup does not run on a handle the server may have dropped.
  def table(table_name)
    refresh
    super(table_name)
  end

  private

  # Reconnects when the current connection is older than DB_REFRESH seconds.
  def refresh
    reconnect if Time.now - @tconn > DB_REFRESH
  end

  # Re-opens the connection and restarts the age timer. The caller's
  # sql_debug setting is carried across, because SimpleBase#connect may
  # reset it.
  def reconnect
    debug = @sql_debug
    connect(@host, @dbname, @user, @pass)
    @sql_debug = debug unless debug.nil?
    @tconn = Time.now
  end
end
# Per-table CRUD helpers that build SQL with ? placeholders and delegate to a
# connection object exposing query(sql, *params) (e.g. SimpleBase).
#
# Table and column names are interpolated into the SQL text, so they must be
# trusted; column names are checked against +columns+ when it is supplied.
class SimpleTable
  # Accepted :limit forms: "n", "offset, n" and "n OFFSET offset".
  LIMIT_PATTERN = /\A\s*\d+\s*(?:,\s*\d+|\s+OFFSET\s+\d+)?\s*\z/i

  # db      - connection object; only db.query(sql, *params) is used.
  # table   - table name, inserted verbatim into SQL.
  # columns - optional Hash of column name String => field info Hash with a
  #           'type' entry; nil disables column-name and value validation.
  def initialize(db, table, columns = nil)
    @db = db
    @table = table
    @columns = columns
  end

  # Returns every row of the table (whatever db.query returns for a SELECT).
  def all
    select
  end

  # primitives

  # SELECT * filtered by equality +conditions+ (column => value; nil or the
  # string 'NULL' means IS NULL).
  # +options+: :group_by / :order_by (one column name each) and :limit
  # (Integer, "offset, count" or "count OFFSET offset").
  # Raises RuntimeError for an unknown column or a malformed limit.
  def select(conditions = {}, options = {})
    where, params = make_where(conditions)
    # SQL requires WHERE, GROUP BY, ORDER BY, LIMIT in exactly this order.
    clauses = [where,
               make_group_by(options[:group_by]),
               make_order_by(options[:order_by]),
               make_limit(options[:limit])]
    @db.query("SELECT * FROM #{@table} #{clauses.join(' ')}", *params)
  end

  # Returns the distinct values of +column+ (String or Symbol) among rows that
  # match +conditions+, or nil when the query itself returned nil.
  def select_distinct(column, conditions = {})
    column = column.to_s
    validate_column(column)
    where, params = make_where(conditions)
    res = @db.query("SELECT DISTINCT #{column} FROM #{@table} #{where}", *params)
    return nil if res.nil?
    key = column.intern # SimpleBase result rows are keyed by Symbol names
    res.map { |r| r[key] }
  end

  # DELETE rows matching +conditions+. An empty Hash deletes every row.
  def delete(conditions)
    where, params = make_where(conditions)
    @db.query("DELETE FROM #{@table} #{where}", *params)
  end

  # INSERT one row from +insdata+ (column => value; nil values insert NULL).
  def insert(insdata)
    cols, vals = split_columns(insdata)
    @db.query("INSERT INTO #{@table} (#{cols}) VALUES (#{placeholders(vals.size)})", *vals)
  end

  # UPDATE rows matching +conditions+ with +updates+ (column => value; nil or
  # 'NULL' assigns SQL NULL). An empty +conditions+ Hash updates every row.
  def update(updates, conditions)
    ups, upa = expand_assignments(updates)
    where, params = make_where(conditions)
    @db.query("UPDATE #{@table} SET #{ups} #{where}", *(upa + params))
  end

  # INSERT key.merge(updates); on a duplicate key, apply +updates+ instead
  # (MySQL INSERT ... ON DUPLICATE KEY UPDATE).
  def upsert(updates, key)
    cols, vals = split_columns(key.merge(updates))
    ups, upa = expand_assignments(updates)
    @db.query("INSERT INTO #{@table} (#{cols}) VALUES (#{placeholders(vals.size)}) " \
              "ON DUPLICATE KEY UPDATE #{ups}", *(vals + upa))
  end

  private

  # Builds "a=? AND b IS NULL" plus bind params (values converted with to_s).
  def expand_conditions(conditions, join_str = ' AND ')
    return "", [] if conditions.empty?
    terms = []; params = []
    conditions.each do |k, v|
      col = k.to_s
      validate_column(col) # checked for NULL conditions too
      if null_value?(v)
        terms << "#{col} IS NULL"
      else
        validate_value(col, v)
        terms << "#{col}=?"
        params << v.to_s
      end
    end
    return terms.join(join_str), params
  end

  # Builds "a=?, b=NULL" plus bind params for SET / ON DUPLICATE KEY UPDATE.
  # Unlike conditions, NULL must be assigned with "=", not tested with IS.
  def expand_assignments(assignments)
    sets = []; params = []
    assignments.each do |k, v|
      col = k.to_s
      validate_column(col)
      if null_value?(v)
        sets << "#{col}=NULL"
      else
        validate_value(col, v)
        sets << "#{col}=?"
        params << v.to_s
      end
    end
    return sets.join(', '), params
  end

  # nil and the string 'NULL' (any case) are treated as SQL NULL.
  def null_value?(value)
    value.nil? || value.to_s.upcase == 'NULL'
  end

  # "WHERE ..." plus params, or "" and [] when there are no conditions.
  def make_where(conditions)
    where, params = expand_conditions(conditions)
    where = "WHERE #{where}" unless where.empty?
    return where, params
  end

  def make_order_by(colname)
    make_column_clause('ORDER BY', colname)
  end

  def make_group_by(colname)
    make_column_clause('GROUP BY', colname)
  end

  # "<keyword> <column>" for a validated column name, or "" when nil.
  def make_column_clause(keyword, colname)
    return "" if colname.nil?
    colname = colname.to_s
    validate_column(colname)
    "#{keyword} #{colname}"
  end

  # "LIMIT ..." for a numeric limit, or "" when nil. Anything else is rejected
  # because the value is interpolated into the SQL text.
  def make_limit(limit)
    return "" if limit.nil?
    limit = limit.to_s
    raise "invalid limit: #{limit}" unless limit =~ LIMIT_PATTERN
    "LIMIT #{limit}"
  end

  # Validates each column/value and returns ["c1,c2", [v1, v2]].
  def split_columns(data)
    cols = []; vals = []
    data.each do |k, v|
      c = k.to_s
      validate_data(c, v)
      cols << c
      vals << v
    end
    return cols.join(','), vals
  end

  # "?,?,?" with +count+ placeholders.
  def placeholders(count)
    (['?'] * count).join(',')
  end

  def validate_data(colname, value)
    validate_column(colname)
    validate_value(colname, value)
  end

  # Raises RuntimeError when column info is known and +colname+ is not in it.
  def validate_column(colname)
    if @columns && @columns[colname].nil?
      raise "invalid column: #{colname}"
    end
  end

  # For integer-typed columns, raises (via Integer()) when +value+ is not an
  # integer. nil (SQL NULL) and tables without column info are not checked.
  def validate_value(colname, value)
    return if @columns.nil? || value.nil?
    type = @columns[colname]['type']
    Integer(value) if MYSQL_INT_TYPES.include?(type)
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.