-
-
Save SSDany/62817fa3b550eeed79175376d0429dcb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspects SQL statements and reports those that touch the partitioned table
# (#table_name) without constraining its partition key (#partition_key).
#
# NOTE(review): the `call(*args, payload)` signature suggests this is used as an
# ActiveSupport::Notifications subscriber (e.g. for "sql.active_record") — confirm.
# Parsing relies on the pg_query gem (PgQuery); Rails provides the environment
# check and the default report sink.
class PartitionKeyReporter
  class MissingPartitionKey < StandardError; end
  class MissingPartitionKeyOnUpdate < MissingPartitionKey; end
  class MissingPartitionKeyOnDelete < MissingPartitionKey; end

  # @return [String] name of the partitioned table whose queries are checked
  def table_name
    @table_name ||= 'resources_partitioned'
  end

  # @return [String] model name; a payload named "<resource_name> Create" is
  #   checked textually instead of being parsed
  def resource_name
    @resource_name ||= 'Resource'
  end

  # @return [String] partition key column of #table_name
  def partition_key
    @partition_key ||= 'user_id'
  end

  # @return [Integer] percentage of queries checked in production, read from
  #   PARTITIONING_QUERY_SAMPLING_PERCENTAGE (0 when unset)
  def sampling_percentage
    ENV.fetch('PARTITIONING_QUERY_SAMPLING_PERCENTAGE', 0).to_i
  end

  # @return [Boolean] whether the current query should be inspected
  def check_query?
    # Always check in non-production environments
    return true unless Rails.env.production?

    # Report only a given percentage of queries
    rand(100) < sampling_percentage
  end

  # Entry point. Only the last argument (the payload) is used.
  #
  # @param payload [Hash] expects :sql (statement text) and :name
  # @return [void]
  def call(*_args, payload)
    return unless check_query?

    sql = payload[:sql]
    return if sql.blank?

    # Create payloads are checked textually: the key column must appear in the SQL.
    if payload[:name] == "#{resource_name} Create"
      report(sql) unless sql.include?(partition_key)
      return
    end

    parsed = parse(sql)
    # Cheap filter for unrelated queries (this also runs PgQuery's load_objects!).
    return if parsed.nil? || !parsed.tables.include?(table_name)

    fragments, err_klass = collect_fragments(parsed)
    # Each fragment must either not touch the partitioned table or filter on its key.
    return if fragments.all? { |tables, conditions| fragment_qualified?(parsed, tables, conditions) }

    report(sql, err_klass)
  end

  # Reports +sql+ as missing the partition key. The default only logs a
  # warning, so an offending query is never interrupted. Replace the body to
  # route the finding elsewhere. The SQL may contain literal values, so
  # consider obfuscating it first, for example:
  #
  #   obfuscated = NewRelic::Agent::Database::Obfuscator.instance.obfuscate(sql, :postgres)
  #
  # @param sql [String] offending statement
  # @param err_klass [Class] MissingPartitionKey or a subclass naming the
  #   kind of statement (UPDATE/DELETE) that was flagged
  # @return [void]
  def report(sql, err_klass = MissingPartitionKey)
    Rails.logger&.warn("#{err_klass.name}: SQL query is missing the partition key:\n#{sql}")
  end

  private

  # Parses +sql+. Returns nil for statements pg_query cannot parse, instead of
  # letting the error escape into the instrumented query.
  def parse(sql)
    PgQuery.parse(sql)
  rescue PgQuery::ParseError
    nil
  end

  # Splits +parsed+ into [tables, conditions] fragments, one per (sub)statement.
  # Set-operation branches, CTE bodies, FROM/JOIN subselects and subqueries
  # used in conditions are queued and visited as statements of their own.
  #
  # @return [Array(Array, Class)] the fragments, and the error class of the last
  #   UPDATE/DELETE statement seen (MissingPartitionKey when there was none)
  def collect_fragments(parsed)
    statements = parsed.tree.stmts.to_a.map(&:stmt)
    fragments = []
    err_klass = MissingPartitionKey

    until statements.empty?
      statement = statements.shift
      conditions = []
      from_data = []

      case statement.node
      when :raw_stmt
        statements << statement.raw_stmt.stmt
      when :select_stmt
        select = statement.select_stmt
        case select.op
        when :SETOP_NONE
          # FROM items: queue subselects, remember relations, collect JOIN ... ON conditions
          select.from_clause&.each do |item|
            statements << item.range_subselect.subquery if item.range_subselect
            from_data << item
            next unless item.node == :join_expr

            quals, subqueries = join_parts(item.join_expr)
            conditions.concat(quals)
            statements.concat(subqueries)
          end
          conditions << select.where_clause if select.where_clause
        when :SETOP_UNION, :SETOP_EXCEPT, :SETOP_INTERSECT
          # Each branch of a set operation is checked on its own
          statements << PgQuery::Node.new(select_stmt: select.larg) if select.larg
          statements << PgQuery::Node.new(select_stmt: select.rarg) if select.rarg
        end
        # CTE bodies are checked on their own; set-operation queries can carry them too
        select.with_clause&.ctes&.each do |cte|
          statements << cte.common_table_expr.ctequery if cte.node == :common_table_expr
        end
      when :update_stmt
        err_klass = MissingPartitionKeyOnUpdate
        conditions << statement.update_stmt.where_clause if statement.update_stmt.where_clause
        from_data = [PgQuery::Node.new(range_var: statement.update_stmt.relation)]
      when :delete_stmt
        err_klass = MissingPartitionKeyOnDelete
        conditions << statement.delete_stmt.where_clause if statement.delete_stmt.where_clause
        from_data = [PgQuery::Node.new(range_var: statement.delete_stmt.relation)]
      end

      # Subqueries in conditions (`x IN (SELECT ...)`, `EXISTS (...)`) are checked on their own
      statements.concat(condition_subqueries(conditions))
      # Recorded even without conditions: an unfiltered scan/UPDATE/DELETE of the
      # partitioned table is missing the partition key too.
      fragments << [relation_names(from_data), conditions]
    end

    [fragments, err_klass]
  end

  # ON-conditions and subselects found anywhere in a (possibly nested) JOIN tree.
  #
  # @return [Array(Array, Array)] [quals, subqueries]
  def join_parts(join_expr)
    pending = [join_expr]
    quals = []
    subqueries = []
    until pending.empty?
      join = pending.shift
      quals << join.quals if join.quals
      [join.larg, join.rarg].each do |side|
        case side&.node
        when :join_expr then pending << side.join_expr
        when :range_subselect then subqueries << side.range_subselect.subquery
        end
      end
    end
    [quals, subqueries]
  end

  # Subselects of sub-links in +conditions+, looking through AND/OR/NOT.
  def condition_subqueries(conditions)
    pending = conditions.dup
    subqueries = []
    until pending.empty?
      node = pending.shift
      case node&.node
      when :sub_link then subqueries << node.sub_link.subselect
      when :bool_expr then pending.concat(node.bool_expr.args.to_a)
      end
    end
    subqueries
  end

  # Unique relation names referenced by FROM nodes, expanding joins and row
  # expressions. An empty +from_data+ (SELECT without FROM) yields [].
  def relation_names(from_data)
    pending = from_data.dup
    names = []
    until pending.empty?
      node = pending.shift
      case node&.node
      when :join_expr then pending.push(node.join_expr.larg, node.join_expr.rarg)
      when :row_expr then pending.concat(node.row_expr.args.to_a)
      when :range_var then names << node.range_var.relname
      end
    end
    names.uniq
  end

  # Whether a fragment is acceptable. It is acceptable if it does not touch
  # #table_name, or if its conditions compare #table_name's partition key with
  # an integer literal or a bind parameter.
  #
  # @param parsed [#aliases] parse result, used to resolve table aliases
  # @param tables [Array<String>] relation names of the fragment
  # @param conditions [Array] condition nodes of the fragment
  def fragment_qualified?(parsed, tables, conditions)
    return true unless tables.include?(table_name)

    pending = conditions.dup
    columns = []
    until pending.empty?
      node = pending.shift
      # Nil nodes occur, e.g. the test expression of EXISTS (...); skip them
      next if node.nil?

      case node.node
      when :a_expr
        # `column = value`, `column IN (values)`, `column BETWEEN a AND b`
        pending << node.a_expr.lexpr if partition_value_expr?(node.a_expr.rexpr)
      when :bool_expr
        pending.concat(node.bool_expr.args.to_a)
      when :coalesce_expr
        pending.concat(node.coalesce_expr.args.to_a)
      when :row_expr
        pending.concat(node.row_expr.args.to_a)
      when :null_test
        pending << node.null_test.arg
      when :boolean_test
        pending << node.boolean_test.arg
      when :sub_link
        pending << node.sub_link.testexpr
      when :column_ref
        column, table = node.column_ref.fields.map { |field| field.string&.sval }.reverse
        table = parsed.aliases[table] || table
        # An unqualified column in a single-table fragment belongs to that table
        table ||= tables.first if tables.one?
        columns << [table, column]
      end
    end

    columns.include?([table_name, partition_key])
  end

  # Whether +node+ supplies concrete value(s): an integer literal, a bind
  # parameter ($1), or a list made only of those (IN / BETWEEN).
  def partition_value_expr?(node)
    if node&.node == :list
      node.list.items.all? { |item| partition_value?(item) }
    else
      partition_value?(node)
    end
  end

  # Whether +node+ is an integer literal or a bind parameter.
  def partition_value?(node)
    case node&.node
    when :a_const then !node.a_const.ival.nil?
    when :param_ref then true
    else false
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment