Skip to content

Instantly share code, notes, and snippets.

@hadees
Created March 5, 2016 05:41
Show Gist options
  • Star 34 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
  • Save hadees/cff6af2b53d340b9b4b2 to your computer and use it in GitHub Desktop.
Arel Helpers
# Helpers for building (mostly PostgreSQL-specific) Arel expressions.
#
# The module extends itself, so every helper can be called directly as
# `ArelHelpers.coalesce(...)`. Including the module into a class also extends
# that class, so the helpers are available both as instance methods and as
# class-level calls there.
module ArelHelpers
  extend self

  # Make the helpers callable at class level in any class that includes this
  # module (in addition to the instance methods `include` provides).
  def self.included(base)
    base.extend self
  end

  # Expands a table or model into one Arel attribute per column (an explicit
  # `table.*`).
  #
  # @param arel_table_or_model [Arel::Table, Object] an Arel table, or any
  #   object `UtilitiesHelper.is_model?` accepts (presumably an ActiveRecord
  #   model class — confirm against UtilitiesHelper)
  # @return [Array] one `arel_table[column_name]` attribute per column
  # @raise [ArgumentError] for any other argument
  def asterisk(arel_table_or_model)
    arel_table, columns = case arel_table_or_model
                          when Arel::Table
                            # NOTE(review): assumes Arel::Table#engine responds to #columns — confirm for the Rails version in use.
                            [arel_table_or_model, arel_table_or_model.engine.columns]
                          when ->(possible_model) { UtilitiesHelper.is_model?(possible_model) }
                            [arel_table_or_model.arel_table, arel_table_or_model.columns]
                          else
                            raise ArgumentError, "Must pass in an arel table or model"
                          end
    columns.map { |c| arel_table[c.name] }
  end

  # SQL `greatest(args...)`.
  def greatest(*args)
    Arel::Nodes::NamedFunction.new "greatest", args
  end

  # SQL `least(args...)`.
  def least(*args)
    Arel::Nodes::NamedFunction.new "least", args
  end

  # SQL `cast(pred AS type)`: `pred.as(type)` supplies the `pred AS type`
  # fragment inside the function call.
  def cast(pred, type)
    Arel::Nodes::NamedFunction.new "cast", [pred.as(type)]
  end

  # SQL `NULLIF(column, value)`.
  def null_if(column, value)
    Arel::Nodes::NamedFunction.new "NULLIF", [column, value]
  end

  # Raw SQL `CASE WHEN pred THEN true_value ELSE false_value END`.
  # Every operand is rendered through `sqlv`, so strings and times become
  # quoted SQL literals and nil becomes NULL.
  def predicate(pred, true_value, false_value)
    Arel::Nodes::SqlLiteral.new(
      "CASE WHEN #{sql_fragment(pred)} THEN #{sql_fragment(true_value)} ELSE #{sql_fragment(false_value)} END"
    )
  end

  # SQL `tsrange(lower, upper)`; accepts either two bounds or a Range
  # (see `range_params`).
  def tsrange(lower_or_range, upper = nil)
    Arel::Nodes::NamedFunction.new "tsrange", range_params(lower_or_range, upper)
  end

  # SQL `tstzrange(lower, upper)`; accepts either two bounds or a Range
  # (see `range_params`).
  def tstzrange(lower_or_range, upper = nil)
    Arel::Nodes::NamedFunction.new "tstzrange", range_params(lower_or_range, upper)
  end

  # PostgreSQL overlap operator: `a && b`.
  def overlap(a, b)
    Arel::Nodes::InfixOperation.new "&&", a, b
  end

  # SQL `coalesce(args...)`.
  def coalesce(*args)
    Arel::Nodes::NamedFunction.new "coalesce", args
  end

  # hstore lookup `hstore -> key`; a Symbol key is sent as a quoted string.
  def hstore_key(hstore, key)
    Arel::Nodes::InfixOperation.new "->", hstore, cloneable(key)
  end

  # SQL `concat(args...)`.
  def concat(*args)
    Arel::Nodes::NamedFunction.new "concat", args
  end

  # SQL modulo: `a % b`.
  def mod(a, b)
    Arel::Nodes::InfixOperation.new "%", a, b
  end

  # SQL `to_char(input, format)`.
  def to_char(input, format)
    Arel::Nodes::NamedFunction.new "to_char", [input, format]
  end

  # SQL `string_agg(input, delimiter)`.
  def string_agg(input, delimiter)
    Arel::Nodes::NamedFunction.new "string_agg", [input, delimiter]
  end

  # `pred BETWEEN lower AND upper`; accepts either two bounds or a Range
  # (see `range_params`).
  def between(pred, lower_or_range, upper = nil)
    Arel::Nodes::Between.new(pred, Arel::Nodes::And.new(range_params(lower_or_range, upper)))
  end

  # SQL `unnest(array)`.
  def unnest(array)
    Arel::Nodes::NamedFunction.new "unnest", [array]
  end

  # SQL `array_agg(expression)`.
  def array_agg(expression)
    Arel::Nodes::NamedFunction.new "array_agg", [expression]
  end

  # SQL `lower(expression)`.
  def lower(expression)
    Arel::Nodes::NamedFunction.new "lower", [expression]
  end

  # ORs a list of predicates together with a left fold. When the running
  # expression equals the next one, the next one is kept instead of being
  # OR'ed with an identical copy. Returns nil for an empty list.
  def accumulative_or(array)
    array.inject do |combined, expression|
      combined == expression ? expression : combined.or(expression)
    end
  end

  # Raw SQL `ARRAY(SELECT unnest(a1) INTERSECT SELECT unnest(a2))`.
  #
  # Unless `opts[:case_sensitive]` is truthy, the elements are cast to text
  # and lower-cased before intersecting.
  def array_intersect(a1, a2, opts = {})
    select1 = unnest(sqlv(a1))
    select2 = unnest(sqlv(a2))
    unless opts[:case_sensitive]
      select1 = lower(cast(select1, "text"))
      select2 = lower(cast(select2, "text"))
    end
    Arel::Nodes::SqlLiteral.new <<-SQL
      ARRAY(
        SELECT #{sql_fragment(select1)} INTERSECT
        SELECT #{sql_fragment(select2)}
      )
    SQL
  end

  # `table.id IN (…)` for every row below `id` in a self-referencing tree,
  # ordered by depth and then by path. A recursive CTE walks child rows whose
  # `parent_column` points at an already-visited id. The path array stops
  # cycles, and the walk stops once a path is longer than `max_depth`.
  #
  # @param table [#name, #[]] the Arel table (its `name` is interpolated as SQL)
  # @param id the root id; it is rendered through `sqlv`, so strings are quoted
  # @param max_depth [Integer] maximum path length to follow
  # @param parent_column [Symbol, String] column holding the parent id
  #   (a developer-supplied identifier; it is not quoted)
  def descendants_search(table, id, max_depth: 999, parent_column: :reports_to_id)
    name = table.name
    root = sql_fragment(id)
    depth = Integer(max_depth)
    tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL
      WITH RECURSIVE descendants_search(id, path) AS (
        SELECT id, ARRAY[id]
        FROM #{name}
        WHERE id = #{root}
        UNION ALL
        SELECT #{name}.id, (path || #{name}.id)
        FROM descendants_search
        JOIN #{name}
        ON descendants_search.id = #{name}.#{parent_column}
        WHERE NOT #{name}.id = ANY(path)
        AND NOT array_length(path,1) > #{depth}
      )
      SELECT id
      FROM descendants_search
      WHERE id != #{root}
      ORDER BY array_length(path, 1), path
    SQL
    table[:id].in(tree_sql)
  end

  # `table.id IN (…)` for every row above `id` in a self-referencing tree
  # (following `parent_column` upwards), ordered by distance and then by path.
  # The path array stops cycles.
  #
  # @param table [#name, #[]] the Arel table (its `name` is interpolated as SQL)
  # @param id the starting id; it is rendered through `sqlv`, so strings are quoted
  # @param parent_column [Symbol, String] column holding the parent id
  #   (a developer-supplied identifier; it is not quoted)
  def ancestor_search(table, id, parent_column: :reports_to_id)
    name = table.name
    start = sql_fragment(id)
    tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL
      WITH RECURSIVE ancestor_search(id, #{parent_column}, path) AS (
        SELECT id, #{parent_column}, ARRAY[id]
        FROM #{name}
        WHERE id = #{start}
        UNION ALL
        SELECT #{name}.id, #{name}.#{parent_column}, (path || #{name}.id)
        FROM ancestor_search
        JOIN #{name}
        ON ancestor_search.#{parent_column} = #{name}.id
        WHERE NOT #{name}.id = ANY(path)
      )
      SELECT id
      FROM ancestor_search
      WHERE id != #{start}
      ORDER BY array_length(path, 1), path
    SQL
    table[:id].in(tree_sql)
  end

  # Converts a Ruby value into something usable in SQL.
  #
  # - nil: `NULL` literal
  # - anything responding to `to_sql` (Arel nodes): its SQL string
  # - Arel attribute: `"table"."column"` literal
  # - Array/Range: `ARRAY[...]` literal. String elements are single-quoted
  #   with embedded quotes doubled, nil elements become NULL, and other
  #   elements are joined as-is.
  # - String/Time/DateTime/Date: a quoted Arel node (`build_quoted`)
  # - anything else: its `to_s` as a literal
  def sqlv(node)
    case node
    when nil
      Arel::Nodes::SqlLiteral.new "NULL"
    when ->(n) { n.respond_to?(:to_sql) }
      node.to_sql
    when Arel::Attributes::Attribute
      Arel::Nodes::SqlLiteral.new "\"#{node.relation.name}\".\"#{node.name}\""
    when Array, Range
      elements = node.map { |element| array_element_sql(element) }.join(",")
      Arel::Nodes::SqlLiteral.new "ARRAY[#{elements}]"
    when String, Time, DateTime, Date
      Arel::Nodes.build_quoted node
    else
      Arel::Nodes::SqlLiteral.new node.to_s
    end
  end

  # This is a special ordering SQL used inside methods like array_agg
  # http://www.postgresql.org/docs/current/static/sql-expressions.html#SYNTAX-AGGREGATES
  def order_by(a, b)
    Arel::Nodes::SqlLiteral.new "#{sql_fragment(a)} ORDER BY #{sql_fragment(b)}"
  end

  # Returns `[lower, upper]` rendered through `sqlv`.
  #
  # For a Range, the bounds come from Range#min/#max, so an exclusive integer
  # range such as 1...5 yields 1 and 4, and any explicit `upper` is ignored.
  # NOTE(review): Range#max raises TypeError for exclusive ranges with
  # non-Integer ends (e.g. Time...Time).
  def range_params(lower_or_range, upper = nil)
    case lower_or_range
    when Range
      lower = lower_or_range.min
      upper = lower_or_range.max
    else
      lower = lower_or_range
    end
    [sqlv(lower), sqlv(upper)]
  end

  # Quotes a Symbol as a SQL string literal; returns any other object unchanged.
  def cloneable(obj)
    case obj
    when Symbol
      Arel::Nodes.build_quoted obj.to_s
    else
      obj
    end
  end

  # Wraps `node` in an ascending or descending ordering node.
  # Kept as a module-only method (not mixed in), so including classes do not
  # get their own `sort` shadowed.
  #
  # @param order [#to_sym] :asc or :desc (strings are accepted)
  # @raise [ArgumentError] for any other value
  def self.sort(node, order)
    direction = order.respond_to?(:to_sym) ? order.to_sym : nil
    case direction
    when :asc
      Arel::Nodes::Ascending.new node
    when :desc
      Arel::Nodes::Descending.new node
    else
      raise ArgumentError, "Must pass in either :asc or :desc"
    end
  end

  private

  # SQL text for a value that is about to be interpolated into a raw SQL
  # string. `sqlv` returns a quoted Arel node for strings and times; that node
  # must be turned into SQL with `to_sql` rather than `to_s`.
  def sql_fragment(node)
    value = sqlv(node)
    value.respond_to?(:to_sql) ? value.to_sql : value.to_s
  end

  # One element of an `ARRAY[...]` literal. Strings are single-quoted with
  # embedded quotes doubled, nil becomes NULL, and everything else is left for
  # Array#join to stringify.
  def array_element_sql(element)
    case element
    when nil then "NULL"
    when String then "'#{element.gsub("'", "''")}'"
    else element
    end
  end
end
@kevinluo201
Copy link

amazing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment