Created May 29, 2021 21:44
-
-
Save azazel75/30261a59f136e604b1c9c10bf026d50c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pglast | |
from pglast.printer import IndentedStream | |
import pglast.printers | |
import pytest | |
from walker import walker | |
# DDL parsed by the ``parsed_create`` fixture: five columns carry an explicit
# NULL constraint, ``id`` is NOT NULL and the table has a primary key
# constraint at table level.
CREATE_SQL = """
CREATE TABLE public."order" (
    id serial NOT NULL,
    customer_id int4 NULL,
    payment_type bpchar(255) NULL,
    paid bool NULL,
    payment_date timestamp NULL DEFAULT now(),
    creation_date timestamp NULL DEFAULT now(),
    CONSTRAINT order_pkey PRIMARY KEY (id)
);
"""

# Expected ``format_stmt()`` output for CREATE_SQL once every NULL constraint
# has been removed. The backslashes after the opening quotes and after ")"
# keep a newline out of the start and the end of the string.
# NOTE(review): the leading whitespace of these lines is reconstructed after
# pglast's default IndentedStream layout (first item indented by four spaces,
# following items prefixed by "  , ") - confirm against a real run.
TEST_RM_NULL_SQL = """\
CREATE TABLE public."order" (
    id serial NOT NULL
  , customer_id int4
  , payment_type bpchar(255)
  , paid bool
  , payment_date timestamp DEFAULT now()
  , creation_date timestamp DEFAULT now()
  , CONSTRAINT order_pkey PRIMARY KEY (id)
)\
"""
def format_stmt(stmt, **opts) -> str:
    """Pretty-print *stmt* through an ``IndentedStream`` configured by *opts*.

    The statement node is wrapped in a ``RawStmt`` before printing, to avoid
    printing errors on a bare statement node.
    """
    stream = IndentedStream(**opts)
    raw = pglast.ast.RawStmt(stmt)
    return stream(pglast.Node(raw))
@pytest.fixture
def parsed_create():
    """Give each test a fresh statement node parsed from ``CREATE_SQL``."""
    raw_statements = pglast.parse_sql(CREATE_SQL)
    first = raw_statements[0]
    return first.stmt
def test_collect(parsed_create):
    """Constraints filtered out of ``tableElts`` come back via ``collect``."""
    visited = 0
    original_len = len(parsed_create.tableElts)

    @walker
    def filter_table_constraints(node, collect, stop):
        # the parameter names ``collect`` and ``stop`` select the helpers
        # the walker passes in, so they must not be renamed
        nonlocal visited
        visited += 1
        if not isinstance(node, pglast.ast.CreateStmt):
            return None
        kept = []
        for element in node.tableElts:
            if isinstance(element, pglast.ast.Constraint):
                collect(element)
            else:
                kept.append(element)
        node.tableElts = tuple(kept)
        # everything happens on the root, its children need no visit
        stop()
        return None

    _, constraints = filter_table_constraints(parsed_create)
    assert len(constraints) == 1
    assert visited == 1
    assert len(parsed_create.tableElts) == original_len - len(constraints)
def test_rm_null_constr(parsed_create):
    """Returning None from the transformer drops every NULL constraint."""
    removed = 0

    @walker
    def remove_null_constraints(node):
        nonlocal removed
        is_null_constraint = (
            isinstance(node, pglast.ast.Constraint)
            and node.contype == pglast.enums.ConstrType.CONSTR_NULL)
        if not is_null_constraint:
            # keep the node (and visit its children)
            return node
        removed += 1
        return None

    remove_null_constraints(parsed_create)
    assert removed == 5
    assert format_stmt(parsed_create) == TEST_RM_NULL_SQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import collections | |
import inspect | |
import typing | |
from typing import Callable, Generator, Optional | |
from pglast import ast | |
# aliases used in the type annotations below
Stmt = Node = ast.Node


class walker:
    """
    A functional transformer for pglast AST trees.

    The wrapped *transformer* is called once for every node reached by a
    breadth-first visit of the tree, the node being its first argument. Its
    return value decides the fate of the visited node:

    - the same node: the node is kept and its children are visited;
    - another ``ast.Node``: it takes the place of the node in the parent and
      its children are visited (the replacement itself is not passed to the
      transformer);
    - a sequence: only allowed for items of a tuple attribute; its items are
      spliced in place of the node and their children are visited;
    - ``None``: the node is removed from its parent (an emptied tuple
      attribute becomes None) and its children are not visited.

    The root node cannot be removed: returning None for it only stops the
    visit. Replacing it with another node makes that node the root returned
    by :meth:`walk`.

    Besides the node, the transformer may declare any of these parameters,
    which are then passed by keyword:

    - ``stop``: a callable preventing the visit of the current node's
      children;
    - ``collect``: a callable appending its argument to the list returned by
      :meth:`walk`;
    - ``ctx``: the dict of extra keyword arguments given to :meth:`walk`;
    - ``path``: the tuple of ``(parent, attr, index)`` triples leading to the
      current node (see :meth:`iterate`).
    """

    def __init__(self, transformer: Callable):
        self.transformer: Callable = transformer

    def iterate(self, root_node: Node) -> Generator[tuple, tuple, object]:
        """
        Visit all the ast nodes under *root_node* breadth-first.

        Each step yields a ``(path, node)`` pair. *path* is a tuple of
        ``(parent, attr, index)`` triples locating *node*: ``index`` is None
        for scalar attributes and the root's triple is ``(None, None, None)``.
        Indices are recorded when a node is queued, so earlier changes to its
        siblings may have shifted it since.

        The caller must answer every step with ``send((stop, new_node))``:
        *new_node* takes the place of *node* as described in the class
        docstring, and a true *stop* skips the visit of its children.

        The generator returns the root of the tree: the replacement node if
        the root was replaced, *root_node* otherwise.

        :raises ValueError: if *root_node* is neither an ``ast.Node`` nor a
            tuple, or if a requested replacement is impossible
        """
        todo = collections.deque()
        if isinstance(root_node, ast.Node):
            todo.append((((None, None, None),), root_node))
        elif isinstance(root_node, tuple):
            for index, item in enumerate(root_node):
                if isinstance(item, ast.Node):
                    todo.append((((root_node, None, index),), item))
        else:
            raise ValueError('Bad argument, expected a ast.Node instance or '
                             'a tuple')

        def enqueue_children(path, node):
            # Queue every ast.Node referenced by *node*, either directly by
            # an attribute or as an item of a tuple attribute.
            for attr in node:
                value = getattr(node, attr)
                if isinstance(value, ast.Node):
                    todo.append((path + ((node, attr, None),), value))
                elif isinstance(value, tuple):
                    for index, item in enumerate(value):
                        if isinstance(item, ast.Node):
                            todo.append((path + ((node, attr, index),), item))

        def replace_in_seq(seq, index, new_value):
            # Copy of *seq* with the item at *index* dropped (None), replaced
            # by the items of a sequence, or replaced by a single value. An
            # empty result becomes None.
            if new_value is None:
                replacement = ()
            elif isinstance(new_value, typing.Sequence):
                replacement = tuple(new_value)
            else:
                replacement = (new_value,)
            new_seq = seq[:index] + type(seq)(replacement) + seq[index + 1:]
            return new_seq or None

        def replace_attr(parent, attr, index, old_value, new_value):
            # Put *new_value* in place of *old_value* inside ``parent.attr``;
            # return the position *old_value* had in a tuple attribute, None
            # for a scalar attribute.
            if parent is None:
                raise ValueError("Needs a parent")
            if attr is None:
                if index is not None:
                    raise ValueError("Cannot replace parent element")
                raise ValueError("'attr' and 'index' cannot be both None")
            if index is None:
                if new_value is None or isinstance(new_value, ast.Node):
                    setattr(parent, attr, new_value)
                    return None
                raise ValueError("Cannot set a sequence to a scalar"
                                 f" value: {new_value!r}")
            old_seq = getattr(parent, attr)
            # Find the item by identity: removing or splicing earlier siblings
            # may have moved it away from the index recorded when queued.
            position = next((i for i, item in enumerate(old_seq or ())
                             if item is old_value), None)
            if position is None:
                raise ValueError(f"Cannot find {old_value!r} in attribute"
                                 f" {attr!r} of its parent")
            setattr(parent, attr, replace_in_seq(old_seq, position, new_value))
            return position

        root = root_node
        while todo:
            path, node = todo.popleft()
            stop, new_node = (yield path, node)
            if new_node is not node:
                *ancestors, (parent, attr, index) = path
                position = index
                if parent is None:
                    # The current node is the root: None only stops the
                    # visit, another node becomes the returned root.
                    if isinstance(new_node, ast.Node):
                        root = new_node
                    elif new_node is not None:
                        raise ValueError("Cannot replace the root node with"
                                         f" {new_node!r}")
                else:
                    position = replace_attr(parent, attr, index, node,
                                            new_node)
                if new_node is None:
                    # removed from the tree, no subtree left to visit
                    stop = True
                elif isinstance(new_node, ast.Node):
                    node = new_node
                else:
                    # Spliced into a tuple attribute: visit the children of
                    # every new item, located at its new position.
                    if not stop:
                        items = (new_node
                                 if isinstance(new_node, typing.Sequence)
                                 else (new_node,))
                        for offset, item in enumerate(items):
                            if isinstance(item, ast.Node):
                                item_path = tuple(ancestors) + (
                                    (parent, attr, position + offset),)
                                enqueue_children(item_path, item)
                    stop = True
            if not stop:
                enqueue_children(path, node)
        return root

    def walk(self, root_node: Node, transformer: Optional[Callable] = None,
             **ctx) -> tuple:
        """
        Apply a transformer to every node reachable from *root_node*.

        :param root_node: an ``ast.Node`` or a tuple of them
        :param transformer: used instead of the one given to the constructor
            when not None
        :param ctx: extra keyword arguments, passed to the transformer as a
            single ``ctx`` dict if it declares a parameter with that name
        :returns: a ``(root, collected)`` pair: *root* is the root of the
            (possibly modified) tree, *collected* holds, in call order, every
            value given to ``collect()``
        :raises ValueError: see :meth:`iterate`
        """
        transformer = transformer or self.transformer
        should_stop = False

        def stop():
            nonlocal should_stop
            should_stop = True

        collected = []

        def collect(x):
            collected.append(x)

        # The first parameter always receives the node; the others select
        # which helpers are passed by keyword.
        wanted_arg_names = tuple(
            inspect.signature(transformer).parameters.keys())[1:]
        all_args = {'stop': stop, 'collect': collect, 'ctx': ctx}
        wanted_args = {k: v for k, v in all_args.items()
                       if k in wanted_arg_names}
        wants_path = 'path' in wanted_arg_names

        it = self.iterate(root_node)
        try:
            path, node = it.send(None)
            while True:
                should_stop = False
                if wants_path:
                    new_node = transformer(node, path=path, **wanted_args)
                else:
                    new_node = transformer(node, **wanted_args)
                path, node = it.send((should_stop, new_node))
        except StopIteration as e:
            res = e.value
        finally:
            it.close()
        return res, collected

    __call__ = walk
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.