diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 2b163d577ebfb..fa7c2396ecb77 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -117,6 +117,13 @@ var optRuleList = []logicalOptRule{ &resolveExpand{}, } +// Interaction Rule List +/* The interaction rule will be trigger when it satisfies following conditions: +1. The related rule has been trigger and changed the plan +2. The interaction rule is enabled +*/ +var optInteractionRuleList = map[logicalOptRule]logicalOptRule{} + type logicalOptimizeOp struct { // tracer is goring to track optimize steps during rule optimizing tracer *tracing.LogicalOptimizeTracer @@ -154,7 +161,14 @@ func (op *logicalOptimizeOp) recordFinalLogicalPlan(final LogicalPlan) { // logicalOptRule means a logical optimizing rule, which contains decorrelate, ppd, column pruning, etc. type logicalOptRule interface { - optimize(context.Context, LogicalPlan, *logicalOptimizeOp) (LogicalPlan, error) + /* Return Parameters: + 1. LogicalPlan: The optimized LogicalPlan after rule is applied + 2. bool: Used to judge whether the plan is changed or not by logical rule. + If the plan is changed, it will return true. + The default value is false. It means that no interaction rule will be triggered. + 3. error: If there is error during the rule optimizer, it will be thrown + */ + optimize(context.Context, LogicalPlan, *logicalOptimizeOp) (LogicalPlan, bool, error) name() string } @@ -1125,6 +1139,7 @@ func logicalOptimize(ctx context.Context, flag uint64, logic LogicalPlan) (Logic }() } var err error + var againRuleList []logicalOptRule for i, rule := range optRuleList { // The order of flags is same as the order of optRule in the list. // We use a bitmask to record which opt rules should be used. If the i-th bit is 1, it means we should @@ -1133,11 +1148,27 @@ func logicalOptimize(ctx context.Context, flag uint64, logic LogicalPlan) (Logic continue } opt.appendBeforeRuleOptimize(i, rule.name(), logic) - logic, err = rule.optimize(ctx, logic, opt) + var planChanged bool + logic, planChanged, err = rule.optimize(ctx, logic, opt) + if err != nil { + return nil, err + } + // Compute interaction rules that should be optimized again + interactionRule, ok := optInteractionRuleList[rule] + if planChanged && ok && isLogicalRuleDisabled(interactionRule) { + againRuleList = append(againRuleList, interactionRule) + } + } + + // Trigger the interaction rule + for i, rule := range againRuleList { + opt.appendBeforeRuleOptimize(i, rule.name(), logic) + logic, _, err = rule.optimize(ctx, logic, opt) if err != nil { return nil, err } } + opt.recordFinalLogicalPlan(logic) return logic, err } diff --git a/planner/core/plan_stats.go b/planner/core/plan_stats.go index a650d72b0935b..7b71b67dd2510 100644 --- a/planner/core/plan_stats.go +++ b/planner/core/plan_stats.go @@ -33,9 +33,10 @@ import ( type collectPredicateColumnsPoint struct{} -func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) { +func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false if plan.SCtx().GetSessionVars().InRestrictedSQL { - return plan, nil + return plan, planChanged, nil } predicateNeeded := variable.EnableColumnTracking.Load() syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait * time.Millisecond.Nanoseconds() @@ -45,7 +46,7 @@ func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan plan.SCtx().UpdateColStatsUsage(predicateColumns) } if !histNeeded { - return plan, nil + return plan, planChanged, nil } // Prepare the table metadata to avoid repeatedly fetching from the infoSchema below. @@ -69,9 +70,9 @@ func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan histNeededItems := collectHistNeededItems(histNeededColumns, histNeededIndices) if histNeeded && len(histNeededItems) > 0 { err := RequestLoadStats(plan.SCtx(), histNeededItems, syncWait) - return plan, err + return plan, planChanged, err } - return plan, nil + return plan, planChanged, nil } func (collectPredicateColumnsPoint) name() string { @@ -80,15 +81,16 @@ func (collectPredicateColumnsPoint) name() string { type syncWaitStatsLoadPoint struct{} -func (syncWaitStatsLoadPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) { +func (syncWaitStatsLoadPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false if plan.SCtx().GetSessionVars().InRestrictedSQL { - return plan, nil + return plan, planChanged, nil } if plan.SCtx().GetSessionVars().StmtCtx.IsSyncStatsFailed { - return plan, nil + return plan, planChanged, nil } err := SyncWaitStatsLoad(plan) - return plan, err + return plan, planChanged, err } func (syncWaitStatsLoadPoint) name() string { diff --git a/planner/core/rule_aggregation_elimination.go b/planner/core/rule_aggregation_elimination.go index cf4494c5d6f32..62b9b4ceeb59d 100644 --- a/planner/core/rule_aggregation_elimination.go +++ b/planner/core/rule_aggregation_elimination.go @@ -254,25 +254,26 @@ func wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetT return expression.BuildCastFunction(ctx, arg, targetTp) } -func (a *aggregationEliminator) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (a *aggregationEliminator) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild, err := a.optimize(ctx, child, opt) + newChild, planChanged, err := a.optimize(ctx, child, opt) if err != nil { - return nil, err + return nil, planChanged, err } newChildren = append(newChildren, newChild) } p.SetChildren(newChildren...) agg, ok := p.(*LogicalAggregation) if !ok { - return p, nil + return p, planChanged, nil } a.tryToEliminateDistinct(agg, opt) if proj := a.tryToEliminateAggregation(agg, opt); proj != nil { - return proj, nil + return proj, planChanged, nil } - return p, nil + return p, planChanged, nil } func (*aggregationEliminator) name() string { diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 714e7dae2f126..beaf3377cba12 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -431,8 +431,10 @@ func (*aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, uni return newAgg, nil } -func (a *aggregationPushDownSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { - return a.aggPushDown(p, opt) +func (a *aggregationPushDownSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + newLogicalPlan, err := a.aggPushDown(p, opt) + return newLogicalPlan, planChanged, err } func (a *aggregationPushDownSolver) tryAggPushDownForUnion(union *LogicalUnionAll, agg *LogicalAggregation, opt *logicalOptimizeOp) error { diff --git a/planner/core/rule_aggregation_skew_rewrite.go b/planner/core/rule_aggregation_skew_rewrite.go index 577533f152bbb..8652d37c43da2 100644 --- a/planner/core/rule_aggregation_skew_rewrite.go +++ b/planner/core/rule_aggregation_skew_rewrite.go @@ -273,24 +273,25 @@ func appendSkewDistinctAggRewriteTraceStep(agg *LogicalAggregation, result Logic opt.appendStepToCurrent(agg.ID(), agg.TP(), reason, action) } -func (a *skewDistinctAggRewriter) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (a *skewDistinctAggRewriter) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild, err := a.optimize(ctx, child, opt) + newChild, planChanged, err := a.optimize(ctx, child, opt) if err != nil { - return nil, err + return nil, planChanged, err } newChildren = append(newChildren, newChild) } p.SetChildren(newChildren...) agg, ok := p.(*LogicalAggregation) if !ok { - return p, nil + return p, planChanged, nil } if newAgg := a.rewriteSkewDistinctAgg(agg, opt); newAgg != nil { - return newAgg, nil + return newAgg, planChanged, nil } - return p, nil + return p, planChanged, nil } func (*skewDistinctAggRewriter) name() string { diff --git a/planner/core/rule_build_key_info.go b/planner/core/rule_build_key_info.go index bd957174c4296..43eac5bd59a59 100644 --- a/planner/core/rule_build_key_info.go +++ b/planner/core/rule_build_key_info.go @@ -25,9 +25,10 @@ import ( type buildKeySolver struct{} -func (*buildKeySolver) optimize(_ context.Context, p LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) { +func (*buildKeySolver) optimize(_ context.Context, p LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false buildKeyInfo(p) - return p, nil + return p, planChanged, nil } // buildKeyInfo recursively calls LogicalPlan's BuildKeyInfo method. diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 2a96eeee1320b..8c51cd677eeac 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -31,9 +31,10 @@ import ( type columnPruner struct { } -func (*columnPruner) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (*columnPruner) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false err := lp.PruneColumns(lp.Schema().Columns, opt) - return lp, err + return lp, planChanged, err } // ExprsHasSideEffects checks if any of the expressions has side effects. diff --git a/planner/core/rule_constant_propagation.go b/planner/core/rule_constant_propagation.go index b4ccd6ca7ca86..25bb3b14ec637 100644 --- a/planner/core/rule_constant_propagation.go +++ b/planner/core/rule_constant_propagation.go @@ -50,7 +50,8 @@ type constantPropagationSolver struct { // which is mainly implemented in the interface "constantPropagation" of LogicalPlan. // Currently only the Logical Join implements this function. (Used for the subquery in FROM List) // In the future, the Logical Apply will implements this function. (Used for the subquery in WHERE or SELECT list) -func (cp *constantPropagationSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (cp *constantPropagationSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false // constant propagation root plan newRoot := p.constantPropagation(nil, 0, opt) @@ -60,9 +61,9 @@ func (cp *constantPropagationSolver) optimize(_ context.Context, p LogicalPlan, } if newRoot == nil { - return p, nil + return p, planChanged, nil } - return newRoot, nil + return newRoot, planChanged, nil } // execOptimize optimize constant propagation exclude root plan node diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index 000a1a4acb291..824cbaa759528 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -193,7 +193,8 @@ func (*decorrelateSolver) aggDefaultValueMap(agg *LogicalAggregation) map[int]*e } // optimize implements logicalOptRule interface. -func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false if apply, ok := p.(*LogicalApply); ok { outerPlan := apply.children[0] innerPlan := apply.children[1] @@ -272,13 +273,13 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo proj.SetSchema(apply.Schema()) proj.Exprs = append(expression.Column2Exprs(outerPlan.Schema().Clone().Columns), proj.Exprs...) apply.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) - np, err := s.optimize(ctx, p, opt) + np, planChanged, err := s.optimize(ctx, p, opt) if err != nil { - return nil, err + return nil, planChanged, err } proj.SetChildren(np) appendMoveProjTraceStep(apply, np, proj, opt) - return proj, nil + return proj, planChanged, nil } appendRemoveProjTraceStep(apply, proj, opt) return s.optimize(ctx, p, opt) @@ -313,7 +314,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo for i, col := range outerPlan.Schema().Columns { first, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) if err != nil { - return nil, err + return nil, planChanged, err } newAggFuncs = append(newAggFuncs, first) @@ -343,20 +344,20 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo } desc, err := aggregation.NewAggFuncDesc(agg.SCtx(), agg.AggFuncs[i].Name, aggArgs, agg.AggFuncs[i].HasDistinct) if err != nil { - return nil, err + return nil, planChanged, err } newAggFuncs = append(newAggFuncs, desc) } agg.AggFuncs = newAggFuncs - np, err := s.optimize(ctx, p, opt) + np, planChanged, err := s.optimize(ctx, p, opt) if err != nil { - return nil, err + return nil, planChanged, err } agg.SetChildren(np) appendPullUpAggTraceStep(apply, np, agg, opt) // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. // agg.buildProjectionIfNecessary() - return agg, nil + return agg, planChanged, nil } // We can pull up the equal conditions below the aggregation as the join key of the apply, if only // the equal conditions contain the correlated column of this apply. @@ -391,7 +392,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 { newFunc, err := aggregation.NewAggFuncDesc(apply.SCtx(), ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) if err != nil { - return nil, err + return nil, planChanged, err } agg.AggFuncs = append(agg.AggFuncs, newFunc) agg.schema.Append(clonedCol) @@ -444,18 +445,18 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo NoOptimize: // CTE's logical optimization is independent. if _, ok := p.(*LogicalCTE); ok { - return p, nil + return p, planChanged, nil } newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - np, err := s.optimize(ctx, child, opt) + np, planChanged, err := s.optimize(ctx, child, opt) if err != nil { - return nil, err + return nil, planChanged, err } newChildren = append(newChildren, np) } p.SetChildren(newChildren...) - return p, nil + return p, planChanged, nil } func (*decorrelateSolver) name() string { diff --git a/planner/core/rule_derive_topn_from_window.go b/planner/core/rule_derive_topn_from_window.go index 50d7d085ef2cf..5b6b8ee179ee4 100644 --- a/planner/core/rule_derive_topn_from_window.go +++ b/planner/core/rule_derive_topn_from_window.go @@ -116,8 +116,9 @@ func windowIsTopN(p *LogicalSelection) (bool, uint64) { return false, 0 } -func (*deriveTopNFromWindow) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { - return p.deriveTopN(opt), nil +func (*deriveTopNFromWindow) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + return p.deriveTopN(opt), planChanged, nil } func (s *baseLogicalPlan) deriveTopN(opt *logicalOptimizeOp) LogicalPlan { diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 97b8ade37cb0f..27a23d576a675 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -167,9 +167,10 @@ type projectionEliminator struct { } // optimize implements the logicalOptRule interface. -func (pe *projectionEliminator) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (pe *projectionEliminator) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false root := pe.eliminate(lp, make(map[string]*expression.Column), false, opt) - return root, nil + return root, planChanged, nil } // eliminate eliminates the redundant projection in a logical plan. diff --git a/planner/core/rule_generate_column_substitute.go b/planner/core/rule_generate_column_substitute.go index 0baae3496d850..476a81e871104 100644 --- a/planner/core/rule_generate_column_substitute.go +++ b/planner/core/rule_generate_column_substitute.go @@ -36,13 +36,14 @@ type ExprColumnMap map[expression.Expression]*expression.Column // For example: select a+1 from t order by a+1, with a virtual generate column c as (a+1) and // an index on c. We need to replace a+1 with c so that we can use the index on c. // See also https://dev.mysql.com/doc/refman/8.0/en/generated-column-index-optimizations.html -func (gc *gcSubstituter) optimize(ctx context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (gc *gcSubstituter) optimize(ctx context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false exprToColumn := make(ExprColumnMap) collectGenerateColumn(lp, exprToColumn) if len(exprToColumn) == 0 { - return lp, nil + return lp, planChanged, nil } - return gc.substitute(ctx, lp, exprToColumn, opt), nil + return gc.substitute(ctx, lp, exprToColumn, opt), planChanged, nil } // collectGenerateColumn collect the generate column and save them to a map from their expressions to themselves. diff --git a/planner/core/rule_join_elimination.go b/planner/core/rule_join_elimination.go index 9168842a7dfa1..45a1cb278f1c5 100644 --- a/planner/core/rule_join_elimination.go +++ b/planner/core/rule_join_elimination.go @@ -245,9 +245,10 @@ func (o *outerJoinEliminator) doOptimize(p LogicalPlan, aggCols []*expression.Co return p, nil } -func (o *outerJoinEliminator) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (o *outerJoinEliminator) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false p, err := o.doOptimize(p, nil, nil, opt) - return p, err + return p, planChanged, err } func (*outerJoinEliminator) name() string { diff --git a/planner/core/rule_join_reorder.go b/planner/core/rule_join_reorder.go index 9005424f06570..dc1e949fe95f2 100644 --- a/planner/core/rule_join_reorder.go +++ b/planner/core/rule_join_reorder.go @@ -222,13 +222,14 @@ type joinTypeWithExtMsg struct { outerBindCondition []expression.Expression } -func (s *joinReOrderSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (s *joinReOrderSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false tracer := &joinReorderTrace{cost: map[string]float64{}, opt: opt} tracer.traceJoinReorder(p) p, err := s.optimizeRecursive(p.SCtx(), p, tracer) tracer.traceJoinReorder(p) appendJoinReorderTraceStep(tracer, p, opt) - return p, err + return p, planChanged, err } // optimizeRecursive recursively collects join groups and applies join reorder algorithm for each group. diff --git a/planner/core/rule_max_min_eliminate.go b/planner/core/rule_max_min_eliminate.go index 2ea1a35e4cb2b..fd3a098aea3f3 100644 --- a/planner/core/rule_max_min_eliminate.go +++ b/planner/core/rule_max_min_eliminate.go @@ -36,8 +36,9 @@ import ( type maxMinEliminator struct { } -func (a *maxMinEliminator) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { - return a.eliminateMaxMin(p, opt), nil +func (a *maxMinEliminator) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + return a.eliminateMaxMin(p, opt), planChanged, nil } // composeAggsByInnerJoin composes the scalar aggregations by cartesianJoin. diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index ad2e6ea5db8da..3a2da51f912ef 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -61,9 +61,10 @@ const FullRange = -1 // partitionProcessor is here because it's easier to prune partition after predicate push down. type partitionProcessor struct{} -func (s *partitionProcessor) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (s *partitionProcessor) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false p, err := s.rewriteDataSource(lp, opt) - return p, err + return p, planChanged, err } func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index d4eb30951deb9..8ef2f5ff325cd 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -42,9 +42,10 @@ type exprPrefixAdder struct { lengths []int } -func (*ppdSolver) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (*ppdSolver) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false _, p := lp.PredicatePushDown(nil, opt) - return p, nil + return p, planChanged, nil } func addSelection(p LogicalPlan, child LogicalPlan, conditions []expression.Expression, chIdx int, opt *logicalOptimizeOp) { diff --git a/planner/core/rule_predicate_simplification.go b/planner/core/rule_predicate_simplification.go index f1342419a1b92..98672bc97575d 100644 --- a/planner/core/rule_predicate_simplification.go +++ b/planner/core/rule_predicate_simplification.go @@ -65,8 +65,9 @@ func findPredicateType(expr expression.Expression) (*expression.Column, predicat return nil, otherPredicate } -func (*predicateSimplification) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { - return p.predicateSimplification(opt), nil +func (*predicateSimplification) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + return p.predicateSimplification(opt), planChanged, nil } func (s *baseLogicalPlan) predicateSimplification(opt *logicalOptimizeOp) LogicalPlan { diff --git a/planner/core/rule_push_down_sequence.go b/planner/core/rule_push_down_sequence.go index c1d7ac5f44a42..70d150c674d91 100644 --- a/planner/core/rule_push_down_sequence.go +++ b/planner/core/rule_push_down_sequence.go @@ -23,8 +23,9 @@ func (*pushDownSequenceSolver) name() string { return "push_down_sequence" } -func (pdss *pushDownSequenceSolver) optimize(_ context.Context, lp LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) { - return pdss.recursiveOptimize(nil, lp), nil +func (pdss *pushDownSequenceSolver) optimize(_ context.Context, lp LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + return pdss.recursiveOptimize(nil, lp), planChanged, nil } func (pdss *pushDownSequenceSolver) recursiveOptimize(pushedSequence *LogicalSequence, lp LogicalPlan) LogicalPlan { diff --git a/planner/core/rule_resolve_grouping_expand.go b/planner/core/rule_resolve_grouping_expand.go index c0074be5deed1..cfd3a89939793 100644 --- a/planner/core/rule_resolve_grouping_expand.go +++ b/planner/core/rule_resolve_grouping_expand.go @@ -72,10 +72,12 @@ type resolveExpand struct { // (upper required) (grouping sets columns appended) // // Expand operator itself is kind like a projection, while difference is that it has a multi projection list, named as leveled projection. -func (*resolveExpand) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { +func (*resolveExpand) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false // As you see, Expand's leveled projection should be built after all column-prune is done. So we just make generating-leveled-projection // as the last rule of logical optimization, which is more clear. (spark has column prune action before building expand) - return genExpand(p, opt) + newLogicalPlan, err := genExpand(p, opt) + return newLogicalPlan, planChanged, err } func (*resolveExpand) name() string { diff --git a/planner/core/rule_result_reorder.go b/planner/core/rule_result_reorder.go index afb17b487d123..0b207ad0be594 100644 --- a/planner/core/rule_result_reorder.go +++ b/planner/core/rule_result_reorder.go @@ -39,12 +39,13 @@ This rule reorders results by modifying or injecting a Sort operator: type resultReorder struct { } -func (rs *resultReorder) optimize(_ context.Context, lp LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) { +func (rs *resultReorder) optimize(_ context.Context, lp LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false ordered := rs.completeSort(lp) if !ordered { lp = rs.injectSort(lp) } - return lp, nil + return lp, planChanged, nil } func (rs *resultReorder) completeSort(lp LogicalPlan) bool { diff --git a/planner/core/rule_semi_join_rewrite.go b/planner/core/rule_semi_join_rewrite.go index 25e4f624ddfe5..db2737cd30706 100644 --- a/planner/core/rule_semi_join_rewrite.go +++ b/planner/core/rule_semi_join_rewrite.go @@ -25,8 +25,10 @@ import ( type semiJoinRewriter struct { } -func (smj *semiJoinRewriter) optimize(_ context.Context, p LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) { - return smj.recursivePlan(p) +func (smj *semiJoinRewriter) optimize(_ context.Context, p LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + newLogicalPlan, err := smj.recursivePlan(p) + return newLogicalPlan, planChanged, err } func (*semiJoinRewriter) name() string { diff --git a/planner/core/rule_topn_push_down.go b/planner/core/rule_topn_push_down.go index 3760d690284a0..d9b94ec124e50 100644 --- a/planner/core/rule_topn_push_down.go +++ b/planner/core/rule_topn_push_down.go @@ -28,8 +28,9 @@ import ( type pushDownTopNOptimizer struct { } -func (*pushDownTopNOptimizer) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) { - return p.pushDownTopN(nil, opt), nil +func (*pushDownTopNOptimizer) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) { + planChanged := false + return p.pushDownTopN(nil, opt), planChanged, nil } func (s *baseLogicalPlan) pushDownTopN(topN *LogicalTopN, opt *logicalOptimizeOp) LogicalPlan {