From 080115d1aad6009facfc7435086f0ba10d55cc8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= Date: Thu, 5 Jan 2017 21:01:47 +0000 Subject: [PATCH 1/2] Add a test case for #316 --- lib/sqlbuilder/builder_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 07c69f9c..2438ff10 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -959,6 +959,13 @@ func TestUpdate(t *testing.T) { "id = id + ?", 10, ).Where("id > ?", 0).String(), ) + + assert.Equal( + `UPDATE "posts" SET "tags" = array_remove(tags, $1) WHERE (hub_id = $2 AND $3 = ANY(tags) AND $4 = ANY(tags))`, + b.Update("posts").Set( + db.Cond{"tags": db.Raw("array_remove(tags, ?)", "foo")}, + ).Where(db.Raw("hub_id = ? AND ? = ANY(tags) AND ? = ANY(tags)", 1, "bar", "baz")).String(), + ) } func TestDelete(t *testing.T) { From 2c780beddf065414caf3488c26c1b5d7068ee0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= Date: Thu, 5 Jan 2017 22:03:52 +0000 Subject: [PATCH 2/2] Closes #316 --- internal/sqladapter/exql/column_value.go | 9 +++-- lib/sqlbuilder/builder.go | 4 ++ lib/sqlbuilder/builder_test.go | 47 +++++++++++++++++++++--- lib/sqlbuilder/convert.go | 20 +++++++++- lib/sqlbuilder/update.go | 46 ++++++++++++----------- 5 files changed, 95 insertions(+), 31 deletions(-) diff --git a/internal/sqladapter/exql/column_value.go b/internal/sqladapter/exql/column_value.go index 8834ee8a..d532705a 100644 --- a/internal/sqladapter/exql/column_value.go +++ b/internal/sqladapter/exql/column_value.go @@ -31,9 +31,12 @@ func (c *ColumnValue) Compile(layout *Template) (compiled string) { } data := columnValueT{ - c.Column.Compile(layout), - c.Operator, - c.Value.Compile(layout), + Column: c.Column.Compile(layout), + Operator: c.Operator, + } + + if c.Value != nil { + data.Value = c.Value.Compile(layout) } compiled = mustParse(layout.ColumnValue, data) diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index bd6a350e..d11c17b5 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -301,6 +301,10 @@ func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) return nil, nil, ErrExpectingPointerToEitherMapOrStruct } + if len(fv.fields) == 0 { + return nil, nil, errors.New("No values mapped.") + } + sort.Sort(&fv) return fv.fields, fv.values, nil diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 2438ff10..1eeb3b7f 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -960,12 +960,49 @@ func TestUpdate(t *testing.T) { ).Where("id > ?", 0).String(), ) - assert.Equal( - `UPDATE "posts" SET "tags" = array_remove(tags, $1) WHERE (hub_id = $2 AND $3 = ANY(tags) AND $4 = ANY(tags))`, - b.Update("posts").Set( + { + q := b.Update("posts").Set("column = ?", "foo") + + assert.Equal( + `UPDATE "posts" SET "column" = $1`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo"}, + q.Arguments(), + ) + } + + { + q := b.Update("posts").Set(db.Raw("column = ?", "foo")) + + assert.Equal( + `UPDATE "posts" SET column = $1`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo"}, + q.Arguments(), + ) + } + + { + q := b.Update("posts").Set( db.Cond{"tags": db.Raw("array_remove(tags, ?)", "foo")}, - ).Where(db.Raw("hub_id = ? AND ? = ANY(tags) AND ? = ANY(tags)", 1, "bar", "baz")).String(), - ) + ).Where(db.Raw("hub_id = ? AND ? = ANY(tags) AND ? = ANY(tags)", 1, "bar", "baz")) + + assert.Equal( + `UPDATE "posts" SET "tags" = array_remove(tags, $1) WHERE (hub_id = $2 AND $3 = ANY(tags) AND $4 = ANY(tags))`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo", 1, "bar", "baz"}, + q.Arguments(), + ) + } } func TestDelete(t *testing.T) { diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index db8a592a..308347cc 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -172,7 +172,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) { switch t := in.(type) { case db.RawValue: - return exql.RawValue(t.String()), nil + return exql.RawValue(t.String()), t.Arguments() case db.Function: fnName := t.Name() fnArgs := []interface{}{} @@ -230,7 +230,14 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal case []interface{}: l := len(t) for i := 0; i < l; i++ { - column := t[i].(string) + column, ok := t[i].(string) + + if !ok { + p, q := tu.ToColumnValues(t[i]) + cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) + args = append(args, q...) + continue + } if !strings.ContainsAny(column, "=") { column = fmt.Sprintf("%s = ?", column) @@ -337,6 +344,15 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + case db.RawValue: + columnValue := exql.ColumnValue{} + p, q := Preprocess(t.Raw(), t.Arguments()) + + columnValue.Column = exql.RawValue(p) + args = append(args, q...) + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) return cv, args case db.Constraints: for _, c := range t.Constraints() { diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index cb1b09bd..4faede8b 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -23,34 +23,38 @@ type updater struct { mu sync.Mutex } -func (qu *updater) Set(terms ...interface{}) Updater { - if len(terms) == 1 { - ff, vv, _ := Map(terms[0], nil) +func (qu *updater) Set(columns ...interface{}) Updater { - cvs := make([]exql.Fragment, 0, len(ff)) - args := make([]interface{}, 0, len(vv)) + if len(columns) == 1 { + ff, vv, err := Map(columns[0], nil) + if err == nil { - for i := range ff { - cv := &exql.ColumnValue{ - Column: exql.ColumnWithName(ff[i]), - Operator: qu.builder.t.AssignmentOperator, - } + cvs := make([]exql.Fragment, 0, len(ff)) + args := make([]interface{}, 0, len(vv)) - var localArgs []interface{} - cv.Value, localArgs = qu.builder.t.PlaceholderValue(vv[i]) + for i := range ff { + cv := &exql.ColumnValue{ + Column: exql.ColumnWithName(ff[i]), + Operator: qu.builder.t.AssignmentOperator, + } - args = append(args, localArgs...) - cvs = append(cvs, cv) - } + var localArgs []interface{} + cv.Value, localArgs = qu.builder.t.PlaceholderValue(vv[i]) + + args = append(args, localArgs...) + cvs = append(cvs, cv) + } - qu.columnValues.Insert(cvs...) - qu.columnValuesArgs = append(qu.columnValuesArgs, args...) - } else if len(terms) > 1 { - cv, arguments := qu.builder.t.ToColumnValues(terms) - qu.columnValues.Insert(cv.ColumnValues...) - qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...) + qu.columnValues.Insert(cvs...) + qu.columnValuesArgs = append(qu.columnValuesArgs, args...) + return qu + } } + cv, arguments := qu.builder.t.ToColumnValues(columns) + qu.columnValues.Insert(cv.ColumnValues...) + qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...) + return qu }