Skip to content

Instantly share code, notes, and snippets.

@rbellamy
Created November 6, 2020 19:01
Show Gist options
  • Save rbellamy/10b6f44f657f4ea0bd9e624d9912e6ae to your computer and use it in GitHub Desktop.
Save rbellamy/10b6f44f657f4ea0bd9e624d9912e6ae to your computer and use it in GitHub Desktop.
postway and postsplit: a Flyway-like migration tool that runs raw psql scripts, and a DAG-based splitter for pg_dump output
#!/usr/bin/env python
import argparse
import logging
import os
import re
from enum import Enum
from itertools import groupby, chain
# import matplotlib.pyplot as plt
import networkx as nx
class ParentName(Enum):
    """Where an object's parent name goes when building its output file name."""
    none = 1  # leave the parent name out: "<object>"
    before = 2  # "<parent>_<object>"
    after = 3  # "<object>_<parent>"

    def __str__(self):
        # argparse renders choices with str(); show just the member name.
        return self.name

    @staticmethod
    def from_string(s):
        """Convert a command-line word into a ParentName member.

        Raises ValueError for unknown words so that argparse rejects them as invalid values.
        """
        try:
            member = ParentName[s]
        except KeyError:
            raise ValueError()
        return member
class PostSplit:
    """Split a plain-text pg_dump file into one SQL file per database object.

    Output goes to a <outdir>/<schema>/<schema_type>/ file hierarchy. Use with files
    produced by ``pg_dump -s``. Objects become nodes of a networkx DiGraph keyed by TOC id.
    An edge ``a -> b`` means ``a`` depends on ``b``. With --postway, files are named and
    wrapped in a transaction so that postway can apply them as versioned (V<n>__) or
    repeatable (R__) migrations.

    Original author: kmatt - from https://gist.github.com/kmatt/2572360
    https://gist.github.com/rbellamy/e61bbe89abe97afb0fd6150190856f4f
    """

    # pg_dump writes "GRANT ... ON TABLE" for views, materialized views and foreign tables as
    # well as for plain tables. An ACL detected as 'table' may therefore belong to any of these.
    _TABLE_LIKE_TYPES = ('table', 'view', 'materialized_view', 'foreign_table')

    def __init__(self):
        self.parser = argparse.ArgumentParser(
            description='A script for parsing a text pg_dump file and creating either a) per-object files, or b) '
                        'per-object files in a form consumable by postway/postway')
        self.logger = logging.getLogger('pg_dump_split')
        self.logger.setLevel(logging.DEBUG)
        self.console_handler = logging.StreamHandler()  # sys.stderr
        # Quiet by default; set_log_level_from_verbose() lowers this in interactive sessions.
        self.console_handler.setLevel(logging.CRITICAL)
        self.console_handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
        self.logger.addHandler(self.console_handler)
        # Only referenced by the commented-out public-schema node in build_graph.
        self.create_public_schema_sql = '''
CREATE SCHEMA public;
GRANT ALL ON SCHEMA public TO postgres;
GRANT ALL ON SCHEMA public TO PUBLIC;
COMMENT ON SCHEMA public IS 'standard public schema';
'''
        self.args = None

    def _get_sections(self):
        """Yield one line iterator per pg_dump TOC entry in ``self.args.infile``.

        Each iterator starts with the ``-- TOC ...`` line and continues up to, but not
        including, the next TOC line. Lines before the first TOC entry are skipped.
        The iterators share groupby state, so each must be fully consumed before the
        generator is advanced (build_graph does this with a list comprehension).
        """
        with open(self.args.infile) as f:
            groups = groupby(f, lambda line: line.startswith('-- TOC'))
            for is_toc, toc_lines in groups:
                if is_toc:
                    # Take the TOC line before advancing ``groups``; advancing invalidates toc_lines.
                    toc_line = next(toc_lines)
                    # A TOC line at the very end of the file has no body. A bare next() would
                    # raise StopIteration here, which PEP 479 turns into RuntimeError.
                    _, body = next(groups, (False, iter(())))
                    yield chain([toc_line], body)

    def _add_edges_by_parent_name(self, graph):
        """Add ``child -> table`` edges for objects whose parent_name is a table in the same schema.

        TOC entries such as constraints and triggers are named "<parent> <object>". build_graph
        stores the first word as ``parent_name``. Returns the same graph, modified in place.
        """
        for parent_n, parent_d in graph.nodes(data=True):
            if parent_d['schema_type'] != 'table':
                continue
            parent_schema = parent_d['schema']
            parent_object_name = parent_d['object_name']
            for child_n, child_d in graph.nodes(data=True):
                # Check the key exists before reading it.
                if 'parent_name' not in child_d:
                    continue
                if child_d['schema'] != parent_schema or child_d['parent_name'] != parent_object_name:
                    continue
                self.logger.debug(
                    '[toc table <- {}] edge: {}.{}({}) <- {}.{}({})'.format(child_d['schema_type'],
                                                                            parent_schema,
                                                                            parent_object_name,
                                                                            parent_n,
                                                                            child_d['schema'],
                                                                            child_d.get('object_name'),
                                                                            child_n))
                graph.add_edge(child_n, parent_n)
        return graph

    def _add_edges_by_sql(self, parent_schema_types, child_schema_types, graph):
        """Add ``child -> parent`` edges when a child's SQL references a parent object.

        Each node whose schema_type is in ``parent_schema_types`` is checked against each node
        whose schema_type is in ``child_schema_types`` with _is_dependent_sql.

        Example: if 'myent_pkg.r' appears in the 'check_missing_goals_setup' function, then
        'myent_pkg' (TOC 1765) is a dependency of 'check_missing_goals_setup' (TOC 4054).
        The edge points from the dependent to its dependency:

            graph.add_edge('4054', '1765')

        :param parent_schema_types: schema types that may be depended upon, e.g. ['table', 'view']
        :type parent_schema_types: list[str]
        :param child_schema_types: schema types whose SQL is scanned for references
        :type child_schema_types: list[str]
        :param graph: the object graph built by build_graph
        :type graph: networkx.DiGraph
        :return: the same graph, with edges added in place
        :rtype: networkx.DiGraph
        """
        parents = [(n, d) for n, d in graph.nodes(data=True) if d['schema_type'] in parent_schema_types]
        children = [(n, d) for n, d in graph.nodes(data=True) if d['schema_type'] in child_schema_types]
        label = '[{} <- {}]'.format('/'.join(parent_schema_types), '/'.join(child_schema_types))
        for outer_n, outer_d in parents:
            for inner_n, inner_d in children:
                if self._is_dependent_sql(outer_d['schema'], outer_d['object_name'],
                                          inner_d['schema'], inner_d['object_name'], inner_d['sql']):
                    self.logger.debug('{} edge: {}.{}({}) <- {}.{}({})'.format(label,
                                                                               outer_d['schema'],
                                                                               outer_d['object_name'],
                                                                               outer_n,
                                                                               inner_d['schema'],
                                                                               inner_d['object_name'],
                                                                               inner_n))
                    graph.add_edge(inner_n, outer_n)
        return graph

    def _merge_acl(self, graph):
        """Append each ACL node's SQL to the object it applies to, then remove merged ACL nodes.

        An ACL node matches a non-ACL node with the same schema and object name whose type
        equals the type detected by _get_acl_parent_type. An ACL detected as 'table' also
        matches views, materialized views and foreign tables. ACL nodes that match nothing
        stay in the graph and are written out on their own.
        """
        # An ACL node's SQL is never modified here, so detect its parent type once.
        acl_nodes = [(n, d, self._get_acl_parent_type(d['sql']))
                     for n, d in graph.nodes(data=True) if d['schema_type'] == 'acl']
        acl_delete_list = []
        for outer_n, outer_d in graph.nodes(data=True):
            parent_schema_type = outer_d['schema_type']
            if parent_schema_type == 'acl':
                continue
            for inner_n, inner_d, acl_parent_schema_type in acl_nodes:
                type_matches = (parent_schema_type == acl_parent_schema_type or
                                (acl_parent_schema_type == 'table' and
                                 parent_schema_type in self._TABLE_LIKE_TYPES))
                if not type_matches:
                    continue
                if outer_d['object_name'] != inner_d['object_name'] or outer_d['schema'] != inner_d['schema']:
                    continue
                self.logger.debug(
                    '[{} <- acl] merge: {}.{}({}) <- {}.{}({})'.format(parent_schema_type,
                                                                       outer_d['schema'],
                                                                       outer_d['object_name'],
                                                                       outer_n,
                                                                       inner_d['schema'],
                                                                       inner_d['object_name'],
                                                                       inner_n))
                graph.nodes[outer_n]['sql'] = '{}\n{}'.format(outer_d['sql'], inner_d['sql'])
                acl_delete_list.append(inner_n)
        graph.remove_nodes_from(acl_delete_list)
        return graph

    @staticmethod
    def _is_dependent_sql(parent_schema, parent_object_name, child_schema, child_object_name, sql):
        """Return True if the child's ``sql`` references the parent object.

        1. If the fully-qualified parent and child names are equal, there is no dependency.
        2. If the schema names differ, the parent must appear fully qualified (schema.name).
        3. If the schema names are the same, the parent's relative (bare) name is enough.

        A name only counts when whitespace, ',', ';', '(' or ')' surrounds it on both sides.
        Lines starting with '-- Name' are ignored. Names are regex-escaped, so characters such
        as '$' or '.' inside an identifier are matched literally.

        :param parent_schema: schema of the candidate dependency
        :type parent_schema: str
        :param parent_object_name: name of the candidate dependency
        :type parent_object_name: str
        :param child_schema: schema of the object whose SQL is scanned
        :type child_schema: str
        :param child_object_name: name of the object whose SQL is scanned
        :type child_object_name: str
        :param sql: the child's SQL text
        :type sql: str
        :return: whether a reference to the parent was found
        :rtype: bool
        """
        if parent_schema == child_schema and parent_object_name == child_object_name:
            return False
        if parent_schema != child_schema:
            target = re.escape('{}.{}'.format(parent_schema, parent_object_name))
        else:
            target = re.escape(parent_object_name)
        delimiter = r'[\s,;\(\)]'
        reference = re.compile(delimiter + target + delimiter)
        return any(reference.search(line)
                   for line in sql.splitlines(keepends=True)
                   if not line.startswith('-- Name'))

    @staticmethod
    def _header(owner, schema, set_role):
        """Return the BEGIN/SET preamble that wraps a postway migration script.

        The ``SET LOCAL ROLE`` line is included only when ``set_role`` and ``owner`` are truthy.
        Otherwise that line is left empty.
        """
        s0 = 'BEGIN;'
        if set_role and owner:
            s1 = 'SET LOCAL ROLE {};'.format(owner)
        else:
            s1 = ''
        s2 = 'SET LOCAL check_function_bodies = false;'
        s3 = 'SET SEARCH_PATH TO {}, pg_catalog, sys, dbo;'.format(schema)
        return '{}\n{}\n{}\n{}\n'.format(s0, s1, s2, s3)

    @staticmethod
    def _footer():
        """Return the statement that closes the transaction opened by _header."""
        return 'COMMIT;'

    def _drop_trigger(self, schema, table_name, trigger_name):
        """Return DROP TRIGGER IF EXISTS statements for both the lower-case and quoted upper-case trigger name."""
        self.logger.debug('table_name: {}, trigger_name: {}'.format(table_name, trigger_name))
        s1 = 'DROP TRIGGER IF EXISTS {} ON {}.{};'.format(trigger_name.lower(), schema, table_name)
        s2 = 'DROP TRIGGER IF EXISTS "{}" ON {}.{};'.format(trigger_name.upper(), schema, table_name)
        return '{}\n{}\n'.format(s1, s2)

    @staticmethod
    def _drop_foreign_table(object_name):
        """Return a DROP FOREIGN TABLE IF EXISTS statement for the (unqualified) ``object_name``."""
        return 'DROP FOREIGN TABLE IF EXISTS {};'.format(object_name)

    @staticmethod
    def _inject_create_or_replace(schema_type, sql):
        """Rewrite ``CREATE <TYPE>`` as ``CREATE OR REPLACE <TYPE>`` (not applied to 'type')."""
        if schema_type != 'type':
            sql = re.sub(r'CREATE\s{}'.format(schema_type.upper()),
                         'CREATE OR REPLACE {}'.format(schema_type.upper()),
                         sql)
        return sql

    @staticmethod
    def _fixup_quoted_create(schema, schema_type, object_name, sql):
        """Unquote and lower-case the object's name in its CREATE statement.

        ``CREATE ... <TYPE> ... "Name"`` becomes ``CREATE ... <TYPE> ... name``. ``schema`` is
        unused and is kept only for interface compatibility.
        """
        pattern = r'CREATE(.*{}.*)("{}")'.format(re.escape(schema_type.upper()), re.escape(object_name))
        # A function replacement avoids template parsing. With a template, a name starting with
        # a digit would turn "\1" into a different group reference, e.g. "\11".
        return re.sub(pattern, lambda m: 'CREATE{}{}'.format(m.group(1), object_name.lower()), sql)

    @staticmethod
    def _inject_postway_separator(schema_type, object_name, sql):
        """Insert a '/' line before the second CREATE OR REPLACE, and before ALTER PACKAGE for packages.

        prepare_and_write does not currently call this. ``object_name`` is unused.
        """
        indices = [m.start() for m in re.finditer('CREATE OR REPLACE', sql)]
        if len(indices) == 2:
            sql = '{}/\n{}'.format(sql[:indices[1]], sql[indices[1]:])
        if schema_type == 'package':
            indices = [m.start() for m in re.finditer('ALTER PACKAGE', sql)]
            if len(indices) == 1:
                sql = '{}/\n{}'.format(sql[:indices[0]], sql[indices[0]:])
        return sql

    @staticmethod
    def _get_acl_parent_type(sql):
        """Guess the type of object an ACL entry applies to from its GRANT/REVOKE SQL.

        Returns the first entry of ``parent_types`` whose upper-cased name occurs in the SQL
        (ignoring '-- Name' comment lines), or None. This is a plain substring test, so list
        order decides between candidates.
        """
        parent_types = ['aggregate', 'foreign_table', 'function', 'materialized_view',
                        'package', 'procedure', 'schema', 'sequence', 'server', 'table',
                        'trigger', 'type', 'view']
        acl_sql = ''.join(line for line in sql.splitlines(keepends=True) if not line.startswith('-- Name'))
        return next((pt for pt in parent_types if pt.upper() in acl_sql), None)

    def _write_file(self, outdir, schema, schema_type, fname, sql):
        """Write ``sql`` to <outdir>/<schema>/<schema_type>/<fname>, creating directories as needed.

        The SQL is stripped, and runs of two or more blank lines are collapsed to one.
        """
        self.logger.info('Schema: {}; Type: {}; Name: {}'.format(schema, schema_type, fname))
        sqlpath = os.path.join(outdir, schema, schema_type)
        if not os.path.exists(sqlpath):
            print('*** mkdir {}'.format(sqlpath))
            os.makedirs(sqlpath, exist_ok=True)
        sqlf = os.path.join(sqlpath, fname)
        self.logger.debug('sqlf: {}'.format(sqlf))
        sql = re.sub(r'\n{3,}', r'\n\n', sql.strip())
        with open(sqlf, 'w') as out:
            out.write(sql)

    def _get_file_name(self, postway, version, versioned, name):
        """Return ``(file_name, version)`` for an object.

        versioned -> 'V<version + 1>__<name>.sql' and the incremented version;
        postway   -> 'R__<name>.sql' (repeatable), version unchanged;
        otherwise -> '<name>.sql', version unchanged.
        """
        if versioned:
            version += 1
            return 'V{}__{}.sql'.format(version, name), version
        if postway:
            return 'R__{}.sql'.format(name), version
        return '{}.sql'.format(name), version

    def parse_arguments(self):
        """
        Parse command line arguments.
        Sets self.args parameter for use throughout class/script.
        """
        self.parser.add_argument('-i', '--infile', required=True,
                                 help='The dump file created using pg_dump in text mode.')
        self.parser.add_argument('-o', '--outdir',
                                 help="The directory to use as the parent for the type'd directories. Created if it "
                                      "doesn't exist. Defaults to the current directory.")
        self.parser.add_argument('--include-parent-name', type=ParentName.from_string, choices=list(ParentName),
                                 default=ParentName.none,
                                 help='When naming files for objects that have parents (triggers, fk, pk, etc), use '
                                      'the parent name in the file name.')
        self.parser.add_argument('--postway', action='store_true',
                                 help='Create postway-compliant migration scripts.')
        self.parser.add_argument('--postway-versioned-only', action='store_true',
                                 help='Create postway-compliant VERSIONED migration scripts ONLY. Normally, postsplit '
                                      'creates both versioned and repeatable scripts, with '
                                      'functions/procedures/packages etc. being repeatable and '
                                      'tables/indexs/constraints etc. being versioned. However, EDB EPAS '
                                      '"check_function_bodies=false" DOES NOT WORK for packages. This means that if '
                                      'there are packages that have dependencies, they will fail to build unless they '
                                      'just happen to sort correctly.')
        # type=int: the value is incremented arithmetically in _get_file_name.
        self.parser.add_argument('--postway-version', type=int, default=1,
                                 help='Start with postway-version for migration scripts.')
        self.parser.add_argument('--plot', action='store_true',
                                 help='Show plot.')
        self.parser.add_argument('-V', '--version', action='version', version='%(prog)s 1.0.0',
                                 help='Print the version number of postsplit.')
        self.parser.add_argument('-v', '--verbose', action='count', help='verbose level... repeat up to three times.')
        self.args = self.parser.parse_args()
        return self

    def set_log_level_from_verbose(self):
        """Map the -v count to the console handler level (none=ERROR, 1=WARNING, 2=INFO, 3+=DEBUG)."""
        if not self.args.verbose:
            self.console_handler.setLevel('ERROR')
        elif self.args.verbose == 1:
            self.console_handler.setLevel('WARNING')
        elif self.args.verbose == 2:
            self.console_handler.setLevel('INFO')
        elif self.args.verbose >= 3:
            self.console_handler.setLevel('DEBUG')
        else:
            self.logger.critical('UNEXPLAINED NEGATIVE COUNT!')
        return self

    def build_graph(self):
        """Parse the dump into a networkx DiGraph of objects keyed by TOC id.

        Node attributes: schema, schema_type, name (output file stem), parent_name, object_name,
        owner and sql. An edge ``a -> b`` means ``a`` depends on ``b``. Edges come from the TOC
        "Dependencies" lines, from references found in SQL, and from parent (table) names.
        ACL entries are then merged into the objects they apply to.
        """
        graph = nx.DiGraph()
        type_regex = re.compile(
            r'-- Name: ([-\w\s\.\$\^]+)(?:\([-\w\s\[\],.\*\"]*\))?; Type: ([-\w\s]+); Schema: ([-\w]+); Owner: ([-\w]*)(?:; Tablespace: )?([-\w]*)\n',
            flags=re.IGNORECASE)
        toc_regex = re.compile(r'-- TOC entry (\d*) \(class (\d+) OID (\d+)\)\n', flags=re.IGNORECASE)
        dep_regex = re.compile(r'-- Dependencies: (.*)')
        name_regex = re.compile(r'(\w+)\s+(\w+)')
        user_mapping_name_regex = re.compile(r'USER\sMAPPING\s([\w\s]+)')
        # build the graph based on TOC
        for sec in self._get_sections():
            # remove all lines with just `--`
            section = [x for x in sec if x != '--\n']
            toc_id = toc_regex.search(section[0]).group(1)
            graph.add_node(toc_id, schema='', schema_type='', name='', sql='')
            type_index = 1
            if len(section) > 1 and section[1].startswith('-- Dependencies'):
                dep_ids = dep_regex.search(section[1]).group(1).split()
                # edges defined by TOC dependencies
                for dep_id in set(dep_ids):
                    # dep_id is a dependency of toc_id
                    self.logger.debug('TOC edge: {}({}) <- {}({})'.format('unknown', dep_id, 'unknown', toc_id))
                    graph.add_edge(toc_id, dep_id)
                type_index = 2
            type_match = type_regex.search(section[type_index]) if type_index < len(section) else None
            if type_match is None:
                # The node keeps name='' and is pruned below.
                self.logger.warning('Skipping TOC entry {}: no "-- Name" line found.'.format(toc_id))
                continue
            name, schema_type, schema, owner, _tablespace = type_match.groups()
            # ignore the schema_version table if building for postway
            if self.args.postway and 'schema_version' in name:
                continue
            schema_type = schema_type.replace(' ', '_').lower()
            if schema_type == 'user_mapping':
                parent_name = ''
                name = user_mapping_name_regex.search(name).group(1).replace(' ', '_').lower()
                object_name = name
            else:
                name_match = name_regex.match(name)
                if name_match:
                    # TOC names like "<parent> <object>" (e.g. "<table> <constraint>")
                    parent_name, object_name = name_match.group(1), name_match.group(2)
                    if self.args.include_parent_name == ParentName.none:
                        name = object_name
                    elif self.args.include_parent_name == ParentName.before:
                        name = '{}_{}'.format(parent_name, object_name)
                    else:
                        name = '{}_{}'.format(object_name, parent_name)
                else:
                    parent_name = ''
                    object_name = name
            name = name.lower()
            if self.args.postway:
                name = '{}_{}'.format(name, schema_type)
            if schema == '-':
                schema = 'public'
            # The object's SQL is the section minus its TOC, Dependencies and SET search_path lines.
            sql = [y for y in section if not y.startswith(('-- TOC', '-- Dependencies', 'SET search_path'))]
            self.logger.debug(
                'Owner: {}; Schema: {}; Type: {}; Name: {}; Parent Name: {}; Object Name: {}'.format(owner, schema,
                                                                                                      schema_type, name,
                                                                                                      parent_name,
                                                                                                      object_name))
            graph.nodes[toc_id].update(schema=schema,
                                       schema_type=schema_type,
                                       name=name,
                                       parent_name=parent_name,
                                       object_name=object_name,
                                       owner=owner,
                                       sql=''.join(sql))
        # prune the graph - remove all nodes that have no name attribute
        prune_list = [pr for pr, d in graph.nodes(data=True) if not d.get('name')]
        graph.remove_nodes_from(prune_list)
        # find edges where type is a dependency of type/function/procedure/package
        graph = self._add_edges_by_sql(['type'], ['type', 'function', 'procedure', 'package'], graph)
        # find edges where table/view is a dependency of view
        graph = self._add_edges_by_sql(['table', 'view'], ['view'], graph)
        # find edges where table is a dependency of index
        graph = self._add_edges_by_sql(['table'], ['index'], graph)
        # find edges where table is a dependency of constraints/fk_constraints
        graph = self._add_edges_by_parent_name(graph)
        # merge acl into parent script
        graph = self._merge_acl(graph)
        # add public schema
        # graph.add_node(0, schema='public', schema_type='schema', name='public_schema',
        #                sql=self.create_public_schema_sql)
        return graph

    def prepare_and_write(self, toc_id, postway, version, versioned_only, outdir, node):
        """Write one graph node to disk and return the (possibly incremented) version counter.

        For postway output (any type except 'schema'), the SQL is fixed up and wrapped in
        _header/_footer. Functions, procedures, packages, views, triggers and foreign tables
        are repeatable unless ``versioned_only`` is set. Everything else is versioned.
        Nodes without a name are skipped.
        """
        self.logger.debug('{}: {}'.format(toc_id, node))
        owner, schema, schema_type, name, parent_name, object_name, sql = map(node.get,
                                                                              ('owner', 'schema', 'schema_type',
                                                                               'name', 'parent_name', 'object_name',
                                                                               'sql'))
        if not name:
            return version
        self.logger.info(
            'Owner: {}; Schema: {}; Type: {}; Name: {}; Parent Name: {}; Object Name: {}'.format(owner, schema,
                                                                                                  schema_type, name,
                                                                                                  parent_name,
                                                                                                  object_name))
        if postway and schema_type != 'schema':
            sql = self._fixup_quoted_create(schema, schema_type, object_name, sql)
            if schema_type in ('function', 'procedure', 'package', 'view', 'trigger', 'foreign_table'):
                fname, version = self._get_file_name(postway, version, versioned_only, name)
                # Repeatable scripts must be re-runnable: drop or CREATE OR REPLACE first.
                if schema_type == 'trigger':
                    sql = '{}\n{}'.format(self._drop_trigger(schema, parent_name, object_name), sql)
                elif schema_type == 'foreign_table':
                    sql = '{}\n{}'.format(self._drop_foreign_table(object_name), sql)
                else:
                    sql = self._inject_create_or_replace(schema_type, sql)
            else:
                fname, version = self._get_file_name(postway, version, True, name)
            sql = '{}\n{}\n{}'.format(self._header(owner, schema, True), sql, self._footer())
        else:
            fname, version = self._get_file_name(False, version, False, name)
        self._write_file(outdir, schema, schema_type, fname, sql)
        return version
