Last active
December 29, 2022 06:42
-
-
Save CAFxX/48b62fb904c7ae31932a3b4ebae15fe9 to your computer and use it in GitHub Desktop.
MySQL bulk data loader (via LOAD DATA LOCAL INFILE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package mysql | |
import (
	"bytes"
	"context"
	"database/sql"
	"encoding/csv"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"os"
	"strings"

	"github.com/go-sql-driver/mysql"
	"github.com/gocarina/gocsv"
)
// LoadData inserts the provided rows in the specified table. | |
func LoadData(ctx context.Context, db *sql.DB, tableName string, rows [][]string, opts ...Option) (sql.Result, error) { | |
var buf bytes.Buffer | |
buf.Grow(estimateRowsSize(rows)) | |
w := csv.NewWriter(&buf) | |
err := w.WriteAll(rows) | |
if err != nil { | |
return nil, err | |
} | |
return LoadDataFromReader(ctx, db, tableName, &buf, opts...) | |
} | |
// estimateRowsSize returns an upper-bound guess of the CSV-encoded size of
// rows. Each cell is counted as its length plus 3 bytes (a separator and two
// enclosing quotes), and each row adds 1 byte for its line terminator.
func estimateRowsSize(rows [][]string) int {
	total := len(rows) // one terminator per row
	for i := range rows {
		for j := range rows[i] {
			total += 3 + len(rows[i][j])
		}
	}
	return total
}
// LoadDataStructs inserts the provided rows, serialized using gocarina/gocsv, in the specified table. | |
func LoadDataStructs(ctx context.Context, db *sql.DB, tableName string, rows []any, opts ...Option) (sql.Result, error) { | |
var buf bytes.Buffer | |
err := gocsv.MarshalWithoutHeaders(&rows, &buf) | |
if err != nil { | |
return nil, err | |
} | |
return LoadDataFromReader(ctx, db, tableName, &buf, opts...) | |
} | |
// LoadDataFromLocalFile inserts the rows in the provided CSV file into the specified table. | |
// | |
// The file must be in CSV format. | |
func LoadDataFromLocalFile(ctx context.Context, db *sql.DB, tableName, fileName string, opts ...Option) (sql.Result, error) { | |
f, err := os.Open(fileName) | |
if err != nil { | |
return nil, err | |
} | |
defer f.Close() | |
return LoadDataFromReader(ctx, db, tableName, f, opts...) | |
} | |
// LoadDataFromReader inserts the rows (in CSV format) from the provided Reader into the specified table. | |
// | |
// The data read from the reader must be in CSV format. | |
func LoadDataFromReader(ctx context.Context, db *sql.DB, tableName string, r io.Reader, opts ...Option) (sql.Result, error) { | |
if strings.Contains(tableName, "`") { | |
return nil, fmt.Errorf("invalid table name: %q", tableName) | |
} | |
cfg := cfg{} | |
for _, opt := range opts { | |
err := opt(&cfg) | |
if err != nil { | |
return nil, fmt.Errorf("option: %w", err) | |
} | |
} | |
// This is silly, unsafe, and racy. The mysql driver should really | |
// support passing a reader as a parameter to ExecContext, e.g. | |
// db.ExecContext(ctx, "LOAD DATA LOCAL INFILE ? INTO TABLE ?", mysql.LoadDataReader(r), tableName) | |
readerName := fmt.Sprintf("LoadData/%d-%d", rand.Uint64(), rand.Uint64()) | |
mysql.RegisterReaderHandler(readerName, func() io.Reader { return r }) | |
defer mysql.DeregisterReaderHandler(readerName) | |
sql := "LOAD DATA LOCAL INFILE 'Reader::"+readerName+"'" | |
if cfg.dup != "" { | |
sql += " " + cfg.dup | |
} | |
sql += " INTO TABLE `"+tableName+"`" | |
if len(cfg.partitions) > 0 { | |
sql += " PARTITION (`" + cfg.partitions[0] + "`" | |
for _, p := range cfg.partitions[1:] { | |
sql += ",`" + p + "`" | |
} | |
sql += ")" | |
} | |
if cfg.charset != "" { | |
sql += " CHARACTER SET " + cfg.charset | |
} | |
if len(cfg.columns) > 0 { | |
sql += " `" + cfg.columns[0] + "`" | |
for _, col := range cfg.columns[1:] { | |
sql += ",`" + col + "`" | |
} | |
} | |
return db.ExecContext(ctx, sql) | |
} | |
// cfg accumulates the optional LOAD DATA clauses requested through Option
// values. The zero value adds no optional clauses.
type cfg struct {
	// dup is the duplicate-key mode written before INTO TABLE
	// ("REPLACE", "IGNORE", or empty for none).
	dup string
	// charset, when non-empty, is written after CHARACTER SET.
	charset string
	// partitions, when non-empty, is written as the PARTITION (...) list.
	partitions []string
	// columns, when non-empty, lists the target columns in CSV field order.
	columns []string
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package mysql | |
import ( | |
"errors" | |
"strings" | |
) | |
type Option func(*cfg) error | |
func CharacterSet(cs string) Option { | |
return func(c *cfg) error { | |
if strings.Contains(cs, "`") { | |
return errors.New("invalid charset") | |
} | |
if c.charset != "" && c.charset != cs { | |
return errors.New("multiple characters set specified") | |
} | |
c.charset = cs | |
return nil | |
} | |
} | |
func Columns(cols ...string) Option { | |
return func(c *cfg) error { | |
for _, col := range cols { | |
if strings.Contains(col, "`") { | |
return errors.New("invalid column") | |
} | |
} | |
c.columns = append(c.columns, cols) | |
return nil | |
} | |
} | |
func ReplaceDuplicates() Option { | |
return func(c *cfg) error { | |
if c.dup != "" && c.dup != "REPLACE" { | |
return errors.New("ReplaceDuplicates and IgnoreDuplicates are mutually exclusive") | |
} | |
c.dup = "REPLACE" | |
return nil | |
} | |
} | |
func IgnoreDuplicates() Option { | |
return func(c *cfg) error { | |
if c.dup != "" && c.dup != "IGNORE" { | |
return errors.New("IgnoreDuplicates and ReplaceDuplicates are mutually exclusive") | |
} | |
c.dup = "IGNORE" | |
return nil | |
} | |
} | |
func Partitions(ps ...string) Option { | |
return func(c *cfg) error { | |
for _, p := range ps { | |
if strings.Contains(p, "`") { | |
return errors.New("invalid partition") | |
} | |
} | |
c.partitions = append(c.partitions, ps) | |
return nil | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment