Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Operator overload from Function #408

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {

// check operator overloading
if fns, ok := v.config.Operators[node.Operator]; ok {
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, l, r)
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, v.config.Functions, l, r)
if ok {
return t, info{}
}
Expand Down
38 changes: 31 additions & 7 deletions conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/expr-lang/expr/vm/runtime"
)

type FunctionTable map[string]*builtin.Function

type Config struct {
Env any
Types TypesTable
Expand Down Expand Up @@ -85,21 +87,43 @@ func (c *Config) ConstExpr(name string) {
func (c *Config) Check() {
for operator, fns := range c.Operators {
for _, fn := range fns {
fnType, ok := c.Types[fn]
if !ok || fnType.Type.Kind() != reflect.Func {
fnType, foundType := c.Types[fn]
fnFunc, foundFunc := c.Functions[fn]
if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, operator))
}
requiredNumIn := 2
if fnType.Method {
requiredNumIn = 3 // As first argument of method is receiver.

if foundType {
checkType(fnType, fn, operator)
}
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
if foundFunc {
checkFunc(fnFunc, fn, operator)
}
}
}
}

func checkType(fnType Tag, fn string, operator string) {
requiredNumIn := 2
if fnType.Method {
requiredNumIn = 3 // As first argument of method is receiver.
}
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
}
}

func checkFunc(fn *builtin.Function, name string, operator string) {
if len(fn.Types) == 0 {
panic(fmt.Errorf("function %s for %s operator misses types", name, operator))
}
for _, t := range fn.Types {
if t.NumIn() != 2 || t.NumOut() != 1 {
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", name, operator))
}
}
}

func (c *Config) IsOverridden(name string) bool {
if _, ok := c.Functions[name]; ok {
return true
Expand Down
55 changes: 46 additions & 9 deletions conf/operators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,65 @@ import (
// Functions should be provided in the environment to allow operator overloading.
type OperatorsTable map[string][]string

func FindSuitableOperatorOverload(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
func FindSuitableOperatorOverload(fns []string, types TypesTable, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) {
t, fn, ok := FindSuitableOperatorOverloadInFunctions(fns, funcs, l, r)
if !ok {
t, fn, ok = FindSuitableOperatorOverloadInTypes(fns, types, l, r)
}
return t, fn, ok
}

func FindSuitableOperatorOverloadInTypes(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
for _, fn := range fns {
fnType := types[fn]
fnType, ok := types[fn]
if !ok {
continue
}
firstInIndex := 0
if fnType.Method {
firstInIndex = 1 // As first argument to method is receiver.
}
firstArgType := fnType.Type.In(firstInIndex)
secondArgType := fnType.Type.In(firstInIndex + 1)
ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex)
if done {
return ret, fn, true
}
}
return nil, "", false
}

firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
if firstArgumentFit && secondArgumentFit {
return fnType.Type.Out(0), fn, true
func FindSuitableOperatorOverloadInFunctions(fns []string, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) {
for _, fn := range fns {
fnType, ok := funcs[fn]
if !ok {
continue
}
firstInIndex := 0
for _, overload := range fnType.Types {
ret, done := checkTypeSuits(overload, l, r, firstInIndex)
if done {
return ret, fn, true
}
}
}
return nil, "", false
}

func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
firstArgType := t.In(firstInIndex)
secondArgType := t.In(firstInIndex + 1)

firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
if firstArgumentFit && secondArgumentFit {
return t.Out(0), true
}
return nil, false
}

type OperatorPatcher struct {
Operators OperatorsTable
Types TypesTable
Functions FunctionTable
}

func (p *OperatorPatcher) Visit(node *ast.Node) {
Expand All @@ -48,7 +85,7 @@ func (p *OperatorPatcher) Visit(node *ast.Node) {
leftType := binaryNode.Left.Type()
rightType := binaryNode.Right.Type()

ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, leftType, rightType)
ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, p.Functions, leftType, rightType)
if ok {
newNode := &ast.CallNode{
Callee: &ast.IdentifierNode{Value: fn},
Expand Down
1 change: 1 addition & 0 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
config.Visitors = append(config.Visitors, &conf.OperatorPatcher{
Operators: config.Operators,
Types: config.Types,
Functions: config.Functions,
})
}

Expand Down
167 changes: 167 additions & 0 deletions test/operator/operator_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package operator_test

import (
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -55,3 +56,169 @@ func TestOperator_interface(t *testing.T) {
require.NoError(t, err)
require.Equal(t, true, output)
}

type Value struct {
Int int
}

func TestOperator_Function(t *testing.T) {
env := map[string]interface{}{
"foo": Value{1},
"bar": Value{2},
}

tests := []struct {
input string
want int
}{
{
input: `foo + bar`,
want: 3,
},
{
input: `2 + 4`,
want: 6,
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) {
program, err := expr.Compile(
tt.input,
expr.Env(env),
expr.Operator("+", "Add", "AddInt"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value, __ Value) int),
),
expr.Function("AddInt", func(args ...interface{}) (interface{}, error) {
return args[0].(int) + args[1].(int), nil
},
new(func(_ int, __ int) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, tt.want, output)
})
}

}

func TestOperator_Function_WithTypes(t *testing.T) {
env := map[string]interface{}{
"foo": Value{1},
"bar": Value{2},
}

require.PanicsWithError(t, `function Add for + operator misses types`, func() {
_, _ = expr.Compile(
`foo + bar`,
expr.Env(env),
expr.Operator("+", "Add", "AddInt"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
}),
)
})

require.PanicsWithError(t, `function Add for + operator does not have a correct signature`, func() {
_, _ = expr.Compile(
`foo + bar`,
expr.Env(env),
expr.Operator("+", "Add", "AddInt"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value) int),
),
)
})

}

func TestOperator_FunctionOverTypesPrecedence(t *testing.T) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you suggested I also added test on precedence order.

env := struct {
Add func(a, b int) int
}{
Add: func(a, b int) int {
return a + b
},
}

program, err := expr.Compile(
`1 + 2`,
expr.Env(env),
expr.Operator("+", "Add"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
nikolaymatrosov marked this conversation as resolved.
Show resolved Hide resolved
// Wierd function that returns 100 + a + b in testing purposes.
return args[0].(int) + args[1].(int) + 100, nil
},
new(func(_ int, __ int) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 103, output)
}

func TestOperator_CanBeDefinedEitherInTypesOrInFunctions(t *testing.T) {
env := struct {
Add func(a, b int) int
}{
Add: func(a, b int) int {
return a + b
},
}

program, err := expr.Compile(
`1 + 2`,
expr.Env(env),
expr.Operator("+", "Add", "AddValues"),
expr.Function("AddValues", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value, __ Value) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 3, output)
}

func TestOperator_Polymorphic(t *testing.T) {
env := struct {
Add func(a, b int) int
Foo Value
Bar Value
}{
Add: func(a, b int) int {
return a + b
},
Foo: Value{1},
Bar: Value{2},
}

program, err := expr.Compile(
`1 + 2 + (Foo + Bar)`,
expr.Env(env),
expr.Operator("+", "Add", "AddValues"),
expr.Function("AddValues", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value, __ Value) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 6, output)
}
Loading