Skip to content

Commit

Permalink
DT synthesis: tree postprocessing and export
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Sep 26, 2024
1 parent baa5633 commit 03040ca
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 64 deletions.
12 changes: 5 additions & 7 deletions paynt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def setup_logger(log_path = None):
help="path to output file for SAYNT belief FSC")
@click.option("--export-fsc-paynt", type=click.Path(), default=None,
help="path to output file for SAYNT inductive FSC")
@click.option("--export-evaluation", type=click.Path(), default=None,
help="base filename to output evaluation result")
@click.option("--export-synthesis", type=click.Path(), default=None,
help="base filename to output synthesis result")

@click.option("--mdp-split-wrt-mdp", is_flag=True, default=False,
help="if set, MDP abstraction scheduler will be used for splitting, otherwise game abstraction scheduler will be used")
Expand All @@ -126,8 +126,6 @@ def setup_logger(log_path = None):
help="decision tree synthesis: tree depth")
@click.option("--tree-enumeration", is_flag=True, default=False,
help="decision tree synthesis: if set, all trees of size at most tree_depth will be enumerated")
@click.option("--add-dont-care-action", is_flag=True, default=False,
help="decision tree synthesis: if set, an explicit action simulating a random action selection will be added to each state")

@click.option(
"--constraint-bound", type=click.FLOAT, help="bound for creating constrained POMDP for Cassandra models",
Expand All @@ -148,9 +146,9 @@ def paynt_run(
fsc_synthesis, fsc_memory_size, posterior_aware,
storm_pomdp, iterative_storm, get_storm_result, storm_options, prune_storm,
use_storm_cutoffs, unfold_strategy_storm,
export_fsc_storm, export_fsc_paynt, export_evaluation,
export_fsc_storm, export_fsc_paynt, export_synthesis,
mdp_split_wrt_mdp, mdp_discard_unreachable_choices, mdp_use_randomized_abstraction,
tree_depth, tree_enumeration, add_dont_care_action,
tree_depth, tree_enumeration,
constraint_bound,
ce_generator,
profiling
Expand All @@ -166,6 +164,7 @@ def paynt_run(

# set CLI parameters
paynt.quotient.quotient.Quotient.disable_expected_visits = disable_expected_visits
paynt.synthesizer.synthesizer.Synthesizer.export_synthesis_filename_base = export_synthesis
paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.conflict_generator_type = ce_generator
paynt.quotient.pomdp.PomdpQuotient.initial_memory_size = fsc_memory_size
paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware
Expand All @@ -177,7 +176,6 @@ def paynt_run(

paynt.synthesizer.decision_tree.SynthesizerDecisionTree.tree_depth = tree_depth
paynt.synthesizer.decision_tree.SynthesizerDecisionTree.tree_enumeration = tree_enumeration
paynt.quotient.mdp.MdpQuotient.add_dont_care_action = add_dont_care_action

storm_control = None
if storm_pomdp:
Expand Down
4 changes: 2 additions & 2 deletions paynt/parser/jani.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def construct_edge(self, edge, substitution = None):
for templ_edge_dest in edge.template_edge.destinations:
assignments = templ_edge_dest.assignments.clone()
if substitution is not None:
# assignments.substitute(substitution, substitute_transcendental_numbers=True)
assignments.substitute(substitution) # legacy version
assignments.substitute(substitution, substitute_transcendental_numbers=True)
# assignments.substitute(substitution) # legacy version
templ_edge.add_destination(stormpy.storage.JaniTemplateEdgeDestination(assignments))

new_edge = stormpy.storage.JaniEdge(
Expand Down
146 changes: 117 additions & 29 deletions paynt/quotient/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import stormpy
import payntbind
import json
import graphviz

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,33 +49,21 @@ def __str__(self):
domain = f"[{self.domain_min}..{self.domain_max}]"
return f"{self.name}:{domain}"

@classmethod
def from_model(cls, model):
assert model.has_state_valuations(), "model has no state valuations"
sv = model.state_valuations
valuation = json.loads(str(sv.get_json(0)))
variable_name = [var_name for var_name in valuation]
state_valuations = []
for state in range(model.nr_states):
valuation = json.loads(str(sv.get_json(state)))
valuation = [valuation[var_name] for var_name in variable_name]
state_valuations.append(valuation)
variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)]
variables = [v for v in variables if len(v.domain) > 1]
return variables



class DecisionTreeNode:

def __init__(self, parent):
self.parent = parent
self.variable_index = None
self.child_true = None
self.child_false = None
self.identifier = None
self.holes = None
self.hole_assignment = None

self.action = None
self.variable = None
self.variable_bound = None

@property
def is_terminal(self):
Expand All @@ -93,6 +82,11 @@ def add_children(self):
self.child_true = DecisionTreeNode(self)
self.child_false = DecisionTreeNode(self)

def get_depth(self):
if self.is_terminal:
return 0
return 1 + max([child.get_depth() for child in self.child_nodes])

def assign_identifiers(self, identifier=0):
self.identifier = identifier
if self.is_terminal:
Expand All @@ -109,26 +103,97 @@ def associate_holes(self, node_hole_info):
self.child_false.associate_holes(node_hole_info)

def associate_assignment(self, assignment):
self.hole_assignment = [assignment.hole_options(hole)[0] for hole in self.holes]
hole_assignment = [assignment.hole_options(hole)[0] for hole in self.holes]
if self.is_terminal:
self.action = hole_assignment[0]
return

self.variable = hole_assignment[0]
self.variable_bound = hole_assignment[self.variable+1]

self.child_true.associate_assignment(assignment)
self.child_false.associate_assignment(assignment)

def apply_hint(self, subfamily, tree_hint):
if self.is_terminal or tree_hint.is_terminal:
return
for hole_index,option in enumerate(tree_hint.hole_assignment):
hole = self.holes[hole_index]
subfamily.hole_set_options(hole,[option])

variable_hint = tree_hint.variable
subfamily.hole_set_options(self.holes[0],[variable_hint])
subfamily.hole_set_options(self.holes[variable_hint+1],[tree_hint.variable_bound])
self.child_true.apply_hint(subfamily,tree_hint.child_true)
self.child_false.apply_hint(subfamily,tree_hint.child_false)

def simplify(self, variables, state_valuations):
if self.is_terminal:
return

bound = variables[self.variable].domain[self.variable_bound]
state_valuations_true = [valuation for valuation in state_valuations if valuation[self.variable] <= bound]
state_valuations_false = [valuation for valuation in state_valuations if valuation[self.variable] > bound]
child_skip = None
if len(state_valuations_true) == 0:
child_skip = self.child_false
elif len(state_valuations_false) == 0:
child_skip = self.child_true
if child_skip is not None:
self.variable = child_skip.variable
self.variable_bound = child_skip.variable_bound
self.action = child_skip.action
self.child_true = child_skip.child_true
self.child_false = child_skip.child_false
self.simplify(variables,state_valuations)
return

self.child_true.simplify(variables, state_valuations_true)
self.child_false.simplify(variables, state_valuations_false)
if not self.is_terminal and self.child_true.is_terminal and self.child_false.is_terminal and self.child_true.action == self.child_false.action:
self.variable = self.variable_bound = None
self.action = self.child_true.action
self.child_true = self.child_false = None

def to_string(self, variables, action_labels, indent_level=0, indent_size=2):
indent = " "*indent_level*indent_size
if self.is_terminal:
return indent + f"{action_labels[self.action]}" + "\n"
var = variables[self.variable]
s = ""
s += indent + f"if {var.name}<={var.domain[self.variable_bound]}:" + "\n"
s += self.child_true.to_string(variables,action_labels,indent_level+1)
s += indent + f"else:" + "\n"
s += self.child_false.to_string(variables,action_labels,indent_level+1)
return s

@property
def graphviz_id(self):
return str(self.identifier)

def to_graphviz(self, graphviz_tree, variables, action_labels):
if not self.is_terminal:
for child in self.child_nodes:
child.to_graphviz(graphviz_tree,variables,action_labels)

if self.is_terminal:
node_label = action_labels[self.action]
else:
var = variables[self.variable]
node_label = f"{var.name}<={var.domain[self.variable_bound]}"

graphviz_tree.node(self.graphviz_id, label=node_label, shape="box", style="rounded", margin="0.05,0.05")
if not self.is_terminal:
graphviz_tree.edge(self.graphviz_id,self.child_true.graphviz_id,label="True")
graphviz_tree.edge(self.graphviz_id,self.child_false.graphviz_id,label="False")



class DecisionTree:

def __init__(self, model):
self.variables = Variable.from_model(model)
def __init__(self, quotient, variable_name, state_valuations):
self.quotient = quotient
self.state_valuations = state_valuations
variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)]
variables = [v for v in variables if len(v.domain) > 1]
self.variables = variables
logger.debug(f"found the following {len(self.variables)} variables: {[str(v) for v in self.variables]}")
self.reset()

Expand All @@ -142,6 +207,9 @@ def set_depth(self, depth:int):
node.add_children()
self.root.assign_identifiers()

def get_depth(self):
return self.root.get_depth()

def collect_nodes(self, node_condition=None):
if node_condition is None:
node_condition = lambda node : True
Expand Down Expand Up @@ -170,25 +238,45 @@ def to_list(self):
node_info[node.identifier] = (parent,child_true,child_false)
return node_info

def simplify(self):
self.root.simplify(self.variables, self.state_valuations)

def to_string(self):
return self.root.to_string(self.variables,self.quotient.action_labels)

def to_graphviz(self):
logging.getLogger("graphviz").setLevel(logging.WARNING)
logging.getLogger("graphviz.sources").setLevel(logging.ERROR)
graphviz_tree = graphviz.Digraph(comment="decision tree")
self.root.to_graphviz(graphviz_tree,self.variables,self.quotient.action_labels)
return graphviz_tree

class MdpQuotient(paynt.quotient.quotient.Quotient):

# if set, an explicit action simulating a random action selection will be added to each state
add_dont_care_action = False
class MdpQuotient(paynt.quotient.quotient.Quotient):

def __init__(self, mdp, specification):
super().__init__(specification=specification)
updated = payntbind.synthesis.restoreActionsInAbsorbingStates(mdp)
if updated is not None: mdp = updated
# action_labels, _ payntbind.synthesis.extractActionLabels(mdp)
if MdpQuotient.add_dont_care_action:
action_labels,_ = payntbind.synthesis.extractActionLabels(mdp)
if "__random__" not in action_labels:
logger.debug("adding explicit don't-care action to every state...")
mdp = payntbind.synthesis.addDontCareAction(mdp)
# stormpy.export_to_drn(mdp, sketch_path+".drn")

self.quotient_mdp = mdp
self.choice_destinations = payntbind.synthesis.computeChoiceDestinations(mdp)
self.action_labels,self.choice_to_action = payntbind.synthesis.extractActionLabels(mdp)
self.decision_tree = DecisionTree(mdp)

assert mdp.has_state_valuations(), "model has no state valuations"
sv = mdp.state_valuations
valuation = json.loads(str(sv.get_json(0)))
variable_name = [var_name for var_name in valuation]
state_valuations = []
for state in range(mdp.nr_states):
valuation = json.loads(str(sv.get_json(state)))
valuation = [valuation[var_name] for var_name in variable_name]
state_valuations.append(valuation)
self.decision_tree = DecisionTree(self,variable_name,state_valuations)

self.coloring = None
self.family = None
Expand Down
Loading

0 comments on commit 03040ca

Please sign in to comment.