Skip to content

Instantly share code, notes, and snippets.

@dlisboa
Last active July 18, 2024 17:14
Show Gist options
  • Save dlisboa/6e88704c2b503b35461c47ef3d214829 to your computer and use it in GitHub Desktop.
Save dlisboa/6e88704c2b503b35461c47ef3d214829 to your computer and use it in GitHub Desktop.
Go range-over-func refactoring
// SeedTable inserts every seed block parsed from data into db, running one
// INSERT statement per block in the order the blocks appear.
//
// It stops at the first failing statement and returns that error wrapped
// (errors.Is/As still work) with the index of the offending seed. Statements
// that already ran are not rolled back. t is currently unused.
func SeedTable(t *testing.T, db *sql.DB, data string) error {
	for i, s := range seeds(data) {
		if _, err := db.Exec(query(s)); err != nil {
			return fmt.Errorf("testdb: seed %d: %w", i, err)
		}
	}
	return nil
}
// seed is one parsed block of seed data. Each element is one non-blank line
// of the block, split into whitespace-separated fields. query expects
// seed[0] to be the table line (table name in its second field), seed[1] to
// hold the column names, and any further lines to be value rows.
type seed [][]string
// seeds returns an iterator over the seed blocks in data, for use with
// range-over-func (for i, s := range seeds(data)).
//
// Each non-blank line is split into whitespace-separated fields and appended
// to the current seed. A line whose first field is "---" ends the current
// seed, which is passed to the loop body along with its 0-based index.
// A final seed with no closing "---" is still yielded. Empty seeds (for
// example from consecutive "---" lines) are skipped, so every yielded seed
// has at least one line.
//
// The iterator may be ranged over any number of times. Each pass starts from
// the beginning of data with the index reset to 0.
func seeds(data string) func(func(int, seed) bool) {
	lines := strings.Split(data, "\n")
	return func(yield func(int, seed) bool) {
		// Keep per-pass state local to this call. If it were shared across
		// calls, a second pass would continue counting from the first one,
		// and after an early break it would append to an already-yielded seed.
		var (
			index   int
			current seed
		)
		// emit yields the accumulated seed (if any) and resets for the next
		// one. It reports whether iteration should continue.
		emit := func() bool {
			if len(current) == 0 {
				return true
			}
			more := yield(index, current)
			index++
			current = nil
			return more
		}
		for _, line := range lines {
			fields := strings.Fields(line)
			switch {
			case len(fields) == 0:
				// Blank or whitespace-only line: ignore.
			case fields[0] == "---":
				if !emit() {
					return
				}
			default:
				current = append(current, fields)
			}
		}
		// Flush a trailing seed that has no closing "---".
		emit()
	}
}
// query renders one seed as a multi-row INSERT statement.
//
// seed[0] is the table line, whose second field names the table. seed[1]
// holds the column names, and each later line becomes one parenthesised
// VALUES tuple. Fields are copied into the SQL verbatim, with no quoting or
// escaping, so they must already be valid SQL literals. query panics with an
// index-out-of-range error if seed has fewer than two lines or the table
// line has fewer than two fields.
func query(seed seed) string {
	header := seed[0]
	cols := strings.Join(seed[1], ", ")

	rows := seed[2:]
	tuples := make([]string, 0, len(rows))
	for _, fields := range rows {
		tuples = append(tuples, fmt.Sprintf("(%s)", strings.Join(fields, ", ")))
	}

	return fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES %s RETURNING *",
		header[1], cols, strings.Join(tuples, ", "),
	)
}
// SeedTable inserts the seed blocks described by data into db.
//
// data is a sequence of blocks separated by lines starting with "---".
// Leading and trailing whitespace on each line is ignored, as are blank
// lines. In each block:
//   - the first line's second whitespace-separated field is the table name;
//   - the second line lists the column names;
//   - every remaining line is one row of values.
//
// Row fields are spliced into the SQL verbatim, with nothing quoted or
// escaped, so they must already be valid SQL literals. Each block becomes one
// multi-row "INSERT ... RETURNING *" statement, run with db.Exec in order.
//
// It returns an error if data is blank or any block is malformed (missing
// table name, column line, or rows); in those cases nothing is executed. It
// also returns an error if a statement fails. Statements that ran before
// such a failure are not rolled back.
func SeedTable(t *testing.T, db *sql.DB, data string) error {
	t.Helper()
	// strings.Split never returns an empty slice, so check the text itself.
	if strings.TrimSpace(data) == "" {
		return errors.New("testdb: no data to seed")
	}

	// Group the non-blank lines into blocks. A trailing block without a
	// closing "---" is kept. Empty blocks (consecutive separators) are
	// dropped instead of being indexed into later.
	var blocks [][]string
	var acc []string
	flush := func() {
		if len(acc) > 0 {
			blocks = append(blocks, acc)
			acc = nil
		}
	}
	for _, line := range strings.Split(data, "\n") {
		line = strings.TrimSpace(line)
		switch {
		case line == "":
			// Blank line: skip.
		case strings.HasPrefix(line, "---"):
			flush()
		default:
			acc = append(acc, line)
		}
	}
	flush()

	// Build and validate every statement before executing any, so a
	// malformed later block does not leave the database half-seeded.
	stmts := make([]string, 0, len(blocks))
	for i, block := range blocks {
		table := strings.Fields(block[0])
		if len(table) < 2 || len(block) < 3 {
			return fmt.Errorf("testdb: seed %d: want a table line, a column line and at least one row", i)
		}
		columns := strings.Join(strings.Fields(block[1]), ", ")
		values := make([]string, 0, len(block)-2)
		for _, row := range block[2:] {
			values = append(values, fmt.Sprintf("(%s)", strings.Join(strings.Fields(row), ", ")))
		}
		stmts = append(stmts, fmt.Sprintf(`INSERT INTO %s(%s) VALUES %s RETURNING *`, table[1], columns, strings.Join(values, ", ")))
	}

	for i, stmt := range stmts {
		t.Logf("testdb: %s", stmt)
		if _, err := db.Exec(stmt); err != nil {
			return fmt.Errorf("testdb: seed %d: %w", i, err)
		}
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment