Skip to content

Commit

Permalink
delegate path formula harmonization to TreeNode
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Sep 19, 2024
1 parent fecbfc4 commit c23c119
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 117 deletions.
10 changes: 8 additions & 2 deletions paynt/synthesizer/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def harmonize_inconsistent_scheduler(self, family):

def verify_family(self, family):
self.num_families_considered += 1

self.quotient.build(family)
if family.mdp is None:
self.num_families_skipped += 1
Expand All @@ -85,6 +84,7 @@ def verify_family(self, family):
return
self.harmonize_inconsistent_scheduler(family)


def counters_reset(self):
self.num_families_considered = 0
self.num_families_skipped = 0
Expand All @@ -110,6 +110,10 @@ def synthesize_tree(self, depth:int):

def synthesize_tree_sequence(self, opt_result_value):
tree_hint = None
global_timeout = paynt.utils.timer.GlobalTimer.global_timer.time_limit_seconds
if global_timeout is None:
global_timeout = 300
depth_timeout = global_timeout / 2 / SynthesizerDecisionTree.tree_depth
for depth in range(SynthesizerDecisionTree.tree_depth+1):
print()
self.quotient.set_depth(depth)
Expand All @@ -120,7 +124,8 @@ def synthesize_tree_sequence(self, opt_result_value):
self.counters_reset()
self.stat = paynt.synthesizer.statistic.Statistic(self)
self.stat.start(family)
self.synthesis_timer = paynt.utils.timer.Timer()
timeout = depth_timeout if depth < SynthesizerDecisionTree.tree_depth else None
self.synthesis_timer = paynt.utils.timer.Timer(timeout)
self.synthesis_timer.start()
families = [family]

Expand All @@ -133,6 +138,7 @@ def synthesize_tree_sequence(self, opt_result_value):
self.synthesize_one(family)
self.stat.finished_synthesis()
self.stat.print()
self.synthesis_timer = None
self.counters_print()

new_assignment_synthesized = self.best_assignment != best_assignment_old
Expand Down
139 changes: 85 additions & 54 deletions payntbind/src/synthesis/quotient/ColoringSmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,33 +50,38 @@ ColoringSmt<ValueType>::ColoringSmt(
STORM_LOG_THROW(variable_found, storm::exceptions::UnexpectedException, "Unexpected variable name.");
}

// create substitution variables
z3::expr_vector state_substitution_variables(ctx);
z3::expr_vector choice_substitution_variables(ctx);
for(auto const& name: variable_name) {
z3::expr variable = ctx.int_const(name.c_str());
state_substitution_variables.push_back(variable);
choice_substitution_variables.push_back(variable);
}
z3::expr action_substitution_variable = ctx.int_const("act");
choice_substitution_variables.push_back(action_substitution_variable);

// create the tree
uint64_t num_nodes = tree_list.size();
uint64_t num_actions = *std::max_element(choice_to_action.begin(),choice_to_action.end())-1;
uint64_t num_actions = *std::max_element(choice_to_action.begin(),choice_to_action.end())+1;
for(uint64_t node = 0; node < num_nodes; ++node) {
auto [parent,child_true,child_false] = tree_list[node];
STORM_LOG_THROW(
(child_true != num_nodes) == (child_false != num_nodes), storm::exceptions::UnexpectedException,
"Inner node has only one child."
);
if(child_true != num_nodes) {
tree.push_back(std::make_shared<InnerNode>(node,ctx,this->variable_name,this->variable_domain));
tree.push_back(std::make_shared<InnerNode>(node,ctx,this->variable_name,this->variable_domain,state_substitution_variables));
} else {
tree.push_back(std::make_shared<TerminalNode>(node,ctx,this->variable_name,this->variable_domain,num_actions));
tree.push_back(std::make_shared<TerminalNode>(node,ctx,this->variable_name,this->variable_domain,num_actions,action_substitution_variable));
}
}
getRoot()->createTree(tree_list,tree);

// create substitution variables
z3::expr_vector substitution_variables(ctx);
for(auto const& name: variable_name) {
substitution_variables.push_back(ctx.int_const(name.c_str()));
}
substitution_variables.push_back(ctx.int_const("act"));
getRoot()->createHoles(family);
getRoot()->createPaths(substitution_variables);
harmonizing_variable = ctx.int_const("__harm__");
getRoot()->createPathsHarmonizing(substitution_variables, harmonizing_variable);
getRoot()->createPaths(harmonizing_variable);

