Skip to content

Instantly share code, notes, and snippets.

@azazel75
Created May 29, 2021 21:44
Show Gist options
  • Save azazel75/30261a59f136e604b1c9c10bf026d50c to your computer and use it in GitHub Desktop.
Save azazel75/30261a59f136e604b1c9c10bf026d50c to your computer and use it in GitHub Desktop.
import pglast
from pglast.printer import IndentedStream
import pglast.printers
import pytest
from walker import walker
# Source DDL parsed by the ``parsed_create`` fixture.  Five columns carry an
# explicit NULL constraint and there is one table-level PRIMARY KEY constraint.
CREATE_SQL = """
CREATE TABLE public."order" (
id serial NOT NULL,
customer_id int4 NULL,
payment_type bpchar(255) NULL,
paid bool NULL,
payment_date timestamp NULL DEFAULT now(),
creation_date timestamp NULL DEFAULT now(),
CONSTRAINT order_pkey PRIMARY KEY (id)
);
"""
# Expected printer output for CREATE_SQL after its NULL constraints have been
# removed.  The backslash after the opening quotes and the one before the
# closing quotes suppress the leading and trailing newlines.
# NOTE(review): the leading whitespace of the continuation lines may have been
# lost when this file was pasted; confirm against pglast's IndentedStream output.
TEST_RM_NULL_SQL = """\
CREATE TABLE public."order" (
id serial NOT NULL
, customer_id int4
, payment_type bpchar(255)
, paid bool
, payment_date timestamp DEFAULT now()
, creation_date timestamp DEFAULT now()
, CONSTRAINT order_pkey PRIMARY KEY (id)
)\
"""
def format_stmt(stmt, **opts) -> str:
    """Print a pglast statement node back to SQL text.

    :param stmt: the ``pglast.ast`` statement node to print
    :param opts: keyword options forwarded to ``IndentedStream``
    :returns: the SQL text produced by the stream
    """
    printer = IndentedStream(**opts)
    # The bare statement is wrapped in a RawStmt (inside a pglast.Node) so the
    # printer does not fail on it.
    wrapped = pglast.Node(pglast.ast.RawStmt(stmt))
    return printer(wrapped)
@pytest.fixture
def parsed_create():
    """Fixture: the statement node of the first statement in ``CREATE_SQL``."""
    statements = pglast.parse_sql(CREATE_SQL)
    first_raw = statements[0]
    return first_raw.stmt
def test_collect(parsed_create):
    """Constraints moved out of the table by the transformer are collected,
    and calling ``stop`` on the root prevents any further visit."""
    visits = 0
    original_len = len(parsed_create.tableElts)

    # The parameter names ``collect`` and ``stop`` matter: the walker passes
    # its helpers by matching them against the transformer's signature.
    @walker
    def filter_table_constraints(node, collect, stop):
        nonlocal visits
        visits += 1
        if not isinstance(node, pglast.ast.CreateStmt):
            return
        kept = []
        for element in node.tableElts:
            if isinstance(element, pglast.ast.Constraint):
                collect(element)
                continue
            kept.append(element)
        node.tableElts = tuple(kept)
        stop()

    _, constraints = filter_table_constraints(parsed_create)
    assert len(constraints) == 1
    assert visits == 1
    assert len(parsed_create.tableElts) == original_len - len(constraints)
def test_rm_null_constr(parsed_create):
    """Returning None from the transformer drops the visited NULL
    constraints, and the printed statement no longer mentions them."""
    removed = 0

    @walker
    def remove_null_constraints(node):
        nonlocal removed
        is_null_constraint = (
            isinstance(node, pglast.ast.Constraint)
            and node.contype == pglast.enums.ConstrType.CONSTR_NULL)
        if not is_null_constraint:
            return node
        removed += 1
        return None

    remove_null_constraints(parsed_create)
    assert removed == 5
    assert format_stmt(parsed_create) == TEST_RM_NULL_SQL
