diff --git a/compiler/kernel/expr.go b/compiler/kernel/expr.go index 8c9443b320..7c5fee5577 100644 --- a/compiler/kernel/expr.go +++ b/compiler/kernel/expr.go @@ -16,19 +16,6 @@ import ( "golang.org/x/text/unicode/norm" ) -type exprContext struct { - resetters expr.Resetters -} - -func newExprContext() *exprContext { return new(exprContext) } - -func (r *exprContext) addResetter(resetter expr.Resetter) { - if r == nil { - return - } - r.resetters = append(r.resetters, resetter) -} - // compileExpr compiles the given Expression into an object // that evaluates the expression against a provided Record. It returns an // error if compilation fails for any reason. @@ -59,7 +46,7 @@ func (r *exprContext) addResetter(resetter expr.Resetter) { // TBD: string values and net.IP address do not need to be copied because they // are allocated by go libraries and temporary buffers are not used. This will // change down the road when we implement no-allocation string and IP conversion. -func (b *Builder) compileExpr(ectx *exprContext, e dag.Expr) (expr.Evaluator, error) { +func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { if e == nil { return nil, errors.New("null expression not allowed") } @@ -73,43 +60,40 @@ func (b *Builder) compileExpr(ectx *exprContext, e dag.Expr) (expr.Evaluator, er case *dag.Var: return expr.NewVar(e.Slot), nil case *dag.Search: - return b.compileSearch(ectx, e) + return b.compileSearch(e) case *dag.This: return expr.NewDottedExpr(b.zctx(), field.Path(e.Path)), nil case *dag.Dot: - return b.compileDotExpr(ectx, e) + return b.compileDotExpr(e) case *dag.UnaryExpr: - return b.compileUnary(ectx, *e) + return b.compileUnary(*e) case *dag.BinaryExpr: - return b.compileBinary(ectx, e) + return b.compileBinary(e) case *dag.Conditional: - return b.compileConditional(ectx, *e) + return b.compileConditional(*e) case *dag.Call: - return b.compileCall(ectx, *e) + return b.compileCall(*e) case *dag.RegexpMatch: - return b.compileRegexpMatch(ectx, e) + return b.compileRegexpMatch(e) case *dag.RegexpSearch: - return b.compileRegexpSearch(ectx, e) + return b.compileRegexpSearch(e) case *dag.RecordExpr: - return b.compileRecordExpr(ectx, e) + return b.compileRecordExpr(e) case *dag.ArrayExpr: - return b.compileArrayExpr(ectx, e) + return b.compileArrayExpr(e) case *dag.SetExpr: - return b.compileSetExpr(ectx, e) + return b.compileSetExpr(e) case *dag.MapCall: - return b.compileMapCall(ectx, e) + return b.compileMapCall(e) case *dag.MapExpr: - return b.compileMapExpr(ectx, e) + return b.compileMapExpr(e) case *dag.Agg: - agg, err := b.compileAgg(ectx, e) + agg, err := b.compileAgg(e) if err != nil { return nil, err } - if ectx == nil { - panic("system error: exprContext is nil") - } aggexpr := expr.NewAggregatorExpr(b.zctx(), agg) - ectx.addResetter(aggexpr) + b.resetters = append(b.resetters, aggexpr) return aggexpr, nil case *dag.OverExpr: return b.compileOverExpr(e) @@ -118,31 +102,31 @@ func (b *Builder) compileExpr(ectx *exprContext, e dag.Expr) (expr.Evaluator, er } } -func (b *Builder) compileExprWithEmpty(ectx *exprContext, e dag.Expr) (expr.Evaluator, error) { +func (b *Builder) compileExprWithEmpty(e dag.Expr) (expr.Evaluator, error) { if e == nil { return nil, nil } - return b.compileExpr(ectx, e) + return b.compileExpr(e) } -func (b *Builder) compileBinary(ectx *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileBinary(e *dag.BinaryExpr) (expr.Evaluator, error) { if slice, ok := e.RHS.(*dag.BinaryExpr); ok && slice.Op == ":" { - return b.compileSlice(ectx, e.LHS, slice) + return b.compileSlice(e.LHS, slice) } if e.Op == "in" { // Do a faster comparison if the LHS is a compile-time constant expression. - if in, err := b.compileConstIn(ectx, e); in != nil && err == nil { + if in, err := b.compileConstIn(e); in != nil && err == nil { return in, err } } - if e, err := b.compileConstCompare(ectx, e); e != nil && err == nil { + if e, err := b.compileConstCompare(e); e != nil && err == nil { return e, nil } - lhs, err := b.compileExpr(ectx, e.LHS) + lhs, err := b.compileExpr(e.LHS) if err != nil { return nil, err } - rhs, err := b.compileExpr(ectx, e.RHS) + rhs, err := b.compileExpr(e.RHS) if err != nil { return nil, err } @@ -166,7 +150,7 @@ func (b *Builder) compileBinary(ectx *exprContext, e *dag.BinaryExpr) (expr.Eval } } -func (b *Builder) compileConstIn(r *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileConstIn(e *dag.BinaryExpr) (expr.Evaluator, error) { literal, err := b.evalAtCompileTime(e.LHS) if err != nil || literal.IsError() { // If the RHS here is a literal value, it would be good @@ -179,14 +163,14 @@ func (b *Builder) compileConstIn(r *exprContext, e *dag.BinaryExpr) (expr.Evalua if eql == nil || err != nil { return nil, nil } - operand, err := b.compileExpr(r, e.RHS) + operand, err := b.compileExpr(e.RHS) if err != nil { return nil, err } return expr.NewFilter(operand, expr.Contains(eql)), nil } -func (b *Builder) compileConstCompare(ectx *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileConstCompare(e *dag.BinaryExpr) (expr.Evaluator, error) { switch e.Op { case "==", "!=", "<", "<=", ">", ">=": default: @@ -203,19 +187,19 @@ func (b *Builder) compileConstCompare(ectx *exprContext, e *dag.BinaryExpr) (exp // non-error situation that isn't a simple comparison. return nil, nil } - operand, err := b.compileExpr(ectx, e.LHS) + operand, err := b.compileExpr(e.LHS) if err != nil { return nil, err } return expr.NewFilter(operand, comparison), nil } -func (b *Builder) compileSearch(ectx *exprContext, search *dag.Search) (expr.Evaluator, error) { +func (b *Builder) compileSearch(search *dag.Search) (expr.Evaluator, error) { val, err := zson.ParseValue(b.zctx(), search.Value) if err != nil { return nil, err } - e, err := b.compileExpr(ectx, search.Expr) + e, err := b.compileExpr(search.Expr) if err != nil { return nil, err } @@ -228,24 +212,24 @@ func (b *Builder) compileSearch(ectx *exprContext, search *dag.Search) (expr.Eva return expr.NewSearch(search.Text, val, e) } -func (b *Builder) compileSlice(ectx *exprContext, container dag.Expr, slice *dag.BinaryExpr) (expr.Evaluator, error) { - from, err := b.compileExprWithEmpty(ectx, slice.LHS) +func (b *Builder) compileSlice(container dag.Expr, slice *dag.BinaryExpr) (expr.Evaluator, error) { + from, err := b.compileExprWithEmpty(slice.LHS) if err != nil { return nil, err } - to, err := b.compileExprWithEmpty(ectx, slice.RHS) + to, err := b.compileExprWithEmpty(slice.RHS) if err != nil { return nil, err } - e, err := b.compileExpr(ectx, container) + e, err := b.compileExpr(container) if err != nil { return nil, err } return expr.NewSlice(b.zctx(), e, from, to), nil } -func (b *Builder) compileUnary(ectx *exprContext, unary dag.UnaryExpr) (expr.Evaluator, error) { - e, err := b.compileExpr(ectx, unary.Operand) +func (b *Builder) compileUnary(unary dag.UnaryExpr) (expr.Evaluator, error) { + e, err := b.compileExpr(unary.Operand) if err != nil { return nil, err } @@ -259,48 +243,48 @@ func (b *Builder) compileUnary(ectx *exprContext, unary dag.UnaryExpr) (expr.Eva } } -func (b *Builder) compileConditional(ectx *exprContext, node dag.Conditional) (expr.Evaluator, error) { - predicate, err := b.compileExpr(ectx, node.Cond) +func (b *Builder) compileConditional(node dag.Conditional) (expr.Evaluator, error) { + predicate, err := b.compileExpr(node.Cond) if err != nil { return nil, err } - thenExpr, err := b.compileExpr(ectx, node.Then) + thenExpr, err := b.compileExpr(node.Then) if err != nil { return nil, err } - elseExpr, err := b.compileExpr(ectx, node.Else) + elseExpr, err := b.compileExpr(node.Else) if err != nil { return nil, err } return expr.NewConditional(b.zctx(), predicate, thenExpr, elseExpr), nil } -func (b *Builder) compileDotExpr(ectx *exprContext, dot *dag.Dot) (expr.Evaluator, error) { - record, err := b.compileExpr(ectx, dot.LHS) +func (b *Builder) compileDotExpr(dot *dag.Dot) (expr.Evaluator, error) { + record, err := b.compileExpr(dot.LHS) if err != nil { return nil, err } return expr.NewDotExpr(b.zctx(), record, dot.RHS), nil } -func (b *Builder) compileLval(ectx *exprContext, e dag.Expr) (*expr.Lval, error) { +func (b *Builder) compileLval(e dag.Expr) (*expr.Lval, error) { switch e := e.(type) { case *dag.BinaryExpr: if e.Op != "[" { return nil, fmt.Errorf("internal error: invalid lval %#v", e) } - lhs, err := b.compileLval(ectx, e.LHS) + lhs, err := b.compileLval(e.LHS) if err != nil { return nil, err } - rhs, err := b.compileExpr(ectx, e.RHS) + rhs, err := b.compileExpr(e.RHS) if err != nil { return nil, err } lhs.Elems = append(lhs.Elems, expr.NewExprLvalElem(b.zctx(), rhs)) return lhs, nil case *dag.Dot: - lhs, err := b.compileLval(ectx, e.LHS) + lhs, err := b.compileLval(e.LHS) if err != nil { return nil, err } @@ -316,21 +300,21 @@ func (b *Builder) compileLval(ectx *exprContext, e dag.Expr) (*expr.Lval, error) return nil, fmt.Errorf("internal error: invalid lval %#v", e) } -func (b *Builder) compileAssignment(ectx *exprContext, node *dag.Assignment) (expr.Assignment, error) { - lhs, err := b.compileLval(ectx, node.LHS) +func (b *Builder) compileAssignment(node *dag.Assignment) (expr.Assignment, error) { + lhs, err := b.compileLval(node.LHS) if err != nil { return expr.Assignment{}, err } - rhs, err := b.compileExpr(ectx, node.RHS) + rhs, err := b.compileExpr(node.RHS) if err != nil { return expr.Assignment{}, fmt.Errorf("rhs of assigment expression: %w", err) } return expr.Assignment{LHS: lhs, RHS: rhs}, err } -func (b *Builder) compileCall(ectx *exprContext, call dag.Call) (expr.Evaluator, error) { +func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { if tf := expr.NewShaperTransform(call.Name); tf != 0 { - return b.compileShaper(ectx, call, tf) + return b.compileShaper(call, tf) } var path field.Path // First check if call is to a user defined function, otherwise check for @@ -348,42 +332,42 @@ func (b *Builder) compileCall(ectx *exprContext, call dag.Call) (expr.Evaluator, dagPath := &dag.This{Kind: "This", Path: path} args = append([]dag.Expr{dagPath}, args...) } - exprs, err := b.compileExprs(ectx, args) + exprs, err := b.compileExprs(args) if err != nil { return nil, fmt.Errorf("%s(): bad argument: %w", call.Name, err) } return expr.NewCall(b.zctx(), fn, exprs), nil } -func (b *Builder) compileMapCall(ectx *exprContext, a *dag.MapCall) (expr.Evaluator, error) { - e, err := b.compileExpr(ectx, a.Expr) +func (b *Builder) compileMapCall(a *dag.MapCall) (expr.Evaluator, error) { + e, err := b.compileExpr(a.Expr) if err != nil { return nil, err } - inner, err := b.compileExpr(ectx, a.Inner) + inner, err := b.compileExpr(a.Inner) if err != nil { return nil, err } return expr.NewMapCall(b.zctx(), e, inner), nil } -func (b *Builder) compileShaper(ectx *exprContext, node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { +func (b *Builder) compileShaper(node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { args := node.Args - field, err := b.compileExpr(ectx, args[0]) + field, err := b.compileExpr(args[0]) if err != nil { return nil, err } - typExpr, err := b.compileExpr(ectx, args[1]) + typExpr, err := b.compileExpr(args[1]) if err != nil { return nil, err } return expr.NewShaper(b.zctx(), field, typExpr, tf) } -func (b *Builder) compileExprs(ectx *exprContext, in []dag.Expr) ([]expr.Evaluator, error) { +func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) { var exprs []expr.Evaluator for _, e := range in { - ev, err := b.compileExpr(ectx, e) + ev, err := b.compileExpr(e) if err != nil { return nil, err } @@ -392,8 +376,8 @@ func (b *Builder) compileExprs(ectx *exprContext, in []dag.Expr) ([]expr.Evaluat return exprs, nil } -func (b *Builder) compileRegexpMatch(ectx *exprContext, match *dag.RegexpMatch) (expr.Evaluator, error) { - e, err := b.compileExpr(ectx, match.Expr) +func (b *Builder) compileRegexpMatch(match *dag.RegexpMatch) (expr.Evaluator, error) { + e, err := b.compileExpr(match.Expr) if err != nil { return nil, err } @@ -404,8 +388,8 @@ func (b *Builder) compileRegexpMatch(ectx *exprContext, match *dag.RegexpMatch) return expr.NewRegexpMatch(re, e), nil } -func (b *Builder) compileRegexpSearch(ectx *exprContext, search *dag.RegexpSearch) (expr.Evaluator, error) { - e, err := b.compileExpr(ectx, search.Expr) +func (b *Builder) compileRegexpSearch(search *dag.RegexpSearch) (expr.Evaluator, error) { + e, err := b.compileExpr(search.Expr) if err != nil { return nil, err } @@ -417,12 +401,12 @@ func (b *Builder) compileRegexpSearch(ectx *exprContext, search *dag.RegexpSearc return expr.SearchByPredicate(expr.Contains(match), e), nil } -func (b *Builder) compileRecordExpr(ectx *exprContext, record *dag.RecordExpr) (expr.Evaluator, error) { +func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, error) { var elems []expr.RecordElem for _, elem := range record.Elems { switch elem := elem.(type) { case *dag.Field: - e, err := b.compileExpr(ectx, elem.Value) + e, err := b.compileExpr(elem.Value) if err != nil { return nil, err } @@ -431,7 +415,7 @@ func (b *Builder) compileRecordExpr(ectx *exprContext, record *dag.RecordExpr) ( Field: e, }) case *dag.Spread: - e, err := b.compileExpr(ectx, elem.Expr) + e, err := b.compileExpr(elem.Expr) if err != nil { return nil, err } @@ -441,34 +425,34 @@ func (b *Builder) compileRecordExpr(ectx *exprContext, record *dag.RecordExpr) ( return expr.NewRecordExpr(b.zctx(), elems) } -func (b *Builder) compileArrayExpr(ectx *exprContext, array *dag.ArrayExpr) (expr.Evaluator, error) { - elems, err := b.compileVectorElems(ectx, array.Elems) +func (b *Builder) compileArrayExpr(array *dag.ArrayExpr) (expr.Evaluator, error) { + elems, err := b.compileVectorElems(array.Elems) if err != nil { return nil, err } return expr.NewArrayExpr(b.zctx(), elems), nil } -func (b *Builder) compileSetExpr(ectx *exprContext, set *dag.SetExpr) (expr.Evaluator, error) { - elems, err := b.compileVectorElems(ectx, set.Elems) +func (b *Builder) compileSetExpr(set *dag.SetExpr) (expr.Evaluator, error) { + elems, err := b.compileVectorElems(set.Elems) if err != nil { return nil, err } return expr.NewSetExpr(b.zctx(), elems), nil } -func (b *Builder) compileVectorElems(ectx *exprContext, elems []dag.VectorElem) ([]expr.VectorElem, error) { +func (b *Builder) compileVectorElems(elems []dag.VectorElem) ([]expr.VectorElem, error) { var out []expr.VectorElem for _, elem := range elems { switch elem := elem.(type) { case *dag.Spread: - e, err := b.compileExpr(ectx, elem.Expr) + e, err := b.compileExpr(elem.Expr) if err != nil { return nil, err } out = append(out, expr.VectorElem{Spread: e}) case *dag.VectorValue: - e, err := b.compileExpr(ectx, elem.Expr) + e, err := b.compileExpr(elem.Expr) if err != nil { return nil, err } @@ -478,14 +462,14 @@ func (b *Builder) compileVectorElems(ectx *exprContext, elems []dag.VectorElem) return out, nil } -func (b *Builder) compileMapExpr(ectx *exprContext, m *dag.MapExpr) (expr.Evaluator, error) { +func (b *Builder) compileMapExpr(m *dag.MapExpr) (expr.Evaluator, error) { var entries []expr.Entry for _, f := range m.Entries { - key, err := b.compileExpr(ectx, f.Key) + key, err := b.compileExpr(f.Key) if err != nil { return nil, err } - val, err := b.compileExpr(ectx, f.Value) + val, err := b.compileExpr(f.Value) if err != nil { return nil, err } @@ -498,12 +482,11 @@ func (b *Builder) compileOverExpr(over *dag.OverExpr) (expr.Evaluator, error) { if over.Body == nil { return nil, errors.New("over expression requires a lateral query body") } - var ectx exprContext - names, lets, err := b.compileDefs(&ectx, over.Defs) + names, lets, err := b.compileDefs(over.Defs) if err != nil { return nil, err } - exprs, err := b.compileExprs(&ectx, over.Exprs) + exprs, err := b.compileExprs(over.Exprs) if err != nil { return nil, err } diff --git a/compiler/kernel/filter.go b/compiler/kernel/filter.go index 75b4c65868..51b50a630c 100644 --- a/compiler/kernel/filter.go +++ b/compiler/kernel/filter.go @@ -17,7 +17,7 @@ func (f *Filter) AsEvaluator() (expr.Evaluator, error) { if f == nil { return nil, nil } - return f.builder.compileExpr(nil, f.pushdown) + return f.builder.compileExpr(f.pushdown) } func (f *Filter) AsBufferFilter() (*expr.BufferFilter, error) { @@ -39,7 +39,7 @@ func (f *DeleteFilter) AsEvaluator() (expr.Evaluator, error) { // expression so we get all values that don't match. We also add a missing // call so if the expression results in an error("missing") the value is // kept. - return f.builder.compileExpr(nil, &dag.BinaryExpr{ + return f.builder.compileExpr(&dag.BinaryExpr{ Kind: "BinaryExpr", Op: "or", LHS: &dag.UnaryExpr{ diff --git a/compiler/kernel/groupby.go b/compiler/kernel/groupby.go index 0e66bed0a4..b43dd210cd 100644 --- a/compiler/kernel/groupby.go +++ b/compiler/kernel/groupby.go @@ -13,24 +13,24 @@ import ( ) func (b *Builder) compileGroupBy(parent zbuf.Puller, summarize *dag.Summarize) (*groupby.Op, error) { - var ectx exprContext - keys, err := b.compileAssignments(&ectx, summarize.Keys) + b.exprReset() + keys, err := b.compileAssignments(summarize.Keys) if err != nil { return nil, err } - names, reducers, err := b.compileAggAssignments(&ectx, summarize.Aggs) + names, reducers, err := b.compileAggAssignments(summarize.Aggs) if err != nil { return nil, err } dir := order.Direction(summarize.InputSortDir) - return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut, ectx.resetters) + return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut, b.resetters) } -func (b *Builder) compileAggAssignments(ectx *exprContext, assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) { +func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) { names := make(field.List, 0, len(assignments)) aggs := make([]*expr.Aggregator, 0, len(assignments)) for _, assignment := range assignments { - name, agg, err := b.compileAggAssignment(ectx, assignment) + name, agg, err := b.compileAggAssignment(assignment) if err != nil { return nil, nil, err } @@ -40,7 +40,7 @@ func (b *Builder) compileAggAssignments(ectx *exprContext, assignments []dag.Ass return names, aggs, nil } -func (b *Builder) compileAggAssignment(ectx *exprContext, assignment dag.Assignment) (field.Path, *expr.Aggregator, error) { +func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, *expr.Aggregator, error) { aggAST, ok := assignment.RHS.(*dag.Agg) if !ok { return nil, nil, errors.New("aggregator is not an aggregation expression") @@ -49,23 +49,23 @@ func (b *Builder) compileAggAssignment(ectx *exprContext, assignment dag.Assignm if !ok { return nil, nil, fmt.Errorf("internal error: aggregator assignment LHS is not a static path: %#v", assignment.LHS) } - m, err := b.compileAgg(ectx, aggAST) + m, err := b.compileAgg(aggAST) return this.Path, m, err } -func (b *Builder) compileAgg(ectx *exprContext, agg *dag.Agg) (*expr.Aggregator, error) { +func (b *Builder) compileAgg(agg *dag.Agg) (*expr.Aggregator, error) { name := agg.Name var err error var arg expr.Evaluator if agg.Expr != nil { - arg, err = b.compileExpr(ectx, agg.Expr) + arg, err = b.compileExpr(agg.Expr) if err != nil { return nil, err } } var where expr.Evaluator if agg.Where != nil { - where, err = b.compileExpr(ectx, agg.Where) + where, err = b.compileExpr(agg.Where) if err != nil { return nil, err } diff --git a/compiler/kernel/op.go b/compiler/kernel/op.go index 7acabbc998..3cfbf563a5 100644 --- a/compiler/kernel/op.go +++ b/compiler/kernel/op.go @@ -49,13 +49,14 @@ import ( var ErrJoinParents = errors.New("join requires two upstream parallel query paths") type Builder struct { - rctx *runtime.Context - mctx *zed.Context - source *data.Source - readers []zio.Reader - progress *zbuf.Progress - deletes *sync.Map - funcs map[string]expr.Function + rctx *runtime.Context + mctx *zed.Context + source *data.Source + readers []zio.Reader + progress *zbuf.Progress + deletes *sync.Map + funcs map[string]expr.Function + resetters expr.Resetters } func NewBuilder(rctx *runtime.Context, source *data.Source) *Builder { @@ -120,13 +121,15 @@ func (b *Builder) Deletes() *sync.Map { return b.deletes } +func (b *Builder) exprReset() { b.resetters = nil } + func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) { switch v := o.(type) { case *dag.Summarize: return b.compileGroupBy(parent, v) case *dag.Cut: - var ectx exprContext - assignments, err := b.compileAssignments(&ectx, v.Args) + b.exprReset() + assignments, err := b.compileAssignments(v.Args) if err != nil { return nil, err } @@ -135,7 +138,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) if v.Quiet { cutter.Quiet() } - return op.NewApplier(b.rctx, parent, cutter, ectx.resetters), nil + return op.NewApplier(b.rctx, parent, cutter, b.resetters), nil case *dag.Drop: if len(v.Args) == 0 { return nil, errors.New("drop: no fields given") @@ -151,12 +154,12 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) dropper := expr.NewDropper(b.rctx.Zctx, fields) return op.NewApplier(b.rctx, parent, dropper, expr.NopResetter), nil case *dag.Sort: - var ectx exprContext - fields, err := b.compileExprs(&ectx, v.Args) + b.exprReset() + fields, err := b.compileExprs(v.Args) if err != nil { return nil, err } - sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst, ectx.resetters) + sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst, b.resetters) if err != nil { return nil, fmt.Errorf("compiling sort: %w", err) } @@ -178,36 +181,36 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) case *dag.Pass: return pass.New(parent), nil case *dag.Filter: - var ectx exprContext - f, err := b.compileExpr(&ectx, v.Expr) + b.exprReset() + f, err := b.compileExpr(v.Expr) if err != nil { return nil, fmt.Errorf("compiling filter: %w", err) } - return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f), ectx.resetters), nil + return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f), b.resetters), nil case *dag.Top: - var ectx exprContext - fields, err := b.compileExprs(&ectx, v.Args) + b.exprReset() + fields, err := b.compileExprs(v.Args) if err != nil { return nil, fmt.Errorf("compiling top: %w", err) } - return top.New(b.rctx.Zctx, parent, v.Limit, fields, ectx.resetters, v.Flush), nil + return top.New(b.rctx.Zctx, parent, v.Limit, fields, b.resetters, v.Flush), nil case *dag.Put: - var ectx exprContext - clauses, err := b.compileAssignments(&ectx, v.Args) + b.exprReset() + clauses, err := b.compileAssignments(v.Args) if err != nil { return nil, err } putter := expr.NewPutter(b.rctx.Zctx, clauses) - return op.NewApplier(b.rctx, parent, putter, ectx.resetters), nil + return op.NewApplier(b.rctx, parent, putter, b.resetters), nil case *dag.Rename: - var ectx exprContext + b.exprReset() var srcs, dsts []*expr.Lval for _, a := range v.Args { - src, err := b.compileLval(&ectx, a.RHS) + src, err := b.compileLval(a.RHS) if err != nil { return nil, err } - dst, err := b.compileLval(&ectx, a.LHS) + dst, err := b.compileLval(a.LHS) if err != nil { return nil, err } @@ -215,7 +218,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) dsts = append(dsts, dst) } renamer := expr.NewRenamer(b.rctx.Zctx, srcs, dsts) - return op.NewApplier(b.rctx, parent, renamer, ectx.resetters), nil + return op.NewApplier(b.rctx, parent, renamer, b.resetters), nil case *dag.Fuse: return fuse.New(b.rctx, parent) case *dag.Shape: @@ -229,21 +232,21 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) if err != nil { return nil, err } - var ectx exprContext - args, err := b.compileExprs(&ectx, v.Args) + b.exprReset() + args, err := b.compileExprs(v.Args) if err != nil { return nil, err } - return explode.New(b.rctx.Zctx, parent, args, ectx.resetters, typ, v.As) + return explode.New(b.rctx.Zctx, parent, args, b.resetters, typ, v.As) case *dag.Over: return b.compileOver(parent, v) case *dag.Yield: - var ectx exprContext - exprs, err := b.compileExprs(&ectx, v.Exprs) + b.exprReset() + exprs, err := b.compileExprs(v.Exprs) if err != nil { return nil, err } - t := yield.New(parent, exprs, ectx.resetters) + t := yield.New(parent, exprs, b.resetters) return t, nil case *dag.PoolScan: if parent != nil { @@ -360,11 +363,11 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) } } -func (b *Builder) compileDefs(ectx *exprContext, defs []dag.Def) ([]string, []expr.Evaluator, error) { +func (b *Builder) compileDefs(defs []dag.Def) ([]string, []expr.Evaluator, error) { exprs := make([]expr.Evaluator, 0, len(defs)) names := make([]string, 0, len(defs)) for _, def := range defs { - e, err := b.compileExpr(ectx, def.Expr) + e, err := b.compileExpr(def.Expr) if err != nil { return nil, nil, err } @@ -378,16 +381,16 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller, if len(over.Defs) != 0 && over.Body == nil { return nil, errors.New("internal error: over operator has defs but no body") } - var ectx exprContext - withNames, withExprs, err := b.compileDefs(&ectx, over.Defs) + b.exprReset() + withNames, withExprs, err := b.compileDefs(over.Defs) if err != nil { return nil, err } - exprs, err := b.compileExprs(&ectx, over.Exprs) + exprs, err := b.compileExprs(over.Exprs) if err != nil { return nil, err } - enter := traverse.NewOver(b.rctx, parent, exprs, ectx.resetters) + enter := traverse.NewOver(b.rctx, parent, exprs, b.resetters) if over.Body == nil { return enter, nil } @@ -407,10 +410,10 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller, return scope.NewExit(exit), nil } -func (b *Builder) compileAssignments(ectx *exprContext, assignments []dag.Assignment) ([]expr.Assignment, error) { +func (b *Builder) compileAssignments(assignments []dag.Assignment) ([]expr.Assignment, error) { keys := make([]expr.Assignment, 0, len(assignments)) for _, assignment := range assignments { - a, err := b.compileAssignment(ectx, &assignment) + a, err := b.compileAssignment(&assignment) if err != nil { return nil, err } @@ -446,7 +449,7 @@ func (b *Builder) compileScope(scope *dag.Scope, parents []zbuf.Puller) ([]zbuf. // where aggregation expressions in udfs do not have separate state per // invocation. The fix for this might use exprContext to compile udf // expressions per invocation. - if err := b.compileFuncs(&exprContext{}, scope.Funcs); err != nil { + if err := b.compileFuncs(scope.Funcs); err != nil { return nil, err } return b.compileSeq(scope.Body, parents) @@ -494,7 +497,7 @@ func (b *Builder) compileScatter(par *dag.Scatter, parents []zbuf.Puller) ([]zbu return ops, nil } -func (b *Builder) compileFuncs(ectx *exprContext, fns []*dag.Func) error { +func (b *Builder) compileFuncs(fns []*dag.Func) error { udfs := make([]*expr.UDF, 0, len(fns)) for _, f := range fns { if _, ok := b.funcs[f.Name]; ok { @@ -506,7 +509,7 @@ func (b *Builder) compileFuncs(ectx *exprContext, fns []*dag.Func) error { } for i := range fns { var err error - if udfs[i].Body, err = b.compileExpr(ectx, fns[i].Expr); err != nil { + if udfs[i].Body, err = b.compileExpr(fns[i].Expr); err != nil { return err } } @@ -518,12 +521,12 @@ func (b *Builder) compileExprSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([ if len(parents) > 1 { parent = combine.New(b.rctx, parents) } - var ectx exprContext - e, err := b.compileExpr(&ectx, swtch.Expr) + b.exprReset() + e, err := b.compileExpr(swtch.Expr) if err != nil { return nil, err } - s := exprswitch.New(b.rctx, parent, e, ectx.resetters) + s := exprswitch.New(b.rctx, parent, e, b.resetters) var exits []zbuf.Puller for _, c := range swtch.Cases { var val *zed.Value @@ -551,16 +554,16 @@ func (b *Builder) compileSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([]zbu if len(parents) > 1 { parent = combine.New(b.rctx, parents) } - var ectx exprContext + b.exprReset() cases := make([]expr.Evaluator, len(swtch.Cases)) for i, c := range swtch.Cases { var err error - cases[i], err = b.compileExpr(&ectx, c.Expr) + cases[i], err = b.compileExpr(c.Expr) if err != nil { return nil, fmt.Errorf("compiling switch case filter: %w", err) } } - switcher := switcher.New(b.rctx, parent, ectx.resetters) + switcher := switcher.New(b.rctx, parent, b.resetters) var ops []zbuf.Puller for i, c := range cases { o, err := b.compileSeq(swtch.Cases[i].Path, []zbuf.Puller{switcher.AddCase(c)}) @@ -591,17 +594,17 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error if len(parents) != 2 { return nil, ErrJoinParents } - var ectx exprContext - assignments, err := b.compileAssignments(&ectx, o.Args) + b.exprReset() + assignments, err := b.compileAssignments(o.Args) if err != nil { return nil, err } lhs, rhs := splitAssignments(assignments) - leftKey, err := b.compileExpr(&ectx, o.LeftKey) + leftKey, err := b.compileExpr(o.LeftKey) if err != nil { return nil, err } - rightKey, err := b.compileExpr(&ectx, o.RightKey) + rightKey, err := b.compileExpr(o.RightKey) if err != nil { return nil, err } @@ -621,19 +624,19 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error default: return nil, fmt.Errorf("unknown kind of join: '%s'", o.Style) } - join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs, ectx.resetters) + join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs, b.resetters) if err != nil { return nil, err } return []zbuf.Puller{join}, nil case *dag.Merge: - var ectx exprContext - e, err := b.compileExpr(&ectx, o.Expr) + b.exprReset() + e, err := b.compileExpr(o.Expr) if err != nil { return nil, err } cmp := expr.NewComparator(true, o.Order == order.Desc, e).WithMissingAsNull() - return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare, ectx.resetters)}, nil + return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare, b.resetters)}, nil case *dag.Combine: return []zbuf.Puller{combine.New(b.rctx, parents)}, nil default: @@ -686,7 +689,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) { if in == nil { return zed.Null, nil } - e, err := b.compileExpr(nil, in) + e, err := b.compileExpr(in) if err != nil { return zed.Null, err } @@ -702,7 +705,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) { func compileExpr(in dag.Expr) (expr.Evaluator, error) { b := NewBuilder(runtime.NewContext(context.Background(), zed.NewContext()), nil) - return b.compileExpr(nil, in) + return b.compileExpr(in) } func EvalAtCompileTime(zctx *zed.Context, in dag.Expr) (val zed.Value, err error) { diff --git a/runtime/sam/op/fork/fork.go b/runtime/sam/op/fork/fork.go index 044a9a7338..a55772e5b0 100644 --- a/runtime/sam/op/fork/fork.go +++ b/runtime/sam/op/fork/fork.go @@ -39,3 +39,5 @@ func (s splitter) Forward(r *op.Router, b zbuf.Batch) bool { } return true } + +func (s splitter) Reset() {} diff --git a/runtime/sam/op/router.go b/runtime/sam/op/router.go index 30f3aafa2d..2abf1e4426 100644 --- a/runtime/sam/op/router.go +++ b/runtime/sam/op/router.go @@ -11,6 +11,7 @@ import ( type Selector interface { Forward(*Router, zbuf.Batch) bool + expr.Resetter } type Router struct { @@ -94,11 +95,7 @@ func (r *Router) blocked() bool { // after receiving the EOS, it's done will be captured as soon as we unblock // all channels. func (r *Router) sendEOS(err error) bool { - defer func() { - if r, ok := r.selector.(expr.Resetter); ok { - r.Reset() - } - }() + defer r.selector.Reset() // First, we need to send EOS to all non-blocked legs and // catch any dones in progress. This result in all routes // being blocked.