for(uint64_t state = 0; state < numStates(); ++state) {
state_path_enabled.push_back(BitVector(numPaths()));
}
Expand Down Expand Up @@ -105,14 +110,26 @@ ColoringSmt<ValueType>::ColoringSmt(
}

// create choice substitutions
std::vector<z3::expr_vector> state_substitution_expr;
for(uint64_t state = 0; state < numStates(); ++state) {
z3::expr_vector substitution_expr(ctx);
for(uint64_t value: state_valuation[state]) {
substitution_expr.push_back(ctx.int_val(value));
}
state_substitution_expr.push_back(substitution_expr);
}

// create choice substitutions
// std::vector<z3::expr_vector> choice_action_substitution_expr;
std::vector<z3::expr_vector> choice_substitution_expr;
for(uint64_t state = 0; state < numStates(); ++state) {
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
z3::expr_vector substitution_expr(ctx);
for(uint64_t value: state_valuation[state]) {
substitution_expr.push_back(ctx.int_val(value));
}
substitution_expr.push_back(ctx.int_val(choice_to_action[choice]));
z3::expr action_substitution_expr = ctx.int_val(choice_to_action[choice]);
substitution_expr.push_back(action_substitution_expr);
choice_substitution_expr.push_back(substitution_expr);
}
}
Expand All @@ -132,17 +149,41 @@ ColoringSmt<ValueType>::ColoringSmt(

// create choice colors
timers["ColoringSmt:: create choice colors"].start();

choice_path_label.resize(numChoices());
for(uint64_t state = 0; state < numStates(); ++state) {
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
std::vector<std::string> path_label;
z3::expr_vector path_evaluated(ctx);
for(uint64_t path = 0; path < numPaths(); ++path) {
std::string label = "p" + std::to_string(choice) + "_" + std::to_string(path);
path_label.push_back(label);
path_evaluated.push_back(path_expression[path].substitute(substitution_variables,choice_substitution_expr[choice]));
// path_evaluated.push_back(path_expression[path].substitute(substitution_variables,choice_substitution_expr[choice]).simplify());
choice_path_label[choice].push_back(label);
}
}
}

/*std::vector<z3::expr_vector> state_path_expresssion;
for(uint64_t state = 0; state < numStates(); ++state) {
z3::expr_vector path_evaluated(ctx);
for(uint64_t path = 0; path < numPaths(); ++path) {
path_evaluated.push_back(path_expression[path].substitute(state_substitution_variables,state_substitution_expr[state]));
}
state_path_expresssion.push_back(path_evaluated);
}
for(uint64_t state = 0; state < numStates(); ++state) {
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
z3::expr_vector path_evaluated(ctx);
for(uint64_t path = 0; path < numPaths(); ++path) {
path_evaluated.push_back(state_path_expresssion[state][path].substitute(action_substitution_variables,choice_action_substitution_expr[choice]));
}
choice_path_expresssion.push_back(path_evaluated);
}
}*/
for(uint64_t state = 0; state < numStates(); ++state) {
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
z3::expr_vector path_evaluated(ctx);
for(uint64_t path = 0; path < numPaths(); ++path) {
path_evaluated.push_back(path_expression[path].substitute(choice_substitution_variables,choice_substitution_expr[choice]));
}
choice_path_label.push_back(path_label);
choice_path_expresssion.push_back(path_evaluated);
}
}
Expand All @@ -153,47 +194,34 @@ ColoringSmt<ValueType>::ColoringSmt(
return;
}

timers["ColoringSmt:: create harmonizing variants (1)"].start();
// create harmonizing variants
std::vector<const TreeNode::Hole *> all_holes(family.numHoles());
getRoot()->loadAllHoles(all_holes);
std::vector<z3::expr_vector> hole_what;
std::vector<z3::expr_vector> hole_with;
for(const TreeNode::Hole *hole: all_holes) {
z3::expr_vector what(ctx); what.push_back(hole->solver_variable); hole_what.push_back(what);
z3::expr_vector with(ctx); with.push_back(hole->solver_variable_harm); hole_with.push_back(with);
}

std::vector<std::vector<std::vector<uint64_t>>> path_step_holes(numPaths());
for(uint64_t path = 0; path < numPaths(); ++path) {
getRoot()->loadPathStepHoles(getRoot()->paths[path],path_step_holes[path]);
}

z3::expr_vector path_expression_harmonizing(ctx);
for(uint64_t path = 0; path < numPaths(); ++path) {
z3::expr_vector variants(ctx);
variants.push_back(path_expression[path]);
for(uint64_t step = 0; step < path_step_expression[path].size(); ++step) {
for(uint64_t hole: path_step_holes[path][step]) {
variants.push_back(
(harmonizing_variable == (int)hole) and path_step_expression[path][step].substitute(hole_what[hole],hole_with[hole])
);
}
timers["ColoringSmt:: create harmonizing variants"].start();
std::vector<z3::expr_vector> state_path_expression_harmonizing;
for(uint64_t state = 0; state < numStates(); ++state) {
state_path_expression_harmonizing.push_back(z3::expr_vector(ctx));
for(uint64_t path = 0; path < numPaths(); ++path) {
z3::expr_vector evaluated(ctx);
getRoot()->substitutePrefixExpressionHarmonizing(getRoot()->paths[path], state_substitution_expr[state], evaluated);
state_path_expression_harmonizing[state].push_back(z3::mk_or(evaluated));
}
path_expression_harmonizing.push_back(z3::mk_or(variants));
}
timers["ColoringSmt:: create harmonizing variants (1)"].stop();

for(uint64_t choice = 0; choice < numChoices(); ++choice) {
choice_path_expresssion_harm.push_back(z3::expr_vector(ctx));
std::vector<z3::expr_vector> action_path_expression_harmonizing;
for(uint64_t action = 0; action < num_actions; ++action) {
action_path_expression_harmonizing.push_back(z3::expr_vector(ctx));
for(uint64_t path = 0; path < numPaths(); ++path) {
z3::expr evaluated = getRoot()->substituteActionExpressionHarmonizing(getRoot()->paths[path], action, harmonizing_variable);
action_path_expression_harmonizing[action].push_back(evaluated);
}
}
timers["ColoringSmt:: create harmonizing variants (2)"].start();
for(uint64_t path = 0; path < numPaths(); ++path) {
for(uint64_t choice = 0; choice < numChoices(); ++choice) {
choice_path_expresssion_harm[choice].push_back(path_expression_harmonizing[path].substitute(substitution_variables,choice_substitution_expr[choice]));
for(uint64_t state = 0; state < numStates(); ++state) {
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
choice_path_expresssion_harm.push_back(z3::expr_vector(ctx));
uint64_t action = choice_to_action[choice];
for(uint64_t path = 0; path < numPaths(); ++path) {
choice_path_expresssion_harm[choice].push_back(state_path_expression_harmonizing[state][path] or action_path_expression_harmonizing[action][path]);
}
}
}
timers["ColoringSmt:: create harmonizing variants (2)"].stop();
timers["ColoringSmt:: create harmonizing variants"].stop();

timers[__FUNCTION__].stop();
}
Expand Down Expand Up @@ -428,6 +456,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
solver.add(choice_path_expresssion[choice][path], label);
}
}
// std::cout << "(1) added choices: " << choices.getNumberOfSetBits() << std::endl;
bool consistent = check();
timers["areChoicesConsistent::1 is scheduler consistent?"].stop();

Expand All @@ -448,6 +477,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
BitVector state_reached(numStates(),false);
state_reached.set(initial_state,true);
consistent = true;
uint64_t num_choices_added = 0;
while(consistent) {
STORM_LOG_THROW(not unexplored_states.empty(), storm::exceptions::UnexpectedException, "all states explored but UNSAT core not found");
uint64_t state = unexplored_states.front(); unexplored_states.pop();
Expand All @@ -459,6 +489,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
const char *label = choice_path_label[choice][path].c_str();
solver.add(choice_path_expresssion[choice][path], label);
}
// std::cout << "(2) adding choice " << (++num_choices_added) << std::endl;
consistent = check();
if(not consistent) {
break;
Expand Down
Loading

0 comments on commit c23c119

Please sign in to comment.