Skip to content

Commit

Permalink
DT synthesis: smaller inner node formulae
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Sep 26, 2024
1 parent 5b16207 commit baa5633
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/buildtest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
run: docker push ${{ matrix.buildType.imageName }}:${{ matrix.buildType.dockerTag }}

deploy-mdp:
name: Deploy on latest (mdp) (${{ matrix.buildType.name }})
name: Deploy on branch (mdp) (${{ matrix.buildType.name }})
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
2 changes: 2 additions & 0 deletions paynt/quotient/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,10 @@ def __init__(self, mdp, specification):
super().__init__(specification=specification)
updated = payntbind.synthesis.restoreActionsInAbsorbingStates(mdp)
if updated is not None: mdp = updated
# action_labels, _ = payntbind.synthesis.extractActionLabels(mdp)
if MdpQuotient.add_dont_care_action:
mdp = payntbind.synthesis.addDontCareAction(mdp)
# stormpy.export_to_drn(mdp, sketch_path+".drn")

self.quotient_mdp = mdp
self.choice_destinations = payntbind.synthesis.computeChoiceDestinations(mdp)
Expand Down
4 changes: 3 additions & 1 deletion paynt/synthesizer/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def verify_family(self, family):
self.check_specification_for_mdp(family)
if not family.analysis_result.can_improve:
return
self.harmonize_inconsistent_scheduler(family)
# self.harmonize_inconsistent_scheduler(family)


def counters_reset(self):
Expand Down Expand Up @@ -115,6 +115,8 @@ def synthesize_tree_sequence(self, opt_result_value):
global_timeout = 300
depth_timeout = global_timeout / 2 / SynthesizerDecisionTree.tree_depth
for depth in range(SynthesizerDecisionTree.tree_depth+1):
print()
print("DEPTH = ", depth)
print()
self.quotient.set_depth(depth)
best_assignment_old = self.best_assignment
Expand Down
110 changes: 11 additions & 99 deletions payntbind/src/synthesis/quotient/ColoringSmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ ColoringSmt<ValueType>::ColoringSmt(
solver(ctx), harmonizing_variable(ctx), one_consistency_check(one_consistency_check) {

timers[__FUNCTION__].start();
// std::cout << __FUNCTION__ << " start" << std::endl;

for(uint64_t state = 0; state < numStates(); ++state) {
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
Expand Down Expand Up @@ -109,7 +108,7 @@ ColoringSmt<ValueType>::ColoringSmt(
}
}

// create choice substitutions
// create state substitutions
std::vector<z3::expr_vector> state_substitution_expr;
for(uint64_t state = 0; state < numStates(); ++state) {
z3::expr_vector substitution_expr(ctx);
Expand Down Expand Up @@ -141,7 +140,8 @@ ColoringSmt<ValueType>::ColoringSmt(
state_path_expression.push_back(z3::expr_vector(ctx));
for(uint64_t path = 0; path < numPaths(); ++path) {
z3::expr_vector substituted(ctx);
getRoot()->substitutePrefixExpression(getRoot()->paths[path], state_substitution_expr[state], substituted);
// getRoot()->substitutePrefixExpression(getRoot()->paths[path], state_substitution_expr[state], substituted);
getRoot()->substitutePrefixExpression(getRoot()->paths[path], state_valuation[state], substituted);
state_path_expression[state].push_back(z3::mk_or(substituted));
}
}
Expand Down Expand Up @@ -415,12 +415,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
timers[__FUNCTION__].start();
std::vector<std::vector<uint64_t>> hole_options_vector(family.numHoles());

if(one_consistency_check) {
solver.pop();
timers[__FUNCTION__].stop();
return std::make_pair(false,hole_options_vector);
}

timers["areChoicesConsistent::1 is scheduler consistent?"].start();
solver.push();
getRoot()->addFamilyEncoding(subfamily,solver);
Expand All @@ -432,7 +426,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
solver.add(choice_path_expresssion[choice][path], label);
}
}
// std::cout << "(1) added choices: " << choices.getNumberOfSetBits() << std::endl;
bool consistent = check();
timers["areChoicesConsistent::1 is scheduler consistent?"].stop();

Expand All @@ -444,6 +437,14 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
timers[__FUNCTION__].stop();
return std::make_pair(true,hole_options_vector);
}

