Skip to content

Instantly share code, notes, and snippets.

@ionling
Last active January 18, 2024 07:00
Show Gist options
  • Save ionling/10f50bf3d77040fa8bb4f6695c23befe to your computer and use it in GitHub Desktop.
Save ionling/10f50bf3d77040fa8bb4f6695c23befe to your computer and use it in GitHub Desktop.
Check PostgreSQL data differences
package main
import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"log/slog"
	"math/rand"
	"net/url"
	"os"
	"reflect"
	"strings"

	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/driver/pgdriver"
	"github.com/uptrace/bun/extra/bundebug"
	"github.com/uptrace/bun/extra/bunotel"
	"golang.org/x/sync/errgroup"
)
// maxCheckRows is the largest number of rows CheckRows fetches from each
// side (source and destination) for a single table.
const maxCheckRows = 3000
// Connection strings for the two databases, read from the SRC_DSN and
// DST_DSN environment variables when the program starts.
var (
	srcDSN = os.Getenv("SRC_DSN") // source database
	dstDSN = os.Getenv("DST_DSN") // destination database
)

// Command-line flags. -check and -sync-seqs select the action; do gives
// -check precedence when both are set.
var (
	checkF = flag.Bool("check", false,
		"Check the difference between the source and destination table")
	syncSeqsF = flag.Bool("sync-seqs", false,
		"Sync the last value of sequences from source to destination")

	schemaF  = flag.String("schema", "public", "PostgreSQL database schema")
	tablesF  = flag.String("tables", "", "Table names separated by ',', default to all tables in database")
	orderByF = flag.String("orderby", "", "Order by clause used in query")
)
// main parses the command-line flags, runs the selected action via do with a
// debug-level text logger on stdout, and on failure prints the error to
// stderr and exits with status 1 so scripts and CI can detect it.
func main() {
	flag.Parse()
	l := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelDebug,
	}))
	if err := do(context.Background(), l); err != nil {
		fmt.Fprintln(os.Stderr, "ERR:", err)
		os.Exit(1)
	}
}
func do(ctx context.Context, l *slog.Logger) error {
srcDB, srcCleanup, err := newPostgres(srcDSN, *schemaF)
if err != nil {
return fmt.Errorf("new src db: %w", err)
}
defer srcCleanup()
dstDB, dstCleanup, err := newPostgres(dstDSN, *schemaF)
if err != nil {
return fmt.Errorf("new dst db: %w", err)
}
defer dstCleanup()
lr := LogicalRep{
SrcDB: srcDB,
DstDB: dstDB,
Schema: *schemaF,
l: l,
}
switch {
default:
return fmt.Errorf("no action")
case *checkF:
return lr.Check(ctx)
case *syncSeqsF:
res, err := lr.SyncSequences(ctx)
if err != nil {
return err
}
fmt.Printf("%v\n", res)
}
return nil
}
// newPostgres opens a bun handle to the PostgreSQL database at dsn and pings
// it.
//
// dsn is expected in URL form (postgres://user:pass@host:port/db[?opts]).
// application_name=logicalrep and search_path=schema are added to its query
// string. Query hooks for bundebug and bunotel are installed.
//
// On success, cleanup closes the handle. On failure, any handle that was
// opened is closed before returning, db and cleanup are nil, and the caller
// has nothing to release.
func newPostgres(dsn, schema string) (db *bun.DB, cleanup func(), err error) {
	if dsn == "" {
		// Report an unset DSN clearly instead of letting driver-side DSN
		// parsing fail on it.
		return nil, nil, fmt.Errorf("empty dsn")
	}

	// Start the query string with "?" when the DSN has none yet; always using
	// "&" would glue the options onto the database name.
	sep := "&"
	if !strings.Contains(dsn, "?") {
		sep = "?"
	}
	dsn += sep + "application_name=logicalrep&search_path=" + url.QueryEscape(schema)

	// The config has defaults for timeouts,
	// so it's not necessary to specify them again:
	// - DialTimeout: 5 * time.Second
	// - ReadTimeout: 10 * time.Second
	// - WriteTimeout: 5 * time.Second
	connector := pgdriver.NewConnector(pgdriver.WithDSN(dsn))
	sqldb := sql.OpenDB(connector)
	db = bun.NewDB(sqldb, pgdialect.New(), bun.WithDiscardUnknownColumns())
	db.AddQueryHook(bundebug.NewQueryHook())
	db.AddQueryHook(bunotel.NewQueryHook(
		bunotel.WithDBName(connector.Config().Database),
	))

	if err = db.Ping(); err != nil {
		// Callers skip cleanup on error, so release the pool here.
		_ = db.Close() // best effort; the ping error is the one to report
		return nil, nil, fmt.Errorf("ping: %w", err)
	}
	return db, func() { _ = db.Close() }, nil
}
// Check compares the source and destination databases table by table. It
// first compares row counts (CheckCount) and then compares a window of up to
// maxCheckRows rows per table (CheckRows).
//
// The tables come from -tables when that flag is set, otherwise from
// ListTables; lr.Tables is replaced with that list. Both summaries are logged.
// Differences found are only logged; an error is returned only when a step
// itself fails.
func (lr *LogicalRep) Check(ctx context.Context) error {
	l := lr.l.With("func", "Check")

	tables, err := lr.checkTables(ctx)
	if err != nil {
		return err
	}
	lr.Tables = tables

	countRes, err := lr.CheckCount(ctx)
	if err != nil {
		return fmt.Errorf("check count: %w", err)
	}
	// Log the count summary before sampling rows so it is not lost if the
	// row check fails.
	l.InfoContext(ctx, "check count", "result", countRes)

	rowsRes, err := lr.CheckRows(ctx, countRes)
	if err != nil {
		return fmt.Errorf("check rows: %w", err)
	}
	l.InfoContext(ctx, "check rows", "result", rowsRes)
	return nil
}

