Skip to content

Commit

Permalink
include complex function with static arguments (#59)
Browse files Browse the repository at this point in the history
* include complex function with static arguments

also fix nested rendering issue

* fix another small whitespace issue
  • Loading branch information
Ferada authored Jan 17, 2022
1 parent 3b39011 commit 1524f8a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
31 changes: 21 additions & 10 deletions db/chain/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestExpressionChain_Render(t *testing.T) {
OrWhere("field3 > ?", "pajarito").
OrHaving("haveable < ?", 1).
AndHaving("moreHaveable == ?", 3),
want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR field3 > $3 HAVING moreHaveable == $4 OR haveable < $5",
want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR field3 > $3 HAVING moreHaveable == $4 OR haveable < $5",
wantArgs: []interface{}{1, 2, "pajarito", 3, 1},
wantErr: false,
},
Expand All @@ -110,7 +110,7 @@ func TestExpressionChain_Render(t *testing.T) {
AndWhere("field1 > ?", 1).
AndWhere("field2 = ?", 2).
OrWhereGroup(NewNoDB().AndWhere("inner = ?", 1).AndWhere("inner2 > ?", 2)),
want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR ( inner = $3 AND inner2 > $4)",
want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR (inner = $3 AND inner2 > $4)",
wantArgs: []interface{}{1, 2, 1, 2},
wantErr: false,
},
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestExpressionChain_Render(t *testing.T) {
AndWhere("field2 = ?", 2).
AndWhere("field3 > ?", "pajarito").
Join("another_convenient_table", "pirulo = ?", "unpirulo"),
want: "DELETE FROM convenient_table JOIN another_convenient_table ON pirulo = $1 WHERE field1 > $2 AND field2 = $3 AND field3 > $4",
want: "DELETE FROM convenient_table JOIN another_convenient_table ON pirulo = $1 WHERE field1 > $2 AND field2 = $3 AND field3 > $4",
wantArgs: []interface{}{"unpirulo", 1, 2, "pajarito"},
wantErr: false,
},
Expand All @@ -178,9 +178,9 @@ func TestExpressionChain_Render(t *testing.T) {
name: "basic insert multi",
chain: func() *ExpressionChain {
cn, err := NewNoDB().InsertMulti(map[string][]interface{}{
"field1": []interface{}{"value1", "value1.1"},
"field2": []interface{}{2, 22},
"field3": []interface{}{"blah", "blah2"}})
"field1": {"value1", "value1.1"},
"field2": {2, 22},
"field3": {"blah", "blah2"}})
if err != nil {
t.Logf("insert multi failed: %v", err)
t.FailNow()
Expand All @@ -204,9 +204,9 @@ func TestExpressionChain_Render(t *testing.T) {
name: "insert multi with chan value",
chain: func() *ExpressionChain {
cn, err := NewNoDB().InsertMulti(map[string][]interface{}{
"field1": []interface{}{"value1", "value1.1"},
"field2": []interface{}{2, NewNoDB().Select("MAX(value)").From("table").AndWhere("arbitrary = ?", 222)},
"field3": []interface{}{"blah", "blah2"}})
"field1": {"value1", "value1.1"},
"field2": {2, NewNoDB().Select("MAX(value)").From("table").AndWhere("arbitrary = ?", 222)},
"field3": {"blah", "blah2"}})
if err != nil {
t.Logf("insert multi failed: %v", err)
t.FailNow()
Expand Down Expand Up @@ -481,7 +481,7 @@ func TestExpressionChain_Render(t *testing.T) {
Where(NewNoDB().AndWhere(Equals("atablename.field1"), "something"))
}).
Returning("atablename.field2"),
want: "INSERT INTO atablename (field1, field2) VALUES ($1, $2) ON CONFLICT ( field1 ) DO UPDATE SET (field2) = (atablename.field2 + 1) WHERE atablename.field1 = $3 RETURNING atablename.field2",
want: "INSERT INTO atablename (field1, field2) VALUES ($1, $2) ON CONFLICT ( field1 ) DO UPDATE SET (field2) = (atablename.field2 + 1) WHERE atablename.field1 = $3 RETURNING atablename.field2",
wantArgs: []interface{}{"somethingelse", 2, "something"},
wantErr: false,
},
Expand Down Expand Up @@ -600,6 +600,17 @@ func TestExpressionChain_Render(t *testing.T) {
wantArgs: []interface{}{1},
wantErr: false,
},
{
name: "Complex function with static arguments gets included",
chain: func() *ExpressionChain {
f := ComplexFunction("COALESCE").Static("true").Static("false")
ec := NewNoDB().Select("true").AndWhereGroup(NewNoDB().OrWhere(f.Fn()).OrWhere(f.Fn()))
return ec
}(),
want: "SELECT true WHERE (COALESCE(true, false) OR COALESCE(true, false))",
wantArgs: []interface{}{},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
4 changes: 1 addition & 3 deletions db/chain/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ func (ec *ExpressionChain) whereGroup(c *ExpressionChain, whereFunc baseSegmentF
dst.WriteRune('(')
whereArgs := c.renderWhereRaw(dst)
dst.WriteRune(')')
if len(whereArgs) > 0 {
whereFunc(dst.String(), whereArgs...)
}
whereFunc(dst.String(), whereArgs...)
}

// appendExpandedOp is the constructor of the most common chain segment.
Expand Down
11 changes: 8 additions & 3 deletions db/chain/rendering.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,14 @@ func (ec *ExpressionChain) render(raw bool, query *strings.Builder) ([]interface
}
if ec.mainOperation.segment == sqlSelect {
query.WriteString("SELECT ")
query.WriteString(expression)
if ec.mainOperation.segment == sqlSelect {
query.WriteString(expression)
}
} else {
query.WriteString("DELETE ")
query.WriteString("DELETE")
}
if len(ec.mainOperation.arguments) != 0 {
query.WriteRune(' ')
}
// FROM
if ec.table == "" && ec.mainOperation.segment == sqlDelete {
Expand Down Expand Up @@ -219,7 +224,7 @@ func (ec *ExpressionChain) render(raw bool, query *strings.Builder) ([]interface

// WHERE
if segmentsPresent(ec, sqlWhere) > 0 {
query.WriteString(" WHERE")
query.WriteString(" WHERE ")
args = append(args, ec.renderWhereRaw(query)...)
}

Expand Down
2 changes: 1 addition & 1 deletion db/chain/segment.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func (q *querySegmentAtom) render(firstForSegment, lastForSegment bool,
if !firstForSegment {
dst.WriteRune(' ')
dst.WriteString(string(q.sqlBool))
dst.WriteRune(' ')
}
dst.WriteRune(' ')
dst.WriteString(q.expression)
return q.arguments
}

0 comments on commit 1524f8a

Please sign in to comment.