Skip to content

Commit

Permalink
Merge pull request #306 from upper/issue-303
Browse files Browse the repository at this point in the history
WIP: Better support for raw queries and subqueries.
  • Loading branch information
José Carlos authored Dec 13, 2016
2 parents a3ef9f0 + 6df76ea commit 8232c84
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 84 deletions.
36 changes: 18 additions & 18 deletions internal/sqladapter/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type PartialDatabase interface {
FindTablePrimaryKeys(name string) ([]string, error)

NewLocalCollection(name string) db.Collection
CompileStatement(stmt *exql.Statement) (query string)
CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{})
ConnectionURL() db.ConnectionURL

Err(in error) (out error)
Expand Down Expand Up @@ -316,7 +316,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res
}

if execer, ok := d.PartialDatabase.(HasStatementExec); ok {
query = d.compileStatement(stmt)
query, args = d.compileStatement(stmt, args)
res, err = execer.StatementExec(query, args...)
return
}
Expand All @@ -325,7 +325,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res

if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
var p *Stmt
if p, query, err = d.prepareStatement(stmt); err != nil {
if p, query, args, err = d.prepareStatement(stmt, args); err != nil {
return nil, err
}
defer p.Close()
Expand All @@ -334,8 +334,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res
return
}

query = d.compileStatement(stmt)

query, args = d.compileStatement(stmt, args)
if tx != nil {
res, err = tx.(*sqlTx).Exec(query, args...)
return
Expand Down Expand Up @@ -367,7 +366,7 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro

if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
var p *Stmt
if p, query, err = d.prepareStatement(stmt); err != nil {
if p, query, args, err = d.prepareStatement(stmt, args); err != nil {
return nil, err
}
defer p.Close()
Expand All @@ -376,7 +375,7 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro
return
}

query = d.compileStatement(stmt)
query, args = d.compileStatement(stmt, args)
if tx != nil {
rows, err = tx.(*sqlTx).Query(query, args...)
return
Expand Down Expand Up @@ -410,7 +409,7 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{})

if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
var p *Stmt
if p, query, err = d.prepareStatement(stmt); err != nil {
if p, query, args, err = d.prepareStatement(stmt, args); err != nil {
return nil, err
}
defer p.Close()
Expand All @@ -419,7 +418,7 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{})
return
}

query = d.compileStatement(stmt)
query, args = d.compileStatement(stmt, args)
if tx != nil {
row = tx.(*sqlTx).QueryRow(query, args...)
return
Expand All @@ -439,47 +438,48 @@ func (d *database) Driver() interface{} {
}

// compileStatement compiles the given statement into a string.
func (d *database) compileStatement(stmt *exql.Statement) string {
return d.PartialDatabase.CompileStatement(stmt)
func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) {
return d.PartialDatabase.CompileStatement(stmt, args)
}

// prepareStatement compiles a query and tries to use previously generated
// statement.
func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) {
func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) {
d.sessMu.Lock()
defer d.sessMu.Unlock()

sess, tx := d.sess, d.Transaction()
if sess == nil && tx == nil {
return nil, "", db.ErrNotConnected
return nil, "", nil, db.ErrNotConnected
}

pc, ok := d.cachedStatements.ReadRaw(stmt)
if ok {
// The statement was cached.
ps, err := pc.(*Stmt).Open()
if err == nil {
return ps, ps.query, nil
_, args = d.compileStatement(stmt, args)
return ps, ps.query, args, nil
}
}

query := d.compileStatement(stmt)
query, args := d.compileStatement(stmt, args)
sqlStmt, err := func(query *string) (*sql.Stmt, error) {
if tx != nil {
return tx.(*sqlTx).Prepare(*query)
}
return sess.Prepare(*query)
}(&query)
if err != nil {
return nil, "", err
return nil, "", nil, err
}

p, err := NewStatement(sqlStmt, query).Open()
if err != nil {
return nil, query, err
return nil, query, args, err
}
d.cachedStatements.Write(stmt, p)
return p, p.query, nil
return p, p.query, args, nil
}