if(one_consistency_check) {
solver.pop();
solver.pop();
timers[__FUNCTION__].stop();
return std::make_pair(false,hole_options_vector);
}

solver.pop();

timers["areChoicesConsistent::2 better unsat core"].start();
Expand Down Expand Up @@ -539,95 +540,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh



/**
 * Check whether the selected choices are consistent within a subfamily, exploring first the
 * states that own the choices listed in a previously obtained UNSAT core.
 *
 * States are explored breadth-first; for every explored state only its first choice set in
 * `choices` is encoded, and the solver is queried after each such choice.
 *
 * @param choices bit vector of selected choices
 * @param subfamily family whose encoding is asserted before any choice is added
 * @param unsat_core_hint (choice,path) pairs; only the choice component is used, to seed the exploration
 * @return (true, hole assignment read from the model) if all encoded choices are satisfiable;
 *   otherwise (false, hole options read from the model of the harmonized UNSAT core, with the
 *   two options of the harmonizing hole in ascending order)
 * @throws storm::exceptions::UnexpectedException if the harmonized UNSAT core is not satisfiable
 */
template<typename ValueType>
std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areChoicesConsistentUseHint(BitVector const& choices, Family const& subfamily, std::vector<std::pair<uint64_t,uint64_t>> const& unsat_core_hint) {
    timers[__FUNCTION__].start();
    std::vector<std::vector<uint64_t>> hole_options_vector(family.numHoles());

    timers["areChoicesConsistent::2 better unsat core"].start();
    // outer solver frame: family encoding; inner solver frame: labelled choice assertions
    solver.push();
    getRoot()->addFamilyEncoding(subfamily,solver);
    solver.push();

    // seed the queue with the states of the hinted choices, then with the initial state
    std::queue<uint64_t> unexplored_states;
    BitVector state_reached(numStates(),false);
    for(auto [choice,path]: unsat_core_hint) {
        uint64_t state = choice_to_state[choice];
        if(not state_reached[state]) {
            unexplored_states.push(state);
            state_reached.set(state,true);
        }
    }
    if(not state_reached[initial_state]) {
        unexplored_states.push(initial_state);
        state_reached.set(initial_state,true);
    }

    bool consistent = true;
    // Stop at the first failed check: assertions are only ever added to this frame, so further
    // states cannot make the query satisfiable again; they would only cost additional solver
    // calls and enlarge the assertion set from which the UNSAT core is extracted.
    while(consistent and not unexplored_states.empty()) {
        uint64_t state = unexplored_states.front(); unexplored_states.pop();
        for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
            if(not choices[choice]) {
                continue;
            }
            // assert every enabled path of this choice under its own label
            for(uint64_t path: state_path_enabled[state]) {
                const char *label = choice_path_label[choice][path].c_str();
                solver.add(choice_path_expresssion[choice][path], label);
            }
            consistent = check();
            if(not consistent) {
                break;
            }
            for(uint64_t dst: choice_destinations[choice]) {
                if(not state_reached[dst]) {
                    unexplored_states.push(dst);
                    state_reached.set(dst,true);
                }
            }
            // only the first selected choice of a state is considered
            break;
        }
    }
    timers["areChoicesConsistent::2 better unsat core"].stop();

    if(consistent) {
        z3::model model = solver.get_model();
        solver.pop();
        solver.pop();
        getRoot()->loadHoleAssignmentFromModel(model,hole_options_vector);
        timers[__FUNCTION__].stop();
        return std::make_pair(true,hole_options_vector);
    }

    // read the core before the frame holding the labelled assertions is dropped
    z3::expr_vector unsat_core_expr = solver.unsat_core();
    solver.pop();
    // NOTE(review): presumably fills this->unsat_core, which is iterated below; confirm in loadUnsatCore
    loadUnsatCore(unsat_core_expr,subfamily);

    if(PRINT_UNSAT_CORE)
        std::cout << "-- unsat core start --" << std::endl;
    timers["areChoicesConsistent::3 unsat core analysis"].start();
    // re-encode the core with the harmonizing variants of the path expressions
    solver.push();
    solver.add(0 <= harmonizing_variable and harmonizing_variable < family.numHoles(), "harmonizing_domain");
    for(auto [choice,path]: this->unsat_core) {
        solver.add(choice_path_expresssion_harm[choice][path]);
    }
    consistent = check();
    STORM_LOG_THROW(consistent, storm::exceptions::UnexpectedException, "harmonized UNSAT core is not SAT");
    z3::model model = solver.get_model();
    solver.pop();
    solver.pop();

    uint64_t harmonizing_hole = model.eval(harmonizing_variable).get_numeral_uint64();
    getRoot()->loadHoleAssignmentFromModel(model,hole_options_vector);
    getRoot()->loadHoleAssignmentFromModelHarmonizing(model,hole_options_vector,harmonizing_hole);
    // report the two options of the harmonizing hole in ascending order
    if(hole_options_vector[harmonizing_hole][0] > hole_options_vector[harmonizing_hole][1]) {
        uint64_t tmp = hole_options_vector[harmonizing_hole][0];
        hole_options_vector[harmonizing_hole][0] = hole_options_vector[harmonizing_hole][1];
        hole_options_vector[harmonizing_hole][1] = tmp;
    }
    if(PRINT_UNSAT_CORE)
        std::cout << "-- unsat core end --" << std::endl;
    timers["areChoicesConsistent::3 unsat core analysis"].stop();

    timers[__FUNCTION__].stop();
    return std::make_pair(false,hole_options_vector);
}


template class ColoringSmt<>;
Expand Down
3 changes: 0 additions & 3 deletions payntbind/src/synthesis/quotient/ColoringSmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ class ColoringSmt {
std::pair<bool,std::vector<std::vector<uint64_t>>> areChoicesConsistent(
BitVector const& choices, Family const& subfamily
);
std::pair<bool,std::vector<std::vector<uint64_t>>> areChoicesConsistentUseHint(
BitVector const& choices, Family const& subfamily, std::vector<std::pair<uint64_t,uint64_t>> const& unsat_core_hint
);

std::map<std::string,storm::utility::Stopwatch> timers;
std::vector<std::pair<std::string,double>> getProfilingInfo() {
Expand Down
27 changes: 23 additions & 4 deletions payntbind/src/synthesis/quotient/TreeNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ uint64_t TerminalNode::getPathActionHole(std::vector<bool> const& path) {
}


void TerminalNode::substitutePrefixExpression(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const {
void TerminalNode::substitutePrefixExpression(std::vector<bool> const& path, std::vector<uint64_t> const& state_valuation, z3::expr_vector & substituted) const {
//
}

Expand Down Expand Up @@ -374,10 +374,29 @@ uint64_t InnerNode::getPathActionHole(std::vector<bool> const& path) {
}


void InnerNode::substitutePrefixExpression(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const {
void InnerNode::substitutePrefixExpression(std::vector<bool> const& path, std::vector<uint64_t> const& state_valuation, z3::expr_vector & substituted) const {
bool step_to_true_child = path[depth];
z3::expr step = step_to_true_child ? step_true : step_false;
substituted.push_back(step.substitute(state_substitution_variables,state_valuation));
// z3::expr step = step_to_true_child ? step_true : step_false;
// substituted.push_back(step.substitute(state_substitution_variables,state_valuation));

z3::expr_vector step_options(ctx);
z3::expr const& dv = decision_hole.solver_variable;
for(uint64_t variable = 0; variable < numVariables(); ++variable) {
z3::expr const& vv = variable_hole[variable].solver_variable;
// mind the negation below
if(step_to_true_child) {
if(state_valuation[variable] > 0) {
// not (Vi = vj => sj<=xj)
step_options.push_back( dv == ctx.int_val(variable) and not(ctx.int_val(state_valuation[variable]) <= vv));
}
} else {
if(state_valuation[variable] < variable_domain[variable].size()-1) {
step_options.push_back( dv == ctx.int_val(variable) and not(ctx.int_val(state_valuation[variable]) > vv));
}
}
}
substituted.push_back(z3::mk_or(step_options));

getChild(step_to_true_child)->substitutePrefixExpression(path,state_valuation,substituted);
}

Expand Down
6 changes: 3 additions & 3 deletions payntbind/src/synthesis/quotient/TreeNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class TreeNode {
virtual uint64_t getPathActionHole(std::vector<bool> const& path) {return 0;}

/** Add a step expression evaluated for a given state valuation. */
virtual void substitutePrefixExpression(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const {};
virtual void substitutePrefixExpression(std::vector<bool> const& path, std::vector<uint64_t> const& state_valuation, z3::expr_vector & substituted) const {};
/** Add an action expression evaluated for a given state valuation. */
virtual z3::expr substituteActionExpression(std::vector<bool> const& path, uint64_t action) const {return z3::expr(ctx);};
/** Add a step expression evaluated for a given state valuation (harmonizing). */
Expand Down Expand Up @@ -175,7 +175,7 @@ class TerminalNode: public TreeNode {
void createPaths(z3::expr const& harmonizing_variable) override;
uint64_t getPathActionHole(std::vector<bool> const& path);

void substitutePrefixExpression(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const override;
void substitutePrefixExpression(std::vector<bool> const& path, std::vector<uint64_t> const& state_valuation, z3::expr_vector & substituted) const override;
z3::expr substituteActionExpression(std::vector<bool> const& path, uint64_t action) const override;
void substitutePrefixExpressionHarmonizing(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const override;
z3::expr substituteActionExpressionHarmonizing(std::vector<bool> const& path, uint64_t action, z3::expr const& harmonizing_variable) const override;
Expand Down Expand Up @@ -229,7 +229,7 @@ class InnerNode: public TreeNode {
void createPaths(z3::expr const& harmonizing_variable) override;
uint64_t getPathActionHole(std::vector<bool> const& path);

void substitutePrefixExpression(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const override;
void substitutePrefixExpression(std::vector<bool> const& path, std::vector<uint64_t> const& state_valuation, z3::expr_vector & substituted) const override;
z3::expr substituteActionExpression(std::vector<bool> const& path, uint64_t action) const override;
void substitutePrefixExpressionHarmonizing(std::vector<bool> const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const override;
z3::expr substituteActionExpressionHarmonizing(std::vector<bool> const& path, uint64_t action, z3::expr const& harmonizing_variable) const override;
Expand Down
1 change: 0 additions & 1 deletion payntbind/src/synthesis/quotient/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ void bindings_coloring(py::module& m) {
.def("selectCompatibleChoices", py::overload_cast<synthesis::Family const&, storm::storage::BitVector const&>(&synthesis::ColoringSmt<>::selectCompatibleChoices))
.def("areChoicesConsistent", &synthesis::ColoringSmt<>::areChoicesConsistent)
.def_property_readonly("unsat_core", [](synthesis::ColoringSmt<>& coloring) {return coloring.unsat_core;})
.def("areChoicesConsistentUseHint", &synthesis::ColoringSmt<>::areChoicesConsistentUseHint)
.def("getProfilingInfo", &synthesis::ColoringSmt<>::getProfilingInfo)
;
}
3 changes: 1 addition & 2 deletions payntbind/src/synthesis/translation/choiceTransformation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,9 @@ std::shared_ptr<storm::models::sparse::Model<ValueType>> addDontCareAction(
components.transitionMatrix = builder.build();
auto rewardModels = synthesis::translateRewardModels(model,translated_to_original_choice,translated_choice_mask);
for(auto & [name,reward_model]: rewardModels) {
std::cout << "processing " << name << std::endl;
std::vector<ValueType> & choice_reward = reward_model.getStateActionRewardVector();
ValueType reward_sum = 0;
for(uint64_t state = 0; state < num_states; ++state) {
ValueType reward_sum = 0;
uint64_t new_translated_choice = row_groups_new[state+1]-1;
uint64_t state_num_choices = new_translated_choice-row_groups_new[state];
for(uint64_t translated_choice = row_groups_new[state]; translated_choice < new_translated_choice; ++translated_choice) {
Expand Down

0 comments on commit baa5633

Please sign in to comment.