Skip to content

Instantly share code, notes, and snippets.

@orlissenberg
Last active October 11, 2020 21:48
Show Gist options
Save orlissenberg/245ef9ed83f60e9499b6ad9d657da537 to your computer and use it in GitHub Desktop.
Create a PostgreSQL schema from YAML, using mixins for reusable column patterns.
{# Table DDL template, rendered once per Table with `table=`.
   Emits CREATE TABLE IF NOT EXISTS whose body is table.getColumns() (already
   rendered column definitions), each prefixed with a tab and separated by
   ",\n". Then one ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY statement
   per entry of table.getForeignKeys().
   NOTE(review): the ALTER TABLE statements are not guarded, so re-running the
   generated SQL on an existing database presumably fails on duplicate
   constraint names - confirm. -#}
CREATE TABLE IF NOT EXISTS {{ table.name }} (
{% for column in table.getColumns() %}{{ "\t" }}{{ column }}{% if not loop.last %},{{ "\n" }}{% endif %}{% endfor %}
);
{% for foreignKey in table.getForeignKeys() %}
ALTER TABLE {{ table.name }}
ADD CONSTRAINT {{ foreignKey.name }}
FOREIGN KEY ({{ ", ".join(foreignKey.columns) }})
REFERENCES {{ foreignKey.reference }}({{ ", ".join(foreignKey.referenceColumns) }});
{# TODO: referential actions (e.g. ON DELETE CASCADE) - ForeignKey.onDelete / onUpdate exist but are not rendered here #}
{% endfor %}
# Reusable column groups. A table opts in by listing mixin names under its
# `mixins:` key; mixin columns are emitted before the table's own columns.
mixins:
  # Creation audit columns (timestamp-now renders as
  # "timestamp DEFAULT current_timestamp NOT NULL").
  created:
    columns:
      created_at:
        type: timestamp-now
      created_by:
        type: string
  # Last-update audit columns.
  updated:
    columns:
      updated_at:
        type: timestamp-now
      updated_by:
        type: string
  # Free-text comment field.
  commented:
    columns:
      comments:
        type: string
  # Begin and end date (datetime renders as a plain "timestamp").
  date_ranged:
    columns:
      started_at:
        type: datetime
      ended_at:
        type: datetime
  # Auto-increment primary key ("bigserial PRIMARY KEY").
  auto_incremented:
    columns:
      id:
        type: pk_auto_increment
  # Traces the row back to its source.
  sourced:
    columns:
      source:
        type: string
        required: false
  # Mandatory name column.
  named:
    columns:
      name:
        type: string
        required: true
# Tables to generate. Each table may list mixins (defined under `mixins:`)
# and its own columns; optional `foreign_keys` entries take `columns`,
# `reference` and `reference_columns`.
tables:
  # Example table: auto-increment id, audit columns and a source column,
  # plus one required varchar(42) column.
  example:
    mixins:
      - auto_incremented
      - created
      - updated
      - sourced
    columns:
      foo_bar:
        type: string
        required: true
        size: 42
import os

import yaml
from jinja2 import Environment, FileSystemLoader
# https://docs.python.org/3.8/
class Table:
    """A table rendered as CREATE TABLE plus its foreign-key constraints.

    Args:
        name: Table name, used verbatim in the generated SQL.
        schema: Owning schema; its ``mixins`` dict resolves the mixin names
            listed in ``self.mixins``.
    """

    def __init__(self, name, schema):
        self.name = name
        # Column objects declared directly on this table.
        self.columns = []
        # ForeignKey objects rendered after the CREATE TABLE statement.
        self.ForeignKeys = []
        # Names of mixins whose columns are prepended; None means "no mixins".
        self.mixins = None
        self.schema = schema

    def getColumns(self):
        """Return the rendered SQL column definitions for this table.

        Mixin columns come first, in the order the mixins are listed, followed
        by the table's own columns.

        Raises:
            KeyError: if a listed mixin name is not defined in the schema.
        """
        outputColumns = []
        # ``mixins`` stays None for tables without a ``mixins`` entry; treat
        # that as an empty list instead of failing to iterate None.
        for mixinName in self.mixins or ():
            outputColumns.extend(
                column.render() for column in self.schema.mixins[mixinName].columns
            )
        outputColumns.extend(column.render() for column in self.columns)
        return outputColumns

    def getForeignKeys(self):
        """Return the ForeignKey objects to render after the CREATE TABLE."""
        return self.ForeignKeys
class ForeignKey:
    """A named foreign-key constraint belonging to a table.

    The schema loader fills in the local ``columns``, the referenced table
    (``reference``) and that table's ``referenceColumns``.
    """

    def __init__(self, name):
        self.name = name
        # Referenced table name, then referenced and local column names.
        self.reference, self.referenceColumns, self.columns = None, [], []
        # Referential actions; not read anywhere in the visible code yet.
        self.onDelete = self.onUpdate = ""
class Column:
    """One table column as declared in the YAML schema.

    ``columnType`` is the schema-level type name and decides how
    :meth:`render` spells the column in SQL. ``columnSize`` (default 50) is
    only used by the ``string`` / ``varchar-n`` types.
    """

    # PostgreSQL type names emitted unchanged (plus NOT NULL / UNIQUE).
    # https://www.postgresql.org/docs/current/datatype.html
    PASSTHROUGH_TYPES = frozenset({
        "date", "timestamp", "timestamptz", "smallint", "integer", "bigint",
        "decimal", "money", "real", "double precision", "serial",
        "bigserial", "boolean", "text",
    })

    # Schema spellings that are not valid PostgreSQL type names, mapped to the
    # real name. "timestampz" is kept as an accepted (misspelled) alias of
    # timestamptz. PostgreSQL has no plain "double" type.
    TYPE_ALIASES = {
        "timestampz": "timestamptz",
        "double": "double precision",
    }

    def __init__(self, name, size=50):
        self.name = name
        self.columnType = None
        self.columnSize = size
        self.isRequired = False
        self.isUnique = False

    def render(self):
        """Return this column's SQL definition for the CREATE TABLE body.

        ``pk_auto_increment`` and ``timestamp-now`` have fixed definitions.
        The other known types honour ``isRequired`` (NOT NULL) and
        ``isUnique`` (UNIQUE). Any other type is delegated to the
        ``create_column.j2`` template through the module-level ``env``.
        """
        required = " NOT NULL" if self.isRequired else ""
        unique = " UNIQUE" if self.isUnique else ""
        columnType = self.TYPE_ALIASES.get(self.columnType, self.columnType)

        if columnType in self.PASSTHROUGH_TYPES:
            return f'{self.name} {columnType}{required}{unique}'
        if columnType == "pk_auto_increment":
            return f'{self.name} bigserial PRIMARY KEY'
        if columnType in ("varchar-n", "string"):
            return f'{self.name} varchar({self.columnSize}){required}{unique}'
        if columnType == "timestamp-now":
            return f'{self.name} timestamp DEFAULT current_timestamp NOT NULL'
        if columnType == "datetime":
            return f'{self.name} timestamp{required}{unique}'
        # Unknown type: let the column template decide. ``env`` is the Jinja
        # environment created at module level by the script section.
        template = env.get_template('create_column.j2')
        return template.render(column=self)
class Mixin:
    """A named, reusable group of columns that tables can opt into.

    A table lists mixin names, and ``Table.getColumns`` expands each name into
    this mixin's columns ahead of the table's own.
    """

    def __init__(self, name):
        # ``columns`` holds Column objects in declaration order.
        self.name, self.columns = name, []
class Schema:
    """Mixins and tables assembled from one or more YAML documents.

    ``load`` may be called repeatedly (e.g. once per YAML file). Mixins and
    tables accumulate across calls, and a later mixin with the same name
    replaces the earlier one.
    """

    def __init__(self, name):
        self.name = name
        self.mixins = {}
        self.tables = []

    def createColumn(self, name, data):
        """Build a Column from its YAML mapping.

        Args:
            name: Column name.
            data: Mapping with optional ``type``, ``required``, ``unique`` and
                ``size`` keys. May be None for a bare ``name:`` entry.
        """
        data = data or {}
        column = Column(name)
        column.columnType = data.get("type")
        column.isRequired = data.get("required")
        column.isUnique = data.get("unique")
        if data.get("size"):
            column.columnSize = data.get("size")
        return column

    def load(self, data):
        """Add the ``mixins`` and ``tables`` sections of ``data`` to this schema.

        Either section may be missing, so mixins and tables can live in
        separate files. Returns ``self`` so calls can be chained.
        """
        data = data or {}
        self._loadMixins(data.get("mixins") or {})
        for tableName, tableData in (data.get("tables") or {}).items():
            self.tables.append(self._loadTable(tableName, tableData or {}))
        return self

    def _loadMixins(self, mixinList):
        """Register every mixin in ``mixinList`` together with its columns."""
        for mixinName, mixinData in mixinList.items():
            mixin = Mixin(mixinName)
            columnList = (mixinData or {}).get("columns") or {}
            mixin.columns = [
                self.createColumn(columnName, columnData)
                for columnName, columnData in columnList.items()
            ]
            self.mixins[mixinName] = mixin

    def _loadTable(self, tableName, tableData):
        """Build a Table (mixin names, columns, foreign keys) from its mapping."""
        table = Table(tableName, self)
        # Always a list, so rendering never has to iterate None.
        table.mixins = tableData.get("mixins") or []
        table.columns = [
            self.createColumn(columnName, columnData)
            for columnName, columnData in (tableData.get("columns") or {}).items()
        ]
        table.ForeignKeys = [
            self._createForeignKey(foreignKeyName, foreignKeyData)
            for foreignKeyName, foreignKeyData in (tableData.get("foreign_keys") or {}).items()
        ]
        return table

    def _createForeignKey(self, name, data):
        """Build a ForeignKey from its mapping (``columns``, ``reference``,
        ``reference_columns``)."""
        data = data or {}
        foreignKey = ForeignKey(name)
        # Default the column lists to [] so the table template's
        # ", ".join(...) never receives None.
        foreignKey.columns = data.get("columns") or []
        foreignKey.reference = data.get("reference")
        foreignKey.referenceColumns = data.get("reference_columns") or []
        return foreignKey
# Script entry: load the mixin definitions, then the table definitions, and
# write one CREATE TABLE (+ foreign keys) section per table to
# output/tables.sql.
# safe_load: the schema files only need plain YAML mappings, lists and
# scalars, so FullLoader's extra Python-specific tags are not needed.
with open("schema.yaml", "r", encoding="utf-8") as file:
    schema = Schema("public").load(yaml.safe_load(file))
with open("tables.yaml", "r", encoding="utf-8") as file:
    schema.load(yaml.safe_load(file))

# https://jinja.palletsprojects.com/en/2.11.x/templates
# ``env`` must remain a module-level name: Column.render falls back to it for
# column types it does not handle itself.
env = Environment(
    loader=FileSystemLoader('postgresql/templates')
)

tableTemplate = env.get_template('create_table.j2')
output = "".join(tableTemplate.render(table=table) for table in schema.tables)

# Create the output directory on first run instead of failing on open().
os.makedirs("output", exist_ok=True)
with open("output/tables.sql", "w", encoding="utf-8") as file:
    file.write(output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment