Skip to content

Commit

Permalink
Expressions: coverage (#133)
Browse files Browse the repository at this point in the history
* Repro issue 84

* Improve coverage for expressions, specifically Print

* Clean up warnings

---------

Co-authored-by: ysheffer <[email protected]>
  • Loading branch information
yaronf and ysheffer authored May 2, 2024
1 parent 0bfaab9 commit 584aa35
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 24 deletions.
72 changes: 48 additions & 24 deletions datalog/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) (
return nil, fmt.Errorf("datalog: expressions: unknown variable %d", id.(Variable))
}
id = *idptr
default: // do nothing
}
err := s.Push(id)
if err != nil {
return nil, fmt.Errorf("datalog: expressions: stack overflow")
}
s.Push(id)
case OpTypeUnary:
v, err := s.Pop()
if err != nil {
Expand All @@ -45,7 +49,10 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) (
if err != nil {
return nil, fmt.Errorf("datalog: expressions: unary eval failed: %w", err)
}
s.Push(res)
err = s.Push(res)
if err != nil {
return nil, fmt.Errorf("datalog: expressions: stack overflow")
}
case OpTypeBinary:
right, err := s.Pop()
if err != nil {
Expand All @@ -60,7 +67,10 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) (
if err != nil {
return nil, fmt.Errorf("datalog: expressions: binary eval failed: %w", err)
}
s.Push(res)
err = s.Push(res)
if err != nil {
return nil, fmt.Errorf("datalog: expressions: stack overflow")
}
default:
return nil, fmt.Errorf("datalog: expressions: unsupported Op: %v", op.Type())
}
Expand All @@ -83,22 +93,31 @@ func (e *Expression) Print(symbols *SymbolTable) string {
id := op.(Value).ID
switch id.Type() {
case TermTypeString:
s.Push(fmt.Sprintf("\"%s\"", symbols.Str(id.(String))))
err := s.Push(fmt.Sprintf("\"%s\"", symbols.Str(id.(String))))
if err != nil {
return "<invalid expression: stack overflow>"
}
case TermTypeVariable:
s.Push(fmt.Sprintf("$%s", symbols.Var(id.(Variable))))
err := s.Push(fmt.Sprintf("$%s", symbols.Var(id.(Variable))))
if err != nil {
return "<invalid expression: stack overflow>"
}
default:
s.Push(id.String())
err := s.Push(id.String())
if err != nil {
return "<invalid expression: stack overflow>"
}
}
case OpTypeUnary:
v, err := s.Pop()
if err != nil {
return "<invalid expression: unary operation failed to pop value>"
}
res := op.(UnaryOp).Print(v)
err = s.Push(res)
if err != nil {
return "<invalid expression: binary operation failed to pop right value>"
return "<invalid expression: stack overflow>"
}
s.Push(res)
case OpTypeBinary:
right, err := s.Pop()
if err != nil {
Expand All @@ -109,7 +128,10 @@ func (e *Expression) Print(symbols *SymbolTable) string {
return "<invalid expression: binary operation failed to pop left value>"
}
res := op.(BinaryOp).Print(left, right)
s.Push(res)
err = s.Push(res)
if err != nil {
return "<invalid expression: stack overflow>"
}
default:
return fmt.Sprintf("<invalid expression: unsupported op type %v>", op.Type())
}
Expand Down Expand Up @@ -160,6 +182,8 @@ func (op UnaryOp) Print(value string) string {
out = fmt.Sprintf("!%s", value)
case UnaryParens:
out = fmt.Sprintf("(%s)", value)
case UnaryLength:
out = fmt.Sprintf("%s.length()", value)
default:
out = fmt.Sprintf("unknown(%s)", value)
}
Expand All @@ -186,7 +210,7 @@ type Negate struct{}
func (Negate) Type() UnaryOpType {
return UnaryNegate
}
func (Negate) Eval(value Term, symbols *SymbolTable) (Term, error) {
func (Negate) Eval(value Term, _ *SymbolTable) (Term, error) {
var out Term
switch value.Type() {
case TermTypeBool:
Expand All @@ -206,7 +230,7 @@ type Parens struct{}
func (Parens) Type() UnaryOpType {
return UnaryParens
}
func (Parens) Eval(value Term, symbols *SymbolTable) (Term, error) {
func (Parens) Eval(value Term, _ *SymbolTable) (Term, error) {
return value, nil
}

Expand All @@ -228,7 +252,7 @@ func (Length) Eval(value Term, symbols *SymbolTable) (Term, error) {
case TermTypeSet:
out = Integer(len(value.(Set)))
default:
return nil, fmt.Errorf("datalog: unexpected Negate value type: %d", value.Type())
return nil, fmt.Errorf("datalog: unexpected Length value type: %d", value.Type())
}
return out, nil
}
Expand Down Expand Up @@ -318,7 +342,7 @@ type LessThan struct{}
func (LessThan) Type() BinaryOpType {
return BinaryLessThan
}
func (LessThan) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (LessThan) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: LessThan type mismatch: %d != %d", g, w)
}
Expand All @@ -344,7 +368,7 @@ type LessOrEqual struct{}
func (LessOrEqual) Type() BinaryOpType {
return BinaryLessOrEqual
}
func (LessOrEqual) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (LessOrEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: LessOrEqual type mismatch: %d != %d", g, w)
}
Expand All @@ -370,7 +394,7 @@ type GreaterThan struct{}
func (GreaterThan) Type() BinaryOpType {
return BinaryGreaterThan
}
func (GreaterThan) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (GreaterThan) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: GreaterThan type mismatch: %d != %d", g, w)
}
Expand All @@ -396,7 +420,7 @@ type GreaterOrEqual struct{}
func (GreaterOrEqual) Type() BinaryOpType {
return BinaryGreaterOrEqual
}
func (GreaterOrEqual) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (GreaterOrEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: GreaterOrEqual type mismatch: %d != %d", g, w)
}
Expand All @@ -422,7 +446,7 @@ type Equal struct{}
func (Equal) Type() BinaryOpType {
return BinaryEqual
}
func (Equal) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Equal) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: Equal type mismatch: %d != %d", g, w)
}
Expand Down Expand Up @@ -510,7 +534,7 @@ type Intersection struct{}
func (Intersection) Type() BinaryOpType {
return BinaryIntersection
}
func (Intersection) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Intersection) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
set, ok := left.(Set)
if !ok {
return nil, errors.New("datalog: Intersection left value must be a Set")
Expand All @@ -530,7 +554,7 @@ type Union struct{}
func (Union) Type() BinaryOpType {
return BinaryUnion
}
func (Union) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Union) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
set, ok := left.(Set)
if !ok {
return nil, errors.New("datalog: Union left value must be a Set")
Expand Down Expand Up @@ -654,7 +678,7 @@ type Sub struct{}
func (Sub) Type() BinaryOpType {
return BinarySub
}
func (Sub) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Sub) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
ileft, ok := left.(Integer)
if !ok {
return nil, fmt.Errorf("datalog: Sub requires left value to be an Integer, got %T", left)
Expand Down Expand Up @@ -682,7 +706,7 @@ type Mul struct{}
func (Mul) Type() BinaryOpType {
return BinaryMul
}
func (Mul) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Mul) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
ileft, ok := left.(Integer)
if !ok {
return nil, fmt.Errorf("datalog: Mul requires left value to be an Integer, got %T", left)
Expand Down Expand Up @@ -711,7 +735,7 @@ type Div struct{}
func (Div) Type() BinaryOpType {
return BinaryDiv
}
func (Div) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Div) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
ileft, ok := left.(Integer)
if !ok {
return nil, fmt.Errorf("datalog: Div requires left value to be an Integer, got %T", left)
Expand All @@ -735,7 +759,7 @@ type And struct{}
func (And) Type() BinaryOpType {
return BinaryAnd
}
func (And) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (And) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
bleft, ok := left.(Bool)
if !ok {
return nil, fmt.Errorf("datalog: And requires left value to be a Bool, got %T", left)
Expand All @@ -755,7 +779,7 @@ type Or struct{}
func (Or) Type() BinaryOpType {
return BinaryOr
}
func (Or) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Or) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
bleft, ok := left.(Bool)
if !ok {
return nil, fmt.Errorf("datalog: Or requires left value to be a Bool, got %T", left)
Expand Down
49 changes: 49 additions & 0 deletions datalog/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1188,3 +1188,52 @@ func TestBinaryOr(t *testing.T) {
})
}
}

func TestPrint(t *testing.T) {
syms := SymbolTable{}
syms.Insert("abc")
testCases := []struct {
desc string
expr Expression
res string
}{
{
desc: "number",
expr: Expression{Value{Integer(9)}},
res: "9",
},
{
desc: "string",
expr: Expression{Value{syms.Sym("abc")}},
res: "\"abc\"",
},
{
desc: "unary",
expr: Expression{Value{syms.Sym("abc")}, UnaryOp{Length{}}},
res: "\"abc\".length()",
},
{
desc: "binary",
expr: Expression{Value{Integer(9)}, Value{Integer(4)}, BinaryOp{Mul{}}},
res: "9 * 4",
},
{
desc: "parens",
expr: Expression{
Value{Integer(9)},
Value{Integer(3)},
BinaryOp{Add{}},
UnaryOp{Parens{}},
Value{Integer(4)},
BinaryOp{Div{}},
},
res: "(9 + 3) / 4",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
p := tc.expr.Print(&syms)
require.Equal(t, tc.res, p)
})
}
}

0 comments on commit 584aa35

Please sign in to comment.