Skip to content

Commit

Permalink
Merge pull request #28 from TheGreatfpmK/new-master
Browse files Browse the repository at this point in the history
Finding one policy for given family of MDPs
  • Loading branch information
TheGreatfpmK authored Dec 6, 2023
2 parents fefb966 + 1ba265a commit 4b75c14
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 1 deletion.
44 changes: 43 additions & 1 deletion paynt/quotient/quotient.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,28 @@ def build(self, family):
# prepare to discard designs
self.discarded = 0


def build_with_second_coloring(self, family, main_coloring, main_family):
''' Construct the quotient MDP for the family. '''

# select actions compatible with the family and restrict the quotient
alt_hole_selected_actions,alt_selected_actions,alt_selected_actions_bv = self.coloring.select_actions(family)
main_hole_selected_actions,main_selected_actions,main_selected_actions_bv = main_coloring.select_actions(main_family)

selected_actions_bv = main_selected_actions_bv.__and__(alt_selected_actions_bv)
main_family.mdp = self.build_from_choice_mask(selected_actions_bv)
main_family.mdp.design_space = main_family
family.mdp = self.build_from_choice_mask(selected_actions_bv)
family.mdp.design_space = family

# cash restriction information
main_family.hole_selected_actions = main_hole_selected_actions
main_family.selected_actions = main_selected_actions
main_family.selected_actions_bv = selected_actions_bv

# prepare to discard designs
self.discarded = 0


@staticmethod
def mdp_to_dtmc(mdp):
Expand Down Expand Up @@ -129,7 +151,27 @@ def scheduler_selection(self, mdp, scheduler):
selection[hole_index].add(option)
selection = [list(options) for options in selection]

return selection
return selection

def scheduler_selection_with_coloring(self, mdp, scheduler, coloring):
''' Get hole options involved in the scheduler selection. '''
assert scheduler.memoryless and scheduler.deterministic

# construct DTMC that corresponds to this scheduler and filter reachable states/choices
choices = scheduler.compute_action_support(mdp.model.nondeterministic_choice_indices)
dtmc,_,choice_map = self.restrict_mdp(mdp.model, choices)
choices = [ choice_map[state] for state in range(dtmc.nr_states) ]

# map relevant choices to hole options
selection = [set() for hole_index in mdp.design_space.hole_indices]
for choice in choices:
global_choice = mdp.quotient_choice_map[choice]
choice_options = coloring.action_to_hole_options[global_choice]
for hole_index,option in choice_options.items():
selection[hole_index].add(option)
selection = [list(options) for options in selection]

return selection


@staticmethod
Expand Down
199 changes: 199 additions & 0 deletions paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import stormpy.synthesis

import paynt.quotient.holes
import paynt.quotient.coloring
import paynt.quotient.models
import paynt.synthesizer.synthesizer

Expand Down Expand Up @@ -570,6 +571,196 @@ def split(self, family, prop, hole_selection, splitter):
subfamilies.append(subfamily)

return suboptions,subfamilies


def create_action_coloring(self, quotient_mdp):

holes = paynt.quotient.holes.Holes()
action_to_hole_options = []
for state in quotient_mdp.states:

state_actions = self.quotient.state_to_actions[int(state)]
if len(state_actions) <= 1:
for action in range(quotient_mdp.get_nr_available_actions(state)):
action_to_hole_options.append({})
continue

name = f'state_{state}'
options = list(range(len(state_actions)))
option_labels = [self.quotient.action_labels[action] for action in state_actions]
hole = paynt.quotient.holes.Hole(name, options, option_labels)
holes.append(hole)

for action in range(quotient_mdp.get_nr_available_actions(state)):
choice = quotient_mdp.get_choice_index(state, action)
choice_index = -1
for index, action_list in enumerate(list(self.quotient.state_action_choices[int(state)])):
if choice in action_list:
choice_index = index
break
assert choice_index != -1

hole_options = {len(holes)-1: state_actions.index(choice_index)}
action_to_hole_options.append(hole_options)

coloring = paynt.quotient.coloring.Coloring(quotient_mdp, holes, action_to_hole_options)

return coloring


def update_scores(self, score_lists, selection):
for hole, score_list in score_lists.items():
for choice in selection[hole]:
if choice not in score_list:
score_list.append(choice)


def create_policy(self, scheduler, family):
choice_to_action = []
for choice in range(family.mdp.choices):
action = self.quotient.choice_to_action[family.mdp.quotient_choice_map[choice]]
choice_to_action.append(action)

policy = self.quotient.empty_policy()
for state in range(family.mdp.model.nr_states):
state_choice = scheduler.get_choice(state).get_deterministic_choice()
choice = family.mdp.model.transition_matrix.get_row_group_start(state) + state_choice
action = choice_to_action[choice]
quotient_state = family.mdp.quotient_state_map[state]
policy[quotient_state] = action

return policy


# synthesize one policy for family of MDPs (if such policy exists)
# set all_sat=True if all MDPs in the family are sat
# returns - True, unsat_families, sat_families, policy
# - False, unsat_families, sat_families, sat_policies
def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration_limit=0):
sat_mdp_families = []
sat_mdp_policies = []
unsat_mdp_families = []

