diff --git a/drivers/drivers_test.go b/drivers/drivers_test.go index ead41f786ee..21f3c8b4dc7 100644 --- a/drivers/drivers_test.go +++ b/drivers/drivers_test.go @@ -435,6 +435,7 @@ func TestCopy(t *testing.T) { testCases := []struct { dbName string + testCase string setupQueries []setupQuery src string dest string @@ -449,7 +450,8 @@ func TestCopy(t *testing.T) { dest: "staff_copy", }, { - dbName: "pgsql", + dbName: "pgsql", + testCase: "schemaInDest", setupQueries: []setupQuery{ {query: "DROP TABLE staff_copy"}, {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, @@ -467,7 +469,8 @@ func TestCopy(t *testing.T) { dest: "staff_copy", }, { - dbName: "pgx", + dbName: "pgx", + testCase: "schemaInDest", setupQueries: []setupQuery{ {query: "DROP TABLE staff_copy"}, {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, @@ -484,6 +487,17 @@ func TestCopy(t *testing.T) { src: "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff", dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)", }, + { + dbName: "mysql", + testCase: "bulkCopy", + setupQueries: []setupQuery{ + {query: "SET GLOBAL local_infile = ON"}, + {query: "DROP TABLE staff_copy"}, + {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, + }, + src: "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff", + dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)", + }, { dbName: "sqlserver", setupQueries: []setupQuery{ @@ -508,7 +522,11 @@ func TestCopy(t *testing.T) { continue } - t.Run(test.dbName, func(t *testing.T) { + testName := test.dbName + if test.testCase != "" { + testName += "-" + test.testCase + } + t.Run(testName, func(t *testing.T) { // TODO test copy from a different DB, maybe csvq? // TODO test copy from same DB diff --git a/drivers/mysql/copy.go b/drivers/mysql/copy.go new file mode 100644 index 00000000000..81560686f77 --- /dev/null +++ b/drivers/mysql/copy.go @@ -0,0 +1,108 @@ +package mysql + +import ( + "context" + "database/sql" + "encoding/csv" + "fmt" + "io" + "os" + "reflect" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/xo/usql/drivers" +) + +func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) { + localInfileSupported := false + row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile") + err := row.Scan(&localInfileSupported) + if err == nil && localInfileSupported && !hasBlobColumn(rows) { + return bulkCopy(ctx, db, rows, table) + } else { + return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table) + } +} + +func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) { + mysql.RegisterReaderHandler("data", func() io.Reader { + return toCsvReader(rows) + }) + defer mysql.DeregisterReaderHandler("data") + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + var cnt int64 + res, err := tx.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s", + strings.Replace(table, "(", " FIELDS TERMINATED BY ',' (", 1))) + if err != nil { + tx.Rollback() + } else { + err = tx.Commit() + if err == nil { + cnt, err = res.RowsAffected() + } + } + return cnt, err +} + +func hasBlobColumn(rows *sql.Rows) bool { + columnTypes, err := rows.ColumnTypes() + if err != nil { + return false + } + for _, ct := range columnTypes { + if ct.DatabaseTypeName() == "BLOB" { + return true + } + } + return false +} + +func toCsvReader(rows *sql.Rows) io.Reader { + r, w := io.Pipe() + go writeAsCsv(rows, w) + return r +} + +// writeAsCsv writes the rows in a CSV format compatible with LOAD DATA INFILE +func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) { + defer w.Close() // noop if already closed + columnTypes, err := rows.ColumnTypes() + if err != nil { + w.CloseWithError(err) + return + } + values := make([]interface{}, len(columnTypes)) + valueRefs := make([]reflect.Value, len(columnTypes)) + for i := 0; i < len(columnTypes); i++ { + valueRefs[i] = reflect.New(columnTypes[i].ScanType()) + values[i] = valueRefs[i].Interface() + } + record := make([]string, len(values)) + csvWriter := csv.NewWriter(io.MultiWriter(w, os.Stdout)) + for rows.Next() { + if err = rows.Err(); err != nil { + break + } + err = rows.Scan(values...) + if err != nil { + break + } + for i, valueRef := range valueRefs { + // NB: Does not work for BLOBs. Use regular copy if there are BLOB columns + record[i] = fmt.Sprintf("%v", valueRef.Elem().Interface()) + } + err = csvWriter.Write(record) + if err != nil { + break + } + } + if err == nil { + csvWriter.Flush() + err = csvWriter.Error() + } + w.CloseWithError(err) // same as w.Close(), if err is nil +} diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 4e333ea4c5a..a19f789bf9e 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -45,7 +45,7 @@ func init() { NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer { return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w) }, - Copy: drivers.CopyWithInsert(func(int) string { return "?" }), + Copy: copyRows, NewCompleter: mymeta.NewCompleter, }, "memsql", "vitess", "tidb") }