diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 20c9a0ad..6c63a2f5 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -49,7 +49,7 @@ type PartialDatabase interface { FindTablePrimaryKeys(name string) ([]string, error) NewLocalCollection(name string) db.Collection - CompileStatement(stmt *exql.Statement) (query string) + CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) ConnectionURL() db.ConnectionURL Err(in error) (out error) @@ -316,7 +316,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res } if execer, ok := d.PartialDatabase.(HasStatementExec); ok { - query = d.compileStatement(stmt) + query, args = d.compileStatement(stmt, args) res, err = execer.StatementExec(query, args...) return } @@ -325,7 +325,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res if db.Conf.PreparedStatementCacheEnabled() && tx == nil { var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { + if p, query, args, err = d.prepareStatement(stmt, args); err != nil { return nil, err } defer p.Close() @@ -334,8 +334,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res return } - query = d.compileStatement(stmt) - + query, args = d.compileStatement(stmt, args) if tx != nil { res, err = tx.(*sqlTx).Exec(query, args...) return @@ -367,7 +366,7 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro if db.Conf.PreparedStatementCacheEnabled() && tx == nil { var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { + if p, query, args, err = d.prepareStatement(stmt, args); err != nil { return nil, err } defer p.Close() @@ -376,7 +375,7 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro return } - query = d.compileStatement(stmt) + query, args = d.compileStatement(stmt, args) if tx != nil { rows, err = tx.(*sqlTx).Query(query, args...) return @@ -410,7 +409,7 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) if db.Conf.PreparedStatementCacheEnabled() && tx == nil { var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { + if p, query, args, err = d.prepareStatement(stmt, args); err != nil { return nil, err } defer p.Close() @@ -419,7 +418,7 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) return } - query = d.compileStatement(stmt) + query, args = d.compileStatement(stmt, args) if tx != nil { row = tx.(*sqlTx).QueryRow(query, args...) return @@ -439,19 +438,19 @@ func (d *database) Driver() interface{} { } // compileStatement compiles the given statement into a string. -func (d *database) compileStatement(stmt *exql.Statement) string { - return d.PartialDatabase.CompileStatement(stmt) +func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + return d.PartialDatabase.CompileStatement(stmt, args) } // prepareStatement compiles a query and tries to use previously generated // statement. -func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) { +func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { d.sessMu.Lock() defer d.sessMu.Unlock() sess, tx := d.sess, d.Transaction() if sess == nil && tx == nil { - return nil, "", db.ErrNotConnected + return nil, "", nil, db.ErrNotConnected } pc, ok := d.cachedStatements.ReadRaw(stmt) @@ -459,11 +458,12 @@ func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) // The statement was cached. ps, err := pc.(*Stmt).Open() if err == nil { - return ps, ps.query, nil + _, args = d.compileStatement(stmt, args) + return ps, ps.query, args, nil } } - query := d.compileStatement(stmt) + query, args := d.compileStatement(stmt, args) sqlStmt, err := func(query *string) (*sql.Stmt, error) { if tx != nil { return tx.(*sqlTx).Prepare(*query) @@ -471,15 +471,15 @@ func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) return sess.Prepare(*query) }(&query) if err != nil { - return nil, "", err + return nil, "", nil, err } p, err := NewStatement(sqlStmt, query).Open() if err != nil { - return nil, query, err + return nil, query, args, err } d.cachedStatements.Write(stmt, p) - return p, p.query, nil + return p, p.query, args, nil } var waitForConnMu sync.Mutex diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index b5eb8623..97d99bc2 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -113,6 +113,8 @@ func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, e return b.sess.StatementExec(q, args...) case string: return b.sess.StatementExec(exql.RawSQL(q), args...) + case db.RawValue: + return b.Exec(q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -124,6 +126,8 @@ func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, e return b.sess.StatementQuery(q, args...) case string: return b.sess.StatementQuery(exql.RawSQL(q), args...) + case db.RawValue: + return b.Query(q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -135,6 +139,8 @@ func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, return b.sess.StatementQueryRow(q, args...) case string: return b.sess.StatementQueryRow(exql.RawSQL(q), args...) + case db.RawValue: + return b.QueryRow(q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -320,7 +326,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql for i := 0; i < l; i++ { switch v := columns[i].(type) { case *selector: - expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments()...) + expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments()) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case db.Function: @@ -330,11 +336,11 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) f[i] = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: - expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()...) + expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case exql.Fragment: diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index af03f883..6c999aeb 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "database/sql/driver" "fmt" "reflect" "strings" @@ -24,49 +25,64 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils { return &templateWithUtils{template} } -func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) { +func expandQuery(in string, args []interface{}, fn func(interface{}) (string, []interface{})) (string, []interface{}) { argn := 0 argx := make([]interface{}, 0, len(args)) for i := 0; i < len(in); i++ { - if in[i] == '?' { - if len(args) > argn { - k := `?` - - values, isSlice := toInterfaceArguments(args[argn]) - if isSlice { - if len(values) == 0 { - k = `(NULL)` - } else { - k = `(?` + strings.Repeat(`, ?`, len(values)-1) + `)` - } - } else { - if len(values) == 1 { - switch t := values[0].(type) { - case db.RawValue: - k, values = t.Raw(), nil - case *selector: - k, values = `(`+t.statement().Compile(t.stringer.t)+`)`, t.Arguments() - } - } else if len(values) == 0 { - k = `NULL` - } - } - - if k != `?` { - in = in[:i] + k + in[i+1:] - i += len(k) - 1 - } - - if len(values) > 0 { - argx = append(argx, values...) - } - argn++ + if in[i] != '?' { + continue + } + if len(args) > argn { + k, values := fn(args[argn]) + if k != "" { + in = in[:i] + k + in[i+1:] + i += len(k) - 1 } + if len(values) > 0 { + argx = append(argx, values...) + } + argn++ } } + if len(argx) < len(args) { + argx = append(argx, args[argn:]...) + } return in, argx } +func preprocessFn(arg interface{}) (string, []interface{}) { + values, isSlice := toInterfaceArguments(arg) + + if isSlice { + if len(values) == 0 { + return `(NULL)`, nil + } + return `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`, values + } + + if len(values) == 1 { + switch t := arg.(type) { + case db.RawValue: + return Preprocess(t.Raw(), t.Arguments()) + case *selector: + return `(` + t.statement().Compile(t.stringer.t) + `)`, t.Arguments() + } + } else if len(values) == 0 { + return `NULL`, nil + } + + return "", []interface{}{arg} +} + +func Preprocess(in string, args []interface{}) (string, []interface{}) { + return expandQuery(in, args, preprocessFn) +} + +func expandPlaceholders(in string, args []interface{}) (string, []interface{}) { + // TODO: Remove after immutable query builder + return in, args +} + // ToWhereWithArguments converts the given parameters into a exql.Where // value. func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { @@ -77,7 +93,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. if len(t) > 0 { if s, ok := t[0].(string); ok { if strings.ContainsAny(s, "?") || len(t) == 1 { - s, args = expandPlaceholders(s, t[1:]...) + s, args = expandPlaceholders(s, t[1:]) where.Conditions = []exql.Fragment{exql.RawValue(s)} } else { var val interface{} @@ -106,7 +122,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. } return case db.RawValue: - r, v := expandPlaceholders(t.Raw(), t.Arguments()...) + r, v := expandPlaceholders(t.Raw(), t.Arguments()) where.Conditions = []exql.Fragment{exql.RawValue(r)} args = append(args, v...) return @@ -182,12 +198,16 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] // toInterfaceArguments converts the given value into an array of interfaces. func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) { - v := reflect.ValueOf(value) - if value == nil { return nil, false } + switch t := value.(type) { + case driver.Valuer: + return []interface{}{t}, false + } + + v := reflect.ValueOf(value) if v.Type().Kind() == reflect.Slice { var i, total int @@ -274,11 +294,11 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal // A function with one or more arguments. fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) columnValue.Value = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: - expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...) + expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()) columnValue.Value = exql.RawValue(expanded) args = append(args, rawArgs...) default: diff --git a/lib/sqlbuilder/placeholder_test.go b/lib/sqlbuilder/placeholder_test.go index 3f05da39..80917b71 100644 --- a/lib/sqlbuilder/placeholder_test.go +++ b/lib/sqlbuilder/placeholder_test.go @@ -9,74 +9,74 @@ import ( func TestPlaceholderSimple(t *testing.T) { { - ret, _ := expandPlaceholders("?", 1) + ret, _ := Preprocess("?", []interface{}{1}) assert.Equal(t, "?", ret) } { - ret, _ := expandPlaceholders("?") + ret, _ := Preprocess("?", nil) assert.Equal(t, "?", ret) } } func TestPlaceholderMany(t *testing.T) { { - ret, _ := expandPlaceholders("?, ?, ?", 1, 2, 3) + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, 3}) assert.Equal(t, "?, ?, ?", ret) } } func TestPlaceholderArray(t *testing.T) { { - ret, _ := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) assert.Equal(t, "?, ?, (?, ?, ?)", ret) } { - ret, _ := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + ret, _ := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) assert.Equal(t, "(?, ?, ?), ?, ?", ret) } { - ret, _ := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + ret, _ := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) assert.Equal(t, "?, (?, ?, ?), ?", ret) } { - ret, _ := expandPlaceholders("???", 1, []interface{}{2, 3, 4}, 5) + ret, _ := Preprocess("???", []interface{}{1, []interface{}{2, 3, 4}, 5}) assert.Equal(t, "?(?, ?, ?)?", ret) } { - ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}) + ret, _ := Preprocess("??", []interface{}{[]interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}}) assert.Equal(t, "(?, ?, ?)(NULL)", ret) } } func TestPlaceholderArguments(t *testing.T) { { - _, args := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + _, args := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } { - _, args := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + _, args := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } { - _, args := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + _, args := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } { - _, args := expandPlaceholders("?, ?", []interface{}{1, 2, 3}, []interface{}{4, 5}) + _, args := Preprocess("?, ?", []interface{}{[]interface{}{1, 2, 3}, []interface{}{4, 5}}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } } func TestPlaceholderReplace(t *testing.T) { { - ret, args := expandPlaceholders("?, ?, ?", 1, db.Raw("foo"), 3) + ret, args := Preprocess("?, ?, ?", []interface{}{1, db.Raw("foo"), 3}) assert.Equal(t, "?, foo, ?", ret) assert.Equal(t, []interface{}{1, 3}, args) } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 695748ee..7df9f272 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -156,7 +156,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { switch value := columns[i].(type) { case db.RawValue: - col, args := expandPlaceholders(value.Raw(), value.Arguments()...) + col, args := expandPlaceholders(value.Raw(), value.Arguments()) sort = &exql.SortColumn{ Column: exql.RawValue(col), } @@ -170,7 +170,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) sort = &exql.SortColumn{ Column: exql.RawValue(expanded), } diff --git a/mysql/database.go b/mysql/database.go index a00d9a7c..2cb66f73 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -149,8 +149,8 @@ func (d *database) clone() (*database, error) { // CompileStatement allows sqladapter to compile the given statement into the // format MySQL expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return stmt.Compile(template) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + return sqlbuilder.Preprocess(stmt.Compile(template), args) } // Err allows sqladapter to translate some known errors into generic errors. diff --git a/postgresql/database.go b/postgresql/database.go index ed4257bd..8b92ba2d 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -148,8 +148,9 @@ func (d *database) clone() (*database, error) { // CompileStatement allows sqladapter to compile the given statement into the // format PostgreSQL expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return sqladapter.ReplaceWithDollarSign(stmt.Compile(template)) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + query, args := sqlbuilder.Preprocess(stmt.Compile(template), args) + return sqladapter.ReplaceWithDollarSign(query), args } // Err allows sqladapter to translate some known errors into generic errors. diff --git a/postgresql/local_test.go b/postgresql/local_test.go index b09e86ad..82a9069e 100644 --- a/postgresql/local_test.go +++ b/postgresql/local_test.go @@ -117,3 +117,92 @@ func TestIssue210(t *testing.T) { _, err = sess.Collection("hello").Find().Count() assert.NoError(t, err) } + +func TestNonTrivialSubqueries(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + { + q, err := sess.Query(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + ) + + assert.NoError(t, err) + assert.NotNil(t, q) + + assert.True(t, q.Next()) + + var number int + assert.NoError(t, q.Scan(&number)) + + assert.Equal(t, 1, number) + assert.NoError(t, q.Close()) + } + + { + row, err := sess.QueryRow(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + ) + + assert.NoError(t, err) + assert.NotNil(t, row) + + var number int + assert.NoError(t, row.Scan(&number)) + + assert.Equal(t, 1, number) + } + + { + res, err := sess.Exec(`UPDATE artist a1 SET id = ?`, + sess.Select(db.Raw("id + 1")).From("artist a2").Where("a2.id = a1.id"), + ) + + assert.NoError(t, err) + assert.NotNil(t, res) + } + + { + q, err := sess.Query(db.Raw(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + )) + + assert.NoError(t, err) + assert.NotNil(t, q) + + assert.True(t, q.Next()) + + var number int + assert.NoError(t, q.Scan(&number)) + + assert.Equal(t, 2, number) + assert.NoError(t, q.Close()) + } + + { + row, err := sess.QueryRow(db.Raw(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + )) + + assert.NoError(t, err) + assert.NotNil(t, row) + + var number int + assert.NoError(t, row.Scan(&number)) + + assert.Equal(t, 2, number) + } + + { + res, err := sess.Exec(db.Raw(`UPDATE artist a1 SET id = ?`, + sess.Select(db.Raw("id + 1")).From("artist a2").Where("a2.id = a1.id"), + )) + + assert.NoError(t, err) + assert.NotNil(t, res) + } +} diff --git a/ql/database.go b/ql/database.go index 16116fb8..59e2d48f 100644 --- a/ql/database.go +++ b/ql/database.go @@ -227,8 +227,9 @@ func (d *database) clone() (*database, error) { // CompileStatement allows sqladapter to compile the given statement into the // format SQLite expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return sqladapter.ReplaceWithDollarSign(stmt.Compile(template)) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + query, args := sqlbuilder.Preprocess(stmt.Compile(template), args) + return sqladapter.ReplaceWithDollarSign(query), args } // Err allows sqladapter to translate some known errors into generic errors. diff --git a/sqlite/database.go b/sqlite/database.go index 074a764f..96c85d70 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -169,8 +169,8 @@ func (d *database) clone() (*database, error) { // CompileStatement allows sqladapter to compile the given statement into the // format SQLite expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return stmt.Compile(template) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + return sqlbuilder.Preprocess(stmt.Compile(template), args) } // Err allows sqladapter to translate some known errors into generic errors.