diff --git a/expr_test.go b/expr_test.go index 5128ea51..db254130 100644 --- a/expr_test.go +++ b/expr_test.go @@ -901,18 +901,147 @@ func TestExpr(t *testing.T) { `all(1..3, {# > 0})`, true, }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 4})`, + false, + }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# != 2})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# != 2})`, + false, + }, { `none(1..3, {# == 0})`, true, }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 4})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 3})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 3})`, + false, + }, { `any([1,1,0,1], {# == 0})`, true, }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 2})`, + false, + }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 4})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 4})`, + false, + }, { `one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`, true, }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, + false, + }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, + false, + }, + { `count(1..30, {# % 3 == 0})`, 10, @@ -2524,3 +2653,63 @@ func TestOperatorDependsOnEnv(t *testing.T) { require.NoError(t, err) assert.Equal(t, 42, out) } + +func TestIssue624(t *testing.T) { + type tag struct { + Name string + } + + type item struct { + Tags []tag + } + + i := item{ + Tags: []tag{ + {Name: "one"}, + {Name: "two"}, + }, + } + + rule := `[ +true && true, +one(Tags, .Name in ["one"]), +one(Tags, .Name in ["two"]), +one(Tags, .Name in ["one"]) && one(Tags, .Name in ["two"]) +]` + resp, err := expr.Eval(rule, i) + require.NoError(t, err) + require.Equal(t, []interface{}{true, true, true, true}, resp) +} + +func TestPredicateCombination(t *testing.T) { + tests := []struct { + code1 string + code2 string + }{ + {"all(1..3, {# > 0}) && all(1..3, {# < 4})", "all(1..3, {# > 0 && # < 4})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 4})", "all(1..3, {# > 1 && # < 4})"}, + {"all(1..3, {# > 0}) && all(1..3, {# < 2})", "all(1..3, {# > 0 && # < 2})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 2})", "all(1..3, {# > 1 && # < 2})"}, + + {"any(1..3, {# > 0}) || any(1..3, {# < 4})", "any(1..3, {# > 0 || # < 4})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 4})", "any(1..3, {# > 1 || # < 4})"}, + {"any(1..3, {# > 0}) || any(1..3, {# < 2})", "any(1..3, {# > 0 || # < 2})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 2})", "any(1..3, {# > 1 || # < 2})"}, + + {"none(1..3, {# > 0}) && none(1..3, {# < 4})", "none(1..3, {# > 0 || # < 4})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 4})", "none(1..3, {# > 1 || # < 4})"}, + {"none(1..3, {# > 0}) && none(1..3, {# < 2})", "none(1..3, {# > 0 || # < 2})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 2})", "none(1..3, {# > 1 || # < 2})"}, + } + for _, tt := range tests { + t.Run(tt.code1, func(t *testing.T) { + out1, err := expr.Eval(tt.code1, nil) + require.NoError(t, err) + + out2, err := expr.Eval(tt.code2, nil) + require.NoError(t, err) + + require.Equal(t, out1, out2) + }) + } +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index a9c0fa3d..6d1fb0b5 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -36,5 +36,6 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLen{}) Walk(node, &filterLast{}) Walk(node, &filterFirst{}) + Walk(node, &predicateCombination{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index e45de763..316b1718 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -1,6 +1,7 @@ package optimizer_test import ( + "fmt" "reflect" "strings" "testing" @@ -339,3 +340,118 @@ func TestOptimize_filter_map_first(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } + +func TestOptimize_predicate_combination(t *testing.T) { + tests := []struct { + op string + fn string + wantOp string + }{ + {"and", "all", "and"}, + {"&&", "all", "&&"}, + {"or", "any", "or"}, + {"||", "any", "||"}, + {"and", "none", "or"}, + {"&&", "none", "||"}, + } + + for _, tt := range tests { + rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn) + t.Run(rule, func(t *testing.T) { + tree, err := parser.Parse(rule) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: tt.fn, + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: tt.wantOp, + Left: &ast.BinaryNode{ + Operator: "and", + Left: &ast.BinaryNode{ + Operator: ">", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + Right: &ast.BinaryNode{ + Operator: "<", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 30}, + }, + }, + }, + }, + } + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) + }) + } +} + +func TestOptimize_predicate_combination_nested(t *testing.T) { + tree, err := parser.Parse(`all(users, {all(.Friends, {.Age == 18 })}) && all(users, {all(.Friends, {.Name != "Bob" })})`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Friends"}, + }, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: "&&", + Left: &ast.BinaryNode{ + Operator: "==", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + }, + }, + }, + }, + }, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go new file mode 100644 index 00000000..6e8a7f7c --- /dev/null +++ b/optimizer/predicate_combination.go @@ -0,0 +1,61 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/parser/operator" +) + +/* +predicateCombination is a visitor that combines multiple predicate calls into a single call. +For example, the following expression: + + all(x, x > 1) && all(x, x < 10) -> all(x, x > 1 && x < 10) + any(x, x > 1) || any(x, x < 10) -> any(x, x > 1 || x < 10) + none(x, x > 1) && none(x, x < 10) -> none(x, x > 1 || x < 10) +*/ +type predicateCombination struct{} + +func (v *predicateCombination) Visit(node *Node) { + if op, ok := (*node).(*BinaryNode); ok && operator.IsBoolean(op.Operator) { + if left, ok := op.Left.(*BuiltinNode); ok { + if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { + if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { + if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { + closure := &ClosureNode{ + Node: &BinaryNode{ + Operator: combinedOp, + Left: left.Arguments[1].(*ClosureNode).Node, + Right: right.Arguments[1].(*ClosureNode).Node, + }, + } + v.Visit(&closure.Node) + Patch(node, &BuiltinNode{ + Name: left.Name, + Arguments: []Node{ + left.Arguments[0], + closure, + }, + }) + } + } + } + } + } +} + +func combinedOperator(fn, op string) (string, bool) { + switch { + case fn == "all" && (op == "and" || op == "&&"): + return op, true + case fn == "any" && (op == "or" || op == "||"): + return op, true + case fn == "none" && (op == "and" || op == "&&"): + switch op { + case "and": + return "or", true + case "&&": + return "||", true + } + } + return "", false +}