diff --git a/src/main.go b/src/main.go index b353e0c..4cd336a 100644 --- a/src/main.go +++ b/src/main.go @@ -6,7 +6,7 @@ import ( "time" ) -const VERSION = "0.27.0" +const VERSION = "0.27.1" func main() { config := LoadConfig() diff --git a/src/query_handler_test.go b/src/query_handler_test.go index 784f499..50127a8 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -466,6 +466,23 @@ func TestHandleQuery(t *testing.T) { "description": {"oid", "description"}, "values": {"10", ""}, }, + // CASE + "SELECT CASE WHEN true THEN 'yes' ELSE 'no' END AS case": { + "description": {"case"}, + "values": {"yes"}, + }, + "SELECT CASE WHEN false THEN 'yes' ELSE 'no' END AS case": { + "description": {"case"}, + "values": {"no"}, + }, + "SELECT CASE WHEN true THEN 'one' WHEN false THEN 'two' ELSE 'three' END AS case": { + "description": {"case"}, + "values": {"one"}, + }, + "SELECT CASE WHEN (SELECT count(extname) FROM pg_catalog.pg_extension WHERE extname = 'bdr') > 0 THEN 'pgd' WHEN (SELECT count(*) FROM pg_replication_slots) > 0 THEN 'log' ELSE NULL END AS type": { + "description": {"type"}, + "values": {""}, + }, } for query, responses := range responsesByQuery { diff --git a/src/select_remapper.go b/src/select_remapper.go index f2d84b1..d91c813 100644 --- a/src/select_remapper.go +++ b/src/select_remapper.go @@ -59,6 +59,12 @@ func (selectRemapper *SelectRemapper) RemapSetStatement(stmt *pgQuery.RawStmt) * func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt { selectStatement = selectRemapper.remapTypeCastsInSelect(selectStatement) + // CASE - only process if we have CASE expressions + if hasCaseExpr := selectRemapper.hasCaseExpressions(selectStatement); hasCaseExpr { + selectRemapper.traceTreeTraversal("CASE expressions", indentLevel) + return selectRemapper.remapCaseExpressions(selectStatement, indentLevel) + } + // UNION if selectStatement.FromClause == nil && selectStatement.Larg != nil && selectStatement.Rarg != nil { selectRemapper.traceTreeTraversal("UNION left", indentLevel) @@ -114,6 +120,57 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu return selectStatement } +func (selectRemapper *SelectRemapper) hasCaseExpressions(selectStatement *pgQuery.SelectStmt) bool { + for _, target := range selectStatement.TargetList { + if target.GetResTarget().Val.GetCaseExpr() != nil { + return true + } + } + return false +} + +func (selectRemapper *SelectRemapper) remapCaseExpressions(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt { + for _, target := range selectStatement.TargetList { + if caseExpr := target.GetResTarget().Val.GetCaseExpr(); caseExpr != nil { + for _, when := range caseExpr.Args { + if whenClause := when.GetCaseWhen(); whenClause != nil { + if whenClause.Expr != nil { + if aExpr := whenClause.Expr.GetAExpr(); aExpr != nil { + if subLink := aExpr.Lexpr.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE WHEN left", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + if subLink := aExpr.Rexpr.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE WHEN right", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + } + } + + if whenClause.Result != nil { + if subLink := whenClause.Result.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE THEN", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + } + } + } + + if caseExpr.Defresult != nil { + if subLink := caseExpr.Defresult.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE ELSE", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + } + } + } + return selectStatement +} + // FROM PG_FUNCTION() func (selectRemapper *SelectRemapper) remapTableFunction(fromNode *pgQuery.Node, indentLevel int) *pgQuery.Node { selectRemapper.traceTreeTraversal("FROM function()", indentLevel)