Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Builds a SQL statement that hashes a table in Postgres, MySQL, or Redshift, producing the same result on all three.
package main
import (
"flag"
"fmt"
"strings"
)
// Database dialect identifiers, compared against the value of the -dbType flag.
const (
	// DB_REDSHIFT selects the Redshift dialect.
	DB_REDSHIFT = "redshift"
	// DB_POSTGRES selects the Postgres dialect; it is the flag's default.
	DB_POSTGRES = "postgres"
	// DB_MYSQL selects the MySQL dialect.
	DB_MYSQL = "mysql"
)
// Table describes a database table whose rows are hashed by the generated SQL.
type Table struct {
// Name is the table name.
Name string
// Schema is the schema the table belongs to; it may be empty.
Schema string
// Columns lists the columns that contribute to each row's hash.
Columns []Column
}
// Column describes one column of a Table.
type Column struct {
// Name is the column name as it is written into the generated SQL.
Name string
// Type is the column's SQL type (e.g. "date", "boolean"); it selects how the
// value is normalized to text before hashing.
Type string
// NotNull marks the column as non-nullable. When false, the column's hash
// expression is wrapped in coalesce so that NULL still yields a value.
NotNull bool
// Encoding is not read by any code visible in this file.
// NOTE(review): presumably a Redshift column encoding — confirm.
Encoding string
}
// Command-line options, bound to flags in init, plus state derived from them.
var (
// tableName is the name of the table to hash (-table).
tableName string
// schemaName is the optional schema qualifying the table (-schema).
schemaName string
// idCol is the column the row range is filtered on (-id).
idCol string
// lowerBound is the exclusive lower limit for idCol (-lower).
lowerBound string
// upperBound is the inclusive upper limit for idCol (-upper).
upperBound string
// colList is the comma-separated name/type list describing the columns (-cols).
colList string
// dbType selects the SQL dialect to generate (-dbType).
dbType string
// tablePath is the table reference placed in the FROM clause; main sets it to
// "schema.table" when a schema is given, otherwise to the bare table name.
tablePath string
// debug makes hashAllCols emit a per-row query instead of the aggregate one (-debug).
debug bool
)
// init registers the command-line flags and binds them to the package-level
// option variables. Parsing happens later, in main.
func init() {
	flag.StringVar(&tableName, "table", "", "Table name.")
	flag.StringVar(&schemaName, "schema", "", "Schema name.")
	flag.StringVar(&colList, "cols", "", "Comma separated list of column name/type pairs, e.g. name1, type1, name2, type2, ...")
	flag.StringVar(&idCol, "id", "id", "Id column")
	flag.StringVar(&lowerBound, "lower", "0", "Lower value of id column, exclusive, to hash.")
	flag.StringVar(&upperBound, "upper", "1000000", "Upper value of id column, inclusive, to hash.")
	// The help text is built from the dialect constants so it cannot drift
	// from the values the generator actually recognises.
	flag.StringVar(&dbType, "dbType", DB_POSTGRES, "Type of database: "+DB_POSTGRES+", "+DB_REDSHIFT+" or "+DB_MYSQL+".")
	flag.BoolVar(&debug, "debug", false, "debug")
}
// main parses the command line, builds the table description and prints the
// hashing SQL statement to standard output.
func main() {
	flag.Parse()

	// Without -cols, the positional arguments are taken as the column list.
	if len(colList) == 0 {
		colList = strings.Join(flag.Args(), " ")
	}

	// Both a table name and a column list are mandatory.
	if len(tableName) == 0 || len(colList) == 0 {
		flag.Usage()
		return
	}

	pairs := strings.Split(colList, ",")
	table := makeTable(tableName, schemaName, pairs...)
	fmt.Println()

	// hashAllCols reads the FROM target from the package-level tablePath.
	tablePath = tableName
	if schemaName != "" {
		tablePath = fmt.Sprintf(`%s.%s`, schemaName, tableName)
	}

	statement := hashAllCols(table, idCol, lowerBound, upperBound)
	fmt.Println(statement)
}
// makeTable builds a Table from a flat list of alternating column names and
// column types (name1, type1, name2, type2, ...). Surrounding whitespace is
// trimmed from every entry.
//
// An odd-length list no longer panics: a final name without a type becomes a
// column with an empty Type, and an empty final entry (as produced by a
// trailing comma) is dropped.
func makeTable(name, schema string, colPairs ...string) *Table {
	cols := make([]Column, 0, (len(colPairs)+1)/2)
	for i := 0; i < len(colPairs); i += 2 {
		col := Column{Name: strings.TrimSpace(colPairs[i])}
		switch {
		case i+1 < len(colPairs):
			col.Type = strings.TrimSpace(colPairs[i+1])
		case col.Name == "":
			// Dangling empty entry: nothing to describe.
			continue
		}
		cols = append(cols, col)
	}
	return &Table{Name: name, Schema: schema, Columns: cols}
}
// hashAllCols returns the SQL statement that hashes the rows of table whose
// idCol value lies in (lowerBound, upperBound].
//
// Every column is turned into a text/md5 expression by the dialect-specific
// renderer and the expressions are concatenated into one per-row value. In
// normal mode the per-row md5 is aggregated through getSumOfHash; in debug
// mode a per-row listing is returned instead, with the inner "md5" wrappers
// stripped so the raw concatenated values are visible.
//
// The dialect, debug switch and FROM target come from the package-level
// dbType, debug and tablePath variables.
func hashAllCols(table *Table, idCol, lowerBound, upperBound string) string {
	isMySQL := dbType == DB_MYSQL

	// Any dialect other than MySQL uses the Postgres renderer.
	render := colAsString_pg
	if isMySQL {
		render = colAsString_mysql
	}

	exprs := make([]string, 0, len(table.Columns))
	for i := range table.Columns {
		exprs = append(exprs, render(&table.Columns[i]))
	}

	// Debug mode separates the column values with commas for readability.
	var rowExpr string
	switch {
	case isMySQL && debug:
		rowExpr = fmt.Sprintf("concat_ws(',',%s)", strings.Join(exprs, ","))
	case isMySQL:
		rowExpr = fmt.Sprintf("concat(%s)", strings.Join(exprs, ", "))
	case debug:
		rowExpr = strings.Join(exprs, " || ',' || ")
	default:
		rowExpr = strings.Join(exprs, " || ")
	}

	if debug {
		raw := strings.Replace(rowExpr, "md5", "", -1)
		return fmt.Sprintf(`select md5(%s), %s from %s where %s > '%s' and %s <= '%s' order by %s desc`,
			raw, raw, tablePath, idCol, lowerBound, idCol, upperBound, idCol)
	}

	perRow := fmt.Sprintf(`select md5(%s) as hash from %s where %s > '%s' and %s <= '%s'`,
		rowExpr, tablePath, idCol, lowerBound, idCol, upperBound)
	return fmt.Sprintf("select %s from (%s) a;", getSumOfHash("hash"), perRow)
}
// colAsString_pg returns the Postgres-syntax expression that yields the md5
// of col's value rendered as text.
//
// The column name is double-quoted. Dates become a day count relative to
// 0001-01-01, timestamps become whole epoch seconds and booleans become
// integers before hashing. Unless the column is marked NotNull, the result is
// wrapped in coalesce so a NULL value yields ' ' instead of NULL.
func colAsString_pg(col *Column) string {
	expr := fmt.Sprintf(`"%s"`, col.Name)

	// Normalize types whose text form is not a plain number or string.
	switch {
	case col.Type == "date":
		expr = fmt.Sprintf("(%s - '0001-01-01'::date)", expr)
	case strings.Contains(col.Type, "timestamp"):
		expr = fmt.Sprintf("floor(extract(epoch from %s))", expr)
	case col.Type == "boolean":
		expr += "::integer"
	}

	// varchar columns are hashed as-is; everything else is cast to text first.
	if !strings.Contains(col.Type, "varchar") {
		expr += "::text"
	}
	expr = fmt.Sprintf("md5(%s)", expr)

	if col.NotNull {
		return expr
	}
	return fmt.Sprintf("coalesce(%s, ' ')", expr)
}
// colAsString_mysql returns the MySQL-syntax expression that yields the md5
// of col's value.
//
// Dates become a day count and timestamp/datetime values become whole epoch
// seconds before hashing. Unless the column is marked NotNull, the result is
// wrapped in coalesce so a NULL value yields ' ' instead of NULL.
func colAsString_mysql(col *Column) string {
	expr := col.Name

	switch {
	case col.Type == "date":
		// 366 makes the count relative to day 1 of year 1 rather than year 0,
		// which never existed.
		expr = fmt.Sprintf("(to_days(%s) - 366)", expr)
	case strings.Contains(col.Type, "timestamp") || col.Type == "datetime":
		// NOTE(review): the fixed 7 hour shift presumably compensates for a
		// server time-zone offset — confirm before reusing elsewhere.
		expr = fmt.Sprintf("floor(unix_timestamp(%s - interval 7 hour))", expr)
	}

	expr = fmt.Sprintf("md5(%s)", expr)
	if col.NotNull {
		return expr
	}
	return fmt.Sprintf("coalesce(%s, ' ')", expr)
}
// getSumOfHash returns the select-list that aggregates a column of 32-character
// hex md5 strings: the hash is cut into four 8-hex-digit chunks, each chunk is
// converted to an integer, and each chunk position is summed separately. The
// result is four comma-separated sum(...) expressions in the dialect selected
// by the package-level dbType.
//
// col is the name of the column (or expression) holding the md5 string.
//
// Any dialect other than Redshift or MySQL gets the Postgres form, matching
// the fallback hashAllCols uses for the per-column expressions.
func getSumOfHash(col string) string {
	var chunkFmt string
	switch dbType {
	case DB_REDSHIFT:
		// strtol is summed directly: wrapping it in trunc was reported to lose
		// precision on large bigint values, which would make big sums drift.
		chunkFmt = `sum(strtol(substring(%s,%d,8),16))`
	case DB_MYSQL:
		chunkFmt = `sum(cast(conv(substring(%s,%d,8), 16, 10) as unsigned))`
	default:
		chunkFmt = `sum(('x'||substring(%s,%d,8))::bit(32)::bigint)`
	}

	// 1-based start positions of the four 8-character chunks.
	offsets := [...]int{1, 9, 17, 25}
	parts := make([]string, len(offsets))
	for i, off := range offsets {
		parts[i] = fmt.Sprintf(chunkFmt, col, off)
	}
	return strings.Join(parts, ", ")
}
@hullsean

This comment has been minimized.

Copy link

commented Apr 19, 2017

Hi Jason, doesn't your code assume the databases are FROZEN in time? I would guess you'd have to put MySQL or Postgres in READONLY mode, and make sure redshift is also paused.

How do you guys use this in the real world?

@aadant

This comment has been minimized.

Copy link

commented Feb 14, 2019

queryPart = `sum(trunc(strtol(substring(%s,%d,8),16)))` is probably wrong.

Example

select cast(trunc(cast(9009946224037101 as bigint)) as bigint);
trunc

9009946224037100
(1 row)

I would remove the trunc in Redshift

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.