if __name__ == '__main__':
    # postsplit entry point: parse the dump, order objects by dependency, write one file each.
    postsplit = PostSplit().parse_arguments().set_log_level_from_verbose()
    outdir = postsplit.args.outdir
    if not outdir:
        # default to the directory that contains the dump file
        outdir = os.path.dirname(postsplit.args.infile)
    print('Start parsing {}.'.format(postsplit.args.infile))
    postsplit.logger.debug('args: {}'.format(postsplit.args))
    g = postsplit.build_graph()
    # --plot is accepted but plotting is disabled (it needs matplotlib):
    # if postsplit.args.plot:
    #     print('Plot the graph.')
    #     nx.draw(g, with_labels=True, font_weight='bold')
    #     plt.show()
    print('Sort the DAG.')
    try:
        # Edges point from dependent to dependency, so reverse the topological order
        # to emit dependencies first.
        s = list(reversed(list(nx.topological_sort(g))))
    except nx.NetworkXUnfeasible as err:
        # Only a cyclic graph gets here, so find_cycle is guaranteed to find one to report.
        postsplit.logger.debug(nx.find_cycle(g))
        postsplit.logger.error(err)
        raise SystemExit(1)
    postsplit.logger.debug(s)
    print('Write the files.')
    p = postsplit.args.postway
    p_v = postsplit.args.postway_version  # postway version - used in version migration scripts
    p_v_o = postsplit.args.postway_versioned_only
    for n in s:
        p_v = postsplit.prepare_and_write(n, p, p_v, p_v_o, outdir, g.nodes[n])
    print('Done processing {}.'.format(postsplit.args.infile))
#!/usr/bin/env python
import argparse
import collections.abc
import logging
import ntpath
import os
import re
import subprocess
import zlib
from collections import namedtuple
from enum import Enum
from glob import glob
from timeit import default_timer as timer
import psycopg2
from packaging import version
from psycopg2._psycopg import AsIs
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from psycopg2.extensions import make_dsn
from psycopg2.extras import LoggingConnection
class OrderedSet(collections.abc.MutableSet):
    """A mutable set that remembers insertion order.

    Elements live in a circular doubly linked list of ``[key, prev_cell, next_cell]`` cells
    anchored by a sentinel cell (``self.end``). ``self.map`` maps each key to its cell, so
    membership, insertion and removal do not scan the list.
    """

    def __init__(self, iterable=None):
        sentinel = []
        sentinel.extend([None, sentinel, sentinel])  # an empty list: the sentinel links to itself
        self.end = sentinel
        self.map = {}  # key -> [key, prev_cell, next_cell]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append ``key`` at the end unless it is already present."""
        if key in self.map:
            return
        sentinel = self.end
        last = sentinel[1]
        cell = [key, last, sentinel]
        last[2] = cell
        sentinel[1] = cell
        self.map[key] = cell

    def discard(self, key):
        """Remove ``key`` if present; otherwise do nothing."""
        if key not in self.map:
            return
        _, before, after = self.map.pop(key)
        before[2] = after
        after[1] = before

    def _walk(self, step):
        """Yield keys following link index ``step`` (2 = forward, 1 = backward)."""
        sentinel = self.end
        cell = sentinel[step]
        while cell is not sentinel:
            yield cell[0]
            cell = cell[step]

    def __iter__(self):
        return self._walk(2)

    def __reversed__(self):
        return self._walk(1)

    def pop(self, last=True):
        """Remove and return the last key, or the first key when ``last`` is False."""
        if not self:
            raise KeyError('set is empty')
        cell = self.end[1] if last else self.end[2]
        key = cell[0]
        self.discard(key)
        return key

    def __repr__(self):
        cls_name = self.__class__.__name__
        if not self:
            return '%s()' % (cls_name,)
        return '%s(%r)' % (cls_name, list(self))

    def __eq__(self, other):
        """Order-sensitive against another OrderedSet, order-insensitive against anything else."""
        if not isinstance(other, OrderedSet):
            return set(self) == set(other)
        return len(self) == len(other) and list(self) == list(other)
class PostWayCommand(Enum):
    """Commands postway can run (presumably selected on the command line via from_string)."""
    baseline = 1
    version = 2
    clean = 3
    migrate = 4
    validate = 5
    info = 6

    def __str__(self):
        # argparse renders choices with str(); show just the member name.
        return self.name

    @staticmethod
    def from_string(s):
        """Convert a command word into a PostWayCommand; raise ValueError for unknown words."""
        try:
            command = PostWayCommand[s]
        except KeyError:
            raise ValueError()
        return command
class PostWay:
"""
PostWay is a wholly new derivative of Flyway from BoxFuse. It's a derivative in that it derives it's workflow and
the schema_version table directly from Flyway. It's wholly new in that the code and focus are specific to
EDB/Postgres, and the method for executing SQL scripts uses native EDB/Postgres tools - specifically psql.
Management routines are managed via psycopg2. Migration scripts are run through psql.
"""
SchemaRecord = namedtuple('SchemaRecord',
'name, script, baseline_version, max_current_version, max_schema_version, '
'repeatable_schema_migrations, versioned_schema_migrations, '
'diff_repeatable_migrations, diff_versioned_migrations, changed_repeatable_migrations')
MigrationRecord = namedtuple('MigrationRecord', 'script, version, description, checksum')
ChangedMigrationRecord = namedtuple('ChangedMigrationRecord',
'script, version, description, old_checksum, new_checksum')
ExecutingMigrationRecord = namedtuple('ExecutingMigrationRecord', 'schema, script, version, description, checksum')
def __init__(self):
self.parser = argparse.ArgumentParser(
add_help=False,
description='''
PostWay is a wholly new derivative of Flyway from BoxFuse. It's a derivative in that it derives it's workflow and
the schema_version table directly from Flyway. It's wholly new in that the code and focus are specific to
EDB/Postgres, and the method for executing SQL scripts uses native EDB/Postgres tools - specifically psql.
Management routines are managed via psycopg2. Migration scripts are run through psql.
''')
self.logger = logging.getLogger('postway')
self.logger.setLevel(logging.DEBUG)
self.console_handler = logging.StreamHandler() # sys.stderr
self.console_handler.setLevel(
logging.CRITICAL) # set later by set_log_level_from_verbose() in interactive sessions
self.console_handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
self.logger.addHandler(self.console_handler)
self.args = None
self.conn = None
self.is_db = False
self.schema_records = None
self.is_max_schema_version = False
self.is_baselined = False
self.is_validated = False
self.versioned_prefix = 'V'
self.repeatable_prefix = 'R'
self.migration_separator = '__'
self.migration_suffix = '.sql'
self.baseline_version = 1
self.baseline_description = '<< Flyway Baseline >>'
self.max_schema_version_sql = '''
SELECT version
FROM %(nspname)s.schema_version
WHERE installed_rank = (SELECT MAX(installed_rank) FROM %(nspname)s.schema_version WHERE version IS NOT NULL);
'''
self.get_repeatable_migrations_sql = '''
SELECT script, version, description, checksum
FROM %(nspname)s.schema_version
WHERE version IS NULL
ORDER BY installed_rank ASC
'''
self.get_versioned_migrations_sql = '''
SELECT script, version, description, checksum
FROM %(nspname)s.schema_version
WHERE version IS NOT NULL AND script != %(description)s
ORDER BY installed_rank ASC
'''
self.schema_exists_sql = '''
SELECT EXISTS(
SELECT 1 FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname = %(schema)s
AND c.relkind = 'r'
AND c.oid NOT IN (SELECT inhrelid FROM pg_inherits)
);
'''
self.schema_version_exists_sql = '''
SELECT EXISTS(
SELECT 1 FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname = %(schema)s
AND c.relkind = 'r'
AND c.oid NOT IN (SELECT inhrelid FROM pg_inherits)
AND c.relname = 'schema_version'
);
'''
self.create_schema_version_sql = '''
--
-- Copyright 2010-2017 Boxfuse GmbH
--
-- Licensed under the Apache License, Version 2.0 (the 'License');
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an 'AS IS' BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
CREATE TABLE %(nspname)s.schema_version (
installed_rank SERIAL PRIMARY KEY,
version VARCHAR(50),
description VARCHAR(200) NOT NULL,
type VARCHAR(20) NOT NULL,
script VARCHAR(1000) NOT NULL,
checksum NUMBER,
installed_by VARCHAR(100) NOT NULL,
installed_on TIMESTAMP NOT NULL DEFAULT now(),
execution_time INTEGER NOT NULL,
success BOOLEAN NOT NULL
) WITH (
OIDS=FALSE
);
CREATE INDEX schema_version_s_idx ON %(nspname)s.schema_version (version, success);
'''
self.is_baselined_sql = 'SELECT EXISTS(SELECT 1 FROM %(nspname)s.schema_version WHERE script = %(script)s);'
self.baselined_version_sql = 'SELECT version FROM %(nspname)s.schema_version WHERE script = %(script)s;'
self.insert_schema_version_sql = '''
INSERT INTO %(nspname)s.schema_version (version, description, "type", script, checksum, installed_by, execution_time, success)
VALUES (%(version)s, %(description)s, %(type)s, %(script)s, %(checksum)s, %(installed_by)s, %(execution_time)s, %(success)s);
'''
self.drop_schema_sql = 'DROP SCHEMA IF EXISTS %(nspname)s CASCADE;'
self.db_no_more_connections_sql = "UPDATE pg_database SET datallowconn = 'false' WHERE datname = %(dbname)s;"
self.db_disconnect_sql = '''
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = %(dbname)s;
'''
self.db_drop_sql = 'DROP DATABASE IF EXISTS %(dbname)s'
self.db_create_sql = '''
CREATE DATABASE %(dbname)s WITH OWNER=%(owner)s LC_COLLATE='C' LC_CTYPE='C' TEMPLATE='template0';
'''
@property
def _dsn(self):
    """libpq connection string built from the host/port/dbname/user/password arguments."""
    opts = self.args
    return make_dsn(host=opts.host, port=opts.port, dbname=opts.dbname,
                    user=opts.user, password=opts.password)
def _get_search_path(self):
    """Return the row from ``SHOW search_path`` on self.conn (a 1-tuple holding the path string)."""
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute('SHOW search_path;')
        row = cursor.fetchone()
    return row
def _set_search_path(self, schema):
    """Prepend ``schema`` to the session search_path, logging the value before and after at DEBUG."""
    current = self._get_search_path()
    self.logger.debug('old search_path: {}'.format(current[0]))
    params = {'schema': schema, 'search_path': AsIs(', '.join(current))}
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute('SET search_path=%(schema)s,%(search_path)s', params)
    self.logger.debug('new search_path: {}'.format(self._get_search_path()[0]))
def _is_schema(self, schema):
    """Return whether ``schema`` contains at least one ordinary table that is not an inheritance child."""
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute(self.schema_exists_sql, {'schema': schema})
        exists = cursor.fetchone()[0]
    self.logger.warning('schema exists: {}'.format(exists))
    return exists
def _is_schema_version(self, schema):
    """Return whether <schema>.schema_version exists as an ordinary, non-inherited table."""
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute(self.schema_version_exists_sql, {'schema': schema})
        exists = cursor.fetchone()[0]
    self.logger.warning('schema_version exists: {}'.format(exists))
    return exists
def _schema_baseline(self, schema, script):
    """Return the version stored for the ``script`` row in <schema>.schema_version, or 0 if there is none."""
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute(self.baselined_version_sql, {'nspname': AsIs(schema), 'script': script})
        row = cursor.fetchone()
        if row is None:
            self.logger.warning('Schema is not baselined')
            return 0
        self.logger.warning('Schema is baselined')
        return row[0]
def _get_repeatable_schema_migrations(self, schema):
    """Return the repeatable (version IS NULL) rows of <schema>.schema_version as MigrationRecords, in installed_rank order."""
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute(self.get_repeatable_migrations_sql, {'nspname': AsIs(schema)})
        rows = cursor.fetchall()
    return [self.MigrationRecord._make(row) for row in rows]
def _get_versioned_schema_migrations(self, schema, baseline_description):
    """Return the versioned rows of <schema>.schema_version as MigrationRecords, in installed_rank order.

    The row whose script equals ``baseline_description`` (the baseline marker) is excluded.
    Each version string is parsed with packaging.version.parse so that versions compare as
    versions rather than as strings.
    """
    params = {'nspname': AsIs(schema), 'description': baseline_description}
    with self.conn, self.conn.cursor() as cursor:
        cursor.execute(self.get_versioned_migrations_sql, params)
        rows = cursor.fetchall()
    # noinspection PyTypeChecker
    return [self.MigrationRecord._make((row[0], version.parse(row[1]), row[2], row[3])) for row in rows]
def _get_current_migrations(self, migration_base_directory, schema, repeatable_prefix, versioned_prefix,
                            migration_separator, migration_suffix):
    """Collect the migration scripts on disk for ``schema``.

    :return: (max_current_version, repeatable_migrations, versioned_migrations).
        max_current_version is the highest version among the versioned scripts, or 0 if there are none.
    """
    # Call through self rather than a module-level instance, so this works on any PostWay object.
    repeatable_migrations = self._get_repeatable_migrations(migration_base_directory, schema, repeatable_prefix,
                                                            migration_separator, migration_suffix)
    versioned_migrations = self._get_versioned_migrations(migration_base_directory, schema, versioned_prefix,
                                                          migration_separator, migration_suffix)
    max_current_version = max((migration.version for migration in versioned_migrations), default=0)
    return max_current_version, repeatable_migrations, versioned_migrations
def _get_schema_migrations(self, schema, baseline_description, baseline_version, user):
    """Return (repeatable, versioned) MigrationRecords already recorded in <schema>.schema_version.

    ``baseline_version`` and ``user`` are accepted but not used.
    """
    return (self._get_repeatable_schema_migrations(schema),
            self._get_versioned_schema_migrations(schema, baseline_description))
def _get_repeatable_migrations(self, migration_base_directory, schema, prefix, separator, suffix):
    """Return MigrationRecords for <base>/<schema>/<prefix>*<separator>*<suffix> files, sorted by _path_leaf(script)."""
    schema_dir = os.path.join(migration_base_directory, schema)
    pattern = '{}*{}*{}'.format(prefix, separator, suffix)
    records = [self._repeatable_file_parts(schema, path, separator, prefix, suffix)
               for path in self._get_files(schema_dir, pattern)]
    records.sort(key=lambda record: self._path_leaf(record.script))
    return records
def _get_versioned_migrations(self, migration_base_directory, schema, prefix, separator, suffix):
    """Return MigrationRecords for <base>/<schema>/<prefix>*<separator>*<suffix> files, sorted by parsed version."""
    schema_dir = os.path.join(migration_base_directory, schema)
    pattern = '{}*{}*{}'.format(prefix, separator, suffix)
    records = [self._versioned_file_parts(path, separator, prefix, suffix)
               for path in self._get_files(schema_dir, pattern)]
    records.sort(key=lambda record: record.version)
    return records
def _repeatable_file_parts(self, schema, file_path, separator, prefix, suffix):
    """Build the MigrationRecord (version None) for a repeatable script, with its description and _crc checksum.

    ``schema`` and ``prefix`` are accepted for signature symmetry but not used.
    """
    description = self._migration_file_parts(file_path, separator, suffix)[0]
    return self.MigrationRecord(file_path, None, description, self._crc(file_path))
def _versioned_file_parts(self, file_path, separator, prefix, suffix):
    """Build the MigrationRecord for a versioned script: parsed version, description and _crc checksum."""
    description, parts = self._migration_file_parts(file_path, separator, suffix)
    # e.g. 'V1_2' -> '1.2': drop the prefix and treat underscores as version dots
    version_text = parts[0].replace(prefix, '').replace('_', '.')
    parsed_version = version.parse(version_text)
    checksum = self._crc(file_path)
    return self.MigrationRecord(file_path, parsed_version, description, checksum)
def _migration_file_parts(self, file_path, separator, suffix):
    """Split a migration file name into (description, parts).

    The name (from _path_leaf, presumably the base name) is split on the FIRST ``separator``
    only, as Flyway does. A description may therefore contain the separator itself.
    ``suffix`` is removed only from the end, and underscores become spaces.
    Example: 'V1_2__add__users.sql' -> ('add  users', ['V1_2', 'add__users.sql']).
    """
    file_name = self._path_leaf(file_path)
    file_parts = file_name.split(separator, 1)
    description = file_parts[1]
    if suffix and description.endswith(suffix):
        description = description[:-len(suffix)]
    return description.replace('_', ' '), file_parts
def _create_schema_version(self, schema):
    """Ensure <schema>.schema_version exists, creating it when missing. Always returns True."""
    if not self._is_schema_version(schema):
        with self.conn, self.conn.cursor() as cursor:
            self.logger.warning('Creating schema_version in {}'.format(schema))
            cursor.execute(self.create_schema_version_sql, {'nspname': AsIs(schema)})
    return True
def _execute_migration(self, migration, user, bindir, host, port, dbname, password):
    """Run one migration script with psql and record it in <schema>.schema_version.

    :param migration: an ExecutingMigrationRecord (schema, script, version, description, checksum)
    :param bindir: directory containing psql; when empty or None, psql is looked up on PATH
    :param password: exported to psql as PGPASSWORD when not None
    If _execute_psql reports failure, the process exits with status 1 and nothing is recorded.
    """
    penv = os.environ.copy()
    if password is not None:
        # subprocess requires str environment values. Without a password, psql falls back to
        # its own mechanisms (e.g. ~/.pgpass).
        penv['PGPASSWORD'] = password
    psql = [os.path.join(bindir or '', 'psql')]
    # Omit options that were not supplied, so psql uses its defaults instead of
    # subprocess failing on a None argument.
    for flag, value in (('-h', host), ('-p', port), ('-d', dbname), ('-U', user)):
        if value is not None:
            psql.extend([flag, '{}'.format(value)])
    with self.conn:
        with self.conn.cursor() as curs:
            start, end, success = self._execute_psql(penv, psql, migration.script)
            if not success:
                raise SystemExit(1)
            execution_time = end - start
            curs.execute(self.insert_schema_version_sql,
                         {'nspname': AsIs(migration.schema),
                          'version': None if migration.version is None else '{}'.format(migration.version),
                          'description': migration.description,
                          'type': 'SQL',
                          'script': migration.script,
                          'checksum': migration.checksum,
                          'installed_by': user,
                          'execution_time': int(round(execution_time * 1000)),  # milliseconds
                          'success': True})
def _execute_psql(self, penv, psql, script):
    """Run ``script`` with the psql command line ``psql`` and scan its output for errors.

    Without ON_ERROR_STOP, psql exits 0 even when statements fail. Failures are therefore
    detected by matching "psql:<file>:<line>: ERROR:" lines in the combined output. Both
    the standard "psql:" prefix and EDB's "psql.bin:" prefix are recognised. Every ERROR is
    logged at error level and every NOTICE at warning level.

    ``psql`` is extended in place with -e (at verbosity >= 3) and -f <script>.
    Raises subprocess.CalledProcessError (from _execute_process) if psql exits non-zero.

    :return: (start, end, success) - timer() values around the run, and False if any ERROR was seen
    """
    error_regex = re.compile(r'psql(?:\.bin)?:(.*):([\d]+):\sERROR:\s+(.*)')
    warning_regex = re.compile(r'psql(?:\.bin)?:(.*):([\d]+):\sNOTICE:\s+(.*)')
    if self.args.verbose is not None and self.args.verbose >= 3:
        psql.extend(['-e'])  # echo each statement
    psql.extend(['-f', script])
    self.logger.warning(subprocess.list2cmdline(psql))
    start = timer()
    end, output = self._execute_process(psql, penv)
    success = True
    for error_match in error_regex.finditer(output):
        success = False
        self.logger.error(
            'Error found at {start}-{end}: {match}'.format(start=error_match.start(),
                                                           end=error_match.end(),
                                                           match=error_match.group()))
    for warning_match in warning_regex.finditer(output):
        self.logger.warning(
            'Warning found at {start}-{end}: {match}'.format(start=warning_match.start(),
                                                             end=warning_match.end(),
                                                             match=warning_match.group()))
    return start, end, success
def _execute_clean_db(self, user, host, port, dbname, password):
self._execute_dropdb(user, host, port, dbname, password)
self._execute_createdb(user, host, port, dbname, password)
def _execute_dropdb(self, user, host, port, dbname, password):
    """
    Drop database *dbname*, first blocking and terminating its connections.

    Works through the ``postgres`` maintenance database:
      1. runs ``db_no_more_connections_sql`` and ``db_disconnect_sql`` in one
         transaction;
      2. runs ``db_drop_sql`` on a separate autocommit connection.

    Both connections are closed when done.

    :return: ``(start, end)`` timer() readings around the whole operation.
    """
    from contextlib import closing

    start = timer()
    db_dsn = make_dsn(host=host,
                      port=port,
                      dbname='postgres',
                      user=user,
                      password=password)
    # psycopg2's `with conn:` only commits/rolls back; closing() actually releases the connection.
    with closing(psycopg2.connect(db_dsn, connection_factory=LoggingConnection)) as conn:
        conn.initialize(self.logger)
        with conn:
            with conn.cursor() as curs:
                self.logger.warning('Setting database {} to accept no more connections.'.format(dbname))
                curs.execute(self.db_no_more_connections_sql, {'dbname': dbname})
            with conn.cursor() as curs:
                self.logger.warning('Disconnecting database {}.'.format(dbname))
                curs.execute(self.db_disconnect_sql, {'dbname': dbname})
    self.logger.warning('Dropping database {}.'.format(dbname))
    # DROP DATABASE cannot run inside a transaction block, hence the autocommit connection.
    with closing(psycopg2.connect(db_dsn, connection_factory=LoggingConnection)) as connect:
        connect.initialize(self.logger)
        connect.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        with connect.cursor() as cursor:
            cursor.execute(self.db_drop_sql, {'dbname': AsIs(dbname)})
    end = timer()
    return start, end
def _execute_createdb(self, user, host, port, dbname, password):
    """
    Create database *dbname* owned by *user*.

    Connects to the ``postgres`` maintenance database in autocommit mode, runs
    ``db_create_sql``, and closes the connection afterwards.

    :return: ``(start, end)`` timer() readings around the operation.
    """
    from contextlib import closing

    start = timer()
    db_dsn = make_dsn(host=host,
                      port=port,
                      dbname='postgres',
                      user=user,
                      password=password)
    self.logger.warning('Creating database {}.'.format(dbname))
    # closing() releases the connection; previously it (and its cursor) was leaked.
    with closing(psycopg2.connect(db_dsn, connection_factory=LoggingConnection)) as connect:
        connect.initialize(self.logger)
        # CREATE DATABASE cannot run inside a transaction block, hence autocommit.
        connect.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        with connect.cursor() as cursor:
            cursor.execute(self.db_create_sql, {'dbname': AsIs(dbname), 'owner': user})
    end = timer()
    return start, end
def _execute_process(self, process, penv):
try:
stdout = subprocess.check_output(process, stderr=subprocess.STDOUT, env=penv)
end = timer()
output = stdout.decode('utf-8')
self.logger.debug(output)
except subprocess.CalledProcessError as db_exc:
output = db_exc.output.decode('utf-8')
self.logger.error(output)
raise
return end, output
# noinspection PyTypeChecker
def _find_changed_migrations(self, expected, difference):
    """
    Pair recorded migrations with not-yet-matching found migrations of the same script.

    For every ``exp`` in *expected* and ``dif`` in *difference* sharing a ``script``,
    a ChangedMigrationRecord is built from exp's fields followed by dif's checksum.
    Order follows *expected*, then *difference*.

    :return: OrderedSet of ChangedMigrationRecord.
    """
    return OrderedSet([
        self.ChangedMigrationRecord._make(list(recorded) + [candidate.checksum])
        for recorded in expected
        for candidate in difference
        if recorded.script == candidate.script
    ])
@staticmethod
def _crc(file_path):
crc = 0
with open(file_path, "rb") as f:
data = f.read()
crc = zlib.crc32(data) & 0xffffffff
return crc
@staticmethod
def _path_leaf(path):
head, tail = ntpath.split(path)
return tail or ntpath.basename(head)
@staticmethod
def _get_files(path, glob_pattern):
files = [y for x in os.walk(path) for y in
glob(os.path.join(x[0], glob_pattern))]
return files
def set_log_level_from_verbose(self):
    """
    Map the ``-v`` count in ``self.args.verbose`` onto the console handler's level.

    None/0 -> ERROR, 1 -> WARNING, 2 -> INFO, 3 or more -> DEBUG. Any other value
    (a negative count) leaves the level unchanged and logs a critical message.

    :return: self, for call chaining.
    """
    verbosity = self.args.verbose
    if not verbosity:
        level_name = 'ERROR'
    elif verbosity >= 3:
        level_name = 'DEBUG'
    else:
        level_name = {1: 'WARNING', 2: 'INFO'}.get(verbosity)
    if level_name is None:
        self.logger.critical('UNEXPLAINED NEGATIVE COUNT!')
    else:
        self.console_handler.setLevel(level_name)
    return self
def parse_arguments(self):
    """
    Register PostWay's command-line options on ``self.parser`` and parse them.

    Sets ``self.args`` for use throughout the class/script. ``-h`` is taken by
    ``--host``, so this assumes ``self.parser`` was created with ``add_help=False``
    (NOTE(review): confirm in __init__); ``--help`` is registered explicitly below.

    :return: self, for call chaining.
    """
    self.parser.add_argument('-m', '--migration-base-directory',
                             default=os.getcwd(),
                             help='The base directory in which to look for schema-named directories for migration '
                                  'scripts.')
    # nargs='?' lets the declared default (validate) apply when no command is given;
    # argparse never uses the default of a required positional.
    self.parser.add_argument('command',
                             nargs='?',
                             default=PostWayCommand.validate,
                             type=PostWayCommand.from_string,
                             choices=list(PostWayCommand),
                             help='The PostWay command to execute.')
    self.parser.add_argument('-h', '--host', default='localhost', help='The EPAS host.')
    # type=int: the port is validated up front and is an int whether defaulted or supplied.
    self.parser.add_argument('--port', default=5444, type=int, help='The EPAS port.')
    self.parser.add_argument('-U', '--user', required=True, help='The EPAS username.')
    self.parser.add_argument('-W', '--password', required=True, help='The EPAS password.')
    self.parser.add_argument('-d', '--dbname', required=True, help='The EPAS database.')
    self.parser.add_argument('-s', '--schema', dest='schemas', action='append',
                             help='The EPAS schema(s). If none, postway assumes a full DB run, iterating '
                                  'over each *_schema.sql file in the "{MIGRATION_DIRECTORY}/public/schema" '
                                  'directory. Each use appends a schema to the list.')
    self.parser.add_argument('-b', '--bindir',
                             required=True,
                             help='The path to the EPAS bin directory where psql is located. Required to actually '
                                  'run the migration scripts.')
    self.parser.add_argument('-V', '--version', action='version', version='%(prog)s 1.0.0',
                             help='Print the version number of postway.')
    self.parser.add_argument('-v', '--verbose', action='count', help='verbose level... repeat up to three times.')
    # Plain string instead of the private argparse._ gettext alias.
    self.parser.add_argument('--help', action='help', default=argparse.SUPPRESS,
                             help='show this help message and exit')
    self.args = self.parser.parse_args()
    return self
def get_schemas(self):
    """
    Populate ``self.schema_records`` (only if still empty) with the schemas to process.

    Schema files are looked up under ``{migration_base_directory}/public/schema`` with
    the pattern ``*_schema{migration_suffix}``. A schema's name is its file name with
    that ``_schema{migration_suffix}`` tail removed.

    * No ``-s`` given: every discovered schema plus ``public`` is selected and
      ``self.is_db`` is set (whole-database run).
    * ``-s`` given: only schemas whose name equals a requested name are selected;
      ``public`` is added when requested. Requested schemas without a file are
      reported with a warning.

    :return: self, for call chaining.
    """
    if not self.schema_records:
        schemas = self.args.schemas
        directory = self.args.migration_base_directory
        schema_tail = '_schema{}'.format(self.migration_suffix)
        files = self._get_files(os.path.join(directory, 'public', 'schema'), '*{}'.format(schema_tail))
        public = self.SchemaRecord('public', None, None, None, None, None, None, None, None, None)
        # (schema name, schema file) pairs; strip the configured tail rather than a
        # hard-coded '_schema.sql' so a different migration_suffix still yields clean names.
        discovered = [(self._path_leaf(f).replace(schema_tail, ''), f) for f in files]

        def to_record(name, script):
            # Fresh record with only name and schema script known so far.
            return self.SchemaRecord._make([name, script, None, None, None, None, None, None, None, None])

        if not schemas:
            self.schema_records = [to_record(name, f) for name, f in discovered]
            self.schema_records.append(public)
            self.is_db = True
        else:
            # Exact name match: a substring test on the full path would select every file
            # for 'public' (all files live under .../public/schema/) and 'foobar' for 'foo'.
            self.schema_records = [to_record(name, f)
                                   for s in schemas
                                   for name, f in discovered
                                   if name == s]
            found = {name for name, _ in discovered}
            for s in schemas:
                if s != 'public' and s not in found:
                    self.logger.warning('No schema file found for schema {}'.format(s))
            if 'public' in schemas:
                self.schema_records.append(public)
    return self
def connect(self):
    """
    Lazily open ``self.conn`` (a LoggingConnection built from ``self._dsn``).

    An existing truthy ``self.conn`` is reused as-is.

    :return: self, for call chaining.
    """
    if self.conn:
        return self
    self.conn = psycopg2.connect(self._dsn, connection_factory=LoggingConnection)
    self.conn.initialize(self.logger)
    return self
def do_clean(self, user, bindir, host, port, dbname, password):
    """
    Drop and recreate every selected non-public schema.

    For a whole-database run (``self.is_db``) the database itself is dropped and
    recreated first. Each non-public schema is then dropped with ``drop_schema_sql``
    and recreated by running its schema script through psql. The process exits with
    status 1 as soon as psql reports an error for a schema.
    """
    import sys

    if self.is_db:
        self._execute_clean_db(user, host, port, dbname, password)
    self.connect()
    # libpq reads the password from PGPASSWORD; the previous 'PGPASSWD' was ignored,
    # so psql would prompt or fail authentication.
    penv = os.environ.copy()
    penv['PGPASSWORD'] = password
    for sr in self.schema_records:
        if sr.name == 'public':
            continue
        with self.conn:
            with self.conn.cursor() as curs:
                self.logger.warning('Dropping {} schema'.format(sr.name))
                curs.execute(self.drop_schema_sql, {'nspname': AsIs(sr.name)})
        self.logger.warning('Creating {} schema'.format(sr.name))
        # Fresh argv per schema; never share one list across psql runs.
        psql = ['{}/psql'.format(bindir), '-h', host, '-p', '{}'.format(port), '-d', dbname, '-U', user]
        start, end, success = self._execute_psql(penv, psql, sr.script)
        if not success:
            sys.exit(1)
        # Only report success after psql actually succeeded.
        print('{} cleaned'.format(sr.name))
def do_baseline(self, baseline_version, baseline_description, user):
    """
    Ensure each existing selected schema has a schema_version table and a baseline row.

    Runs only once (guarded by ``self.is_baselined``). For each schema record:

    * if the schema exists (``_is_schema``) and its schema_version table is in place
      (``_create_schema_version``), ``_schema_baseline`` is consulted;
    * a result of 0 is treated as "no baseline yet", and a baseline row with
      *baseline_version* is inserted.

    The printed per-schema value is -1 when the schema was skipped. Every record is
    rebuilt with ``baseline_version`` set to *baseline_version*.

    :param baseline_version: version written for the baseline row.
    :param baseline_description: description/script text for the baseline row.
    :param user: recorded as ``installed_by``.
    """
    self.connect()
    if not self.is_baselined:
        schema_records = []
        for sr in self.schema_records:
            schema_baseline_version = -1
            if self._is_schema(sr.name) and self._create_schema_version(sr.name):
                schema_baseline_version = self._schema_baseline(sr.name, baseline_description)
                if schema_baseline_version == 0:
                    with self.conn:
                        with self.conn.cursor() as curs:
                            self.logger.warning(
                                'Inserting baseline version for {}: {}'.format(sr.name, baseline_version))
                            schema_baseline_version = baseline_version
                            curs.execute(self.insert_schema_version_sql,
                                         {'nspname': AsIs(sr.name),
                                          'version': baseline_version,
                                          'description': baseline_description,
                                          'type': 'SQL',
                                          'script': baseline_description,
                                          'checksum': 0,
                                          'installed_by': user,
                                          'execution_time': 0,
                                          'success': True})
            # append, not extend: extend() would splice the record's individual fields
            # into the list and corrupt self.schema_records.
            schema_records.append(self.SchemaRecord(sr.name, sr.script, baseline_version, sr.max_current_version,
                                                    sr.max_schema_version,
                                                    sr.repeatable_schema_migrations, sr.versioned_schema_migrations,
                                                    sr.diff_repeatable_migrations, sr.diff_versioned_migrations,
                                                    sr.changed_repeatable_migrations))
            print('Baseline schema_version for {}: {}'.format(sr.name, schema_baseline_version))
        if schema_records:
            self.schema_records = schema_records
        self.is_baselined = True
def get_max_schema_version(self, baseline_version, baseline_description, user):
    """
    Look up the highest applied version in each schema's schema_version table (runs once).

    For every schema record, ``max_schema_version_sql`` is executed and its first column
    is parsed with ``version.parse``. The record is rebuilt with that value as
    ``max_schema_version`` and the value is printed. ``do_baseline`` is invoked before
    each lookup; it is guarded by ``self.is_baselined``, so only the first call does work.

    NOTE(review): the loop iterates the record list captured before do_baseline runs, so
    the rebuilt records keep the pre-baseline ``baseline_version`` values.
    """
    self.connect()
    if self.is_max_schema_version:
        return
    refreshed = []
    for record in self.schema_records:
        self.do_baseline(baseline_version, baseline_description, user)
        with self.conn:
            with self.conn.cursor() as cursor:
                cursor.execute(self.max_schema_version_sql, {'nspname': AsIs(record.name)})
                row = cursor.fetchone()
                current_max = version.parse(row[0])
        refreshed.append(self.SchemaRecord(record.name, record.script, record.baseline_version,
                                           record.max_current_version, current_max,
                                           record.repeatable_schema_migrations,
                                           record.versioned_schema_migrations,
                                           record.diff_repeatable_migrations,
                                           record.diff_versioned_migrations,
                                           record.changed_repeatable_migrations))
        print('Max schema_version for {}: {}'.format(record.name, current_max))
    if refreshed:
        self.schema_records = refreshed
    self.is_max_schema_version = True
def migrate(self, baseline_version, baseline_description, user, bindir, host, port, dbname, password,
            migration_base_directory, repeatable_prefix, versioned_prefix, migration_separator,
            migration_suffix):
    """
    Validate, then apply every pending migration of all selected schemas.

    Pending versioned migrations from all schemas run first, ordered by version. The
    pending repeatable migrations run afterwards, ordered by script path. A summary of
    how many of each were executed is printed even if a migration aborts the run.
    """
    self.connect()
    print('Beginning migrations.')
    self.validate(baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix,
                  versioned_prefix, migration_separator, migration_suffix)

    def pending(field):
        # Flatten one kind of pending migration from every schema into execution records.
        # noinspection PyTypeChecker
        return [self.ExecutingMigrationRecord._make(
                    (sr.name, mig.script, mig.version, mig.description, mig.checksum))
                for sr in self.schema_records
                for mig in (getattr(sr, field) or ())]

    versioned_migrations = sorted(pending('diff_versioned_migrations'), key=lambda v: v.version)
    repeatable_migrations = sorted(pending('diff_repeatable_migrations'), key=lambda r: r.script)
    versioned = repeatable = 0
    try:
        for migration in versioned_migrations:
            self._execute_migration(migration, user, bindir, host, port, dbname, password)
            versioned += 1
        for migration in repeatable_migrations:
            self._execute_migration(migration, user, bindir, host, port, dbname, password)
            repeatable += 1
    finally:
        print('Executed - Versioned: {}, Repeatable: {}'.format(versioned, repeatable))
def validate(self, baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix,
             versioned_prefix, migration_separator, migration_suffix):
    """
    Compare available migration scripts with the applied ones for every selected schema.

    Runs only once (guarded by ``self.is_validated``); later calls just print the header.
    For each schema record it reports:
      1. Max applied version: ``max_schema_version`` as filled in by get_max_schema_version().
      2. Max current version: the max script version reported by _get_current_migrations().
      3. Versioned scripts ready to run: found versioned migrations that do not match an
         applied record. If an applied versioned script has changed (same script, different
         record, e.g. checksum), the offending records are logged and the process exits with
         status 1 -- applied versioned migrations must never be edited.
      4. Repeatable scripts ready to run: new repeatable scripts plus applied ones that
         changed (the latter also counted as "changed").
    The rebuilt records replace ``self.schema_records``.

    :param baseline_version: passed through to baselining / schema migration lookup.
    :param baseline_description: passed through to baselining / schema migration lookup.
    :param user: database user, passed through to baselining.
    :param migration_base_directory: root directory passed to _get_current_migrations().
    :param repeatable_prefix: repeatable script naming, passed to _get_current_migrations().
    :param versioned_prefix: versioned script naming, passed to _get_current_migrations().
    :param migration_separator: script-name separator, passed to _get_current_migrations().
    :param migration_suffix: script-name suffix, passed to _get_current_migrations().
    :return: None
    """
    import sys

    self.connect()
    print('Validating migrations.')
    if not self.is_validated:
        self.get_max_schema_version(baseline_version, baseline_description, user)
        schema_records = []
        for sr in self.schema_records:
            max_current_version, repeatable_migrations, versioned_migrations = self._get_current_migrations(
                migration_base_directory, sr.name, repeatable_prefix, versioned_prefix, migration_separator,
                migration_suffix)
            repeatable_schema_migrations, versioned_schema_migrations = self._get_schema_migrations(
                sr.name, baseline_description, baseline_version, user)
            repeatable_expected = OrderedSet(repeatable_schema_migrations)
            repeatable_found = OrderedSet(repeatable_migrations)
            versioned_expected = OrderedSet(versioned_schema_migrations)
            versioned_found = OrderedSet(versioned_migrations)
            # Found but not recorded (exact record comparison) => pending.
            diff_repeatable_migrations = OrderedSet(
                sorted(repeatable_found - repeatable_expected, key=lambda r: r.script))
            diff_versioned_migrations = OrderedSet(
                sorted(versioned_found - versioned_expected, key=lambda v: v.version))
            # Pending entries whose script was already recorded => the script changed.
            changed_repeatable_migrations = self._find_changed_migrations(repeatable_expected,
                                                                          diff_repeatable_migrations)
            changed_versioned_migrations = self._find_changed_migrations(versioned_expected,
                                                                         diff_versioned_migrations)
            schema_records.append(self.SchemaRecord(sr.name, sr.script, sr.baseline_version, max_current_version,
                                                    sr.max_schema_version, repeatable_schema_migrations,
                                                    versioned_schema_migrations, diff_repeatable_migrations,
                                                    diff_versioned_migrations, changed_repeatable_migrations))
            if changed_versioned_migrations:
                self.logger.error(
                    'Applied versioned migrations changed for {} - THIS SHOULD NEVER HAPPEN!!!'.format(sr.name))
                self.logger.error(
                    'YOU MUST REVERT THE CHANGES TO THESE SCRIPTS, AND CREATE NEW VERSIONED MIGRATIONS.')
                # Format explicitly: logger.error(name, obj) would treat obj as a %-arg for a
                # message without placeholders and fail inside logging.
                self.logger.error('{}: {}'.format(sr.name, changed_versioned_migrations))
                sys.exit(1)
            print('{} max applied: {}, current: {} version.'
                  .format(sr.name, sr.max_schema_version, max_current_version))
            print('versioned scripts ready to run for {}: {}'
                  .format(sr.name, len(diff_versioned_migrations)))
            print('repeatable scripts ready to run for {}: {} ({} changed)'
                  .format(sr.name, len(diff_repeatable_migrations), len(changed_repeatable_migrations)))
            self.logger.warning('{} versioned: {}'.format(sr.name, diff_versioned_migrations))
            self.logger.warning('{} repeatable: {}'.format(sr.name, diff_repeatable_migrations))
            self.logger.warning('{} changed repeatable: {}'.format(sr.name, changed_repeatable_migrations))
        if schema_records:
            self.schema_records = schema_records
        self.is_validated = True
def info(self, baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix,
         versioned_prefix, migration_separator, migration_suffix):
    """
    Print a per-schema summary of applied, pending and changed migrations.

    validate() runs first so each schema record carries its migration collections.
    Counts go to stdout and each individual entry is logged at INFO. For a
    whole-database run (``self.is_db``) grand totals are printed afterwards.
    """
    self.connect()
    self.validate(baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix,
                  versioned_prefix, migration_separator, migration_suffix)
    # (per-schema line, total line, SchemaRecord field) in print order.
    report = (
        ('{} VERSIONED MIGRATIONS - APPLIED: {}', 'TOTAL VERSIONED MIGRATIONS - APPLIED: {}',
         'versioned_schema_migrations'),
        ('{} VERSIONED MIGRATIONS - PENDING: {}', 'TOTAL VERSIONED MIGRATIONS - PENDING: {}',
         'diff_versioned_migrations'),
        ('{} REPEATABLE MIGRATIONS - APPLIED: {}', 'TOTAL REPEATABLE MIGRATIONS - APPLIED: {}',
         'repeatable_schema_migrations'),
        ('{} REPEATABLE MIGRATIONS - CHANGED/PENDING: {}', 'TOTAL REPEATABLE MIGRATIONS - CHANGED/PENDING: {}',
         'changed_repeatable_migrations'),
        ('{} REPEATABLE MIGRATIONS - PENDING: {}', 'TOTAL REPEATABLE MIGRATIONS - PENDING: {}',
         'diff_repeatable_migrations'),
    )
    totals = [0] * len(report)
    for sr in self.schema_records:
        for position, (line_format, _total_format, field) in enumerate(report):
            entries = getattr(sr, field)
            totals[position] += len(entries)
            print(line_format.format(sr.name, len(entries)))
            for entry in entries:
                self.logger.info(entry)
    if self.is_db:
        for (_line_format, total_format, _field), total in zip(report, totals):
            print(total_format.format(total))
if __name__ == '__main__':
    # Parse the CLI, configure console logging and resolve which schemas to work on.
    postway = PostWay().parse_arguments().set_log_level_from_verbose().get_schemas()
    postway.logger.debug('args: {}'.format(postway.args))
    m_dir = postway.args.migration_base_directory
    b_dir = postway.args.bindir
    u = postway.args.user
    pwd = postway.args.password
    h = postway.args.host
    p = postway.args.port
    d = postway.args.dbname
    v_prefix = postway.versioned_prefix
    r_prefix = postway.repeatable_prefix
    sep = postway.migration_separator
    suf = postway.migration_suffix
    bs_ver = postway.baseline_version
    bs_desc = postway.baseline_description
    c = postway.args.command
    # One handler per PostWay command; an unknown command does nothing.
    dispatch = {
        PostWayCommand.clean: lambda: postway.do_clean(u, b_dir, h, p, d, pwd),
        PostWayCommand.baseline: lambda: postway.do_baseline(bs_ver, bs_desc, u),
        PostWayCommand.version: lambda: postway.get_max_schema_version(bs_ver, bs_desc, u),
        PostWayCommand.migrate: lambda: postway.migrate(bs_ver, bs_desc, u, b_dir, h, p, d, pwd, m_dir,
                                                        r_prefix, v_prefix, sep, suf),
        PostWayCommand.validate: lambda: postway.validate(bs_ver, bs_desc, u, m_dir, r_prefix, v_prefix,
                                                          sep, suf),
        PostWayCommand.info: lambda: postway.info(bs_ver, bs_desc, u, m_dir, r_prefix, v_prefix, sep, suf),
    }
    handler = dispatch.get(c)
    if handler is not None:
        handler()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment