Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add != #142

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions converters_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ func tokenExprBinaryToProtoExprBinary(op datalog.BinaryOp) (*pb.OpBinary, error)
pbBinaryKind = pb.OpBinary_GreaterOrEqual
case datalog.BinaryEqual:
pbBinaryKind = pb.OpBinary_Equal
case datalog.BinaryNotEqual:
pbBinaryKind = pb.OpBinary_NotEqual
case datalog.BinaryContains:
pbBinaryKind = pb.OpBinary_Contains
case datalog.BinaryPrefix:
Expand Down Expand Up @@ -397,6 +399,8 @@ func protoExprBinaryToTokenExprBinary(op *pb.OpBinary) (datalog.BinaryOpFunc, er
binaryOp = datalog.GreaterOrEqual{}
case pb.OpBinary_Equal:
binaryOp = datalog.Equal{}
case pb.OpBinary_NotEqual:
binaryOp = datalog.NotEqual{}
case pb.OpBinary_Contains:
binaryOp = datalog.Contains{}
case pb.OpBinary_Prefix:
Expand Down
46 changes: 45 additions & 1 deletion converters_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ func TestExpressionConvertV2(t *testing.T) {
},
},
},
{
Desc: "int comparison not equal",
Input: datalog.Expression{
datalog.Value{ID: datalog.Variable(3)},
datalog.Value{ID: datalog.Integer(42)},
datalog.BinaryOp{BinaryOpFunc: datalog.NotEqual{}},
},
Expected: &pb.ExpressionV2{
Ops: []*pb.Op{
{Content: &pb.Op_Value{Value: &pb.TermV2{Content: &pb.TermV2_Variable{Variable: 3}}}},
{Content: &pb.Op_Value{Value: &pb.TermV2{Content: &pb.TermV2_Integer{Integer: 42}}}},
{Content: &pb.Op_Binary{Binary: &pb.OpBinary{Kind: pb.OpBinary_NotEqual.Enum()}}},
},
},
},
{
Desc: "int comparison larger",
Input: datalog.Expression{
Expand Down Expand Up @@ -165,7 +180,6 @@ func TestExpressionConvertV2(t *testing.T) {
},
},
},

{
Desc: "string comparison equal",
Input: datalog.Expression{
Expand All @@ -181,6 +195,21 @@ func TestExpressionConvertV2(t *testing.T) {
},
},
},
{
Desc: "string comparison not equal",
Input: datalog.Expression{
datalog.Value{ID: datalog.Variable(10)},
datalog.Value{ID: syms.Insert("abcde")},
datalog.BinaryOp{BinaryOpFunc: datalog.NotEqual{}},
},
Expected: &pb.ExpressionV2{
Ops: []*pb.Op{
{Content: &pb.Op_Value{Value: &pb.TermV2{Content: &pb.TermV2_Variable{Variable: 10}}}},
{Content: &pb.Op_Value{Value: &pb.TermV2{Content: &pb.TermV2_String_{String_: syms.Index("abcde")}}}},
{Content: &pb.Op_Binary{Binary: &pb.OpBinary{Kind: pb.OpBinary_NotEqual.Enum()}}},
},
},
},
{
Desc: "string comparison prefix",
Input: datalog.Expression{
Expand Down Expand Up @@ -281,6 +310,21 @@ func TestExpressionConvertV2(t *testing.T) {
},
},
},
{
Desc: "bytes not equal",
Input: datalog.Expression{
datalog.Value{ID: datalog.Variable(16)},
datalog.Value{ID: datalog.Bytes("abcde")},
datalog.BinaryOp{BinaryOpFunc: datalog.NotEqual{}},
},
Expected: &pb.ExpressionV2{
Ops: []*pb.Op{
{Content: &pb.Op_Value{Value: &pb.TermV2{Content: &pb.TermV2_Variable{Variable: 16}}}},
{Content: &pb.Op_Value{Value: &pb.TermV2{Content: &pb.TermV2_Bytes{Bytes: []byte("abcde")}}}},
{Content: &pb.Op_Binary{Binary: &pb.OpBinary{Kind: pb.OpBinary_NotEqual.Enum()}}},
},
},
},
{
Desc: "bytes in",
Input: datalog.Expression{
Expand Down
31 changes: 31 additions & 0 deletions datalog/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ func (op BinaryOp) Print(left, right string) string {
out = fmt.Sprintf("%s >= %s", left, right)
case BinaryEqual:
out = fmt.Sprintf("%s == %s", left, right)
case BinaryNotEqual:
out = fmt.Sprintf("%s != %s", left, right)
case BinaryContains:
out = fmt.Sprintf("%s.contains(%s)", left, right)
case BinaryPrefix:
Expand Down Expand Up @@ -332,6 +334,7 @@ const (
BinaryOr
BinaryIntersection
BinaryUnion
BinaryNotEqual
)

// LessThan returns true when left is less than right.
Expand Down Expand Up @@ -466,6 +469,34 @@ func (Equal) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
return Bool(left.Equal(right)), nil
}

// NotEqual returns true when left and right are not equal.
// It requires left and right to have the same concrete type
// and only accepts Integer, Bytes or String.
type NotEqual struct{}

func (NotEqual) Type() BinaryOpType {
return BinaryNotEqual
}
func (NotEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: NotEqual type mismatch: %d != %d", g, w)
}

switch left.Type() {
case TermTypeInteger:
case TermTypeBytes:
case TermTypeString:
case TermTypeDate:
case TermTypeBool:
case TermTypeSet:

default:
return nil, fmt.Errorf("datalog: unexpected NotEqual value type: %d", left.Type())
}

return Bool(!left.Equal(right)), nil
}

// Contains returns true when the right value exists in the left Set.
// The right value must be an Integer, Bytes, String or Symbol.
// The left value must be a Set, containing elements of right type.
Expand Down
80 changes: 80 additions & 0 deletions datalog/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,86 @@ func TestBinaryEqual(t *testing.T) {
}
}

func TestBinaryNotEqual(t *testing.T) {
require.Equal(t, BinaryNotEqual, NotEqual{}.Type())
syms := &SymbolTable{}

testCases := []struct {
desc string
left Term
right Term
res Bool
expectedErr bool
}{
{
desc: "not equal integers",
left: Integer(3),
right: Integer(5),
res: true,
},
{
desc: "not equal bytes",
left: Bytes{0},
right: Bytes{1},
res: true,
},
{
desc: "not equal string",
left: syms.Insert("abc"),
right: syms.Insert("def"),
res: true,
},
{
desc: "equal integers",
left: Integer(3),
right: Integer(3),
res: false,
},
{
desc: "equal bytes",
left: Bytes{0, 1, 2},
right: Bytes{0, 1, 2},
res: false,
},
{
desc: "equal strings",
left: syms.Insert("abc"),
right: syms.Insert("abc"),
res: false,
},
{
desc: "invalid left type errors",
left: String(42),
right: Integer(42),
expectedErr: true,
},
{
desc: "invalid right type errors",
left: Integer(42),
right: syms.Insert("abc"),
expectedErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ops := Expression{
Value{tc.left},
Value{tc.right},
BinaryOp{NotEqual{}},
}

res, err := ops.Evaluate(nil, syms)
if tc.expectedErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.res, res)
}
})
}
}

func TestBinaryContains(t *testing.T) {
require.Equal(t, BinaryContains, Contains{}.Type())
syms := &SymbolTable{}
Expand Down
7 changes: 5 additions & 2 deletions parser/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ const (
OpLessThan
OpGreaterThan
OpEqual
OpNotEqual
OpContains
OpPrefix
OpSuffix
Expand All @@ -251,7 +252,7 @@ const (
var operatorMap = map[string]Operator{
"+": OpAdd,
"-": OpSub, "*": OpMul, "/": OpDiv, "&&": OpAnd, "||": OpOr, "<=": OpLessOrEqual, ">=": OpGreaterOrEqual, "<": OpLessThan, ">": OpGreaterThan,
"==": OpEqual, "!": OpNegate, "contains": OpContains, "starts_with": OpPrefix, "ends_with": OpSuffix, "matches": OpMatches, "intersection": OpIntersection, "union": OpUnion, "length": OpLength}
"==": OpEqual, "!=": OpNotEqual, "!": OpNegate, "contains": OpContains, "starts_with": OpPrefix, "ends_with": OpSuffix, "matches": OpMatches, "intersection": OpIntersection, "union": OpUnion, "length": OpLength}

func (o *Operator) Capture(s []string) error {
*o = operatorMap[s[0]]
Expand Down Expand Up @@ -284,7 +285,7 @@ type Expr2 struct {
}

type OpExpr3 struct {
Operator Operator `@("<=" | ">=" | "<" | ">" | "==")`
Operator Operator `@("<=" | ">=" | "<" | ">" | "==" | "!=")`
Expr3 *Expr3 `@@`
}

Expand Down Expand Up @@ -454,6 +455,8 @@ func (op *Operator) ToExpr(expr *biscuit.Expression) {
biscuit_op = biscuit.BinaryGreaterThan
case OpEqual:
biscuit_op = biscuit.BinaryEqual
case OpNotEqual:
biscuit_op = biscuit.BinaryNotEqual
case OpContains:
biscuit_op = biscuit.BinaryContains
case OpPrefix:
Expand Down
54 changes: 52 additions & 2 deletions parser/grammar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ func TestGrammarExpression(t *testing.T) {
biscuit.BinaryEqual,
},
},
{
Input: `$0 != 2`,
Expected: &biscuit.Expression{
biscuit.Value{Term: biscuit.Variable("0")},
biscuit.Value{Term: biscuit.Integer(2)},
biscuit.BinaryNotEqual,
},
},
{
Input: `$1 > 2`,
Expected: &biscuit.Expression{
Expand Down Expand Up @@ -214,6 +222,14 @@ func TestGrammarExpression(t *testing.T) {
biscuit.BinaryEqual,
},
},
{
Input: `$0 != "abcd"`,
Expected: &biscuit.Expression{
biscuit.Value{Term: biscuit.Variable("0")},
biscuit.Value{Term: biscuit.String("abcd")},
biscuit.BinaryNotEqual,
},
},
{
Input: `$0.starts_with("abc")`,
Expected: &biscuit.Expression{
Expand Down Expand Up @@ -304,11 +320,30 @@ func TestGrammarExpression(t *testing.T) {
},
},
{
Input: `hex:12ab == hex:ab`,
Input: `[hex:41].intersection([hex:41]).length() != $0`,
Expected: &biscuit.Expression{
biscuit.Value{Term: biscuit.Set{biscuit.Bytes([]byte("A"))}},
biscuit.Value{Term: biscuit.Set{biscuit.Bytes([]byte("A"))}},
biscuit.BinaryIntersection,
biscuit.UnaryLength,
biscuit.Value{Term: biscuit.Variable("0")},
biscuit.BinaryNotEqual,
},
},
{
Input: `hex:12ab == hex:12ab`, //not sure why but the previous test also passed even when 12ab should not equal ab
Expected: &biscuit.Expression{
biscuit.Value{Term: biscuit.Bytes([]byte{0x12, 0xab})},
biscuit.Value{Term: biscuit.Bytes([]byte{0x12, 0xab})},
biscuit.BinaryEqual,
},
},
{
Input: `hex:12ab != hex:ab`,
Expected: &biscuit.Expression{
biscuit.Value{Term: biscuit.Bytes([]byte{0x12, 0xab})},
biscuit.Value{Term: biscuit.Bytes([]byte{0xab})},
biscuit.BinaryEqual,
biscuit.BinaryNotEqual,
},
},
{
Expand All @@ -332,6 +367,21 @@ func TestGrammarExpression(t *testing.T) {
biscuit.BinaryOr,
},
},
{
Input: `{param1} != {param2} || {param3}`,
Params: map[string]biscuit.Term{
"param1": biscuit.Integer(1),
"param2": biscuit.Integer(2),
"param3": biscuit.Bool(false),
},
Expected: &biscuit.Expression{
biscuit.Value{Term: biscuit.Integer(1)},
biscuit.Value{Term: biscuit.Integer(2)},
biscuit.BinaryNotEqual,
biscuit.Value{Term: biscuit.Bool(false)},
biscuit.BinaryOr,
},
},
}

for _, testCase := range testCases {
Expand Down
2 changes: 1 addition & 1 deletion parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ var BiscuitLexerRules = []lexer.SimpleRule{
{Name: "Arrow", Pattern: `<-`},
{Name: "Or", Pattern: `\|\|`},
{Name: "And", Pattern: `&&`},
{Name: "Operator", Pattern: `==|>=|<=|>|<|\+|-|\*`},
{Name: "Operator", Pattern: `!=|==|>=|<=|>|<|\+|-|\*`},
{Name: "Comment", Pattern: `//[^\n]*`},
{Name: "String", Pattern: `\"[^\"]*\"`},
{Name: "Variable", Pattern: `\$[a-zA-Z0-9_:]+`},
Expand Down
14 changes: 12 additions & 2 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func getRuleTestCases() []testCase {
ExpectFailure: true,
},
{
Input: `rule1("a") <- body1("b"), $0 > 0, $1 < 1, $2 >= 2, $3 <= 3, $4 == 4, [1, 2, 3].contains($5), ![4,5,6].contains($6)`,
Input: `rule1("a") <- body1("b"), $0 > 0, $1 < 1, $2 >= 2, $3 <= 3, $4 == 4, [1, 2, 3].contains($5), ![4,5,6].contains($6), $7 != 6`,
Expected: biscuit.Rule{
Head: biscuit.Predicate{
Name: "rule1",
Expand Down Expand Up @@ -206,11 +206,16 @@ func getRuleTestCases() []testCase {
biscuit.BinaryContains,
biscuit.UnaryNegate,
},
{
biscuit.Value{Term: biscuit.Variable("7")},
biscuit.Value{Term: biscuit.Integer(6)},
biscuit.BinaryNotEqual,
},
},
},
},
{
Input: `rule1("a") <- body1("b"), $0 == "abc", $1.starts_with("def"), $2.ends_with("ghi"), $3.matches("file[0-9]+.txt"), ["a","b"].contains($4), !["c", "d"].contains($5)`,
Input: `rule1("a") <- body1("b"), $0 == "abc", $1.starts_with("def"), $2.ends_with("ghi"), $3.matches("file[0-9]+.txt"), ["a","b"].contains($4), !["c", "d"].contains($5), $6 != "abc"`,
Expected: biscuit.Rule{
Head: biscuit.Predicate{
Name: "rule1",
Expand Down Expand Up @@ -252,6 +257,11 @@ func getRuleTestCases() []testCase {
biscuit.BinaryContains,
biscuit.UnaryNegate,
},
{
biscuit.Value{Term: biscuit.Variable("6")},
biscuit.Value{Term: biscuit.String("abc")},
biscuit.BinaryNotEqual,
},
},
},
},
Expand Down
Loading
Loading