# PartitionKeyReporter - shared as a GitHub Gist by @SSDany
# (gist 62817fa3b550eeed79175376d0429dcb, created July 22, 2024 15:53).
# Flags SQL that touches the partitioned table without constraining its
# partition key, i.e. queries for which PostgreSQL cannot prune partitions.
#
# NOTE(review): the call(*args, payload) signature and the :sql / :name payload
# keys suggest this is registered as an ActiveSupport::Notifications subscriber
# for "sql.active_record" events, e.g.
#   ActiveSupport::Notifications.subscribe('sql.active_record', PartitionKeyReporter.new)
# -- confirm against the subscription site.
class PartitionKeyReporter
  # Base error for a query that does not constrain the partition key.
  class MissingPartitionKey < StandardError; end
  # Used when the offending statement is an UPDATE.
  class MissingPartitionKeyOnUpdate < MissingPartitionKey; end
  # Used when the offending statement is a DELETE.
  class MissingPartitionKeyOnDelete < MissingPartitionKey; end

  # All arguments are optional; nil keeps the historical default, so a plain
  # PartitionKeyReporter.new behaves exactly as before.
  #
  # @param table_name [String, nil] partitioned table to guard ('resources_partitioned')
  # @param resource_name [String, nil] model whose "<name> Create" notifications get
  #   the textual INSERT check ('Resource')
  # @param partition_key [String, nil] partition key column of the table ('user_id')
  def initialize(table_name: nil, resource_name: nil, partition_key: nil)
    @table_name = table_name
    @resource_name = resource_name
    @partition_key = partition_key
  end

  # @return [String] name of the partitioned table being guarded
  def table_name
    @table_name ||= 'resources_partitioned'
  end

  # @return [String] model name used to recognise "<name> Create" notifications
  def resource_name
    @resource_name ||= 'Resource'
  end

  # @return [String] partition key column of #table_name
  def partition_key
    @partition_key ||= 'user_id'
  end

  # Percentage (0-100) of queries checked in production. Read from
  # PARTITIONING_QUERY_SAMPLING_PERCENTAGE on every call; 0 (disabled) when unset.
  def sampling_percentage
    ENV.fetch('PARTITIONING_QUERY_SAMPLING_PERCENTAGE', 0).to_i
  end

  # Every query is checked outside production; in production only a random
  # sampling_percentage percent of them.
  def check_query?
    return true unless Rails.env.production?

    rand(100) < sampling_percentage
  end

  # Notification callback: checks one executed query and hands offenders to #report.
  #
  # @param _args [Array] leading notification arguments (ignored)
  # @param payload [Hash] reads :sql (statement text) and :name (query label)
  def call(*_args, payload)
    return unless check_query?

    sql = payload[:sql]
    return if sql.blank?

    if payload[:name] == "#{resource_name} Create"
      # The model's own INSERTs list their columns explicitly, so a textual
      # check suffices. Word boundaries keep e.g. "other_user_id" from
      # satisfying "user_id".
      report(sql) unless sql.match?(/\b#{Regexp.escape(partition_key)}\b/)
      return
    end

    parsed = PgQuery.parse(sql)
    # Calling #tables also triggers pg_query's load_objects!.
    return unless parsed.tables.include?(table_name)

    fragments, err_klass = collect_fragments(parsed)
    # Every fragment that touches the table must constrain the partition key;
    # otherwise the whole query is reported.
    return if fragments.all? { |tables, conditions| fragment_constrained?(parsed, tables, conditions) }

    report(sql, err_klass)
  end

  # Handles a query that does not constrain the partition key. Outside
  # production it raises +err_klass+ so the offending code fails loudly in
  # development and tests; in production (where only a sample is checked) it
  # logs a warning instead of failing the request. Override to report elsewhere.
  #
  # The SQL may contain literal values; consider obfuscating it before shipping
  # it off-box, for example:
  #   NewRelic::Agent::Database::Obfuscator.instance.obfuscate(sql, :postgres)
  #
  # @param sql [String] offending statement
  # @param err_klass [Class] MissingPartitionKey or one of its subclasses
  # @raise [MissingPartitionKey] outside production
  def report(sql, err_klass = MissingPartitionKey)
    message = "SQL query is missing the partition key:\n#{sql}"
    raise err_klass, message unless Rails.env.production?

    Rails.logger&.warn("[#{err_klass.name}] #{message}")
  end

  private

  # Breadth-first walk over the parsed statement and every nested statement
  # (raw statements, FROM subqueries, CTEs, UNION/EXCEPT/INTERSECT arms and
  # subqueries inside conditions).
  #
  # @return [Array] [fragments, err_klass]: one [tables, conditions] pair per
  #   statement that references tables directly, and the error class to report
  #   with (the last UPDATE/DELETE seen, else MissingPartitionKey)
  def collect_fragments(parsed)
    statements = parsed.tree.stmts.to_a.map(&:stmt)
    fragments = []
    err_klass = MissingPartitionKey

    until statements.empty?
      statement = statements.shift
      next unless statement

      conditions = []
      from_data = []

      case statement.node
      when :raw_stmt
        statements << statement.raw_stmt.stmt
      when :select_stmt
        select_stmt = statement.select_stmt
        case select_stmt.op
        when :SETOP_NONE
          # FROM items: queue subselects, keep items for table extraction and
          # collect JOIN ... ON conditions.
          select_stmt.from_clause&.each do |item|
            statements << item.range_subselect.subquery if item.range_subselect
            from_data << item
            conditions.concat(join_conditions(item.join_expr)) if item.node == :join_expr
          end
          conditions << select_stmt.where_clause if select_stmt.where_clause
          select_stmt.with_clause&.ctes&.each do |cte|
            statements << cte.common_table_expr.ctequery if cte.node == :common_table_expr
          end
        when :SETOP_UNION, :SETOP_EXCEPT, :SETOP_INTERSECT
          statements << PgQuery::Node.new(select_stmt: select_stmt.larg) if select_stmt.larg
          statements << PgQuery::Node.new(select_stmt: select_stmt.rarg) if select_stmt.rarg
        end
      when :update_stmt
        err_klass = MissingPartitionKeyOnUpdate
        conditions << statement.update_stmt.where_clause if statement.update_stmt.where_clause
        from_data << PgQuery::Node.new(range_var: statement.update_stmt.relation)
      when :delete_stmt
        err_klass = MissingPartitionKeyOnDelete
        conditions << statement.delete_stmt.where_clause if statement.delete_stmt.where_clause
        from_data << PgQuery::Node.new(range_var: statement.delete_stmt.relation)
      end

      # Subqueries anywhere inside the conditions are checked as statements of their own.
      statements.concat(sub_selects(conditions))

      # Recorded even when there are no conditions: an unfiltered scan, UPDATE
      # or DELETE of the partitioned table is exactly what must be reported.
      tables = range_var_names(from_data)
      fragments << [tables, conditions] unless tables.empty?
    end

    [fragments, err_klass]
  end

  # Collects the ON conditions of a (possibly nested) JOIN tree.
  def join_conditions(join_expr)
    quals = []
    pending = [join_expr]
    until pending.empty?
      expr = pending.shift
      quals << expr.quals if expr.quals
      pending << expr.larg.join_expr if expr.larg&.node == :join_expr
      pending << expr.rarg.join_expr if expr.rarg&.node == :join_expr
    end
    quals
  end

  # Relation names referenced directly by the given FROM items (JOIN trees are
  # descended; subselects are ignored because they are checked separately).
  def range_var_names(from_data)
    pending = from_data.dup
    names = []
    until pending.empty?
      item = pending.shift
      case item&.node
      when :join_expr
        pending << item.join_expr.larg << item.join_expr.rarg
      when :row_expr
        pending.concat(item.row_expr.args.to_a)
      when :range_var
        names << item.range_var.relname
      end
    end
    names.uniq
  end

  # Subquery statements nested anywhere inside the given condition expressions,
  # e.g. both `EXISTS (SELECT ...)` and `a = 1 AND id IN (SELECT ...)`.
  def sub_selects(conditions)
    pending = conditions.compact
    found = []
    until pending.empty?
      node = pending.shift
      case node&.node
      when :sub_link
        found << node.sub_link.subselect
        pending << node.sub_link.testexpr
      when :a_expr then pending << node.a_expr.lexpr << node.a_expr.rexpr
      when :bool_expr then pending.concat(node.bool_expr.args.to_a)
      when :list then pending.concat(node.list.items.to_a)
      when :coalesce_expr then pending.concat(node.coalesce_expr.args.to_a)
      when :row_expr then pending.concat(node.row_expr.args.to_a)
      when :func_call then pending.concat(node.func_call.args.to_a)
      when :null_test then pending << node.null_test.arg
      when :boolean_test then pending << node.boolean_test.arg
      when :type_cast then pending << node.type_cast.arg
      end
    end
    found.compact
  end

  # True when the fragment does not touch the partitioned table, or when its
  # conditions constrain the partition key.
  def fragment_constrained?(parsed, tables, conditions)
    return true unless tables.include?(table_name)

    # With a single table in scope, unqualified columns can only belong to it.
    implicit_table = tables.one? ? tables.first : nil
    constrained_columns(parsed, conditions, implicit_table).include?([table_name, partition_key])
  end

  # [table, column] pairs that the conditions compare against constants (or
  # that appear in IS NULL / boolean / sub_link tests).
  # NOTE: OR/NOT branches are walked like AND, so `user_id = 1 OR flag` is
  # accepted; this mirrors the original heuristic.
  def constrained_columns(parsed, conditions, implicit_table)
    pending = conditions.compact # a copy, so the fragment's own array stays intact
    columns = []
    until pending.empty?
      node = pending.shift
      case node&.node
      when :a_expr
        # partition_key = ?, IN (...), BETWEEN ? AND ?
        pending << node.a_expr.lexpr if constant_operand?(node.a_expr.rexpr)
      when :bool_expr then pending.concat(node.bool_expr.args.to_a)
      when :coalesce_expr then pending.concat(node.coalesce_expr.args.to_a)
      when :row_expr then pending.concat(node.row_expr.args.to_a)
      when :null_test then pending << node.null_test.arg
      when :boolean_test then pending << node.boolean_test.arg
      when :sub_link then pending << node.sub_link.testexpr
      when :type_cast then pending << node.type_cast.arg
      when :column_ref
        # Fields are [column], [table, column] or [schema, table, column];
        # A_Star entries (t.*) carry no string.
        column, table = node.column_ref.fields.map { |field| field.string&.sval }.reverse
        table = table ? (parsed.aliases[table] || table) : implicit_table
        columns << [table, column]
      end
    end
    columns
  end

  # Whether +node+ is a literal, a bind parameter ($1, as emitted with prepared
  # statements), a cast of one, or a list made only of those.
  def constant_operand?(node)
    case node&.node
    when :list then node.list.items.all? { |item| constant_operand?(item) }
    when :a_const then node.a_const.ival.present? || node.a_const.sval.present?
    when :param_ref then true
    when :type_cast then constant_operand?(node.type_cast.arg)
    else false
    end
  end
end
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment