Skip to content

Instantly share code, notes, and snippets.

@jruizvar
Created August 21, 2020 18:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jruizvar/d3f72e7a2ca8dce9522b1f84473143b0 to your computer and use it in GitHub Desktop.
Save jruizvar/d3f72e7a2ca8dce9522b1f84473143b0 to your computer and use it in GitHub Desktop.
""" SQL Code Generator.
Analisa as regras de uma árvore de decisão ajustada pelos modelos
Random Forest ou Gradient Boosted do Spark, para gerar o
correspondente código em SQL.
O conjunto de árvores pode ser extraido com o método `toDebugString`
e formatado como uma lista:
>> ensemble = model.trees
>> list_of_trees = [tree.toDebugString.split(\n) for tree in ensemble]
Finalmente, a lista de árvores pode ser iterada para gerar o código em SQL:
>> from sql_code_generator import *
>> sql_code = [sql_code_generator(tree) for tree in list_of_trees]
"""
import re
class Node:
def __init__(self, header, left=None, right=None):
""" Define os parâmetros de um nó da árvore
"""
self.header = header
self.left = left
self.right = right
def __repr__(self):
""" Acerta a formatação na linguagem SQL
"""
header = self.header.replace("{", "(").replace("}", ")")
return f"CASE WHEN {header} THEN {self.left} ELSE {self.right} END"
def parser(line, letter="I"):
""" Captura os parâmetros de uma linha.
As linhas que começam com `I` geram um novo nó.
As linhas que começam com `P` atualizam um nó existente.
"""
if letter == "I":
m = re.search(r"If \((.+)\)", line)
return Node(m.group(1))
m = re.search(r"Predict: (.+)", line)
return m.group(1)
def rule(line, root, depth):
""" Analisa cada linha da árvore.
Utiliza uma regex para obter a primeira letra da
linha e a profundidade de cada nó dentro da árvore.
"""
if not root:
return parser(line), ["left"]
m = re.search(r"(\s+)(\w)\w", line)
letter = m.group(2)
if letter == "E":
indentation = 2
s = m.group(1).count(" ") - indentation
d = depth[:s] + ["right"]
return root, d
exec("root." + ".".join(depth) + "= parser(line, letter)")
if letter == "I":
d = depth + ["left"]
else:
d = depth[:-1]
return root, d
def sql_code_generator(tree, root=None, depth=None):
""" Percorre as linhas do modelo Spark de forma recursiva.
No final, retorna uma string com o código em SQL.
"""
if not tree:
return str(root)
r, d = rule(tree[0], root, depth)
return sql_code_generator(tree[1:], r, d)
if __name__ == "__main__":
""" A modo de exemplo analisamos a seguinte árvore:
"""
tree = [
" If (a <= 0.5)", # root = Node("a <= 0.5")
" If (b in {1.0,2.0})", # root.left = Node(b in {1.0,2.0})
" Predict: 0.0", # root.left.left = "0.0"
" Else (b not in {1.0,2.0})",
" Predict: 1.0", # root.left.right = "1.0"
" Else (a > 0.5)",
" Predict: 0.0" # root.right = "0.0"
]
sql_code = sql_code_generator(tree)
print(sql_code)
""" Resultado:
CASE
WHEN
a <= 0.5
THEN
CASE
WHEN
b in
(
1.0, 2.0
)
THEN
0.0
ELSE
1.0
END
ELSE
0.0
END
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment