From 419265804ccd936d046c5344ecaf481e24121973 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Sat, 30 Nov 2024 09:28:51 -0500 Subject: [PATCH] short-circuit for && and || operators Implement short-circuiting logic for boolean binary operators --- hclsyntax/expression_ops.go | 96 ++++++++- hclsyntax/expression_test.go | 397 +++++++++++++++++++++++++++++++++++ 2 files changed, 486 insertions(+), 7 deletions(-) diff --git a/hclsyntax/expression_ops.go b/hclsyntax/expression_ops.go index 6585612c..6d9e1af5 100644 --- a/hclsyntax/expression_ops.go +++ b/hclsyntax/expression_ops.go @@ -16,16 +16,73 @@ import ( type Operation struct { Impl function.Function Type cty.Type + + // ShortCircuit is an optional callback for binary operations which, if set, + // will be called with the result of evaluating the LHS and RHS expressions + // and their individual diagnostics. The LHS and RHS values are guaranteed + // to be unmarked and of the correct type. + // + // ShortCircuit may return cty.NilVal to allow evaluation to proceed as + // normal, or it may return a non-nil value with diagnostics to return + // before the main Impl is called. The returned diagnostics should match + // the side of the Operation which was taken. 
+ ShortCircuit func(lhs, rhs cty.Value, lhsDiags, rhsDiags hcl.Diagnostics) (cty.Value, hcl.Diagnostics) } var ( OpLogicalOr = &Operation{ Impl: stdlib.OrFunc, Type: cty.Bool, + + ShortCircuit: func(lhs, rhs cty.Value, lhsDiags, rhsDiags hcl.Diagnostics) (cty.Value, hcl.Diagnostics) { + switch { + // if both are unknown, we don't short circuit anything + case !lhs.IsKnown() && !rhs.IsKnown(): + return cty.NilVal, nil + + // for ||, a single true is the controlling condition + case lhs.IsKnown() && lhs.True(): + return cty.True, lhsDiags + case rhs.IsKnown() && rhs.True(): + return cty.True, rhsDiags + + // if the opposing side is false we can't short-circuit based on + // boolean logic, so an unknown becomes the controlling condition + case !lhs.IsKnown() && rhs.False(): + return cty.UnknownVal(cty.Bool).RefineNotNull(), lhsDiags + case !rhs.IsKnown() && lhs.False(): + return cty.UnknownVal(cty.Bool).RefineNotNull(), rhsDiags + } + + return cty.NilVal, nil + }, } OpLogicalAnd = &Operation{ Impl: stdlib.AndFunc, Type: cty.Bool, + + ShortCircuit: func(lhs, rhs cty.Value, lhsDiags, rhsDiags hcl.Diagnostics) (cty.Value, hcl.Diagnostics) { + switch { + // if both are unknown, we don't short circuit anything + case !lhs.IsKnown() && !rhs.IsKnown(): + return cty.NilVal, nil + + // For &&, a single false is the controlling condition + case lhs.IsKnown() && lhs.False(): + return cty.False, lhsDiags + case rhs.IsKnown() && rhs.False(): + return cty.False, rhsDiags + + // if the opposing side is true we can't short-circuit based on + // boolean logic, so an unknown becomes the controlling condition + case !lhs.IsKnown() && rhs.True(): + return cty.UnknownVal(cty.Bool).RefineNotNull(), lhsDiags + case !rhs.IsKnown() && lhs.True(): + return cty.UnknownVal(cty.Bool).RefineNotNull(), rhsDiags + } + + return cty.NilVal, nil + }, } OpLogicalNot = &Operation{ Impl: stdlib.NotFunc, @@ -145,10 +202,6 @@ func (e *BinaryOpExpr) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) var 
diags hcl.Diagnostics givenLHSVal, lhsDiags := e.LHS.Value(ctx) - givenRHSVal, rhsDiags := e.RHS.Value(ctx) - diags = append(diags, lhsDiags...) - diags = append(diags, rhsDiags...) - lhsVal, err := convert.Convert(givenLHSVal, lhsParam.Type) if err != nil { diags = append(diags, &hcl.Diagnostic{ @@ -161,6 +214,8 @@ func (e *BinaryOpExpr) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) EvalContext: ctx, }) } + + givenRHSVal, rhsDiags := e.RHS.Value(ctx) rhsVal, err := convert.Convert(givenRHSVal, rhsParam.Type) if err != nil { diags = append(diags, &hcl.Diagnostic{ @@ -174,12 +229,44 @@ func (e *BinaryOpExpr) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) }) } + // diags so far only contains conversion errors, which should cover + // incorrect parameter types. if diags.HasErrors() { - // Don't actually try the call if we have errors already, since the - // this will probably just produce a confusing duplicative diagnostic. + // Add the rest of the diagnostic in case that helps the user, but keep + // them separate as we continue for short-circuit handling. + diags = append(diags, lhsDiags...) + diags = append(diags, rhsDiags...) return cty.UnknownVal(e.Op.Type), diags } + lhsVal, lhsMarks := lhsVal.Unmark() + rhsVal, rhsMarks := rhsVal.Unmark() + + // If we short-circuited above and still passed the type-check of RHS then + // we'll halt here and return the short-circuit result rather than actually + // executing the operation. + if e.Op.ShortCircuit != nil { + forceResult, diags := e.Op.ShortCircuit(lhsVal, rhsVal, lhsDiags, rhsDiags) + if forceResult != cty.NilVal { + // It would be technically more correct to insert rhs diagnostics if + // forceResult is not known since we didn't really short-circuit. 
That + // would however not match the behavior of conditional expressions which + // do drop all diagnostics from the unevaluated expressions + return forceResult.WithMarks(lhsMarks, rhsMarks), diags + } + } + + // No result was forced, so both operands take part in the operation and + // their diagnostics must be reported alongside its result. + diags = append(diags, lhsDiags...) + diags = append(diags, rhsDiags...) + + if diags.HasErrors() { + // Don't actually try the call if we have errors, since this will + // probably just produce confusing duplicate diagnostics. + return cty.UnknownVal(e.Op.Type).WithMarks(lhsMarks, rhsMarks), diags + } + args := []cty.Value{lhsVal, rhsVal} result, err := impl.Call(args) if err != nil { @@ -195,7 +282,7 @@ func (e *BinaryOpExpr) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) return cty.UnknownVal(e.Op.Type), diags } - return result, diags + return result.WithMarks(lhsMarks, rhsMarks), diags } func (e *BinaryOpExpr) Range() hcl.Range { diff --git a/hclsyntax/expression_test.go b/hclsyntax/expression_test.go index df11d28c..843847df 100644 --- a/hclsyntax/expression_test.go +++ b/hclsyntax/expression_test.go @@ -1913,6 +1913,112 @@ EOT cty.False, 0, }, + { + // Logical AND operator short-circuit behavior + `nullobj != null && nullobj.is_thingy`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "nullobj": cty.NullVal(cty.Object(map[string]cty.Type{ + "is_thingy": cty.Bool, + })), + }, + }, + cty.False, + 0, // nullobj != null prevents evaluating nullobj.is_thingy + }, + { + // Logical AND short-circuit handling of unknown values + // If the first operand is an unknown bool then we can't know if + // we will short-circuit or not, and so we must assume we will + // and wait until the value becomes known before fully evaluating RHS. 
+ `unknown < 4 && list[zero]`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "unknown": cty.UnknownVal(cty.Number), + "zero": cty.Zero, + "list": cty.ListValEmpty(cty.Bool), + }, + }, + cty.UnknownVal(cty.Bool).RefineNotNull(), + 0, + }, + { + // Logical OR operator short-circuit behavior + `nullobj == null || nullobj.is_thingy`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "nullobj": cty.NullVal(cty.Object(map[string]cty.Type{ + "is_thingy": cty.Bool, + })), + }, + }, + cty.True, + 0, // nullobj == null prevents evaluating nullobj.is_thingy + }, + { + // Logical OR short-circuit handling of unknown values + // If the first operand is an unknown bool then we can't know if + // we will short-circuit or not, and so we must assume we will + // and wait until the value becomes known before fully evaluating RHS. + `unknown > 4 || list[zero]`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "unknown": cty.UnknownVal(cty.Number), + "zero": cty.Zero, + "list": cty.ListValEmpty(cty.Bool), + }, + }, + cty.UnknownVal(cty.Bool).RefineNotNull(), + 0, + }, + { + // short circuit calls must still retain marks + `lhsTrue || rhsUnknown`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "lhsTrue": cty.True.Mark("a"), + "rhsUnknown": cty.UnknownVal(cty.Bool).Mark("b"), + }, + }, + cty.True.Mark("a").Mark("b"), + 0, + }, + { + // short circuit calls must still retain marks + `lhsUnknown || rhsTrue`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "rhsTrue": cty.True.Mark("a"), + "lhsUnknown": cty.UnknownVal(cty.Bool).Mark("b"), + }, + }, + cty.True.Mark("a").Mark("b"), + 0, + }, + { + // short circuit calls must still retain marks + `lhsUnknown && rhsFalse`, + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "rhsFalse": cty.False.Mark("a"), + "lhsUnknown": cty.UnknownVal(cty.Bool).Mark("b"), + }, + }, + cty.False.Mark("a").Mark("b"), + 0, + }, + { + // short circuit calls must still retain marks + `lhsFalse && rhsUnknown`, + 
&hcl.EvalContext{ + Variables: map[string]cty.Value{ + "lhsFalse": cty.False.Mark("a"), + "rhsUnknown": cty.UnknownVal(cty.Bool).Mark("b"), + }, + }, + cty.False.Mark("a").Mark("b"), + 0, + }, { `true ? var : null`, &hcl.EvalContext{ @@ -2272,6 +2378,36 @@ EOT cty.UnknownVal(cty.String).RefineNotNull().Mark("sensitive"), 0, }, + { + // foo does not exist, but we need to catch the diagnostics when + // coming out of a ShortCircuit call + "foo(value) && true", + &hcl.EvalContext{}, + cty.UnknownVal(cty.Bool).RefineNotNull(), + 1, + }, + { + // foo does not exist, but the short-circuit wins + "foo(value) && false", + &hcl.EvalContext{}, + cty.False, + 0, + }, + { + // foo does not exist, but we need to catch the diagnostics when + // coming out of a ShortCircuit call + "foo(value) || false", + &hcl.EvalContext{}, + cty.UnknownVal(cty.Bool).RefineNotNull(), + 1, + }, + { + // foo does not exist, but the short-circuit wins + "foo(value) || true", + &hcl.EvalContext{}, + cty.True, + 0, + }, } for _, test := range tests { @@ -2399,6 +2535,87 @@ func TestExpressionErrorMessages(t *testing.T) { // describe coherently. "The true and false result expressions must have consistent types. At least one deeply-nested attribute or element is not compatible across both the 'true' and the 'false' value.", }, + + // Error messages describing situations where the logical operator + // short-circuit behavior still found a type error on the RHS that + // we therefore still report, because the LHS only guards against + // value-related problems in the RHS. + { + // It's not valid to access an attribute on a non-object-typed + // value even if we've proven it isn't null. 
+ "notobj != null && notobj.foo", + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "notobj": cty.True, + }, + }, + "Unsupported attribute", + "Can't access attributes on a primitive-typed value (bool).", + }, + { + // It's not valid to access an attribute on a non-object-typed + // value even if we've proven it isn't null. + "notobj == null || notobj.foo", + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "notobj": cty.True, + }, + }, + "Unsupported attribute", + "Can't access attributes on a primitive-typed value (bool).", + }, + { + // It's not valid to access an index on an unindexable type + // even if we've proven it isn't null. + "notlist != null && notlist[0]", + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "notlist": cty.True, + }, + }, + "Invalid index", + "This value does not have any indices.", + }, + { + // Short-circuit can't avoid an error accessing a variable that + // doesn't exist at all, so we can still report typos. + "value != null && valeu", + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "value": cty.True, + }, + }, + "Unknown variable", + `There is no variable named "valeu". 
Did you mean "value"?`, + }, + { + // Short-circuit must still catch type errors on the opposite side + "unknown && \"value\"", + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "unknown": cty.UnknownVal(cty.Bool), + }, + }, + "Invalid operand", + `Unsuitable value for right operand: a bool is required.`, + }, + { + // Short-circuiting must still catch type errors on the opposite side + "value && \"value\"", + &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "value": cty.False, + }, + }, + "Invalid operand", + `Unsuitable value for right operand: a bool is required.`, + }, + { + "foo(value) && true", + &hcl.EvalContext{}, + "Function calls not allowed", + `Functions may not be called here.`, + }, } for _, test := range tests { @@ -2733,3 +2950,183 @@ func TestParseExpression_incompleteFunctionCall(t *testing.T) { }) } } + +func TestAllBoolExpressions(t *testing.T) { + inputs := map[string]cty.Value{ + // truth table for all boolean expressions + "true && true": cty.True, + "true || true": cty.True, + "true && false": cty.False, + "true || false": cty.True, + "true && unknown": cty.DynamicVal, + "true || unknown": cty.True, + "false && true": cty.False, + "false || true": cty.True, + "false && false": cty.False, + "false || false": cty.False, + "false && unknown": cty.False, + "false || unknown": cty.DynamicVal, + "unknown && true": cty.DynamicVal, + "unknown || true": cty.True, + "unknown && false": cty.False, + "unknown || false": cty.DynamicVal, + "unknown && unknown": cty.DynamicVal, + "unknown || unknown": cty.DynamicVal, + + // Truth table for all possible combinations of 3 part boolean + // expressions. Also added equivalent parenthesized versions for when + // the operator precedence affects the result. 
+ "true && true && true": cty.True, + "true || true && true": cty.True, + "true || true || true": cty.True, + "true && true || true": cty.True, + "true && true && false": cty.False, + "true || true && false": cty.True, + "true || true || false": cty.True, + "true && true || false": cty.True, + "true && true && unknown": cty.DynamicVal, + "true || true && unknown": cty.True, + "true || true || unknown": cty.True, + "true && true || unknown": cty.True, + "true && false && true": cty.False, + "true || false && true": cty.True, + "true || false || true": cty.True, + "true && false || true": cty.True, + "true && false && false": cty.False, + "true || false && false": cty.True, + "true || false || false": cty.True, + "true && false || false": cty.False, + "true && false && unknown": cty.False, + "true || false && unknown": cty.True, + "true || false || unknown": cty.True, + "true && false || unknown": cty.DynamicVal, + "true && unknown && true": cty.DynamicVal, + "true || unknown && true": cty.True, + "true || unknown || true": cty.True, + "true && unknown || true": cty.True, + "true && unknown && false": cty.False, + "true || unknown && false": cty.True, + "true || unknown || false": cty.True, + "true && unknown || false": cty.DynamicVal, + "true && unknown && unknown": cty.DynamicVal, + "true || unknown && unknown": cty.True, + "true || unknown || unknown": cty.True, + "true && unknown || unknown": cty.DynamicVal, + "false && true && true": cty.False, + "false || true && true": cty.True, + "false || true || true": cty.True, + "false && true || true": cty.True, + "(false && true) || true": cty.True, + "false && true && false": cty.False, + "false || true && false": cty.False, + "false || true || false": cty.True, + "false && true || false": cty.False, + "false && true && unknown": cty.False, + "false || true && unknown": cty.DynamicVal, + "false || true || unknown": cty.True, + "false && true || unknown": cty.DynamicVal, + "(false && true) || unknown": cty.DynamicVal, + 
"false && false && true": cty.False, + "false || false && true": cty.False, + "false || false || true": cty.True, + "false && false || true": cty.True, + "false && false && false": cty.False, + "false || false && false": cty.False, + "false || false || false": cty.False, + "false && false || false": cty.False, + "false && false && unknown": cty.False, + "false || false && unknown": cty.False, + "false || false || unknown": cty.DynamicVal, + "false && false || unknown": cty.DynamicVal, + "(false && false) || unknown": cty.DynamicVal, + "false && unknown && true": cty.False, + "false || unknown && true": cty.DynamicVal, + "false || unknown || true": cty.True, + "false && unknown || true": cty.True, + "(false && unknown) || true": cty.True, + "false && unknown && false": cty.False, + "false || unknown && false": cty.False, + "false || unknown || false": cty.DynamicVal, + "false && unknown || false": cty.False, + "false && unknown && unknown": cty.False, + "false || unknown && unknown": cty.DynamicVal, + "false || unknown || unknown": cty.DynamicVal, + "false && unknown || unknown": cty.DynamicVal, + "(false && unknown) || unknown": cty.DynamicVal, + "unknown && true && true": cty.DynamicVal, + "unknown || true && true": cty.True, + "unknown || true || true": cty.True, + "unknown && true || true": cty.True, + "unknown && true && false": cty.False, + "unknown || true && false": cty.DynamicVal, + "unknown || (true && false)": cty.DynamicVal, + "unknown || true || false": cty.True, + "unknown && true || false": cty.DynamicVal, + "unknown && true && unknown": cty.DynamicVal, + "unknown || true && unknown": cty.DynamicVal, + "unknown || true || unknown": cty.True, + "unknown && true || unknown": cty.DynamicVal, + "unknown && false && true": cty.False, + "unknown || false && true": cty.DynamicVal, + "unknown || false || true": cty.True, + "unknown && false || true": cty.True, + "(unknown && false) || true": cty.True, + "unknown && false && false": cty.False, + "unknown || 
false && false": cty.DynamicVal, + "unknown || false || false": cty.DynamicVal, + "unknown && false || false": cty.False, + "unknown && false && unknown": cty.False, + "unknown || false && unknown": cty.DynamicVal, + "unknown || false || unknown": cty.DynamicVal, + "unknown && false || unknown": cty.DynamicVal, + "unknown && unknown && true": cty.DynamicVal, + "unknown || unknown && true": cty.DynamicVal, + "unknown || unknown || true": cty.True, + "unknown && unknown || true": cty.True, + "unknown && unknown && false": cty.False, + "unknown || unknown && false": cty.DynamicVal, + "unknown || unknown || false": cty.DynamicVal, + "unknown && unknown || false": cty.DynamicVal, + "unknown && unknown && unknown": cty.DynamicVal, + "unknown || unknown && unknown": cty.DynamicVal, + "unknown || unknown || unknown": cty.DynamicVal, + "unknown && unknown || unknown": cty.DynamicVal, + } + + for input, want := range inputs { + t.Run(input, func(t *testing.T) { + if !want.IsKnown() { + want = cty.UnknownVal(cty.Bool).RefineNotNull() + } + ctx := &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "unknown": cty.UnknownVal(cty.DynamicPseudoType), + }, + } + expr, diags := ParseExpression([]byte(input), "", hcl.Pos{Line: 1, Column: 1, Byte: 0}) + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + got, diags := expr.Value(ctx) + if diags.HasErrors() { + t.Fatal(diags.Error()) + } + + if got.IsKnown() != want.IsKnown() { + t.Fatalf("%q resulted in %#v, wanted %#v\n", input, got, want) + } + if !got.IsKnown() { + // this validates that the unknown refinements are correct too + if !got.RawEquals(want) { + t.Fatalf("wrong unknown, got:%#v, want:%#v\n", got, want) + } + // covered in known comparison + return + } + + if got.Equals(want).False() { + t.Fatalf("%q resulted in %#v, wanted %#v\n", input, got, want) + } + }) + } +}