Skip to content

Commit

Permalink
planner: Refactor framework of rule engine.(Support interaction rule …
Browse files Browse the repository at this point in the history
…of new rule engine ) (#47526)

close #43360
  • Loading branch information
elsa0520 authored Oct 13, 2023
1 parent 08fbb42 commit 93a834a
Show file tree
Hide file tree
Showing 23 changed files with 132 additions and 75 deletions.
35 changes: 33 additions & 2 deletions planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ var optRuleList = []logicalOptRule{
&resolveExpand{},
}

// Interaction Rule List
/* The interaction rule will be trigger when it satisfies following conditions:
1. The related rule has been trigger and changed the plan
2. The interaction rule is enabled
*/
var optInteractionRuleList = map[logicalOptRule]logicalOptRule{}

type logicalOptimizeOp struct {
// tracer is goring to track optimize steps during rule optimizing
tracer *tracing.LogicalOptimizeTracer
Expand Down Expand Up @@ -154,7 +161,14 @@ func (op *logicalOptimizeOp) recordFinalLogicalPlan(final LogicalPlan) {

// logicalOptRule means a logical optimizing rule, which contains decorrelate, ppd, column pruning, etc.
type logicalOptRule interface {
optimize(context.Context, LogicalPlan, *logicalOptimizeOp) (LogicalPlan, error)
/* Return Parameters:
1. LogicalPlan: The optimized LogicalPlan after rule is applied
2. bool: Used to judge whether the plan is changed or not by logical rule.
If the plan is changed, it will return true.
The default value is false. It means that no interaction rule will be triggered.
3. error: If there is error during the rule optimizer, it will be thrown
*/
optimize(context.Context, LogicalPlan, *logicalOptimizeOp) (LogicalPlan, bool, error)
name() string
}

Expand Down Expand Up @@ -1125,6 +1139,7 @@ func logicalOptimize(ctx context.Context, flag uint64, logic LogicalPlan) (Logic
}()
}
var err error
var againRuleList []logicalOptRule
for i, rule := range optRuleList {
// The order of flags is same as the order of optRule in the list.
// We use a bitmask to record which opt rules should be used. If the i-th bit is 1, it means we should
Expand All @@ -1133,11 +1148,27 @@ func logicalOptimize(ctx context.Context, flag uint64, logic LogicalPlan) (Logic
continue
}
opt.appendBeforeRuleOptimize(i, rule.name(), logic)
logic, err = rule.optimize(ctx, logic, opt)
var planChanged bool
logic, planChanged, err = rule.optimize(ctx, logic, opt)
if err != nil {
return nil, err
}
// Compute interaction rules that should be optimized again
interactionRule, ok := optInteractionRuleList[rule]
if planChanged && ok && isLogicalRuleDisabled(interactionRule) {
againRuleList = append(againRuleList, interactionRule)
}
}

// Trigger the interaction rule
for i, rule := range againRuleList {
opt.appendBeforeRuleOptimize(i, rule.name(), logic)
logic, _, err = rule.optimize(ctx, logic, opt)
if err != nil {
return nil, err
}
}

opt.recordFinalLogicalPlan(logic)
return logic, err
}
Expand Down
20 changes: 11 additions & 9 deletions planner/core/plan_stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ import (

type collectPredicateColumnsPoint struct{}

func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) {
func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
if plan.SCtx().GetSessionVars().InRestrictedSQL {
return plan, nil
return plan, planChanged, nil
}
predicateNeeded := variable.EnableColumnTracking.Load()
syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait * time.Millisecond.Nanoseconds()
Expand All @@ -45,7 +46,7 @@ func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan
plan.SCtx().UpdateColStatsUsage(predicateColumns)
}
if !histNeeded {
return plan, nil
return plan, planChanged, nil
}

// Prepare the table metadata to avoid repeatedly fetching from the infoSchema below.
Expand All @@ -69,9 +70,9 @@ func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan
histNeededItems := collectHistNeededItems(histNeededColumns, histNeededIndices)
if histNeeded && len(histNeededItems) > 0 {
err := RequestLoadStats(plan.SCtx(), histNeededItems, syncWait)
return plan, err
return plan, planChanged, err
}
return plan, nil
return plan, planChanged, nil
}

