Skip to content

Instantly share code, notes, and snippets.

@CAFxX
Last active December 31, 2022 02:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CAFxX/f1f7b9d6227e3c4df3ea6d784e5c88b1 to your computer and use it in GitHub Desktop.
Save CAFxX/f1f7b9d6227e3c4df3ea6d784e5c88b1 to your computer and use it in GitHub Desktop.
SQL bulk insert
package bulkinsert
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
)
// conn is the minimal database handle BulkInsert depends on: ad-hoc
// execution plus statement preparation. *sql.DB, *sql.Tx and *sql.Conn all
// satisfy it, so BulkInsert works both inside and outside a transaction.
type conn interface {
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// BulkInsert inserts rows in bulk.
//
// The query should be in the form "INSERT ... VALUES" (without the query placeholders following
// the "VALUES" keyword). The placeholders are automatically generated and appended to the query,
// one "(?,...,?)" group per row; the driver must therefore accept "?"-style placeholders.
//
// cols is the number of columns/values to be inserted per row. The number of arguments must be
// a multiple of cols. A single query never contains more than 65535 placeholders, so cols must
// not exceed 65535.
//
// maxrows is the maximum number of rows to be inserted in a single INSERT query; it is lowered
// further if needed to respect the 65535-placeholder limit. If more than this number of rows
// needs to be inserted, multiple INSERT queries are executed automatically: every full batch of
// maxrows rows reuses one prepared statement, and any leftover rows are inserted with a final
// non-prepared query.
//
// If a non-nil error is returned, it is possible for some rows to have been inserted successfully.
// If partial inserts are undesirable, start a transaction before calling BulkInsert and perform
// a rollback in case of error.
func BulkInsert(ctx context.Context, c conn, query string, cols, maxrows int, args ...any) error {
	// maxPlaceholders caps the number of bind parameters in one query.
	// NOTE(review): presumably chosen to match a 16-bit parameter count used by
	// common database protocols — confirm for the target database.
	const maxPlaceholders = 1<<16 - 1

	if cols <= 0 {
		return errors.New("wrong number of columns")
	}
	if maxrows <= 0 {
		return errors.New("wrong number of max rows per query")
	}
	if len(args)%cols != 0 {
		return errors.New("wrong number of arguments")
	}
	if cols > maxPlaceholders {
		// Not even one row fits in a query. Without this check maxrows would be
		// clamped to 0 below and the batching loop (rows >= 0) would never end.
		return fmt.Errorf("wrong number of columns: %d exceeds the limit of %d placeholders per query", cols, maxPlaceholders)
	}
	if limit := maxPlaceholders / cols; maxrows > limit {
		maxrows = limit
	}

	rows := len(args) / cols
	if rows >= maxrows {
		// One or more full batches: prepare once, execute once per batch.
		stmt, err := c.PrepareContext(ctx, placeholders(query, cols, maxrows))
		if err != nil {
			return fmt.Errorf("preparing statement: %w", err)
		}
		defer stmt.Close()
		batch := cols * maxrows // arguments consumed per full batch
		for rows >= maxrows {
			if _, err := stmt.ExecContext(ctx, args[:batch]...); err != nil {
				return fmt.Errorf("executing prepared statement: %w", err)
			}
			args = args[batch:]
			rows -= maxrows
		}
	}
	if rows > 0 {
		// Leftover rows (fewer than maxrows) are inserted with a single ad-hoc query.
		if _, err := c.ExecContext(ctx, placeholders(query, cols, rows), args...); err != nil {
			return fmt.Errorf("executing statement: %w", err)
		}
	}
	return nil
}
// placeholders appends a space and rows comma-separated groups of cols "?"
// markers to query. For example, placeholders("INSERT INTO t VALUES", 2, 3)
// returns "INSERT INTO t VALUES (?,?),(?,?),(?,?)".
func placeholders(query string, cols, rows int) string {
	var sb strings.Builder
	sb.Grow(len(query) + (cols*2+2)*rows)
	sb.WriteString(query)
	sb.WriteString(" ")

	rowSep := ""
	for i := 0; i < rows; i++ {
		sb.WriteString(rowSep)
		rowSep = ","

		sb.WriteByte('(')
		colSep := ""
		for j := 0; j < cols; j++ {
			sb.WriteString(colSep)
			sb.WriteByte('?')
			colSep = ","
		}
		sb.WriteByte(')')
	}
	return sb.String()
}
@CAFxX
Copy link
Author

CAFxX commented Dec 31, 2022

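Alternative implementation of `placeholders`: it emits the placeholders of the first row once and then copies that row (via `repeat`) instead of writing every placeholder individually: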

// placeholders returns query followed by a space and rows comma-separated
// groups of cols "?" markers, e.g. "INSERT ... VALUES (?,?),(?,?)".
//
// Only the first group is generated marker by marker; the remaining rows-1
// groups are copied from it with repeat. rows <= 0 yields just query + " "
// and a negative cols is treated as 0 (each group is "()"), matching the
// loop-based implementation.
func placeholders(query string, cols, rows int) string {
	if rows <= 0 {
		return query + " "
	}
	if cols < 0 {
		cols = 0 // also keeps the Grow size below non-negative
	}
	var b strings.Builder
	b.Grow(len(query) + (cols*2+2)*rows)
	b.WriteString(query)
	b.WriteByte(' ')

	// Write the first group and remember where it starts, so its length is
	// measured rather than derived from cols (which is wrong for cols == 0).
	start := b.Len()
	b.WriteByte('(')
	repeat(&b, "?", ",", cols)
	b.WriteByte(')')

	if rows > 1 {
		// b.String() does not copy; the returned string stays valid because
		// a strings.Builder never modifies bytes it has already written.
		row := b.String()[start:]
		b.WriteByte(',')
		repeat(&b, row, ",", rows-1)
	}
	return b.String()
}

// repeat appends n copies of elem to b, separated by join (the same output
// as strings.Join of n elems), without building an intermediate string.
// n <= 0 writes nothing.
//
// For n >= 4 the output is produced by copying the tail of b onto itself:
// after the first "elem+join" unit is written, each step re-appends the most
// recently written units, at most doubling the written output per step. A
// single step copies at most about 8 KiB (but always at least one unit).
func repeat(b *strings.Builder, elem, join string, n int) {
	switch {
	case n <= 0:
		// Nothing to write. (A negative n used to reach b.Grow with a
		// negative count, which panics.)
	case n == 1:
		b.WriteString(elem)
	case n == 2:
		b.Grow(len(elem)*2 + len(join))
		b.WriteString(elem)
		b.WriteString(join)
		b.WriteString(elem)
	case n == 3:
		b.Grow(len(elem)*3 + len(join)*2)
		b.WriteString(elem)
		b.WriteString(join)
		b.WriteString(elem)
		b.WriteString(join)
		b.WriteString(elem)
	default:
		unit := len(elem) + len(join)
		if unit == 0 {
			// Output is empty; returning here also avoids dividing by zero
			// when computing maxChunk below.
			return
		}
		b.Grow(len(elem) + unit*(n-1))
		b.WriteString(elem)
		b.WriteString(join)

		maxChunk := 8 * 1024 / unit
		if maxChunk < 1 {
			maxChunk = 1
		}
		// done counts the "elem+join" units written so far; n-1 are needed
		// before the final bare elem.
		for done := 1; done < n-1; {
			chunk := done // never copy more than has been written
			if rest := n - 1 - done; rest < chunk {
				chunk = rest
			}
			if chunk > maxChunk {
				chunk = maxChunk
			}
			s := b.String()
			b.WriteString(s[len(s)-chunk*unit:])
			done += chunk
		}
		b.WriteString(elem)
	}
}

// min returns the smaller of a and b.
// NOTE: on Go 1.21+ the builtin min makes this helper unnecessary.
func min(a, b int) int {
	if a <= b {
		return a
	}
	return b
}
// max returns the larger of a and b.
// NOTE: on Go 1.21+ the builtin max makes this helper unnecessary.
func max(a, b int) int {
	if a >= b {
		return a
	}
	return b
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment