diff --git a/drivers/drivers_test.go b/drivers/drivers_test.go index ead41f786ee..40d9980ff87 100644 --- a/drivers/drivers_test.go +++ b/drivers/drivers_test.go @@ -7,11 +7,13 @@ import ( "bytes" "context" "database/sql" + "errors" "flag" "fmt" "log" "net/url" "os" + "reflect" "regexp" "strings" "testing" @@ -435,9 +437,11 @@ func TestCopy(t *testing.T) { testCases := []struct { dbName string + testCase string setupQueries []setupQuery src string dest string + destCmpQuery string }{ { dbName: "pgsql", @@ -449,7 +453,8 @@ func TestCopy(t *testing.T) { dest: "staff_copy", }, { - dbName: "pgsql", + dbName: "pgsql", + testCase: "schemaInDest", setupQueries: []setupQuery{ {query: "DROP TABLE staff_copy"}, {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, @@ -466,8 +471,9 @@ func TestCopy(t *testing.T) { src: "select * from staff", dest: "staff_copy", }, - { - dbName: "pgx", + { // this holds even select iterates over table in a ran + dbName: "pgx", + testCase: "schemaInDest", setupQueries: []setupQuery{ {query: "DROP TABLE staff_copy"}, {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, @@ -478,12 +484,22 @@ func TestCopy(t *testing.T) { { dbName: "mysql", setupQueries: []setupQuery{ - {query: "DROP TABLE staff_copy"}, {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, }, src: "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff", dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)", }, + { + dbName: "mysql", + testCase: "bulkCopy", + setupQueries: []setupQuery{ + {query: "SET GLOBAL local_infile = ON"}, + {query: "DROP TABLE staff_copy"}, + {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, + }, + src: "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff", + dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)", + }, { dbName: "sqlserver", setupQueries: []setupQuery{ @@ -497,9 +513,11 @@ func TestCopy(t *testing.T) { dbName: "csvq", setupQueries: []setupQuery{ {query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true}, + {query: "DELETE from staff_copy", check: true}, }, - src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff", - dest: "staff_copy", + src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff", + dest: "staff_copy", + destCmpQuery: "select first_name, last_name, address_id, email, store_id, active, username, password, datetime(last_update) from staff_copy", }, } for _, test := range testCases { @@ -508,7 +526,11 @@ func TestCopy(t *testing.T) { continue } - t.Run(test.dbName, func(t *testing.T) { + testName := test.dbName + if test.testCase != "" { + testName += "-" + test.testCase + } + t.Run(testName, func(t *testing.T) { // TODO test copy from a different DB, maybe csvq? // TODO test copy from same DB @@ -524,7 +546,7 @@ func TestCopy(t *testing.T) { t.Fatalf("Could not get rows to copy: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Second) defer cancel() var rlen int64 = 1 n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest) @@ -534,10 +556,95 @@ func TestCopy(t *testing.T) { if n != rlen { t.Fatalf("Expected to copy %d rows but got %d", rlen, n) } + + checkSameData(t, ctx, pg.DB, test.src, db.DB, test.destCmpQuery) }) } } +// checkSameData fails the test if the data in the srcDB."staff" table is different than destDB."staff_copy" table +func checkSameData(t *testing.T, ctx context.Context, srcDB *sql.DB, srcQuery string, destDB *sql.DB, destCmpQuery string) { + if destCmpQuery == "" { + srcQuery = strings.ToLower(srcQuery) + if !strings.Contains(srcQuery, "from staff") { + t.Fatalf("destCmpQuery needs to be configured if src '%s' is not for table 'staff'", srcQuery) + } + // if destCmpQuery needs special syntax, configure it in the test case definitions above + destCmpQuery = strings.Replace(srcQuery, "from staff", "from staff_copy", 1) + } + srcValues, srcColumnTypes, err := getSrcRow(ctx, srcDB, srcQuery) + if err != nil { + t.Fatalf("Could not get src row from database: %v", err) + } + destValues, err := getDestRow(ctx, destDB, destCmpQuery, srcColumnTypes) + if err != nil { + t.Fatalf("Could not get dest row from database: %v", err) + } + // Comparing more than 1 row is more complex because SELECT result order is undefined without order by + adjustDates(srcValues, destValues) + if !reflect.DeepEqual(srcValues, destValues) { + t.Fatalf("Source and dest row don't match: \n%v\n vs \n%v", srcValues, destValues) + } +} + +// adjustDates removes sub-second differences between any dates in the two rows, because +// the difference are likely caused by difference in precision and not by a copy issue +func adjustDates(src []interface{}, dest []interface{}) { + for i, v := range src { + srcDate, okSrc := v.(time.Time) + destDate, okDest := dest[i].(time.Time) + if okSrc && okDest && srcDate.Sub(destDate).Abs() <= time.Second { + dest[i] = srcDate + } + } +} + +func getSrcRow(ctx context.Context, db *sql.DB, query string) ([]interface{}, []*sql.ColumnType, error) { + rows, err := db.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + columnTypes, err := rows.ColumnTypes() + if err != nil { + return nil, nil, err + } + values, err := readRow(rows, columnTypes) + return values, columnTypes, err +} + +func getDestRow(ctx context.Context, db *sql.DB, query string, columnTypes []*sql.ColumnType) ([]interface{}, error) { + rows, err := db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + return readRow(rows, columnTypes) +} + +func readRow(rows *sql.Rows, columnTypes []*sql.ColumnType) ([]interface{}, error) { + if !rows.Next() { + return nil, errors.New("exactly one row expected but got 0") + } + // some DB drivers don't handle reading into *any well so use *reportedType instead + values := make([]interface{}, len(columnTypes)) + for i := 0; i < len(columnTypes); i++ { + values[i] = reflect.New(columnTypes[i].ScanType()).Interface() + } + err := rows.Scan(values...) + if err != nil { + return nil, err + } + if rows.Next() { + return nil, errors.New("exactly one row expected but more found") + } + // dereference the pointers + for i, v := range values { + values[i] = reflect.ValueOf(v).Elem().Interface() + } + return values, nil +} + // filesEqual compares the files at paths a and b and returns an error if // the content is not equal. Ignore is a regex. All matches will be removed // from the file contents before comparison. diff --git a/drivers/mysql/copy.go b/drivers/mysql/copy.go new file mode 100644 index 00000000000..bd3984796a5 --- /dev/null +++ b/drivers/mysql/copy.go @@ -0,0 +1,127 @@ +package mysql + +import ( + "context" + "database/sql" + "encoding/csv" + "fmt" + "io" + "os" + "reflect" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/xo/usql/drivers" +) + +func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) { + localInfileSupported := false + row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile") + err := row.Scan(&localInfileSupported) + if err == nil && localInfileSupported && !hasBlobColumn(rows) { + return bulkCopy(ctx, db, rows, table) + } else { + return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table) + } +} + +func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) { + mysql.RegisterReaderHandler("data", func() io.Reader { + return toCsvReader(rows) + }) + defer mysql.DeregisterReaderHandler("data") + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + var cnt int64 + csvSpec := " FIELDS TERMINATED BY ',' " + stmt := fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s", + // if there is a column list, csvSpec goes between the table name and the list + strings.Replace(table, "(", csvSpec+" (", 1)) + // if there wasn't a column list in the table spec, csvSpec goes at the end + if !strings.Contains(table, "(") { + stmt += csvSpec + } + res, err := tx.ExecContext(ctx, stmt) + if err != nil { + tx.Rollback() + } else { + err = tx.Commit() + if err == nil { + cnt, err = res.RowsAffected() + } + } + return cnt, err +} + +func hasBlobColumn(rows *sql.Rows) bool { + columnTypes, err := rows.ColumnTypes() + if err != nil { + return false + } + for _, ct := range columnTypes { + if ct.DatabaseTypeName() == "BLOB" { + return true + } + } + return false +} + +func toCsvReader(rows *sql.Rows) io.Reader { + r, w := io.Pipe() + go writeAsCsv(rows, w) + return r +} + +// writeAsCsv writes the rows in a CSV format compatible with LOAD DATA INFILE +func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) { + defer w.Close() // noop if already closed + columnTypes, err := rows.ColumnTypes() + if err != nil { + w.CloseWithError(err) + return + } + values := make([]interface{}, len(columnTypes)) + valueRefs := make([]reflect.Value, len(columnTypes)) + for i := 0; i < len(columnTypes); i++ { + valueRefs[i] = reflect.New(columnTypes[i].ScanType()) + values[i] = valueRefs[i].Interface() + } + record := make([]string, len(values)) + csvWriter := csv.NewWriter(io.MultiWriter(w, os.Stdout)) + for rows.Next() { + if err = rows.Err(); err != nil { + break + } + err = rows.Scan(values...) + if err != nil { + break + } + for i, valueRef := range valueRefs { + val := valueRef.Elem().Interface() + val = toIntIfBool(val) + // NB: Does not work for BLOBs. Use regular copy if there are BLOB columns + record[i] = fmt.Sprintf("%v", val) + } + err = csvWriter.Write(record) + if err != nil { + break + } + } + if err == nil { + csvWriter.Flush() + err = csvWriter.Error() + } + w.CloseWithError(err) // same as w.Close(), if err is nil +} + +func toIntIfBool(val interface{}) interface{} { + if boolVal, ok := val.(bool); ok { + val = 0 + if boolVal { + val = 1 + } + } + return val +} diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 4e333ea4c5a..a19f789bf9e 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -45,7 +45,7 @@ func init() { NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer { return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w) }, - Copy: drivers.CopyWithInsert(func(int) string { return "?" }), + Copy: copyRows, NewCompleter: mymeta.NewCompleter, }, "memsql", "vitess", "tidb") }