var waitForConnMu sync.Mutex
Expand Down
12 changes: 9 additions & 3 deletions lib/sqlbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, e
return b.sess.StatementExec(q, args...)
case string:
return b.sess.StatementExec(exql.RawSQL(q), args...)
case db.RawValue:
return b.Exec(q.Raw(), q.Arguments()...)
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
}
Expand All @@ -124,6 +126,8 @@ func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, e
return b.sess.StatementQuery(q, args...)
case string:
return b.sess.StatementQuery(exql.RawSQL(q), args...)
case db.RawValue:
return b.Query(q.Raw(), q.Arguments()...)
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
}
Expand All @@ -135,6 +139,8 @@ func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row,
return b.sess.StatementQueryRow(q, args...)
case string:
return b.sess.StatementQueryRow(exql.RawSQL(q), args...)
case db.RawValue:
return b.QueryRow(q.Raw(), q.Arguments()...)
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
}
Expand Down Expand Up @@ -320,7 +326,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql
for i := 0; i < l; i++ {
switch v := columns[i].(type) {
case *selector:
expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments()...)
expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments())
f[i] = exql.RawValue(expanded)
args = append(args, rawArgs...)
case db.Function:
Expand All @@ -330,11 +336,11 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql
} else {
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
}
expanded, fnArgs := expandPlaceholders(fnName, fnArgs...)
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
f[i] = exql.RawValue(expanded)
args = append(args, fnArgs...)
case db.RawValue:
expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()...)
expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments())
f[i] = exql.RawValue(expanded)
args = append(args, rawArgs...)
case exql.Fragment:
Expand Down
100 changes: 60 additions & 40 deletions lib/sqlbuilder/convert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlbuilder

import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
Expand All @@ -24,49 +25,64 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils {
return &templateWithUtils{template}
}

func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) {
func expandQuery(in string, args []interface{}, fn func(interface{}) (string, []interface{})) (string, []interface{}) {
argn := 0
argx := make([]interface{}, 0, len(args))
for i := 0; i < len(in); i++ {
if in[i] == '?' {
if len(args) > argn {
k := `?`

values, isSlice := toInterfaceArguments(args[argn])
if isSlice {
if len(values) == 0 {
k = `(NULL)`
} else {
k = `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`
}
} else {
if len(values) == 1 {
switch t := values[0].(type) {
case db.RawValue:
k, values = t.Raw(), nil
case *selector:
k, values = `(`+t.statement().Compile(t.stringer.t)+`)`, t.Arguments()
}
} else if len(values) == 0 {
k = `NULL`
}
}

if k != `?` {
in = in[:i] + k + in[i+1:]
i += len(k) - 1
}

if len(values) > 0 {
argx = append(argx, values...)
}
argn++
if in[i] != '?' {
continue
}
if len(args) > argn {
k, values := fn(args[argn])
if k != "" {
in = in[:i] + k + in[i+1:]
i += len(k) - 1
}
if len(values) > 0 {
argx = append(argx, values...)
}
argn++
}
}
if len(argx) < len(args) {
argx = append(argx, args[argn:]...)
}
return in, argx
}

func preprocessFn(arg interface{}) (string, []interface{}) {
values, isSlice := toInterfaceArguments(arg)

if isSlice {
if len(values) == 0 {
return `(NULL)`, nil
}
return `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`, values
}

if len(values) == 1 {
switch t := arg.(type) {
case db.RawValue:
return Preprocess(t.Raw(), t.Arguments())
case *selector:
return `(` + t.statement().Compile(t.stringer.t) + `)`, t.Arguments()
}
} else if len(values) == 0 {
return `NULL`, nil
}

return "", []interface{}{arg}
}

func Preprocess(in string, args []interface{}) (string, []interface{}) {
return expandQuery(in, args, preprocessFn)
}

func expandPlaceholders(in string, args []interface{}) (string, []interface{}) {
// TODO: Remove after immutable query builder
return in, args
}

// ToWhereWithArguments converts the given parameters into a exql.Where
// value.
func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) {
Expand All @@ -77,7 +93,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
if len(t) > 0 {
if s, ok := t[0].(string); ok {
if strings.ContainsAny(s, "?") || len(t) == 1 {
s, args = expandPlaceholders(s, t[1:]...)
s, args = expandPlaceholders(s, t[1:])
where.Conditions = []exql.Fragment{exql.RawValue(s)}
} else {
var val interface{}
Expand Down Expand Up @@ -106,7 +122,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
}
return
case db.RawValue:
r, v := expandPlaceholders(t.Raw(), t.Arguments()...)
r, v := expandPlaceholders(t.Raw(), t.Arguments())
where.Conditions = []exql.Fragment{exql.RawValue(r)}
args = append(args, v...)
return
Expand Down Expand Up @@ -182,12 +198,16 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []

// toInterfaceArguments converts the given value into an array of interfaces.
func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) {
v := reflect.ValueOf(value)

if value == nil {
return nil, false
}

switch t := value.(type) {
case driver.Valuer:
return []interface{}{t}, false
}

v := reflect.ValueOf(value)
if v.Type().Kind() == reflect.Slice {
var i, total int

Expand Down Expand Up @@ -274,11 +294,11 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal
// A function with one or more arguments.
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
}
expanded, fnArgs := expandPlaceholders(fnName, fnArgs...)
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
columnValue.Value = exql.RawValue(expanded)
args = append(args, fnArgs...)
case db.RawValue:
expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...)
expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments())
columnValue.Value = exql.RawValue(expanded)
args = append(args, rawArgs...)
default:
Expand Down
Loading

0 comments on commit 8232c84

Please sign in to comment.