From 1cd8fd03c7cbf7a9f82fe6d95b0c5c4587c640be Mon Sep 17 00:00:00 2001 From: Ladislav Dokoupil Date: Sat, 23 Nov 2024 17:58:46 +0100 Subject: [PATCH] change policy tree output format --- paynt/quotient/mdp_family.py | 29 +++++++++++++++++------------ paynt/synthesizer/policy_tree.py | 13 +++++-------- paynt/synthesizer/synthesizer.py | 6 +++--- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/paynt/quotient/mdp_family.py b/paynt/quotient/mdp_family.py index 49792104..ca335c86 100644 --- a/paynt/quotient/mdp_family.py +++ b/paynt/quotient/mdp_family.py @@ -1,11 +1,9 @@ -import stormpy import payntbind import paynt.family.family import paynt.quotient.quotient import paynt.models.models -import collections import json import logging @@ -110,17 +108,24 @@ def policy_to_state_valuation_actions(self, policy): ] return state_valuation_to_action - def policy_to_json(self, state_valuation_to_action, indent=""): - import json - json_string = "[\n" - for index,valuation_action in enumerate(state_valuation_to_action): - valuation,action = valuation_action - if index > 0: - json_string += ",\n" - json_string += indent + json.dumps(valuation_action) - json_string += "\n" + indent + "]" - return json_string + def policy_to_json(self, state_valuation_to_action, dt_control=True): + ''' + :param state_valuation_to_action: a list of tuples (valuation,action) where valuation is a dictionary of variable + :param dt_control: if True, outputs JSON in the format expected by the DT control tool, + otherwise simpler format is used + ''' + json_whole = [] + for index, valuation_action in enumerate(state_valuation_to_action): + if dt_control: + json_unit = {} + valuation, action = valuation_action + json_unit["c"] = [{"origin": {"action-label": action}}] + json_unit["s"] = valuation + json_whole.append(json_unit) + else: + json_whole.append(valuation_action) + return json_whole def fix_and_apply_policy_to_family(self, family, policy): diff --git a/paynt/synthesizer/policy_tree.py b/paynt/synthesizer/policy_tree.py index a249a2fe..2e671d3e 100644 --- a/paynt/synthesizer/policy_tree.py +++ b/paynt/synthesizer/policy_tree.py @@ -792,21 +792,18 @@ def evaluate_all(self, family, prop, keep_value_only=False): def run(self, optimum_threshold=None): - return self.evaluate(export_filename_base=paynt.synthesizer.synthesizer.Synthesizer.export_synthesis_filename_base) + return self.evaluate() def export_evaluation_result(self, evaluations, export_filename_base): import json policies = self.policy_tree.extract_policies(self.quotient) - policies_string = "{\n" + policies_json = {} for index,key_value in enumerate(policies.items()): policy_id,policy = key_value - if index > 0: - policies_string += ",\n" - policy_json = self.quotient.policy_to_json(policy, indent= " ") - - policies_string += f'"{policy_id}" : {policy_json}' - policies_string += "}\n" + policy_json = self.quotient.policy_to_json(policy) + policies_json[policy_id] = policy_json + policies_string = json.dumps(policies_json, indent=4) policies_filename = export_filename_base + ".json" with open(policies_filename, 'w') as file: diff --git a/paynt/synthesizer/synthesizer.py b/paynt/synthesizer/synthesizer.py index dac612a3..822bf012 100644 --- a/paynt/synthesizer/synthesizer.py +++ b/paynt/synthesizer/synthesizer.py @@ -119,7 +119,7 @@ def export_evaluation_result(self, evaluations, export_filename_base): ''' to be overridden ''' pass - def evaluate(self, family=None, prop=None, keep_value_only=False, print_stats=True, export_filename_base=None): + def evaluate(self, family=None, prop=None, keep_value_only=False, print_stats=True): ''' Evaluate each member of the family wrt the given property. :param family if None, then the design space of the quotient will be used @@ -143,8 +143,8 @@ def evaluate(self, family=None, prop=None, keep_value_only=False, print_stats=Tr self.stat.finished_evaluation(evaluations) logger.info("evaluation finished") - if export_filename_base is not None: - self.export_evaluation_result(evaluations, export_filename_base) + if self.export_synthesis_filename_base is not None: + self.export_evaluation_result(evaluations, self.export_synthesis_filename_base) if print_stats: self.stat.print()