# -*- coding: utf-8 -*-
import collections
import inspect
import typing
from typing import Callable, Generator, Optional
from pglast import ast
# Aliases for pglast's base AST node class, for use in type annotations.
# ``Stmt`` is not referenced in this module; presumably it is kept for
# importers (verify before removing).
Stmt = Node = ast.Node
class walker:
    """
    A functional transformer for pglast ASTs.

    Wraps a *transformer* callable and applies it, breadth first, to every
    ``ast.Node`` reachable from a root.  Calling the instance (an alias of
    :meth:`walk`) returns ``(result, collected)``.

    The transformer always receives the visited node as its first positional
    argument.  It may also declare any of these parameter names, which are
    then passed by keyword:

    - ``stop``: zero-argument callable; when called, the children of the
      current node are not visited.
    - ``collect``: one-argument callable; its arguments are gathered, in call
      order, into the ``collected`` list returned by :meth:`walk`.
    - ``ctx``: the dict of extra keyword arguments given to :meth:`walk`.
    - ``path``: tuple of ``(parent, attr, index)`` triples locating the
      current node, starting from the root.

    The transformer's return value replaces the visited node in its parent:

    - the same node leaves the tree untouched;
    - ``None`` removes the node and skips its children;
    - another node replaces it;
    - a sequence of nodes is spliced in its place inside a tuple attribute.
    """

    def __init__(self, transformer: Callable):
        self.transformer: Callable = transformer

    def iterate(self, root_node: typing.Union[Node, tuple]
                ) -> Generator[typing.Tuple[tuple, Node],
                               typing.Tuple[bool, typing.Any],
                               typing.Any]:
        """
        Visit every AST node reachable from *root_node*, breadth first.

        This is a coroutine-style generator.  Each step yields
        ``(path, node)`` and expects ``(stop, new_node)`` to be sent back:

        - ``new_node`` is the node's replacement (see the class docstring);
        - a true ``stop`` skips the node's children.

        :param root_node: an ``ast.Node``, or a tuple whose ``ast.Node`` items
            are each visited as roots
        :returns: (as the ``StopIteration`` value) the root, or the node that
            replaced it when the transformer replaced a single root node
        :raises ValueError: on a bad *root_node*, or on a replacement that
            cannot be applied (e.g. a sequence assigned to a scalar attribute,
            or an item of a tuple root)
        """
        todo = collections.deque()
        if isinstance(root_node, ast.Node):
            todo.append((((None, None, None),), root_node))
        elif isinstance(root_node, tuple):
            for index, item in enumerate(root_node):
                if isinstance(item, ast.Node):
                    todo.append((((root_node, None, index),), item))
        else:
            raise ValueError('Bad argument, expected a ast.Node instance or '
                             'a tuple')

        def replace_in_seq(seq, old_node, new_value):
            # Find the element by identity rather than by the index recorded
            # when it was queued: removing or splicing an earlier sibling
            # shifts the positions of the ones after it.
            for index, item in enumerate(seq or ()):
                if item is old_node:
                    break
            else:
                raise ValueError(f"{old_node!r} is no longer part of its "
                                 "parent sequence")
            if new_value is None:
                replacement = ()
            elif isinstance(new_value, typing.Sequence):
                replacement = tuple(new_value)
            else:
                replacement = (new_value,)
            new_seq = seq[:index] + type(seq)(replacement) + seq[index + 1:]
            # an emptied sequence is stored as None, not as an empty tuple
            return new_seq or None

        def replace_attr(parent, attr, index, old_node, new_value):
            if parent is None:
                raise ValueError("Needs a parent")
            if attr and index is None:
                # scalar attribute: only a node or None fits
                if new_value is None or isinstance(new_value, ast.Node):
                    setattr(parent, attr, new_value)
                else:
                    raise ValueError("Cannot set a sequence to a scalar"
                                     f" value: {new_value!r}")
            elif attr:
                setattr(parent, attr,
                        replace_in_seq(getattr(parent, attr), old_node,
                                       new_value))
            elif index is not None:
                # parent is the tuple given as root, which cannot be mutated
                raise ValueError("Cannot replace parent element")
            else:
                raise ValueError("'attr' and 'index' cannot be both None")
            return parent

        def enqueue_children(path, parent):
            # ``for attr in node`` yields the node's attribute names
            for attr in parent:
                value = getattr(parent, attr)
                if isinstance(value, ast.Node):
                    todo.append((path + ((parent, attr, None),), value))
                elif isinstance(value, tuple):
                    for index, item in enumerate(value):
                        if isinstance(item, ast.Node):
                            todo.append((path + ((parent, attr, index),),
                                         item))

        result = root_node
        while todo:
            path, node = todo.popleft()
            stop, new_node = (yield path, node)
            if new_node is not node:
                parent, attr, index = path[-1]
                if parent is None:
                    # The current node is the root: there is nothing to patch,
                    # but a replacement node becomes the walk's result.
                    if isinstance(new_node, ast.Node):
                        result = new_node
                    elif new_node is not None:
                        raise ValueError("The root node can only be replaced "
                                         f"by another node: {new_node!r}")
                else:
                    replace_attr(parent, attr, index, node, new_node)
                if new_node is None:
                    stop = True
            if stop:
                continue
            # Descend into whatever now occupies this node's place.  Nodes
            # spliced in from a sequence share this node's path prefix.
            if isinstance(new_node, ast.Node):
                descend = (new_node,)
            elif isinstance(new_node, typing.Sequence):
                descend = tuple(item for item in new_node
                                if isinstance(item, ast.Node))
            else:
                descend = ()
            for child_parent in descend:
                enqueue_children(path, child_parent)
        return result

    def walk(self, root_node: typing.Union[Node, tuple],
             transformer: Optional[Callable] = None,
             **ctx) -> typing.Tuple[typing.Any, list]:
        """
        Apply the transformer to every node reachable from *root_node*.

        :param root_node: an ``ast.Node`` or a tuple of them
        :param transformer: overrides the wrapped transformer for this call
        :param ctx: extra keyword arguments, handed to the transformer as the
            ``ctx`` dict when it declares that parameter
        :returns: ``(result, collected)``: ``result`` is the value returned by
            :meth:`iterate`, and ``collected`` holds, in order, every value
            passed to ``collect``
        """
        transformer = transformer or self.transformer
        should_stop = False

        def stop():
            nonlocal should_stop
            should_stop = True

        collected = []

        def collect(x):
            collected.append(x)

        # The first parameter always receives the node; the remaining names
        # select which helpers are passed as keyword arguments.
        wanted_arg_names = tuple(
            inspect.signature(transformer).parameters)[1:]
        all_args = {'stop': stop, 'collect': collect, 'ctx': ctx}
        wanted_args = {k: v for k, v in all_args.items()
                       if k in wanted_arg_names}
        wants_path = 'path' in wanted_arg_names

        it = self.iterate(root_node)
        try:
            path, node = it.send(None)
            while True:
                should_stop = False
                if wants_path:
                    new_node = transformer(node, path=path, **wanted_args)
                else:
                    new_node = transformer(node, **wanted_args)
                path, node = it.send((should_stop, new_node))
        except StopIteration as e:
            res = e.value
        finally:
            it.close()
        return res, collected

    __call__ = walk
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment