Skip to content

Commit

Permalink
construct model union
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Nov 14, 2024
1 parent bf64fa2 commit 31ec5d7
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 13 deletions.
2 changes: 1 addition & 1 deletion paynt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import paynt.utils.timer
import paynt.parser.sketch

import paynt.quotient
import paynt.quotient.quotient
import paynt.quotient.pomdp
import paynt.quotient.decpomdp
import paynt.quotient.storm_pomdp_control
Expand Down
6 changes: 1 addition & 5 deletions payntbind/lib/payntbind/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import sys

if sys.version_info[0] == 2:
raise ImportError('Python 2.x is not supported for stormpy.')

import stormpy
from .synthesis import *

__version__ = "unknown"
Expand Down
1 change: 0 additions & 1 deletion payntbind/lib/payntbind/synthesis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from . import synthesis
from .synthesis import *
2 changes: 1 addition & 1 deletion payntbind/src/synthesis/translation/SubPomdpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ namespace synthesis {
auto translated_num_states = state_translator.numTranslations();
auto translated_num_choices = choice_translator.numTranslations();
storm::storage::SparseMatrixBuilder<ValueType> builder(
translated_num_choices, translated_num_states, 0, true, true, translated_num_states
translated_num_choices, translated_num_states, 0, false, true, translated_num_states
);
for(uint64_t translated_state = 0; translated_state < translated_num_states; ++translated_state) {
if(translated_state == translated_initial_state) {
Expand Down
7 changes: 4 additions & 3 deletions payntbind/src/synthesis/translation/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#include "src/synthesis/translation/componentTranslations.h"
#include "src/synthesis/translation/choiceTransformation.h"

#include <storm/exceptions/InvalidModelException.h>
#include <storm/utility/builder.h>
#include <storm/transformer/SubsystemBuilder.h>
namespace synthesis {


}
void bindings_translation(py::module& m) {

m.def("computeChoiceDestinations", &synthesis::computeChoiceDestinations<double>);
Expand All @@ -17,6 +17,7 @@ void bindings_translation(py::module& m) {
m.def("enableAllActions", py::overload_cast<storm::models::sparse::Model<double> const&>(&synthesis::enableAllActions<double>));
m.def("restoreActionsInAbsorbingStates", &synthesis::restoreActionsInAbsorbingStates<double>);
m.def("addDontCareAction", &synthesis::addDontCareAction<double>);
m.def("createModelUnion", &synthesis::createModelUnion<double>);

py::class_<synthesis::SubPomdpBuilder<double>, std::shared_ptr<synthesis::SubPomdpBuilder<double>>>(m, "SubPomdpBuilder")
.def(py::init<storm::models::sparse::Pomdp<double> const&>())
Expand Down
143 changes: 141 additions & 2 deletions payntbind/src/synthesis/translation/choiceTransformation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

#include "src/synthesis/translation/componentTranslations.h"

#include <storm/exceptions/InvalidArgumentException.h>
#include <storm/exceptions/InvalidModelException.h>
#include <storm/exceptions/NotSupportedException.h>
#include <storm/exceptions/UnexpectedException.h>
#include <storm/exceptions/InvalidArgumentException.h>
#include <storm/models/sparse/Pomdp.h>
#include <storm/utility/builder.h>
#include <storm/transformer/SubsystemBuilder.h>
#include <storm/utility/builder.h>


namespace synthesis {

Expand Down Expand Up @@ -418,6 +420,140 @@ std::shared_ptr<storm::models::sparse::Model<ValueType>> addDontCareAction(
return storm::utility::builder::buildModelFromComponents<ValueType,storm::models::sparse::StandardRewardModel<ValueType>>(model.getType(),std::move(components));
}



template<typename ValueType>
std::shared_ptr<storm::models::sparse::Model<ValueType>> createModelUnion(
std::vector<std::shared_ptr<storm::models::sparse::Model<ValueType>>> const& models
) {
uint64_t num_models = models.size();
STORM_LOG_THROW(num_models > 0, storm::exceptions::InvalidArgumentException, "the list of models is empty");

uint64_t union_initial_state = 0;
uint64_t union_num_states = 1;
uint64_t union_num_choices = 1;
std::vector<uint64_t> state_offset;
std::vector<uint64_t> choice_offset;
for(uint64_t model_index = 0; model_index < num_models; ++model_index) {
state_offset.push_back(union_num_states);
choice_offset.push_back(union_num_choices);
auto model = models[model_index];
union_num_states += model->getNumberOfStates();
union_num_choices += model->getNumberOfChoices();
}

storm::storage::sparse::ModelComponents<ValueType> components;
storm::models::sparse::StateLabeling union_state_labeling(union_num_states);
union_state_labeling.addLabel("init");
union_state_labeling.addLabelToState("init",union_initial_state);
for(uint64_t model_index = 0; model_index < num_models; ++model_index) {
auto model = models[model_index];
storm::models::sparse::StateLabeling const& state_labeling = model->getStateLabeling();
for (auto const& label : state_labeling.getLabels()) {
if(not union_state_labeling.containsLabel(label)) {
union_state_labeling.addLabel(label);
}
}
for(uint64_t state = 0; state < model->getNumberOfStates(); ++state) {
uint64_t union_state = state_offset[model_index] + state;
for(std::string const& label: state_labeling.getLabelsOfState(state)) {
if(label == "init") {
continue;
}
union_state_labeling.addLabelToState(label,union_state);
}
}
}
components.stateLabeling = union_state_labeling;

if(models[0]->getType() == storm::models::ModelType::Pomdp) {
std::vector<uint32_t> state_observation(union_num_states);
uint64_t num_observations = 0;
for(uint64_t model_index = 0; model_index < num_models; ++model_index) {
auto model = models[model_index];
auto pomdp = static_cast<storm::models::sparse::Pomdp<ValueType> const&>(*model);
if(pomdp.getNrObservations() > num_observations) {
num_observations = pomdp.getNrObservations();
}
for(uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
uint64_t union_state = state_offset[model_index] + state;
state_observation[union_state] = pomdp.getObservation(state);
}
}
state_observation[union_initial_state] = num_observations;
components.observabilityClasses = state_observation;
}

// skipping state and observation valuations

storm::models::sparse::ChoiceLabeling union_choice_labeling(union_num_choices);
union_choice_labeling.addLabel(NO_ACTION_LABEL);
union_choice_labeling.addLabelToChoice(NO_ACTION_LABEL,0);
storm::storage::SparseMatrixBuilder<ValueType> builder(
union_num_choices, union_num_states, 0, false, true, union_num_states
);
ValueType belief_uniform_prob = 1.0/num_models;
builder.newRowGroup(union_initial_state);
for(uint64_t model_index = 0; model_index < num_models; ++model_index) {
auto model = models[model_index];
uint64_t initial_state = state_offset[model_index] + *(model->getInitialStates().begin());
builder.addNextValue(union_initial_state, initial_state, belief_uniform_prob);
}
for(uint64_t model_index = 0; model_index < num_models; ++model_index) {
auto model = models[model_index];
storm::models::sparse::ChoiceLabeling const& choice_labeling = model->getChoiceLabeling();
for (auto const& label : choice_labeling.getLabels()) {
if(not union_choice_labeling.containsLabel(label)) {
union_choice_labeling.addLabel(label);
}
}

auto const& row_groups = model->getTransitionMatrix().getRowGroupIndices();
for(uint64_t state = 0; state < model->getNumberOfStates(); ++state) {
builder.newRowGroup(choice_offset[model_index]+row_groups[state]);
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
uint64_t union_choice = choice_offset[model_index]+choice;
for(auto entry: model->getTransitionMatrix().getRow(choice)) {
builder.addNextValue(union_choice, state_offset[model_index]+entry.getColumn(), entry.getValue());
}
for(std::string const& label: choice_labeling.getLabelsOfChoice(choice)) {
union_choice_labeling.addLabelToChoice(label,union_choice);
}
}
}
}
components.transitionMatrix = builder.build();
components.choiceLabeling = union_choice_labeling;
// skipping choice origins

std::map<std::string,std::vector<ValueType>> reward_models;
for(uint64_t model_index = 0; model_index < num_models; ++model_index) {
auto model = models[model_index];
for(auto const& [reward_name,reward_model] : model->getRewardModels()) {
STORM_LOG_THROW(!reward_model.hasStateRewards() and !reward_model.hasTransitionRewards() and reward_model.hasStateActionRewards(),
storm::exceptions::NotSupportedException, "expected state-action rewards");
if(reward_models.count(reward_name) == 0) {
reward_models.emplace(reward_name,std::vector<ValueType>(union_num_choices,0));
}

for(uint64_t choice = 0; choice < model->getNumberOfChoices(); ++choice) {
uint64_t union_choice = choice_offset[model_index] + choice;
reward_models[reward_name][union_choice] = reward_model.getStateActionReward(choice);
}
}
}

for(auto &[reward_name,action_rewards]: reward_models) {
std::optional<std::vector<ValueType>> state_rewards;
components.rewardModels.emplace(
reward_name, storm::models::sparse::StandardRewardModel<ValueType>(std::move(state_rewards), std::move(action_rewards))
);
}

return storm::utility::builder::buildModelFromComponents<ValueType>(models[0]->getType(),std::move(components));
}


template std::vector<std::vector<uint64_t>> computeChoiceDestinations<double>(
storm::models::sparse::Model<double> const& model);
template std::pair<std::vector<std::string>,std::vector<uint64_t>> extractActionLabels<double>(
Expand All @@ -440,5 +576,8 @@ template std::shared_ptr<storm::models::sparse::Model<double>> restoreActionsInA
storm::models::sparse::Model<double> const& model);
template std::shared_ptr<storm::models::sparse::Model<double>> addDontCareAction<double>(
storm::models::sparse::Model<double> const& model);
template std::shared_ptr<storm::models::sparse::Model<double>> createModelUnion(
std::vector<std::shared_ptr<storm::models::sparse::Model<double>>> const&
);

}
8 changes: 8 additions & 0 deletions payntbind/src/synthesis/translation/choiceTransformation.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,12 @@ std::shared_ptr<storm::models::sparse::Model<ValueType>> addDontCareAction(
storm::models::sparse::Model<ValueType> const& model
);

/**
* Create a union model with a fresh initial state simulating a uniform choice.
*/
template<typename ValueType>
std::shared_ptr<storm::models::sparse::Model<ValueType>> createModelUnion(
std::vector<std::shared_ptr<storm::models::sparse::Model<ValueType>>> const& models
);

}

0 comments on commit 31ec5d7

Please sign in to comment.