Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Builds a SQL statement that hashes a table in PostgreSQL, MySQL, or Redshift, producing the same result on all three.
package main
import (
"flag"
"fmt"
"strings"
)
// Dialect names accepted by the -dbType flag; they select which SQL flavour
// is generated.
const (
	DB_REDSHIFT = "redshift"
	DB_POSTGRES = "postgres"
	DB_MYSQL    = "mysql"
)
// Table describes the table whose rows are hashed.
type Table struct {
// Name is the table name.
Name string
// Schema is the schema name; it is empty when no schema was supplied.
Schema string
// Columns lists the columns that take part in the hash, in the order they
// are concatenated.
Columns []Column
}
// Column describes a single column that takes part in the hash.
type Column struct {
// Name is the column name as used in the generated SQL.
Name string
// Type is the declared column type (for example "date", "boolean", or a
// type containing "timestamp" or "varchar"); it selects how the value is
// normalized before hashing.
Type string
// NotNull, when true, makes the column formatters skip the coalesce wrapper
// that substitutes ' ' for NULL.
NotNull bool
// Encoding is not referenced by the code visible in this file.
Encoding string
}
// Command-line settings bound to flags in init, plus tablePath, which main
// derives from them.
var (
tableName string // -table: name of the table to hash
schemaName string // -schema: optional schema name
idCol string // -id: column used to bound the range of hashed rows
lowerBound string // -lower: exclusive lower bound for idCol
upperBound string // -upper: inclusive upper bound for idCol
colList string // -cols: comma separated column name/type pairs
dbType string // -dbType: SQL dialect to generate
tablePath string // table reference used in the FROM clause; assigned in main
debug bool // -debug: emit the per-row debugging query instead of the summed hash
)
// init registers the command-line flags and binds them to the package-level
// settings. The flags are parsed later, in main.
func init() {
	flag.StringVar(&tableName, "table", "", "Table name.")
	flag.StringVar(&schemaName, "schema", "", "Schema name.")
	flag.StringVar(&colList, "cols", "", "Comma separated list of column name/type pairs, e.g. (name1, type1, name2, type2, ...).")
	flag.StringVar(&idCol, "id", "id", "Id column")
	flag.StringVar(&lowerBound, "lower", "0", "Lower value of id column, exclusive, to hash.")
	flag.StringVar(&upperBound, "upper", "1000000", "Upper value of id column, inclusive, to hash.")
	// The help text names all three dialects: the generator has a MySQL
	// branch in addition to postgres and redshift.
	flag.StringVar(&dbType, "dbType", DB_POSTGRES, "Type of database: postgres, redshift or mysql.")
	flag.BoolVar(&debug, "debug", false, "debug")
}
// main parses the flags, builds the table description and prints the
// generated hash query on standard output.
//
// When -cols is not given, the positional arguments are joined with spaces
// and used as the column list. If the table name or the column list is still
// empty, the usage text is printed and nothing else happens.
func main() {
	flag.Parse()

	if len(colList) == 0 {
		colList = strings.Join(flag.Args(), " ")
	}
	if len(tableName) == 0 || len(colList) == 0 {
		flag.Usage()
		return
	}

	pairs := strings.Split(colList, ",")
	table := makeTable(tableName, schemaName, pairs...)

	// Blank line separating the query from anything printed before it.
	fmt.Println()

	// Qualify the table with its schema only when one was supplied.
	tablePath = tableName
	if schemaName != "" {
		tablePath = schemaName + "." + tableName
	}

	fmt.Println(hashAllCols(table, idCol, lowerBound, upperBound))
}
// makeTable builds a Table from alternating column-name / column-type
// strings. Every name and type is trimmed of surrounding whitespace.
//
// colPairs is expected to hold an even number of entries. If it does not,
// the dangling final name is kept as a column with an empty Type, unless the
// name itself is empty (as produced by a trailing separator), in which case
// it is dropped. The function never indexes past the end of colPairs.
func makeTable(name, schema string, colPairs ...string) *Table {
	cols := make([]Column, 0, (len(colPairs)+1)/2)
	for i := 0; i < len(colPairs); i += 2 {
		col := Column{Name: strings.TrimSpace(colPairs[i])}
		if i+1 < len(colPairs) {
			col.Type = strings.TrimSpace(colPairs[i+1])
		} else if col.Name == "" {
			// A trailing separator leaves an empty dangling entry; ignore it.
			continue
		}
		cols = append(cols, col)
	}
	return &Table{Name: name, Schema: schema, Columns: cols}
}
// hashAllCols returns the SQL statement that hashes the rows of table whose
// idCol value is greater than lowerBound and less than or equal to
// upperBound. The table reference comes from the package-level tablePath and
// the dialect from the package-level dbType.
//
// In normal mode each row is reduced to md5(<row expression>) and the outer
// query aggregates that hash through getSumOfHash. In debug mode a per-row
// query is returned instead: every "md5" substring is stripped from the row
// expression, so the unhashed column values are shown next to their combined
// md5, ordered by idCol descending.
func hashAllCols(table *Table, idCol, lowerBound, upperBound string) string {
	isMySQL := dbType == DB_MYSQL

	// Any dialect other than MySQL uses the PostgreSQL-style column rendering.
	render := colAsString_pg
	if isMySQL {
		render = colAsString_mysql
	}

	parts := make([]string, 0, len(table.Columns))
	for i := range table.Columns {
		parts = append(parts, render(&table.Columns[i]))
	}

	// Debug mode puts a ',' between columns so individual values stay readable.
	var rowExpr string
	switch {
	case isMySQL && debug:
		rowExpr = fmt.Sprintf("concat_ws(',',%s)", strings.Join(parts, ","))
	case isMySQL:
		rowExpr = fmt.Sprintf("concat(%s)", strings.Join(parts, ", "))
	case debug:
		rowExpr = strings.Join(parts, " || ',' || ")
	default:
		rowExpr = strings.Join(parts, " || ")
	}

	if !debug {
		inner := fmt.Sprintf(`select md5(%s) as hash from %s where %s > '%s' and %s <= '%s'`,
			rowExpr, tablePath, idCol, lowerBound, idCol, upperBound)
		return fmt.Sprintf("select %s from (%s) a;", getSumOfHash("hash"), inner)
	}

	plain := strings.Replace(rowExpr, "md5", "", -1)
	return fmt.Sprintf(`select md5(%s), %s from %s where %s > '%s' and %s <= '%s' order by %s desc`,
		plain, plain, tablePath, idCol, lowerBound, idCol, upperBound, idCol)
}
// colAsString_pg returns the SQL expression used for one column in the
// non-MySQL dialects. The column name is double-quoted and then normalized
// by type:
//
//   - "date" becomes the day difference from 0001-01-01;
//   - any type containing "timestamp" becomes floor(epoch seconds);
//   - "boolean" is cast to integer.
//
// The result is wrapped in md5(); types containing "varchar" are hashed
// directly, everything else is cast to text first. Unless the column is
// marked NotNull, the hash is wrapped in coalesce(..., ' ').
func colAsString_pg(col *Column) string {
	expr := `"` + col.Name + `"`

	// The three type tests cannot match the same Type value, so a switch is
	// equivalent to checking them one after another.
	switch {
	case col.Type == "date":
		expr = "(" + expr + " - '0001-01-01'::date)"
	case strings.Contains(col.Type, "timestamp"):
		expr = "floor(extract(epoch from " + expr + "))"
	case col.Type == "boolean":
		expr += "::integer"
	}

	cast := "::text"
	if strings.Contains(col.Type, "varchar") {
		cast = ""
	}
	expr = fmt.Sprintf("md5(%s%s)", expr, cast)

	if col.NotNull {
		return expr
	}
	return fmt.Sprintf("coalesce(%s, ' ')", expr)
}
// colAsString_mysql returns the MySQL expression used for one column. The
// column is normalized by type and wrapped in md5(); unless the column is
// marked NotNull, the hash is wrapped in coalesce(..., ' ').
//
// The identifier is backtick-quoted, with embedded backticks doubled, so that
// reserved words and unusual names produce valid SQL. This mirrors the
// quoting done by colAsString_pg.
func colAsString_mysql(col *Column) string {
	colSql := "`" + strings.Replace(col.Name, "`", "``", -1) + "`"

	switch {
	case col.Type == "date":
		// 366 to represent diff from Day 1, Year 1, not Year 0 which never existed.
		colSql = fmt.Sprintf("(to_days(%s) - 366)", colSql)
	case strings.Contains(col.Type, "timestamp") || col.Type == "datetime":
		// NOTE(review): the fixed 7 hour shift presumably compensates for a
		// particular server/session time zone; confirm before using elsewhere.
		colSql = fmt.Sprintf("floor(unix_timestamp(%s - interval 7 hour))", colSql)
	}

	colSql = fmt.Sprintf("md5(%s)", colSql)
	if !col.NotNull {
		colSql = fmt.Sprintf("coalesce(%s, ' ')", colSql)
	}
	return colSql
}
// getSumOfHash returns four comma separated aggregate expressions over col,
// which is expected to hold a 32 character hex md5. Each expression takes one
// 8 character slice (starting at positions 1, 9, 17 and 25), converts it from
// hex to an integer and sums it over all rows.
//
// The conversion syntax depends on the package-level dbType. Any value other
// than redshift or mysql uses the PostgreSQL form, which matches how
// hashAllCols treats unrecognized dialects. Without that fallback the format
// string would be empty and the result would be fmt's "%!(EXTRA ...)" text.
func getSumOfHash(col string) string {
	var queryPart string
	switch dbType {
	case DB_REDSHIFT:
		// NOTE(review): a reader of the original gist reported that trunc()
		// can lose precision on large bigint values in Redshift; confirm
		// whether it is needed here.
		queryPart = `sum(trunc(strtol(substring(%s,%d,8),16)))`
	case DB_MYSQL:
		queryPart = `sum(cast(conv(substring(%s,%d,8), 16, 10) as unsigned))`
	default:
		// DB_POSTGRES and anything unrecognized.
		queryPart = `sum(('x'||substring(%s,%d,8))::bit(32)::bigint)`
	}

	offsets := [...]int{1, 9, 17, 25}
	queryParts := make([]string, len(offsets))
	for i, off := range offsets {
		queryParts[i] = fmt.Sprintf(queryPart, col, off)
	}
	return strings.Join(queryParts, ", ")
}
@hullsean

This comment has been minimized.

Copy link

@hullsean hullsean commented Apr 19, 2017

Hi Jason, doesn't your code assume the databases are FROZEN in time? I would guess you'd have to put MySQL or Postgres in READONLY mode, and make sure redshift is also paused.

How do you guys use this in the real world?

@aadant

This comment has been minimized.

Copy link

@aadant aadant commented Feb 14, 2019

queryPart = sum(trunc(strtol(substring(%s,%d,8),16))) is probably wrong

Example

select cast(trunc(cast(9009946224037101 as bigint)) as bigint);
      trunc
------------------
 9009946224037100
(1 row)

I would remove the trunc in Redshift

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment