From f499da0a88ca2fcec800843c25b9af5284941299 Mon Sep 17 00:00:00 2001 From: Nikolay Matrosov Date: Wed, 23 Aug 2023 16:49:35 +0200 Subject: [PATCH 1/2] feat: rewritten Operator overload from Function --- checker/checker.go | 2 +- conf/config.go | 38 +++++++-- conf/operators.go | 50 ++++++++++-- expr.go | 1 + test/operator/operator_test.go | 137 +++++++++++++++++++++++++++++++++ 5 files changed, 212 insertions(+), 16 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index c845dd78..b8808441 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -240,7 +240,7 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { // check operator overloading if fns, ok := v.config.Operators[node.Operator]; ok { - t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, l, r) + t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, v.config.Functions, l, r) if ok { return t, info{} } diff --git a/conf/config.go b/conf/config.go index baf5dee0..f34dd034 100644 --- a/conf/config.go +++ b/conf/config.go @@ -9,6 +9,8 @@ import ( "github.com/expr-lang/expr/vm/runtime" ) +type FunctionTable map[string]*builtin.Function + type Config struct { Env any Types TypesTable @@ -85,21 +87,43 @@ func (c *Config) ConstExpr(name string) { func (c *Config) Check() { for operator, fns := range c.Operators { for _, fn := range fns { - fnType, ok := c.Types[fn] - if !ok || fnType.Type.Kind() != reflect.Func { + fnType, foundType := c.Types[fn] + fnFunc, foundFunc := c.Functions[fn] + if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) { panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, operator)) } - requiredNumIn := 2 - if fnType.Method { - requiredNumIn = 3 // As first argument of method is receiver. + + if foundType { + checkType(fnType, fn, operator) } - if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 { - panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator)) + if foundFunc { + checkFunc(fnFunc, fn, operator) } } } } +func checkType(fnType Tag, fn string, operator string) { + requiredNumIn := 2 + if fnType.Method { + requiredNumIn = 3 // As first argument of method is receiver. + } + if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 { + panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator)) + } +} + +func checkFunc(fn *builtin.Function, name string, operator string) { + if len(fn.Types) == 0 { + panic(fmt.Errorf("function %s for %s operator misses types", name, operator)) + } + for _, t := range fn.Types { + if t.NumIn() != 2 || t.NumOut() != 1 { + panic(fmt.Errorf("function %s for %s operator does not have a correct signature", name, operator)) + } + } +} + func (c *Config) IsOverridden(name string) bool { if _, ok := c.Functions[name]; ok { return true diff --git a/conf/operators.go b/conf/operators.go index ced209fd..d3d7ff50 100644 --- a/conf/operators.go +++ b/conf/operators.go @@ -10,28 +10,62 @@ import ( // Functions should be provided in the environment to allow operator overloading. type OperatorsTable map[string][]string -func FindSuitableOperatorOverload(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) { +func FindSuitableOperatorOverload(fns []string, types TypesTable, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) { + t, fn, ok := FindSuitableOperatorOverloadInFunctions(fns, funcs, l, r) + if !ok { + t, fn, ok = FindSuitableOperatorOverloadInTypes(fns, types, l, r) + } + return t, fn, ok +} + +func FindSuitableOperatorOverloadInTypes(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) { for _, fn := range fns { fnType := types[fn] firstInIndex := 0 if fnType.Method { firstInIndex = 1 // As first argument to method is receiver. } - firstArgType := fnType.Type.In(firstInIndex) - secondArgType := fnType.Type.In(firstInIndex + 1) + ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex) + if done { + return ret, fn, true + } + } + return nil, "", false +} - firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType))) - secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType))) - if firstArgumentFit && secondArgumentFit { - return fnType.Type.Out(0), fn, true +func FindSuitableOperatorOverloadInFunctions(fns []string, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) { + for _, fn := range fns { + fnType, ok := funcs[fn] + if !ok { + continue + } + firstInIndex := 0 + for _, overload := range fnType.Types { + ret, done := checkTypeSuits(overload, l, r, firstInIndex) + if done { + return ret, fn, true + } } } return nil, "", false } +func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) { + firstArgType := t.In(firstInIndex) + secondArgType := t.In(firstInIndex + 1) + + firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType))) + secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType))) + if firstArgumentFit && secondArgumentFit { + return t.Out(0), true + } + return nil, false +} + type OperatorPatcher struct { Operators OperatorsTable Types TypesTable + Functions FunctionTable } func (p *OperatorPatcher) Visit(node *ast.Node) { @@ -48,7 +82,7 @@ func (p *OperatorPatcher) Visit(node *ast.Node) { leftType := binaryNode.Left.Type() rightType := binaryNode.Right.Type() - ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, leftType, rightType) + ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, p.Functions, leftType, rightType) if ok { newNode := &ast.CallNode{ Callee: &ast.IdentifierNode{Value: fn}, diff --git a/expr.go b/expr.go index f32df1d3..e43abaaf 100644 --- a/expr.go +++ b/expr.go @@ -192,6 +192,7 @@ func Compile(input string, ops ...Option) (*vm.Program, error) { config.Visitors = append(config.Visitors, &conf.OperatorPatcher{ Operators: config.Operators, Types: config.Types, + Functions: config.Functions, }) } diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index af50a24e..0a71a005 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -1,6 +1,7 @@ package operator_test import ( + "fmt" "testing" "time" @@ -55,3 +56,139 @@ func TestOperator_interface(t *testing.T) { require.NoError(t, err) require.Equal(t, true, output) } + +type Value struct { + Int int +} + +func TestOperator_Function(t *testing.T) { + env := map[string]interface{}{ + "foo": Value{1}, + "bar": Value{2}, + } + + tests := []struct { + input string + want int + }{ + { + input: `foo + bar`, + want: 3, + }, + { + input: `2 + 4`, + want: 6, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) { + program, err := expr.Compile( + tt.input, + expr.Env(env), + expr.Operator("+", "Add", "AddInt"), + expr.Function("Add", func(args ...interface{}) (interface{}, error) { + return args[0].(Value).Int + args[1].(Value).Int, nil + }, + new(func(_ Value, __ Value) int), + ), + expr.Function("AddInt", func(args ...interface{}) (interface{}, error) { + return args[0].(int) + args[1].(int), nil + }, + new(func(_ int, __ int) int), + ), + ) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, tt.want, output) + }) + } + +} + +func TestOperator_Function_WithTypes(t *testing.T) { + env := map[string]interface{}{ + "foo": Value{1}, + "bar": Value{2}, + } + + require.PanicsWithError(t, `function Add for + operator misses types`, func() { + _, _ = expr.Compile( + `foo + bar`, + expr.Env(env), + expr.Operator("+", "Add", "AddInt"), + expr.Function("Add", func(args ...interface{}) (interface{}, error) { + return args[0].(Value).Int + args[1].(Value).Int, nil + }), + ) + }) + + require.PanicsWithError(t, `function Add for + operator does not have a correct signature`, func() { + _, _ = expr.Compile( + `foo + bar`, + expr.Env(env), + expr.Operator("+", "Add", "AddInt"), + expr.Function("Add", func(args ...interface{}) (interface{}, error) { + return args[0].(Value).Int + args[1].(Value).Int, nil + }, + new(func(_ Value) int), + ), + ) + }) + +} + +func TestOperator_FunctionOverTypesPrecedence(t *testing.T) { + env := struct { + Add func(a, b int) int + }{ + Add: func(a, b int) int { + return a + b + }, + } + + program, err := expr.Compile( + `1 + 2`, + expr.Env(env), + expr.Operator("+", "Add"), + expr.Function("Add", func(args ...interface{}) (interface{}, error) { + // Wierd function that returns 100 + a + b in testing purposes. + return args[0].(int) + args[1].(int) + 100, nil + }, + new(func(_ int, __ int) int), + ), + ) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 103, output) +} + +func TestOperator_CanBeDefinedEitherInTypesOrInFunctions(t *testing.T) { + env := struct { + Add func(a, b int) int + }{ + Add: func(a, b int) int { + return a + b + }, + } + + program, err := expr.Compile( + `1 + 2`, + expr.Env(env), + expr.Operator("+", "Add", "AddValues"), + expr.Function("AddValues", func(args ...interface{}) (interface{}, error) { + return args[0].(Value).Int + args[1].(Value).Int, nil + }, + new(func(_ Value, __ Value) int), + ), + ) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 3, output) +} From d9b833c9a76d1304e322753a190853974cf6bc8e Mon Sep 17 00:00:00 2001 From: Nikolay Matrosov Date: Wed, 23 Aug 2023 17:07:14 +0200 Subject: [PATCH 2/2] fix: add test --- conf/operators.go | 5 ++++- test/operator/operator_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/conf/operators.go b/conf/operators.go index d3d7ff50..646153cb 100644 --- a/conf/operators.go +++ b/conf/operators.go @@ -20,7 +20,10 @@ func FindSuitableOperatorOverload(fns []string, types TypesTable, funcs Function func FindSuitableOperatorOverloadInTypes(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) { for _, fn := range fns { - fnType := types[fn] + fnType, ok := types[fn] + if !ok { + continue + } firstInIndex := 0 if fnType.Method { firstInIndex = 1 // As first argument to method is receiver. diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index 0a71a005..59bc73f2 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -192,3 +192,33 @@ func TestOperator_CanBeDefinedEitherInTypesOrInFunctions(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, output) } + +func TestOperator_Polymorphic(t *testing.T) { + env := struct { + Add func(a, b int) int + Foo Value + Bar Value + }{ + Add: func(a, b int) int { + return a + b + }, + Foo: Value{1}, + Bar: Value{2}, + } + + program, err := expr.Compile( + `1 + 2 + (Foo + Bar)`, + expr.Env(env), + expr.Operator("+", "Add", "AddValues"), + expr.Function("AddValues", func(args ...interface{}) (interface{}, error) { + return args[0].(Value).Int + args[1].(Value).Int, nil + }, + new(func(_ Value, __ Value) int), + ), + ) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 6, output) +}