From 1524f8abaf15ec5dc33b092cb28fafe65d9b7c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olof-Joachim=20Frahm=20=28=E6=AC=A7=E9=9B=85=E7=A6=8F=29?= Date: Mon, 17 Jan 2022 13:31:06 +0100 Subject: [PATCH] include complex function with static arguments (#59) * include complex function with static arguments also fix nested rendering issue * fix another small whitespace issue --- db/chain/chain_test.go | 31 +++++++++++++++++++++---------- db/chain/expressions.go | 4 +--- db/chain/rendering.go | 11 ++++++++--- db/chain/segment.go | 2 +- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/db/chain/chain_test.go b/db/chain/chain_test.go index 787fbba..f3498c9 100644 --- a/db/chain/chain_test.go +++ b/db/chain/chain_test.go @@ -99,7 +99,7 @@ func TestExpressionChain_Render(t *testing.T) { OrWhere("field3 > ?", "pajarito"). OrHaving("haveable < ?", 1). AndHaving("moreHaveable == ?", 3), - want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR field3 > $3 HAVING moreHaveable == $4 OR haveable < $5", + want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR field3 > $3 HAVING moreHaveable == $4 OR haveable < $5", wantArgs: []interface{}{1, 2, "pajarito", 3, 1}, wantErr: false, }, @@ -110,7 +110,7 @@ func TestExpressionChain_Render(t *testing.T) { AndWhere("field1 > ?", 1). AndWhere("field2 = ?", 2). OrWhereGroup(NewNoDB().AndWhere("inner = ?", 1).AndWhere("inner2 > ?", 2)), - want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR ( inner = $3 AND inner2 > $4)", + want: "SELECT field1, field2, field3 FROM convenient_table WHERE field1 > $1 AND field2 = $2 OR (inner = $3 AND inner2 > $4)", wantArgs: []interface{}{1, 2, 1, 2}, wantErr: false, }, @@ -162,7 +162,7 @@ func TestExpressionChain_Render(t *testing.T) { AndWhere("field2 = ?", 2). AndWhere("field3 > ?", "pajarito"). Join("another_convenient_table", "pirulo = ?", "unpirulo"), - want: "DELETE FROM convenient_table JOIN another_convenient_table ON pirulo = $1 WHERE field1 > $2 AND field2 = $3 AND field3 > $4", + want: "DELETE FROM convenient_table JOIN another_convenient_table ON pirulo = $1 WHERE field1 > $2 AND field2 = $3 AND field3 > $4", wantArgs: []interface{}{"unpirulo", 1, 2, "pajarito"}, wantErr: false, }, @@ -178,9 +178,9 @@ func TestExpressionChain_Render(t *testing.T) { name: "basic insert multi", chain: func() *ExpressionChain { cn, err := NewNoDB().InsertMulti(map[string][]interface{}{ - "field1": []interface{}{"value1", "value1.1"}, - "field2": []interface{}{2, 22}, - "field3": []interface{}{"blah", "blah2"}}) + "field1": {"value1", "value1.1"}, + "field2": {2, 22}, + "field3": {"blah", "blah2"}}) if err != nil { t.Logf("insert multi failed: %v", err) t.FailNow() @@ -204,9 +204,9 @@ func TestExpressionChain_Render(t *testing.T) { name: "insert multi with chan value", chain: func() *ExpressionChain { cn, err := NewNoDB().InsertMulti(map[string][]interface{}{ - "field1": []interface{}{"value1", "value1.1"}, - "field2": []interface{}{2, NewNoDB().Select("MAX(value)").From("table").AndWhere("arbitrary = ?", 222)}, - "field3": []interface{}{"blah", "blah2"}}) + "field1": {"value1", "value1.1"}, + "field2": {2, NewNoDB().Select("MAX(value)").From("table").AndWhere("arbitrary = ?", 222)}, + "field3": {"blah", "blah2"}}) if err != nil { t.Logf("insert multi failed: %v", err) t.FailNow() @@ -481,7 +481,7 @@ func TestExpressionChain_Render(t *testing.T) { Where(NewNoDB().AndWhere(Equals("atablename.field1"), "something")) }). Returning("atablename.field2"), - want: "INSERT INTO atablename (field1, field2) VALUES ($1, $2) ON CONFLICT ( field1 ) DO UPDATE SET (field2) = (atablename.field2 + 1) WHERE atablename.field1 = $3 RETURNING atablename.field2", + want: "INSERT INTO atablename (field1, field2) VALUES ($1, $2) ON CONFLICT ( field1 ) DO UPDATE SET (field2) = (atablename.field2 + 1) WHERE atablename.field1 = $3 RETURNING atablename.field2", wantArgs: []interface{}{"somethingelse", 2, "something"}, wantErr: false, }, @@ -600,6 +600,17 @@ func TestExpressionChain_Render(t *testing.T) { wantArgs: []interface{}{1}, wantErr: false, }, + { + name: "Complex function with static arguments gets included", + chain: func() *ExpressionChain { + f := ComplexFunction("COALESCE").Static("true").Static("false") + ec := NewNoDB().Select("true").AndWhereGroup(NewNoDB().OrWhere(f.Fn()).OrWhere(f.Fn())) + return ec + }(), + want: "SELECT true WHERE (COALESCE(true, false) OR COALESCE(true, false))", + wantArgs: []interface{}{}, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/db/chain/expressions.go b/db/chain/expressions.go index 16c5026..1777f8c 100644 --- a/db/chain/expressions.go +++ b/db/chain/expressions.go @@ -49,9 +49,7 @@ func (ec *ExpressionChain) whereGroup(c *ExpressionChain, whereFunc baseSegmentF dst.WriteRune('(') whereArgs := c.renderWhereRaw(dst) dst.WriteRune(')') - if len(whereArgs) > 0 { - whereFunc(dst.String(), whereArgs...) - } + whereFunc(dst.String(), whereArgs...) } // appendExpandedOp is the constructor of the most common chain segment. diff --git a/db/chain/rendering.go b/db/chain/rendering.go index dc3c55a..bac6731 100644 --- a/db/chain/rendering.go +++ b/db/chain/rendering.go @@ -169,9 +169,14 @@ func (ec *ExpressionChain) render(raw bool, query *strings.Builder) ([]interface } if ec.mainOperation.segment == sqlSelect { query.WriteString("SELECT ") - query.WriteString(expression) + if ec.mainOperation.segment == sqlSelect { + query.WriteString(expression) + } } else { - query.WriteString("DELETE ") + query.WriteString("DELETE") + } + if len(ec.mainOperation.arguments) != 0 { + query.WriteRune(' ') } // FROM if ec.table == "" && ec.mainOperation.segment == sqlDelete { @@ -219,7 +224,7 @@ func (ec *ExpressionChain) render(raw bool, query *strings.Builder) ([]interface // WHERE if segmentsPresent(ec, sqlWhere) > 0 { - query.WriteString(" WHERE") + query.WriteString(" WHERE ") args = append(args, ec.renderWhereRaw(query)...) } diff --git a/db/chain/segment.go b/db/chain/segment.go index cffa830..e1a3ce9 100644 --- a/db/chain/segment.go +++ b/db/chain/segment.go @@ -117,8 +117,8 @@ func (q *querySegmentAtom) render(firstForSegment, lastForSegment bool, if !firstForSegment { dst.WriteRune(' ') dst.WriteString(string(q.sqlBool)) + dst.WriteRune(' ') } - dst.WriteRune(' ') dst.WriteString(q.expression) return q.arguments }