diff --git a/paynt/quotient/quotient.py b/paynt/quotient/quotient.py index 98655707d..dd6a46c02 100644 --- a/paynt/quotient/quotient.py +++ b/paynt/quotient/quotient.py @@ -91,6 +91,28 @@ def build(self, family): # prepare to discard designs self.discarded = 0 + + def build_with_second_coloring(self, family, main_coloring, main_family): + ''' Construct the quotient MDP for the family. ''' + + # select actions compatible with the family and restrict the quotient + alt_hole_selected_actions,alt_selected_actions,alt_selected_actions_bv = self.coloring.select_actions(family) + main_hole_selected_actions,main_selected_actions,main_selected_actions_bv = main_coloring.select_actions(main_family) + + selected_actions_bv = main_selected_actions_bv.__and__(alt_selected_actions_bv) + main_family.mdp = self.build_from_choice_mask(selected_actions_bv) + main_family.mdp.design_space = main_family + family.mdp = self.build_from_choice_mask(selected_actions_bv) + family.mdp.design_space = family + + # cash restriction information + main_family.hole_selected_actions = main_hole_selected_actions + main_family.selected_actions = main_selected_actions + main_family.selected_actions_bv = selected_actions_bv + + # prepare to discard designs + self.discarded = 0 + @staticmethod def mdp_to_dtmc(mdp): @@ -129,7 +151,27 @@ def scheduler_selection(self, mdp, scheduler): selection[hole_index].add(option) selection = [list(options) for options in selection] - return selection + return selection + + def scheduler_selection_with_coloring(self, mdp, scheduler, coloring): + ''' Get hole options involved in the scheduler selection. ''' + assert scheduler.memoryless and scheduler.deterministic + + # construct DTMC that corresponds to this scheduler and filter reachable states/choices + choices = scheduler.compute_action_support(mdp.model.nondeterministic_choice_indices) + dtmc,_,choice_map = self.restrict_mdp(mdp.model, choices) + choices = [ choice_map[state] for state in range(dtmc.nr_states) ] + + # map relevant choices to hole options + selection = [set() for hole_index in mdp.design_space.hole_indices] + for choice in choices: + global_choice = mdp.quotient_choice_map[choice] + choice_options = coloring.action_to_hole_options[global_choice] + for hole_index,option in choice_options.items(): + selection[hole_index].add(option) + selection = [list(options) for options in selection] + + return selection @staticmethod diff --git a/paynt/synthesizer/policy_tree.py b/paynt/synthesizer/policy_tree.py index 0930c9baa..ddfcb19dc 100644 --- a/paynt/synthesizer/policy_tree.py +++ b/paynt/synthesizer/policy_tree.py @@ -1,6 +1,7 @@ import stormpy.synthesis import paynt.quotient.holes +import paynt.quotient.coloring import paynt.quotient.models import paynt.synthesizer.synthesizer @@ -570,6 +571,196 @@ def split(self, family, prop, hole_selection, splitter): subfamilies.append(subfamily) return suboptions,subfamilies + + + def create_action_coloring(self, quotient_mdp): + + holes = paynt.quotient.holes.Holes() + action_to_hole_options = [] + for state in quotient_mdp.states: + + state_actions = self.quotient.state_to_actions[int(state)] + if len(state_actions) <= 1: + for action in range(quotient_mdp.get_nr_available_actions(state)): + action_to_hole_options.append({}) + continue + + name = f'state_{state}' + options = list(range(len(state_actions))) + option_labels = [self.quotient.action_labels[action] for action in state_actions] + hole = paynt.quotient.holes.Hole(name, options, option_labels) + holes.append(hole) + + for action in range(quotient_mdp.get_nr_available_actions(state)): + choice = quotient_mdp.get_choice_index(state, action) + choice_index = -1 + for index, action_list in enumerate(list(self.quotient.state_action_choices[int(state)])): + if choice in action_list: + choice_index = index + break + assert choice_index != -1 + + hole_options = {len(holes)-1: state_actions.index(choice_index)} + action_to_hole_options.append(hole_options) + + coloring = paynt.quotient.coloring.Coloring(quotient_mdp, holes, action_to_hole_options) + + return coloring + + + def update_scores(self, score_lists, selection): + for hole, score_list in score_lists.items(): + for choice in selection[hole]: + if choice not in score_list: + score_list.append(choice) + + + def create_policy(self, scheduler, family): + choice_to_action = [] + for choice in range(family.mdp.choices): + action = self.quotient.choice_to_action[family.mdp.quotient_choice_map[choice]] + choice_to_action.append(action) + + policy = self.quotient.empty_policy() + for state in range(family.mdp.model.nr_states): + state_choice = scheduler.get_choice(state).get_deterministic_choice() + choice = family.mdp.model.transition_matrix.get_row_group_start(state) + state_choice + action = choice_to_action[choice] + quotient_state = family.mdp.quotient_state_map[state] + policy[quotient_state] = action + + return policy + + + # synthesize one policy for family of MDPs (if such policy exists) + # set all_sat=True if all MDPs in the family are sat + # returns - True, unsat_families, sat_families, policy + # - False, unsat_families, sat_families, sat_policies + def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration_limit=0): + sat_mdp_families = [] + sat_mdp_policies = [] + unsat_mdp_families = [] + + # coloring for MDP choices + # TODO move this outside of the function since this needs to be created only once or create coloring on family.mdp + action_coloring = self.create_action_coloring(self.quotient.quotient_mdp) + action_family = paynt.quotient.holes.DesignSpace(action_coloring.holes) + + # create MDP subfamilies + for hole_assignment in family.all_combinations(): + subfamily = family.copy() + for hole_index, hole_option in enumerate(hole_assignment): + subfamily.assume_hole_options(hole_index, [hole_option]) + + # find out which mdps are sat and unsat + if not all_sat: + self.quotient.build(subfamily) + primary_result = subfamily.mdp.model_check_property(prop) + self.stat.iteration_mdp(subfamily.mdp.states) + + if primary_result.sat == False: + unsat_mdp_families.append(subfamily) + continue + + sat_mdp_families.append(subfamily) + policy = self.create_policy(primary_result.result.scheduler, subfamily) + sat_mdp_policies.append(policy) + else: + sat_mdp_families.append(subfamily) + + # no sat mdps + if len(sat_mdp_families) == 0: + return False, unsat_mdp_families, sat_mdp_families, None + + if len(sat_mdp_policies) == 0: + sat_mdp_policies = [None for _ in sat_mdp_families] + + action_family_stack = [action_family] + iter = 0 + + # AR for policies + while action_family_stack: + + if iteration_limit>0 and iter>iteration_limit: + break + + current_action_family = action_family_stack.pop(-1) + current_results = [] + + score_lists = {hole:[] for hole in current_action_family.hole_indices if len(current_action_family[hole].options) > 1} + + # try to find controller inconsistency across the MDPs + # if the controllers are consistent, return True + for index, mdp_subfamily in enumerate(sat_mdp_families): + self.quotient.build_with_second_coloring(mdp_subfamily, action_coloring, current_action_family) # maybe copy to new family? + + primary_result = current_action_family.mdp.model_check_property(prop) + self.stat.iteration_mdp(current_action_family.mdp.states) + + # discard the family as soon as one MDP is unsat + if primary_result.sat == False: + current_results.append(False) + break + + # add policy if current mdp doesn't have one yet + # TODO maybe this can be done after some number of controllers are consistent? + if sat_mdp_policies[index] == None: + policy = self.create_policy(primary_result.result.scheduler, mdp_subfamily) + sat_mdp_policies[index] = policy + + current_results.append(primary_result) + selection = self.quotient.scheduler_selection_with_coloring(current_action_family.mdp, primary_result.result.scheduler, action_coloring) + self.update_scores(score_lists, selection) + + scores = {hole:len(score_list) for hole, score_list in score_lists.items()} + + splitters = self.quotient.holes_with_max_score(scores) + splitter = splitters[0] + + # refinement as soon as the first inconsistency is found + if scores[splitter] > 1: + break + else: + for index, (result, family) in enumerate(zip(current_results, sat_mdp_families)): + policy = self.create_policy(result.result.scheduler, family) + sat_mdp_policies[index] = policy + return True, unsat_mdp_families, sat_mdp_families, sat_mdp_policies + + if False in current_results: + continue + + used_options = score_lists[splitter] + core_suboptions = [[option] for option in used_options] + other_suboptions = [option for option in current_action_family[splitter].options if option not in used_options] + if other_suboptions: + other_suboptions = [other_suboptions] + else: + other_suboptions = [] + suboptions = other_suboptions + core_suboptions # DFS solves core first + + subfamilies = [] + current_action_family.splitter = splitter + new_design_space = current_action_family.copy() + for suboption in suboptions: + subholes = new_design_space.subholes(splitter, suboption) + action_subfamily = paynt.quotient.holes.DesignSpace(subholes) + action_subfamily.assume_hole_options(splitter, suboption) + subfamilies.append(action_subfamily) + + action_family_stack = action_family_stack + subfamilies + + iter += 1 + + # compute policies for the sat mdps that were never analysed + mdps_without_policy = [index for index, policy in enumerate(sat_mdp_policies) if policy is None] + for mdp_index in mdps_without_policy: + self.quotient.build(sat_mdp_families[mdp_index]) + primary_result = sat_mdp_families[mdp_index].mdp.model_check_property(prop) + self.stat.iteration_mdp(sat_mdp_families[mdp_index].mdp.states) + policy = self.create_policy(primary_result.result.scheduler, sat_mdp_families[mdp_index]) + sat_mdp_policies[mdp_index] = policy + + return False, unsat_mdp_families, sat_mdp_families, sat_mdp_policies @@ -581,6 +772,11 @@ def synthesize_policy_tree(self, family): # game_solver.enable_profiling(True) policy_tree = PolicyTree(family) + # self.quotient.build(policy_tree.root.family) + # policy_exists, _ = self.synthesize_policy_for_tree_node(policy_tree.root.family, prop) + # print(policy_exists) + # exit() + reference_policy = None policy_tree_leaves = [policy_tree.root] while policy_tree_leaves: @@ -592,6 +788,9 @@ def synthesize_policy_tree(self, family): policy_tree_node.policy = result.policy policy_tree_node.policy_source = result.policy_source + # if family.size < 8: + # policy_exists, unsat, sat, policy = self.synthesize_policy_for_tree_node(family, prop) + if result.policy == False: reference_policy = None self.explore(family)