Skip to content

Commit

Permalink
Add CASE statements nested selects support
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 16, 2024
1 parent 68363c1 commit dcbe6cd
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const VERSION = "0.27.0"
const VERSION = "0.27.1"

func main() {
config := LoadConfig()
Expand Down
17 changes: 17 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,23 @@ func TestHandleQuery(t *testing.T) {
"description": {"oid", "description"},
"values": {"10", ""},
},
// CASE
"SELECT CASE WHEN true THEN 'yes' ELSE 'no' END AS case": {
"description": {"case"},
"values": {"yes"},
},
"SELECT CASE WHEN false THEN 'yes' ELSE 'no' END AS case": {
"description": {"case"},
"values": {"no"},
},
"SELECT CASE WHEN true THEN 'one' WHEN false THEN 'two' ELSE 'three' END AS case": {
"description": {"case"},
"values": {"one"},
},
"SELECT CASE WHEN (SELECT count(extname) FROM pg_catalog.pg_extension WHERE extname = 'bdr') > 0 THEN 'pgd' WHEN (SELECT count(*) FROM pg_replication_slots) > 0 THEN 'log' ELSE NULL END AS type": {
"description": {"type"},
"values": {""},
},
}

for query, responses := range responsesByQuery {
Expand Down
57 changes: 57 additions & 0 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ func (selectRemapper *SelectRemapper) RemapSetStatement(stmt *pgQuery.RawStmt) *
func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
selectStatement = selectRemapper.remapTypeCastsInSelect(selectStatement)

// CASE - only process if we have CASE expressions
if hasCaseExpr := selectRemapper.hasCaseExpressions(selectStatement); hasCaseExpr {
selectRemapper.traceTreeTraversal("CASE expressions", indentLevel)
return selectRemapper.remapCaseExpressions(selectStatement, indentLevel)
}

// UNION
if selectStatement.FromClause == nil && selectStatement.Larg != nil && selectStatement.Rarg != nil {
selectRemapper.traceTreeTraversal("UNION left", indentLevel)
Expand Down Expand Up @@ -114,6 +120,57 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu
return selectStatement
}

func (selectRemapper *SelectRemapper) hasCaseExpressions(selectStatement *pgQuery.SelectStmt) bool {
for _, target := range selectStatement.TargetList {
if target.GetResTarget().Val.GetCaseExpr() != nil {
return true
}
}
return false
}

func (selectRemapper *SelectRemapper) remapCaseExpressions(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
for _, target := range selectStatement.TargetList {
if caseExpr := target.GetResTarget().Val.GetCaseExpr(); caseExpr != nil {
for _, when := range caseExpr.Args {
if whenClause := when.GetCaseWhen(); whenClause != nil {
if whenClause.Expr != nil {
if aExpr := whenClause.Expr.GetAExpr(); aExpr != nil {
if subLink := aExpr.Lexpr.GetSubLink(); subLink != nil {
selectRemapper.traceTreeTraversal("CASE WHEN left", indentLevel+1)
subSelect := subLink.Subselect.GetSelectStmt()
subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1)
}
if subLink := aExpr.Rexpr.GetSubLink(); subLink != nil {
selectRemapper.traceTreeTraversal("CASE WHEN right", indentLevel+1)
subSelect := subLink.Subselect.GetSelectStmt()
subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1)
}
}
}

if whenClause.Result != nil {
if subLink := whenClause.Result.GetSubLink(); subLink != nil {
selectRemapper.traceTreeTraversal("CASE THEN", indentLevel+1)
subSelect := subLink.Subselect.GetSelectStmt()
subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1)
}
}
}
}

if caseExpr.Defresult != nil {
if subLink := caseExpr.Defresult.GetSubLink(); subLink != nil {
selectRemapper.traceTreeTraversal("CASE ELSE", indentLevel+1)
subSelect := subLink.Subselect.GetSelectStmt()
subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1)
}
}
}
}
return selectStatement
}

// FROM PG_FUNCTION()
func (selectRemapper *SelectRemapper) remapTableFunction(fromNode *pgQuery.Node, indentLevel int) *pgQuery.Node {
selectRemapper.traceTreeTraversal("FROM function()", indentLevel)
Expand Down

0 comments on commit dcbe6cd

Please sign in to comment.