# coloring for MDP choices
# TODO move this outside of the function since this needs to be created only once or create coloring on family.mdp
action_coloring = self.create_action_coloring(self.quotient.quotient_mdp)
action_family = paynt.quotient.holes.DesignSpace(action_coloring.holes)

# create MDP subfamilies
for hole_assignment in family.all_combinations():
subfamily = family.copy()
for hole_index, hole_option in enumerate(hole_assignment):
subfamily.assume_hole_options(hole_index, [hole_option])

# find out which mdps are sat and unsat
if not all_sat:
self.quotient.build(subfamily)
primary_result = subfamily.mdp.model_check_property(prop)
self.stat.iteration_mdp(subfamily.mdp.states)

if primary_result.sat == False:
unsat_mdp_families.append(subfamily)
continue

sat_mdp_families.append(subfamily)
policy = self.create_policy(primary_result.result.scheduler, subfamily)
sat_mdp_policies.append(policy)
else:
sat_mdp_families.append(subfamily)

# no sat mdps
if len(sat_mdp_families) == 0:
return False, unsat_mdp_families, sat_mdp_families, None

if len(sat_mdp_policies) == 0:
sat_mdp_policies = [None for _ in sat_mdp_families]

action_family_stack = [action_family]
iter = 0

# AR for policies
while action_family_stack:

if iteration_limit>0 and iter>iteration_limit:
break

current_action_family = action_family_stack.pop(-1)
current_results = []

score_lists = {hole:[] for hole in current_action_family.hole_indices if len(current_action_family[hole].options) > 1}

# try to find controller inconsistency across the MDPs
# if the controllers are consistent, return True
for index, mdp_subfamily in enumerate(sat_mdp_families):
self.quotient.build_with_second_coloring(mdp_subfamily, action_coloring, current_action_family) # maybe copy to new family?

primary_result = current_action_family.mdp.model_check_property(prop)
self.stat.iteration_mdp(current_action_family.mdp.states)

# discard the family as soon as one MDP is unsat
if primary_result.sat == False:
current_results.append(False)
break

# add policy if current mdp doesn't have one yet
# TODO maybe this can be done after some number of controllers are consistent?
if sat_mdp_policies[index] == None:
policy = self.create_policy(primary_result.result.scheduler, mdp_subfamily)
sat_mdp_policies[index] = policy

current_results.append(primary_result)
selection = self.quotient.scheduler_selection_with_coloring(current_action_family.mdp, primary_result.result.scheduler, action_coloring)
self.update_scores(score_lists, selection)

scores = {hole:len(score_list) for hole, score_list in score_lists.items()}

splitters = self.quotient.holes_with_max_score(scores)
splitter = splitters[0]

# refinement as soon as the first inconsistency is found
if scores[splitter] > 1:
break
else:
for index, (result, family) in enumerate(zip(current_results, sat_mdp_families)):
policy = self.create_policy(result.result.scheduler, family)
sat_mdp_policies[index] = policy
return True, unsat_mdp_families, sat_mdp_families, sat_mdp_policies

if False in current_results:
continue

used_options = score_lists[splitter]
core_suboptions = [[option] for option in used_options]
other_suboptions = [option for option in current_action_family[splitter].options if option not in used_options]
if other_suboptions:
other_suboptions = [other_suboptions]
else:
other_suboptions = []
suboptions = other_suboptions + core_suboptions # DFS solves core first

subfamilies = []
current_action_family.splitter = splitter
new_design_space = current_action_family.copy()
for suboption in suboptions:
subholes = new_design_space.subholes(splitter, suboption)
action_subfamily = paynt.quotient.holes.DesignSpace(subholes)
action_subfamily.assume_hole_options(splitter, suboption)
subfamilies.append(action_subfamily)

action_family_stack = action_family_stack + subfamilies

iter += 1

# compute policies for the sat mdps that were never analysed
mdps_without_policy = [index for index, policy in enumerate(sat_mdp_policies) if policy is None]
for mdp_index in mdps_without_policy:
self.quotient.build(sat_mdp_families[mdp_index])
primary_result = sat_mdp_families[mdp_index].mdp.model_check_property(prop)
self.stat.iteration_mdp(sat_mdp_families[mdp_index].mdp.states)
policy = self.create_policy(primary_result.result.scheduler, sat_mdp_families[mdp_index])
sat_mdp_policies[mdp_index] = policy

return False, unsat_mdp_families, sat_mdp_families, sat_mdp_policies



Expand All @@ -581,6 +772,11 @@ def synthesize_policy_tree(self, family):
# game_solver.enable_profiling(True)
policy_tree = PolicyTree(family)

# self.quotient.build(policy_tree.root.family)
# policy_exists, _ = self.synthesize_policy_for_tree_node(policy_tree.root.family, prop)
# print(policy_exists)
# exit()

reference_policy = None
policy_tree_leaves = [policy_tree.root]
while policy_tree_leaves:
Expand All @@ -592,6 +788,9 @@ def synthesize_policy_tree(self, family):
policy_tree_node.policy = result.policy
policy_tree_node.policy_source = result.policy_source

# if family.size < 8:
# policy_exists, unsat, sat, policy = self.synthesize_policy_for_tree_node(family, prop)

if result.policy == False:
reference_policy = None
self.explore(family)
Expand Down

0 comments on commit 4b75c14

Please sign in to comment.