Skip to content

Commit

Permalink
feat: bulk load for MySQL
Browse files Browse the repository at this point in the history
The PR implements bulk loading for MySQL using the "LOAD DATA from io.Reader" feature of github.com/go-sql-driver/mysql - https://github.com/go-sql-driver/mysql?tab=readme-ov-file#load-data-local-infile-support . As expected,
bulk loading this way is significantly faster. 1 mln. rows in the "staff" table from the test schema are inserted for 15 sec vs. 120 sec using INSERT: 8x improvement. Note that LOAD DATA INFILE LOCAL is disabled by default on
MySQL 8+ servers and must be enabled using SET GLOBAL local_infile = ON beforehand. MySQL doesn't seem to have any remote bulk loading options that are enabled by default.

The PR also extends TestCopy in drivers_test.go with comparison of copied data to ensure MySQL bulk loading is safe across data types.

Testing Done: tests in drivers_test.go#
  • Loading branch information
murfffi committed Dec 5, 2024
1 parent c3e8cde commit cae603a
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 9 deletions.
123 changes: 115 additions & 8 deletions drivers/drivers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"bytes"
"context"
"database/sql"
"errors"
"flag"
"fmt"
"log"
"net/url"
"os"
"reflect"
"regexp"
"strings"
"testing"
Expand Down Expand Up @@ -435,9 +437,11 @@ func TestCopy(t *testing.T) {

testCases := []struct {
dbName string
testCase string
setupQueries []setupQuery
src string
dest string
destCmpQuery string
}{
{
dbName: "pgsql",
Expand All @@ -449,7 +453,8 @@ func TestCopy(t *testing.T) {
dest: "staff_copy",
},
{
dbName: "pgsql",
dbName: "pgsql",
testCase: "schemaInDest",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
Expand All @@ -466,8 +471,9 @@ func TestCopy(t *testing.T) {
src: "select * from staff",
dest: "staff_copy",
},
{
dbName: "pgx",
{ // this holds even select iterates over table in a ran
dbName: "pgx",
testCase: "schemaInDest",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
Expand All @@ -478,12 +484,22 @@ func TestCopy(t *testing.T) {
{
dbName: "mysql",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
},
src: "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
},
{
dbName: "mysql",
testCase: "bulkCopy",
setupQueries: []setupQuery{
{query: "SET GLOBAL local_infile = ON"},
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
},
src: "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)",
},
{
dbName: "sqlserver",
setupQueries: []setupQuery{
Expand All @@ -497,9 +513,11 @@ func TestCopy(t *testing.T) {
dbName: "csvq",
setupQueries: []setupQuery{
{query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true},
{query: "DELETE from staff_copy", check: true},
},
src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy",
src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy",
destCmpQuery: "select first_name, last_name, address_id, email, store_id, active, username, password, datetime(last_update) from staff_copy",
},
}
for _, test := range testCases {
Expand All @@ -508,7 +526,11 @@ func TestCopy(t *testing.T) {
continue
}

t.Run(test.dbName, func(t *testing.T) {
testName := test.dbName
if test.testCase != "" {
testName += "-" + test.testCase
}
t.Run(testName, func(t *testing.T) {

// TODO test copy from a different DB, maybe csvq?
// TODO test copy from same DB
Expand All @@ -524,7 +546,7 @@ func TestCopy(t *testing.T) {
t.Fatalf("Could not get rows to copy: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Second)
defer cancel()
var rlen int64 = 1
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
Expand All @@ -534,10 +556,95 @@ func TestCopy(t *testing.T) {
if n != rlen {
t.Fatalf("Expected to copy %d rows but got %d", rlen, n)
}

checkSameData(t, ctx, pg.DB, test.src, db.DB, test.destCmpQuery)
})
}
}

// checkSameData fails the test if the data in the srcDB."staff" table is different than destDB."staff_copy" table
func checkSameData(t *testing.T, ctx context.Context, srcDB *sql.DB, srcQuery string, destDB *sql.DB, destCmpQuery string) {
if destCmpQuery == "" {
srcQuery = strings.ToLower(srcQuery)
if !strings.Contains(srcQuery, "from staff") {
t.Fatalf("destCmpQuery needs to be configured if src '%s' is not for table 'staff'", srcQuery)
}
// if destCmpQuery needs special syntax, configure it in the test case definitions above
destCmpQuery = strings.Replace(srcQuery, "from staff", "from staff_copy", 1)
}
srcValues, srcColumnTypes, err := getSrcRow(ctx, srcDB, srcQuery)
if err != nil {
t.Fatalf("Could not get src row from database: %v", err)
}
destValues, err := getDestRow(ctx, destDB, destCmpQuery, srcColumnTypes)
if err != nil {
t.Fatalf("Could not get dest row from database: %v", err)
}
// Comparing more than 1 row is more complex because SELECT result order is undefined without order by
adjustDates(srcValues, destValues)
if !reflect.DeepEqual(srcValues, destValues) {
t.Fatalf("Source and dest row don't match: \n%v\n vs \n%v", srcValues, destValues)
}
}

// adjustDates removes sub-second differences between any dates in the two rows, because
// the difference are likely caused by difference in precision and not by a copy issue
func adjustDates(src []interface{}, dest []interface{}) {
for i, v := range src {
srcDate, okSrc := v.(time.Time)
destDate, okDest := dest[i].(time.Time)
if okSrc && okDest && srcDate.Sub(destDate).Abs() <= time.Second {
dest[i] = srcDate
}
}
}

func getSrcRow(ctx context.Context, db *sql.DB, query string) ([]interface{}, []*sql.ColumnType, error) {
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, nil, err
}
values, err := readRow(rows, columnTypes)
return values, columnTypes, err
}

func getDestRow(ctx context.Context, db *sql.DB, query string, columnTypes []*sql.ColumnType) ([]interface{}, error) {
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return readRow(rows, columnTypes)
}

func readRow(rows *sql.Rows, columnTypes []*sql.ColumnType) ([]interface{}, error) {
if !rows.Next() {
return nil, errors.New("exactly one row expected but got 0")
}
// some DB drivers don't handle reading into *any well so use *reportedType instead
values := make([]interface{}, len(columnTypes))
for i := 0; i < len(columnTypes); i++ {
values[i] = reflect.New(columnTypes[i].ScanType()).Interface()
}
err := rows.Scan(values...)
if err != nil {
return nil, err
}
if rows.Next() {
return nil, errors.New("exactly one row expected but more found")
}
// dereference the pointers
for i, v := range values {
values[i] = reflect.ValueOf(v).Elem().Interface()
}
return values, nil
}

// filesEqual compares the files at paths a and b and returns an error if
// the content is not equal. Ignore is a regex. All matches will be removed
// from the file contents before comparison.
Expand Down
127 changes: 127 additions & 0 deletions drivers/mysql/copy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package mysql

import (
"context"
"database/sql"
"encoding/csv"
"fmt"
"io"
"os"
"reflect"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/xo/usql/drivers"
)

func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
localInfileSupported := false
row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile")
err := row.Scan(&localInfileSupported)
if err == nil && localInfileSupported && !hasBlobColumn(rows) {
return bulkCopy(ctx, db, rows, table)
} else {
return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table)
}
}

func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
mysql.RegisterReaderHandler("data", func() io.Reader {
return toCsvReader(rows)
})
defer mysql.DeregisterReaderHandler("data")
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
var cnt int64
csvSpec := " FIELDS TERMINATED BY ',' "
stmt := fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s",
// if there is a column list, csvSpec goes between the table name and the list
strings.Replace(table, "(", csvSpec+" (", 1))
// if there wasn't a column list in the table spec, csvSpec goes at the end
if !strings.Contains(table, "(") {
stmt += csvSpec
}
res, err := tx.ExecContext(ctx, stmt)
if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
if err == nil {
cnt, err = res.RowsAffected()
}
}
return cnt, err
}

func hasBlobColumn(rows *sql.Rows) bool {
columnTypes, err := rows.ColumnTypes()
if err != nil {
return false
}
for _, ct := range columnTypes {
if ct.DatabaseTypeName() == "BLOB" {
return true
}
}
return false
}

func toCsvReader(rows *sql.Rows) io.Reader {
r, w := io.Pipe()
go writeAsCsv(rows, w)
return r
}

// writeAsCsv writes the rows in a CSV format compatible with LOAD DATA INFILE
func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) {
defer w.Close() // noop if already closed
columnTypes, err := rows.ColumnTypes()
if err != nil {
w.CloseWithError(err)
return
}
values := make([]interface{}, len(columnTypes))
valueRefs := make([]reflect.Value, len(columnTypes))
for i := 0; i < len(columnTypes); i++ {
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
values[i] = valueRefs[i].Interface()
}
record := make([]string, len(values))
csvWriter := csv.NewWriter(io.MultiWriter(w, os.Stdout))
for rows.Next() {
if err = rows.Err(); err != nil {
break
}
err = rows.Scan(values...)
if err != nil {
break
}
for i, valueRef := range valueRefs {
val := valueRef.Elem().Interface()
val = toIntIfBool(val)
// NB: Does not work for BLOBs. Use regular copy if there are BLOB columns
record[i] = fmt.Sprintf("%v", val)
}
err = csvWriter.Write(record)
if err != nil {
break
}
}
if err == nil {
csvWriter.Flush()
err = csvWriter.Error()
}
w.CloseWithError(err) // same as w.Close(), if err is nil
}

func toIntIfBool(val interface{}) interface{} {
if boolVal, ok := val.(bool); ok {
val = 0
if boolVal {
val = 1
}
}
return val
}
2 changes: 1 addition & 1 deletion drivers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func init() {
NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer {
return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w)
},
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
Copy: copyRows,
NewCompleter: mymeta.NewCompleter,
}, "memsql", "vitess", "tidb")
}

0 comments on commit cae603a

Please sign in to comment.