diff --git a/paynt/parser/prism_parser.py b/paynt/parser/prism_parser.py index cdfa73523..d204ef72b 100644 --- a/paynt/parser/prism_parser.py +++ b/paynt/parser/prism_parser.py @@ -40,8 +40,7 @@ def read_prism(cls, sketch_path, properties_path, relative_error, discount_facto holes = None if len(hole_definitions) > 0: logger.info("processing hole definitions...") - prism, hole_expressions, holes = PrismParser.parse_holes( - prism, expression_parser, hole_definitions) + prism, hole_expressions, holes = PrismParser.parse_holes(prism, expression_parser, hole_definitions) specification = PrismParser.parse_specification(properties_path, relative_error, discount_factor, prism, holes) @@ -74,18 +73,23 @@ def load_sketch_prism(cls, sketch_path): sketch_lines = f.readlines() # replace hole definitions with constants - hole_re = re.compile(r'^hole\s+(.*?)\s+(.*?)\s+in\s+\{(.*?)\};$') + hole_re_brace = re.compile(r'^\s*hole\s+(.*?)\s+(.*?)\s+in\s+\{(.*?)\}\s*;\s*$') + # hole_re_bracket = re.compile(r'^\s*hole\s+(.*?)\s+(.*?)\s+in\s+[(.*?)]\s+;\s+$') sketch_output = [] - hole_definitions = {} + hole_definitions = [] for line in sketch_lines: - match = hole_re.search(line) - if match is not None: - hole_type = match.group(1) - hole_name = match.group(2) - hole_options = match.group(3).replace(" ", "") - hole_definitions[hole_name] = hole_options - line = f"const {hole_type} {hole_name};\n" - sketch_output.append(line) + match = hole_re_brace.search(line) + if match is None: + sketch_output.append(line) + continue + hole_type = match.group(1) + hole_name = match.group(2) + hole_options = match.group(3).replace(" ", "") + hole_definitions.append( (hole_name,hole_type,hole_options) ) + sketch_output.append(f"const {hole_type} {hole_name};\n") + sketch_output.append(f"const {hole_type} {hole_name}_MIN;\n") + sketch_output.append(f"const {hole_type} {hole_name}_MAX;\n") + # store modified sketch to a temporary file tmp_path = sketch_path + str(uuid.uuid4()) @@ -110,8 +114,19 @@ def parse_holes(cls, prism, expression_parser, hole_definitions): # parse hole definitions holes = Holes() hole_expressions = [] - for hole_name,definition in hole_definitions.items(): - options = definition.split(",") + hole_min = [] + hole_max = [] + for hole_name,hole_type,hole_options in hole_definitions: + if ".." in hole_options: + assert hole_type == "int", "cannot use range-based definitions for non-integer hole types" + options = hole_options.split("..") + range_start = int(options[0]) + range_end = int(options[1]) + hole_min.append(range_start) + hole_max.append(range_end) + options = [str(o) for o in range(range_start,range_end+1)] + else: + options = hole_options.split(",") expressions = [expression_parser.parse(o) for o in options] hole_expressions.append(expressions) @@ -120,9 +135,20 @@ def parse_holes(cls, prism, expression_parser, hole_definitions): hole = Hole(hole_name, options, option_labels) holes.append(hole) + # substitute constants used as min/max values of holes + hole_range_definitions = {} + for hole_index,hole in enumerate(holes): + var_min = prism.get_constant(f"{hole.name}_MIN").expression_variable + hole_range_definitions[var_min] = expression_parser.parse(str(hole_min[hole_index])) + var_max = prism.get_constant(f"{hole.name}_MAX").expression_variable + hole_range_definitions[var_max] = expression_parser.parse(str(hole_max[hole_index])) + prism = prism.define_constants(hole_range_definitions) + # check that all undefined constants are indeed the holes - undefined_constants = [c for c in prism.constants if not c.defined] - assert len(undefined_constants) == len(holes), "some constants were unspecified" + hole_names = [hole.name for hole in holes] + for c in prism.constants: + if not c.defined: + assert c.name in hole_names, f"constant {c.name} was not specified" # convert single-valued holes to a defined constant trivial_holes_definitions = {}