func (collectPredicateColumnsPoint) name() string {
Expand All @@ -80,15 +81,16 @@ func (collectPredicateColumnsPoint) name() string {

type syncWaitStatsLoadPoint struct{}

func (syncWaitStatsLoadPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) {
func (syncWaitStatsLoadPoint) optimize(_ context.Context, plan LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
if plan.SCtx().GetSessionVars().InRestrictedSQL {
return plan, nil
return plan, planChanged, nil
}
if plan.SCtx().GetSessionVars().StmtCtx.IsSyncStatsFailed {
return plan, nil
return plan, planChanged, nil
}
err := SyncWaitStatsLoad(plan)
return plan, err
return plan, planChanged, err
}

func (syncWaitStatsLoadPoint) name() string {
Expand Down
13 changes: 7 additions & 6 deletions planner/core/rule_aggregation_elimination.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,25 +254,26 @@ func wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetT
return expression.BuildCastFunction(ctx, arg, targetTp)
}

func (a *aggregationEliminator) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (a *aggregationEliminator) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
newChildren := make([]LogicalPlan, 0, len(p.Children()))
for _, child := range p.Children() {
newChild, err := a.optimize(ctx, child, opt)
newChild, planChanged, err := a.optimize(ctx, child, opt)
if err != nil {
return nil, err
return nil, planChanged, err
}
newChildren = append(newChildren, newChild)
}
p.SetChildren(newChildren...)
agg, ok := p.(*LogicalAggregation)
if !ok {
return p, nil
return p, planChanged, nil
}
a.tryToEliminateDistinct(agg, opt)
if proj := a.tryToEliminateAggregation(agg, opt); proj != nil {
return proj, nil
return proj, planChanged, nil
}
return p, nil
return p, planChanged, nil
}

func (*aggregationEliminator) name() string {
Expand Down
6 changes: 4 additions & 2 deletions planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,10 @@ func (*aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, uni
return newAgg, nil
}

func (a *aggregationPushDownSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
return a.aggPushDown(p, opt)
func (a *aggregationPushDownSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
newLogicalPlan, err := a.aggPushDown(p, opt)
return newLogicalPlan, planChanged, err
}

func (a *aggregationPushDownSolver) tryAggPushDownForUnion(union *LogicalUnionAll, agg *LogicalAggregation, opt *logicalOptimizeOp) error {
Expand Down
13 changes: 7 additions & 6 deletions planner/core/rule_aggregation_skew_rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,24 +273,25 @@ func appendSkewDistinctAggRewriteTraceStep(agg *LogicalAggregation, result Logic
opt.appendStepToCurrent(agg.ID(), agg.TP(), reason, action)
}

func (a *skewDistinctAggRewriter) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (a *skewDistinctAggRewriter) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
newChildren := make([]LogicalPlan, 0, len(p.Children()))
for _, child := range p.Children() {
newChild, err := a.optimize(ctx, child, opt)
newChild, planChanged, err := a.optimize(ctx, child, opt)
if err != nil {
return nil, err
return nil, planChanged, err
}
newChildren = append(newChildren, newChild)
}
p.SetChildren(newChildren...)
agg, ok := p.(*LogicalAggregation)
if !ok {
return p, nil
return p, planChanged, nil
}
if newAgg := a.rewriteSkewDistinctAgg(agg, opt); newAgg != nil {
return newAgg, nil
return newAgg, planChanged, nil
}
return p, nil
return p, planChanged, nil
}

func (*skewDistinctAggRewriter) name() string {
Expand Down
5 changes: 3 additions & 2 deletions planner/core/rule_build_key_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ import (

type buildKeySolver struct{}

func (*buildKeySolver) optimize(_ context.Context, p LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, error) {
func (*buildKeySolver) optimize(_ context.Context, p LogicalPlan, _ *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
buildKeyInfo(p)
return p, nil
return p, planChanged, nil
}

// buildKeyInfo recursively calls LogicalPlan's BuildKeyInfo method.
Expand Down
5 changes: 3 additions & 2 deletions planner/core/rule_column_pruning.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ import (
type columnPruner struct {
}

func (*columnPruner) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (*columnPruner) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
err := lp.PruneColumns(lp.Schema().Columns, opt)
return lp, err
return lp, planChanged, err
}

// ExprsHasSideEffects checks if any of the expressions has side effects.
Expand Down
7 changes: 4 additions & 3 deletions planner/core/rule_constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ type constantPropagationSolver struct {
// which is mainly implemented in the interface "constantPropagation" of LogicalPlan.
// Currently only the Logical Join implements this function. (Used for the subquery in FROM List)
// In the future, the Logical Apply will implements this function. (Used for the subquery in WHERE or SELECT list)
func (cp *constantPropagationSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (cp *constantPropagationSolver) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
// constant propagation root plan
newRoot := p.constantPropagation(nil, 0, opt)

Expand All @@ -60,9 +61,9 @@ func (cp *constantPropagationSolver) optimize(_ context.Context, p LogicalPlan,
}

if newRoot == nil {
return p, nil
return p, planChanged, nil
}
return newRoot, nil
return newRoot, planChanged, nil
}

// execOptimize optimize constant propagation exclude root plan node
Expand Down
29 changes: 15 additions & 14 deletions planner/core/rule_decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ func (*decorrelateSolver) aggDefaultValueMap(agg *LogicalAggregation) map[int]*e
}

// optimize implements logicalOptRule interface.
func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
if apply, ok := p.(*LogicalApply); ok {
outerPlan := apply.children[0]
innerPlan := apply.children[1]
Expand Down Expand Up @@ -272,13 +273,13 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
proj.SetSchema(apply.Schema())
proj.Exprs = append(expression.Column2Exprs(outerPlan.Schema().Clone().Columns), proj.Exprs...)
apply.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema()))
np, err := s.optimize(ctx, p, opt)
np, planChanged, err := s.optimize(ctx, p, opt)
if err != nil {
return nil, err
return nil, planChanged, err
}
proj.SetChildren(np)
appendMoveProjTraceStep(apply, np, proj, opt)
return proj, nil
return proj, planChanged, nil
}
appendRemoveProjTraceStep(apply, proj, opt)
return s.optimize(ctx, p, opt)
Expand Down Expand Up @@ -313,7 +314,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
for i, col := range outerPlan.Schema().Columns {
first, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false)
if err != nil {
return nil, err
return nil, planChanged, err
}
newAggFuncs = append(newAggFuncs, first)

Expand Down Expand Up @@ -343,20 +344,20 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
}
desc, err := aggregation.NewAggFuncDesc(agg.SCtx(), agg.AggFuncs[i].Name, aggArgs, agg.AggFuncs[i].HasDistinct)
if err != nil {
return nil, err
return nil, planChanged, err
}
newAggFuncs = append(newAggFuncs, desc)
}
agg.AggFuncs = newAggFuncs
np, err := s.optimize(ctx, p, opt)
np, planChanged, err := s.optimize(ctx, p, opt)
if err != nil {
return nil, err
return nil, planChanged, err
}
agg.SetChildren(np)
appendPullUpAggTraceStep(apply, np, agg, opt)
// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
// agg.buildProjectionIfNecessary()
return agg, nil
return agg, planChanged, nil
}
// We can pull up the equal conditions below the aggregation as the join key of the apply, if only
// the equal conditions contain the correlated column of this apply.
Expand Down Expand Up @@ -391,7 +392,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 {
newFunc, err := aggregation.NewAggFuncDesc(apply.SCtx(), ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false)
if err != nil {
return nil, err
return nil, planChanged, err
}
agg.AggFuncs = append(agg.AggFuncs, newFunc)
agg.schema.Append(clonedCol)
Expand Down Expand Up @@ -444,18 +445,18 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
NoOptimize:
// CTE's logical optimization is independent.
if _, ok := p.(*LogicalCTE); ok {
return p, nil
return p, planChanged, nil
}
newChildren := make([]LogicalPlan, 0, len(p.Children()))
for _, child := range p.Children() {
np, err := s.optimize(ctx, child, opt)
np, planChanged, err := s.optimize(ctx, child, opt)
if err != nil {
return nil, err
return nil, planChanged, err
}
newChildren = append(newChildren, np)
}
p.SetChildren(newChildren...)
return p, nil
return p, planChanged, nil
}

func (*decorrelateSolver) name() string {
Expand Down
5 changes: 3 additions & 2 deletions planner/core/rule_derive_topn_from_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ func windowIsTopN(p *LogicalSelection) (bool, uint64) {
return false, 0
}

func (*deriveTopNFromWindow) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
return p.deriveTopN(opt), nil
func (*deriveTopNFromWindow) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
return p.deriveTopN(opt), planChanged, nil
}

func (s *baseLogicalPlan) deriveTopN(opt *logicalOptimizeOp) LogicalPlan {
Expand Down
5 changes: 3 additions & 2 deletions planner/core/rule_eliminate_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ type projectionEliminator struct {
}

// optimize implements the logicalOptRule interface.
func (pe *projectionEliminator) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (pe *projectionEliminator) optimize(_ context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
root := pe.eliminate(lp, make(map[string]*expression.Column), false, opt)
return root, nil
return root, planChanged, nil
}

// eliminate eliminates the redundant projection in a logical plan.
Expand Down
7 changes: 4 additions & 3 deletions planner/core/rule_generate_column_substitute.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ type ExprColumnMap map[expression.Expression]*expression.Column
// For example: select a+1 from t order by a+1, with a virtual generate column c as (a+1) and
// an index on c. We need to replace a+1 with c so that we can use the index on c.
// See also https://dev.mysql.com/doc/refman/8.0/en/generated-column-index-optimizations.html
func (gc *gcSubstituter) optimize(ctx context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (gc *gcSubstituter) optimize(ctx context.Context, lp LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
exprToColumn := make(ExprColumnMap)
collectGenerateColumn(lp, exprToColumn)
if len(exprToColumn) == 0 {
return lp, nil
return lp, planChanged, nil
}
return gc.substitute(ctx, lp, exprToColumn, opt), nil
return gc.substitute(ctx, lp, exprToColumn, opt), planChanged, nil
}

// collectGenerateColumn collect the generate column and save them to a map from their expressions to themselves.
Expand Down
5 changes: 3 additions & 2 deletions planner/core/rule_join_elimination.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ func (o *outerJoinEliminator) doOptimize(p LogicalPlan, aggCols []*expression.Co
return p, nil
}

func (o *outerJoinEliminator) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
func (o *outerJoinEliminator) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, bool, error) {
planChanged := false
p, err := o.doOptimize(p, nil, nil, opt)
return p, err
return p, planChanged, err
}

func (*outerJoinEliminator) name() string {
Expand Down
Loading

0 comments on commit 93a834a

Please sign in to comment.