// checkTables returns the table names Check should operate on: the trimmed,
// non-empty entries of -tables when that flag is set (an error if none
// remain), otherwise every table returned by ListTables.
func (lr *LogicalRep) checkTables(ctx context.Context) ([]string, error) {
	if *tablesF != "" {
		var tables []string
		for _, t := range strings.Split(*tablesF, ",") {
			if t = strings.TrimSpace(t); t != "" {
				tables = append(tables, t)
			}
		}
		if len(tables) == 0 {
			return nil, fmt.Errorf("no table provided")
		}
		return tables, nil
	}

	listed, err := lr.ListTables(ctx)
	if err != nil {
		return nil, fmt.Errorf("list tables: %w", err)
	}
	tables := make([]string, 0, len(listed))
	for _, t := range listed {
		tables = append(tables, t.TableName)
	}
	return tables, nil
}
// LogicalRep pairs a source and a destination database. Check compares their
// data, and SyncSequences copies sequence positions from source to
// destination.
type LogicalRep struct {
	// SrcDB is the source database: tables and sequences are listed from it,
	// and its rows are counted and sampled.
	SrcDB *bun.DB
	// DstDB is the destination database: it is compared with SrcDB, and its
	// sequences are restarted by SyncSequences.
	DstDB *bun.DB
	// Schema is the PostgreSQL schema used when listing tables and sequences
	// and when qualifying sequence names on the destination.
	Schema string
	// Tables holds the table names Check works on; Check fills it in.
	Tables []string

	// l is the base logger; each method derives one tagged with "func".
	l *slog.Logger
}
// Table is one row scanned from information_schema.tables by ListTables;
// only the table_name column is kept.
type Table struct {
	TableName string `bun:"table_name"`
}
// ListTables returns the base tables in lr.Schema of the source database,
// ordered by name.
//
// information_schema.tables also lists views, foreign tables and temporary
// tables. Those are filtered out because logical replication copies only
// table data, so they would report spurious differences (or fail when absent
// on the destination).
func (lr *LogicalRep) ListTables(ctx context.Context) (res []*Table, err error) {
	// REF https://stackoverflow.com/a/2276722/7134763
	err = lr.SrcDB.NewSelect().
		Table("information_schema.tables").
		Column("table_name").
		Where("table_schema = ?", lr.Schema).
		Where("table_type = 'BASE TABLE'").
		Order("table_name").
		Scan(ctx, &res)
	return res, err
}
// TableCount is the number of rows one table has on each side.
type TableCount struct {
	Src int // rows counted in the source database
	Dst int // rows counted in the destination database
}
// CheckCountRes is the summary produced by CheckCount.
type CheckCountRes struct {
	// Count maps each checked table name to its source/destination row counts.
	Count map[string]TableCount
	// BadCount is how many tables had different counts on the two sides.
	BadCount int
}
// CheckCount counts the rows of every table in lr.Tables in both databases,
// running the source and destination queries concurrently.
//
// Each pair of counts is stored in res.Count. Tables whose counts differ are
// logged at warn level and counted in res.BadCount; equal ones are logged at
// info level. The first query error aborts the check, and the error names the
// table.
func (lr *LogicalRep) CheckCount(ctx context.Context) (res *CheckCountRes, err error) {
	l := lr.l.With("func", "CheckCount")
	res = &CheckCountRes{
		Count: make(map[string]TableCount, len(lr.Tables)),
	}
	for _, t := range lr.Tables {
		l := l.With("table", t)

		// gctx is canceled as soon as Wait returns, so only the two queries
		// use it; logging below uses the caller's ctx.
		eg, gctx := errgroup.WithContext(ctx)
		var srcN, dstN int
		eg.Go(func() (err error) {
			srcN, err = lr.SrcDB.NewSelect().Table(t).Count(gctx)
			return wrap(err, "count src")
		})
		eg.Go(func() (err error) {
			dstN, err = lr.DstDB.NewSelect().Table(t).Count(gctx)
			return wrap(err, "count dst")
		})
		if err := eg.Wait(); err != nil {
			return nil, fmt.Errorf("table %q: %w", t, err)
		}

		res.Count[t] = TableCount{Src: srcN, Dst: dstN}
		l = l.With("src", srcN, "dst", dstN)
		if srcN == dstN {
			l.InfoContext(ctx, "compare")
		} else {
			res.BadCount++
			l.WarnContext(ctx, "compare")
		}
	}
	return res, nil
}
// CheckRowsRes is the summary produced by CheckRows.
type CheckRowsRes struct {
	// BadCount is how many tables had differing sampled rows.
	BadCount int
}
// CheckRows compares a window of up to maxCheckRows rows from every table in
// lr.Tables between the source and destination databases.
//
// Window placement: when the larger of the two counts in countRes exceeds
// maxCheckRows, the window starts at a random offset; otherwise it starts at
// row 0.
//
// Comparison: both sides are queried concurrently into []map[string]any and
// compared with reflect.DeepEqual. Differing tables are logged at warn level
// and counted in res.BadCount. The first query error aborts the check, and the
// error names the table.
//
// Ordering caveat: the comparison is only meaningful when -orderby gives both
// sides the same total order (e.g. the primary key). Without ORDER BY,
// PostgreSQL may return rows in any order, so differences can be spurious.
func (lr *LogicalRep) CheckRows(
	ctx context.Context, countRes *CheckCountRes,
) (res *CheckRowsRes, err error) {
	l := lr.l.With("func", "CheckRows")
	res = &CheckRowsRes{}

	// The -orderby expression is the operator's own SQL. It is passed as
	// bun.Safe so it is inserted verbatim and never re-parsed for
	// placeholders.
	orderBy := bun.Safe("")
	if *orderByF != "" {
		orderBy = bun.Safe("ORDER BY " + *orderByF)
	} else if len(lr.Tables) > 0 {
		l.WarnContext(ctx, "no -orderby given; row order is unspecified and differences may be spurious")
	}

	for _, t := range lr.Tables {
		l := l.With("table", t)
		count := countRes.Count[t]
		maxN := max(count.Src, count.Dst)

		offset := 0
		if maxN > maxCheckRows {
			// Valid offsets are 0..maxN-maxCheckRows inclusive; the +1 lets
			// the last full window be chosen too.
			offset = rand.Intn(maxN - maxCheckRows + 1)
		}

		// The table is quoted as an identifier, as CheckCount does via
		// Table(), so mixed-case names work here as well.
		// Below random selection is difficult to achieve,
		// because we don't have the fixed columns to order them.
		// sql = fmt.Sprintf("SELECT * FROM %s TABLESAMPLE SYSTEM(3000 / %d)", t, maxN)
		const query = "SELECT * FROM ? ? LIMIT ? OFFSET ?"
		args := []any{bun.Ident(t), orderBy, maxCheckRows, offset}

		var srcs, dsts []map[string]any
		// gctx is canceled once Wait returns; logging below uses ctx.
		eg, gctx := errgroup.WithContext(ctx)
		eg.Go(func() error {
			return wrap(lr.SrcDB.NewRaw(query, args...).Scan(gctx, &srcs), "query src")
		})
		eg.Go(func() error {
			return wrap(lr.DstDB.NewRaw(query, args...).Scan(gctx, &dsts), "query dst")
		})
		if err := eg.Wait(); err != nil {
			return nil, fmt.Errorf("table %q: %w", t, err)
		}

		l = l.With("len(srcs)", len(srcs), "len(dsts)", len(dsts))
		if ok := reflect.DeepEqual(srcs, dsts); ok {
			l.InfoContext(ctx, "compare", "ok", ok)
		} else {
			res.BadCount++
			l.WarnContext(ctx, "compare", "ok", ok)
		}
	}
	return res, nil
}
// SyncSequencesRes is the summary produced by SyncSequences.
type SyncSequencesRes struct {
	Total    int              // sequences listed in the source
	ErrCount int              // sequences whose destination update failed
	OKCount  int              // sequences updated on the destination
	Errs     map[string]error // failed sequence name -> its error
}
// SyncSequences lists the sequences of the source database (listSequences)
// and restarts each destination sequence of the same name at a value above
// the source's last_value (see bumpSeqValue).
//
// Per-sequence failures are logged, counted in res.ErrCount and recorded in
// res.Errs; they do not abort the run. Only a failure to list the sequences
// is returned as an error.
func (lr *LogicalRep) SyncSequences(ctx context.Context) (res *SyncSequencesRes, err error) {
	seqs, err := lr.listSequences(ctx)
	if err != nil {
		return nil, fmt.Errorf("list seqs: %w", err)
	}
	l := lr.l.With("func", "SyncSequences")
	res = &SyncSequencesRes{
		Total: len(seqs),
		// Must be non-nil: assigning into a nil map panics on the first failure.
		Errs: make(map[string]error),
	}
	for _, seq := range seqs {
		lv := bumpSeqValue(seq.LastValue)
		l := l.With("seq_name", seq.Name, "src_last_value", seq.LastValue, "dst_last_value", lv)
		if err := lr.setSeqLastValue(ctx, seq.Name, lv); err != nil {
			res.ErrCount++
			res.Errs[seq.Name] = err
			l.ErrorContext(ctx, err.Error())
		} else {
			res.OKCount++
			l.InfoContext(ctx, "ok")
		}
	}
	return res, nil
}

// bumpSeqValue returns last increased by 10% (truncated toward zero), or
// last+10 when that increase truncates to zero. It yields the same result as
// last*110/100, but divides first so the intermediate product cannot overflow
// for large sequence values.
func bumpSeqValue(last int) int {
	next := last + last/10
	if next == last {
		next += 10
	}
	return next
}
// Sequence is one row of the pg_sequences system view (aliased "s").
type Sequence struct {
	bun.BaseModel `bun:"table:pg_sequences,alias:s"`

	Schema string `bun:"schemaname"`   // schema the sequence belongs to
	Name   string `bun:"sequencename"` // sequence name, unqualified
	// LastValue mirrors pg_sequences.last_value. NOTE(review): that column is
	// NULL for never-used sequences; confirm how bun scans NULL into int.
	LastValue int `bun:"last_value"`
}
// listSequences reads pg_sequences from the source database. Results are
// restricted to lr.Schema when it is set; otherwise sequences from every
// schema are returned.
func (lr *LogicalRep) listSequences(ctx context.Context) ([]*Sequence, error) {
	var seqs []*Sequence
	query := lr.SrcDB.NewSelect().Model(&seqs)
	if schema := lr.Schema; schema != "" {
		query = query.Where("schemaname = ?", schema)
	}
	err := query.Scan(ctx)
	return seqs, err
}
// setSeqLastValue runs ALTER SEQUENCE ... RESTART on the destination
// database, so the sequence's next nextval returns lastValue.
//
// The sequence is qualified with lr.Schema when that is set; otherwise it is
// left unqualified and resolved through search_path, matching listSequences'
// handling of an empty schema. Both parts are quoted as identifiers, so
// mixed-case names read from pg_sequences work.
func (lr *LogicalRep) setSeqLastValue(ctx context.Context, name string, lastValue int) error {
	if lr.Schema == "" {
		_, err := lr.DstDB.NewRaw("ALTER SEQUENCE ? RESTART ?",
			bun.Ident(name), lastValue).Exec(ctx)
		return err
	}
	_, err := lr.DstDB.NewRaw("ALTER SEQUENCE ?.? RESTART ?",
		bun.Ident(lr.Schema), bun.Ident(name), lastValue).Exec(ctx)
	return err
}
// max returns the larger of x and y (x when they are equal).
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}
// wrap prefixes err with msg ("msg: err"), keeping err unwrappable via %w.
// A nil err stays nil, so wrap can be applied unconditionally to a result.
func wrap(err error, msg string) error {
	if err != nil {
		return fmt.Errorf("%s: %w", msg, err)
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment