Skip to content

Commit

Permalink
policy and decision tree export
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Nov 12, 2024
1 parent d4df15e commit 66dde3f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions paynt/synthesizer/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def counters_print(self):

def export_decision_tree(self, decision_tree, export_filename_base):
tree = decision_tree.to_graphviz()
# tree_filename = export_filename_base + ".dot"
# with open(tree_filename, 'w') as file:
# file.write(tree.source)
# logger.info(f"exported decision tree to {tree_filename}")
tree_filename = export_filename_base + ".dot"
with open(tree_filename, 'w') as file:
file.write(tree.source)
logger.info(f"exported decision tree to {tree_filename}")

tree_visualization_filename = export_filename_base + ".png"
tree.render(export_filename_base, format="png", cleanup=True) # using export_filename_base since graphviz appends .png by default
Expand Down
2 changes: 1 addition & 1 deletion paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def evaluate_all(self, family, prop, keep_value_only=False):


def run(self, optimum_threshold=None):
return self.evaluate(export_filename_base=None)
return self.evaluate(export_filename_base=paynt.synthesizer.synthesizer.Synthesizer.export_synthesis_filename_base)


def export_evaluation_result(self, evaluations, export_filename_base):
Expand Down
2 changes: 1 addition & 1 deletion paynt/synthesizer/synthesizer_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def update_optimum(self, family):
self.quotient.specification.optimality.update_optimum(iv)
self.best_assignment = ia
self.best_assignment_value = iv
logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds")
# logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds")
if isinstance(self.quotient, paynt.quotient.pomdp.PomdpQuotient):
self.stat.new_fsc_found(family.analysis_result.improving_value, ia, self.quotient.policy_size(ia))

Expand Down

0 comments on commit 66dde3f

Please sign in to comment.