From 64b769365d7eaf44c2f9ae80c287abd7f751aead Mon Sep 17 00:00:00 2001 From: Roman Andriushchenko Date: Thu, 10 Oct 2024 12:56:49 +0200 Subject: [PATCH] DT synthesis: faster formulae creation, fix of various bugs --- paynt/parser/sketch.py | 5 +- paynt/quotient/mdp.py | 9 +-- paynt/quotient/quotient.py | 8 ++- paynt/synthesizer/decision_tree.py | 22 ++++--- paynt/synthesizer/synthesizer_ar.py | 2 +- paynt/verification/property.py | 4 ++ .../src/synthesis/quotient/ColoringSmt.cpp | 15 +++-- .../src/synthesis/quotient/ColoringSmt.h | 6 +- payntbind/src/synthesis/quotient/TreeNode.cpp | 66 ++++++++++++------- payntbind/src/synthesis/quotient/TreeNode.h | 24 +++++-- 10 files changed, 106 insertions(+), 55 deletions(-) diff --git a/paynt/parser/sketch.py b/paynt/parser/sketch.py index 27a8e528..44478833 100644 --- a/paynt/parser/sketch.py +++ b/paynt/parser/sketch.py @@ -88,13 +88,14 @@ def load_sketch(cls, sketch_path, properties_path, specification = PrismParser.parse_specification(properties_path, relative_error) filetype = "drn" project_path = os.path.dirname(sketch_path) - valuations_path = project_path + "/state_valuations.json" + valuations_filename = "state-valuations.json" + valuations_path = project_path + "/" + valuations_filename state_valuations = None if os.path.exists(valuations_path) and os.path.isfile(valuations_path): with open(valuations_path) as file: state_valuations = json.load(file) if state_valuations is not None: - logger.info(f"found state_valuations.json, adding to the model...") + logger.info(f"found state valuations in {valuations_path}, adding to the model...") explicit_quotient = payntbind.synthesis.addStateValuations(explicit_quotient,state_valuations) except Exception as e: print(e) diff --git a/paynt/quotient/mdp.py b/paynt/quotient/mdp.py index 5b58f04d..53f4b24f 100644 --- a/paynt/quotient/mdp.py +++ b/paynt/quotient/mdp.py @@ -181,8 +181,8 @@ def to_graphviz(self, graphviz_tree, variables, action_labels): graphviz_tree.node(self.graphviz_id, label=node_label, shape="box", style="rounded", margin="0.05,0.05") if not self.is_terminal: - graphviz_tree.edge(self.graphviz_id,self.child_true.graphviz_id,label="True") - graphviz_tree.edge(self.graphviz_id,self.child_false.graphviz_id,label="False") + graphviz_tree.edge(self.graphviz_id,self.child_true.graphviz_id,label="T") + graphviz_tree.edge(self.graphviz_id,self.child_false.graphviz_id,label="F") @@ -235,8 +235,9 @@ def to_list(self): node_info[node.identifier] = (parent,child_true,child_false) return node_info - def simplify(self): - self.root.simplify(self.variables, self.state_valuations) + def simplify(self, target_state_mask): + state_valuations = [self.state_valuations[state] for state in ~target_state_mask] + self.root.simplify(self.variables, state_valuations) def to_string(self): return self.root.to_string(self.variables,self.quotient.action_labels) diff --git a/paynt/quotient/quotient.py b/paynt/quotient/quotient.py index 8f22931f..f1a4d79c 100644 --- a/paynt/quotient/quotient.py +++ b/paynt/quotient/quotient.py @@ -328,6 +328,12 @@ def identify_absorbing_states(self, model): break return state_is_absorbing - def identify_target_states(self, model, prop): + def identify_target_states(self, model=None, prop=None): + if model is None: + model = self.quotient_mdp + if prop is None: + prop = self.get_property() + if prop.is_discounted_reward: + return stormpy.BitVector(model.nr_states,False) target_label = prop.get_target_label() return model.labeling.get_states(target_label) diff --git a/paynt/synthesizer/decision_tree.py b/paynt/synthesizer/decision_tree.py index 03c48357..f8720c38 100644 --- a/paynt/synthesizer/decision_tree.py +++ b/paynt/synthesizer/decision_tree.py @@ -86,7 +86,9 @@ def verify_family(self, family): self.check_specification_for_mdp(family) if not family.analysis_result.can_improve: return - # self.harmonize_inconsistent_scheduler(family) + if SynthesizerDecisionTree.scheduler_path is not None: + return + self.harmonize_inconsistent_scheduler(family) def counters_reset(self): @@ -134,7 +136,7 @@ def synthesize_tree_sequence(self, opt_result_value): self.best_tree = self.best_tree_value = None global_timeout = paynt.utils.timer.GlobalTimer.global_timer.time_limit_seconds - if global_timeout is None: global_timeout = 1800 + if global_timeout is None: global_timeout = 900 depth_timeout = global_timeout / 2 / SynthesizerDecisionTree.tree_depth for depth in range(SynthesizerDecisionTree.tree_depth+1): print() @@ -257,18 +259,20 @@ def run(self, optimum_threshold=None): if self.best_tree is None: logger.info("no admissible tree found") else: - self.best_tree.simplify() - logger.info(f"printing the synthesized tree below:") - print(self.best_tree.to_string()) - + target_states = self.quotient.identify_target_states() + self.best_tree.simplify(target_states) depth = self.best_tree.get_depth() + num_nodes = len(self.best_tree.collect_nonterminals()) + logger.info(f"synthesized tree of depth {depth} with {num_nodes} decision nodes") if self.quotient.specification.has_optimality: logger.info(f"the synthesized tree has value {self.best_tree_value}") - num_nodes = len(self.best_tree.collect_nonterminals()) - logger.info(f"the synthesized tree is of depth {depth} and has {num_nodes} decision nodes") + logger.info(f"printing the synthesized tree below:") + print(self.best_tree.to_string()) + if self.export_synthesis_filename_base is not None: self.export_decision_tree(self.best_tree, self.export_synthesis_filename_base) - time_total = paynt.utils.timer.GlobalTimer.read() + time_total = round(paynt.utils.timer.GlobalTimer.read(),2) + logger.info(f"synthesis finished after {time_total} seconds") # print() # for name,time in self.quotient.coloring.getProfilingInfo(): diff --git a/paynt/synthesizer/synthesizer_ar.py b/paynt/synthesizer/synthesizer_ar.py index 1af7788c..f09a412d 100644 --- a/paynt/synthesizer/synthesizer_ar.py +++ b/paynt/synthesizer/synthesizer_ar.py @@ -100,7 +100,7 @@ def update_optimum(self, family): self.quotient.specification.optimality.update_optimum(iv) self.best_assignment = ia self.best_assignment_value = iv - # logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds") + logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds") if isinstance(self.quotient, paynt.quotient.pomdp.PomdpQuotient): self.stat.new_fsc_found(family.analysis_result.improving_value, ia, self.quotient.policy_size(ia)) diff --git a/paynt/verification/property.py b/paynt/verification/property.py index a0523269..553646eb 100644 --- a/paynt/verification/property.py +++ b/paynt/verification/property.py @@ -124,6 +124,10 @@ def __str__(self): def reward(self): return self.formula.is_reward_operator + @property + def is_discounted_reward(self): + return self.formula.is_reward_operator and self.formula.subformula.is_discounted_total_reward_formula + @property def maximizing(self): return not self.minimizing diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.cpp b/payntbind/src/synthesis/quotient/ColoringSmt.cpp index d64d9ec1..3c9723a7 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.cpp +++ b/payntbind/src/synthesis/quotient/ColoringSmt.cpp @@ -137,20 +137,20 @@ ColoringSmt::ColoringSmt( std::vector state_path_expression; for(uint64_t state = 0; state < numStates(); ++state) { + getRoot()->createPrefixSubstitutions(state_valuation[state]); state_path_expression.push_back(z3::expr_vector(ctx)); for(uint64_t path = 0; path < numPaths(); ++path) { - z3::expr_vector substituted(ctx); - // getRoot()->substitutePrefixExpression(getRoot()->paths[path], state_substitution_expr[state], substituted); - getRoot()->substitutePrefixExpression(getRoot()->paths[path], state_valuation[state], substituted); - state_path_expression[state].push_back(z3::mk_or(substituted)); + z3::expr_vector evaluated(ctx); + getRoot()->substitutePrefixExpression(getRoot()->paths[path], evaluated); + state_path_expression[state].push_back(z3::mk_or(evaluated)); } } std::vector action_path_expression; for(uint64_t action = 0; action < num_actions; ++action) { action_path_expression.push_back(z3::expr_vector(ctx)); for(uint64_t path = 0; path < numPaths(); ++path) { - z3::expr substituted = getRoot()->substituteActionExpression(getRoot()->paths[path], action); - action_path_expression[action].push_back(substituted); + z3::expr evaluated = getRoot()->substituteActionExpression(getRoot()->paths[path], action); + action_path_expression[action].push_back(evaluated); } } @@ -173,10 +173,11 @@ ColoringSmt::ColoringSmt( timers["ColoringSmt:: create harmonizing variants"].start(); std::vector state_path_expression_harmonizing; for(uint64_t state = 0; state < numStates(); ++state) { + getRoot()->createPrefixSubstitutionsHarmonizing(state_substitution_expr[state]); state_path_expression_harmonizing.push_back(z3::expr_vector(ctx)); for(uint64_t path = 0; path < numPaths(); ++path) { z3::expr_vector evaluated(ctx); - getRoot()->substitutePrefixExpressionHarmonizing(getRoot()->paths[path], state_substitution_expr[state], evaluated); + getRoot()->substitutePrefixExpressionHarmonizing(getRoot()->paths[path], evaluated); state_path_expression_harmonizing[state].push_back(z3::mk_or(evaluated)); } } diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.h b/payntbind/src/synthesis/quotient/ColoringSmt.h index 88653b94..d1a38237 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.h +++ b/payntbind/src/synthesis/quotient/ColoringSmt.h @@ -65,6 +65,9 @@ class ColoringSmt { protected: + /** If true, the object will be setup for one consistency check. */ + bool disable_counterexamples; + /** The initial state. */ const uint64_t initial_state; /** Valuation of each state. */ @@ -133,9 +136,6 @@ class ColoringSmt { bool PRINT_UNSAT_CORE = false; void loadUnsatCore(z3::expr_vector const& unsat_core_expr, Family const& subfamily); - /** If true, the object will be setup for one consistency check. */ - bool disable_counterexamples; - }; } \ No newline at end of file diff --git a/payntbind/src/synthesis/quotient/TreeNode.cpp b/payntbind/src/synthesis/quotient/TreeNode.cpp index ea68af24..c351c8dd 100644 --- a/payntbind/src/synthesis/quotient/TreeNode.cpp +++ b/payntbind/src/synthesis/quotient/TreeNode.cpp @@ -164,8 +164,11 @@ uint64_t TerminalNode::getPathActionHole(std::vector const& path) { return action_hole.hole; } +void TerminalNode::createPrefixSubstitutions(std::vector const& state_valuation) { + // +} -void TerminalNode::substitutePrefixExpression(std::vector const& path, std::vector const& state_valuation, z3::expr_vector & substituted) const { +void TerminalNode::substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const { // } @@ -173,7 +176,11 @@ z3::expr TerminalNode::substituteActionExpression(std::vector const& path, return action_hole.solver_variable == (int)action; } -void TerminalNode::substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const { +void TerminalNode::createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) { + // +} + +void TerminalNode::substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const { // } @@ -227,7 +234,7 @@ InnerNode::InnerNode( z3::expr_vector const& state_substitution_variables ) : TreeNode(identifier,ctx,variable_name,variable_domain), decision_hole(false,ctx), state_substitution_variables(state_substitution_variables), - step_true(ctx), step_false(ctx), step_true_harm(ctx), step_false_harm(ctx) {} + step_true(ctx), step_false(ctx), substituted_true(ctx), substituted_false(ctx), step_true_harm(ctx), step_false_harm(ctx) {} void InnerNode::createHoles(Family& family) { decision_hole.hole = family.addHole(numVariables()); @@ -374,43 +381,59 @@ uint64_t InnerNode::getPathActionHole(std::vector const& path) { } -void InnerNode::substitutePrefixExpression(std::vector const& path, std::vector const& state_valuation, z3::expr_vector & substituted) const { - bool step_to_true_child = path[depth]; - // z3::expr step = step_to_true_child ? step_true : step_false; - // substituted.push_back(step.substitute(state_substitution_variables,state_valuation)); - - z3::expr_vector step_options(ctx); +void InnerNode::createPrefixSubstitutions(std::vector const& state_valuation) { + z3::expr_vector step_options_true(ctx); + z3::expr_vector step_options_false(ctx); z3::expr const& dv = decision_hole.solver_variable; for(uint64_t variable = 0; variable < numVariables(); ++variable) { z3::expr const& vv = variable_hole[variable].solver_variable; // mind the negation below - if(step_to_true_child) { - if(state_valuation[variable] > 0) { - // not (Vi = vj => sj<=xj) - step_options.push_back( dv == ctx.int_val(variable) and not(ctx.int_val(state_valuation[variable]) <= vv)); - } - } else { - if(state_valuation[variable] < variable_domain[variable].size()-1) { - step_options.push_back( dv == ctx.int_val(variable) and not(ctx.int_val(state_valuation[variable]) > vv)); - } + if(state_valuation[variable] > 0) { + // not (Vi = vj => sj<=xj) + step_options_true.push_back( dv == ctx.int_val(variable) and not(ctx.int_val(state_valuation[variable]) <= vv)); + } + if(state_valuation[variable] < variable_domain[variable].size()-1) { + step_options_false.push_back( dv == ctx.int_val(variable) and not(ctx.int_val(state_valuation[variable]) > vv)); } } - substituted.push_back(z3::mk_or(step_options)); + this->substituted_true = z3::mk_or(step_options_true); + this->substituted_false = z3::mk_or(step_options_false); + child_true->createPrefixSubstitutions(state_valuation); + child_false->createPrefixSubstitutions(state_valuation); +} - getChild(step_to_true_child)->substitutePrefixExpression(path,state_valuation,substituted); +void InnerNode::substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const { + bool step_to_true_child = path[depth]; + substituted.push_back(step_to_true_child ? substituted_true : substituted_false); + getChild(step_to_true_child)->substitutePrefixExpression(path,substituted); } z3::expr InnerNode::substituteActionExpression(std::vector const& path, uint64_t action) const { return getChild(path[depth])->substituteActionExpression(path,action); } -void InnerNode::substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const { +/*void InnerNode::substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const { bool step_to_true_child = path[depth]; z3::expr step = step_to_true_child ? step_true_harm : step_false_harm; substituted.push_back(step.substitute(state_substitution_variables,state_valuation)); getChild(step_to_true_child)->substitutePrefixExpressionHarmonizing(path,state_valuation,substituted); +}*/ + +void InnerNode::createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) { + this->substituted_true = step_true_harm.substitute(state_substitution_variables,state_valuation); + this->substituted_false = step_false_harm.substitute(state_substitution_variables,state_valuation); + child_true->createPrefixSubstitutionsHarmonizing(state_valuation); + child_false->createPrefixSubstitutionsHarmonizing(state_valuation); +} + +void InnerNode::substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const { + bool step_to_true_child = path[depth]; + substituted.push_back(step_to_true_child ? substituted_true : substituted_false); + getChild(step_to_true_child)->substitutePrefixExpressionHarmonizing(path,substituted); } + + z3::expr InnerNode::substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const { return getChild(path[depth])->substituteActionExpressionHarmonizing(path,action,harmonizing_variable); } @@ -430,7 +453,6 @@ bool InnerNode::isPathEnabledInState( ) const { bool step_to_true_child = path[depth]; for(uint64_t variable = 0; variable < numVariables(); ++variable) { - z3::expr const& dv = decision_hole.solver_variable; if(not subfamily.holeContains(decision_hole.hole,variable)) { continue; } diff --git a/payntbind/src/synthesis/quotient/TreeNode.h b/payntbind/src/synthesis/quotient/TreeNode.h index 9dc378fc..71860682 100644 --- a/payntbind/src/synthesis/quotient/TreeNode.h +++ b/payntbind/src/synthesis/quotient/TreeNode.h @@ -118,11 +118,14 @@ class TreeNode { virtual uint64_t getPathActionHole(std::vector const& path) {return 0;} /** Add a step expression evaluated for a given state valuation. */ - virtual void substitutePrefixExpression(std::vector const& path, std::vector const& state_valuation, z3::expr_vector & substituted) const {}; + virtual void createPrefixSubstitutions(std::vector const& state_valuation) {}; + virtual void substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const {}; /** Add an action expression evaluated for a given state valuation. */ virtual z3::expr substituteActionExpression(std::vector const& path, uint64_t action) const {return z3::expr(ctx);}; + /** Add a step expression evaluated for a given state valuation (harmonizing). */ - virtual void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const {}; + virtual void createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) {}; + virtual void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const {}; /** Add an action expression evaluated for a given state valuation (harmonizing). */ virtual z3::expr substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const {return z3::expr(ctx);}; @@ -175,9 +178,12 @@ class TerminalNode: public TreeNode { void createPaths(z3::expr const& harmonizing_variable) override; uint64_t getPathActionHole(std::vector const& path); - void substitutePrefixExpression(std::vector const& path, std::vector const& state_valuation, z3::expr_vector & substituted) const override; + void createPrefixSubstitutions(std::vector const& state_valuation) override; + void substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpression(std::vector const& path, uint64_t action) const override; - void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const override; + + void createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) override; + void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const override; void addFamilyEncoding(Family const& subfamily, z3::solver & solver) const override; @@ -217,6 +223,9 @@ class InnerNode: public TreeNode { z3::expr step_true_harm; z3::expr step_false_harm; + z3::expr substituted_true; + z3::expr substituted_false; + InnerNode( uint64_t identifier, z3::context & ctx, std::vector const& variable_name, @@ -229,9 +238,12 @@ class InnerNode: public TreeNode { void createPaths(z3::expr const& harmonizing_variable) override; uint64_t getPathActionHole(std::vector const& path); - void substitutePrefixExpression(std::vector const& path, std::vector const& state_valuation, z3::expr_vector & substituted) const override; + void createPrefixSubstitutions(std::vector const& state_valuation) override; + void substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpression(std::vector const& path, uint64_t action) const override; - void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const override; + + void createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) override; + void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const override; void addFamilyEncoding(Family const& subfamily, z3::solver & solver) const override;