Skip to content

Commit

Permalink
posterior-unaware FSC unfolding
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Sep 17, 2024
1 parent 3e5bb27 commit ebbef0c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 38 deletions.
53 changes: 23 additions & 30 deletions payntbind/src/synthesis/pomdp_family/FscUnfolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace synthesis {
) : quotient(quotient), state_to_obs_class(state_to_obs_class),
num_actions(num_actions), choice_to_action(choice_to_action) {

this->state_translator = ItemKeyTranslator<std::tuple<uint64_t,uint64_t,bool>>();
this->state_translator = ItemKeyTranslator<std::pair<uint64_t,uint64_t>>();
this->state_action_choices.resize(this->quotient.getNumberOfStates());
std::vector<uint64_t> const& row_groups = this->quotient.getTransitionMatrix().getRowGroupIndices();
for(uint64_t state = 0; state < this->quotient.getNumberOfStates(); ++state) {
Expand Down Expand Up @@ -50,7 +50,7 @@ namespace synthesis {
uint64_t FscUnfolder<ValueType>::translateInitialState() {
uint64_t initial_state = *(this->quotient.getInitialStates().begin());
uint64_t initial_memory = 0;
return this->state_translator.translate(initial_state,std::make_tuple(initial_memory,invalidAction(),false));
return this->state_translator.translate(initial_state,std::make_pair(initial_memory,invalidAction()));
}


Expand All @@ -62,27 +62,25 @@ namespace synthesis {
this->state_translator.resize(this->quotient.getNumberOfStates());
uint64_t translated_state = this->translateInitialState();
while(true) {
auto[state,memory_action_transitioned] = this->state_translator.retrieve(translated_state);
auto[memory,action,transitioned] = memory_action_transitioned;
auto[state,memory_action] = this->state_translator.retrieve(translated_state);
auto[memory,action] = memory_action;
uint64_t observation = this->state_to_obs_class[state];
if(action == invalidAction() and not transitioned) {
if(action == invalidAction()) {
// random choice of an action
for(auto [action,prob] : action_function[memory][observation]) {
this->state_translator.translate(state,std::make_tuple(memory,action,false));
for(auto [action,_] : action_function[memory][observation]) {
this->state_translator.translate(state,std::make_pair(memory,action));
}
} else if(action != invalidAction()) {
} else { // action != invalidAction()
// executing variants of the selected actions
for(uint64_t choice: this->state_action_choices[state][action]) {
for(auto const &entry: this->quotient.getTransitionMatrix().getRow(choice)) {
uint64_t state_dst = entry.getColumn();
this->state_translator.translate(state_dst,std::make_tuple(memory,invalidAction(),true));
// executing memory update
for(auto [memory_dst,_] : update_function[memory][observation]) {
this->state_translator.translate(state_dst,std::make_pair(memory_dst,invalidAction()));
}
}
}
} else { // action == invalidAction() and transitioned
// executing memory update
for(auto [memory_dst,prob] : update_function[memory][observation]) {
this->state_translator.translate(state,std::make_tuple(memory_dst,invalidAction(),false));
}
}
translated_state++;
if(translated_state >= numberOfTranslatedStates()) {
Expand All @@ -91,7 +89,7 @@ namespace synthesis {
}

this->product_state_to_state = this->state_translator.translationToItem();
this->product_state_to_state_memory_action_transitioned = this->state_translator.translationToItemKey();
// this->product_state_to_state_memory_action = this->state_translator.translationToItemKey();
}

template<typename ValueType>
Expand All @@ -103,36 +101,31 @@ namespace synthesis {
storm::storage::SparseMatrixBuilder<ValueType> builder(0, 0, 0, false, true, 0);
for(uint64_t translated_state = 0; translated_state < numberOfTranslatedStates(); ++translated_state) {
builder.newRowGroup(numberOfTranslatedChoices());
auto[state,memory_action_transitioned] = this->state_translator.retrieve(translated_state);
auto[memory,action,transitioned] = memory_action_transitioned;
auto[state,memory_action] = this->state_translator.retrieve(translated_state);
auto[memory,action] = memory_action;
uint64_t observation = this->state_to_obs_class[state];
if(action == invalidAction() and not transitioned) {
if(action == invalidAction()) {
// random choice of an action
uint64_t product_choice = numberOfTranslatedChoices();
this->product_choice_to_choice.push_back(invalidChoice());
for(auto [action,prob] : action_function[memory][observation]) {
uint64_t translated_dst = this->state_translator.translate(state,std::make_tuple(memory,action,false));
uint64_t translated_dst = this->state_translator.translate(state,std::make_pair(memory,action));
builder.addNextValue(product_choice, translated_dst, prob);
}
} else if(action != invalidAction()) {
} else { // action != invalidAction()
// executing variants of the selected actions
for(uint64_t choice: this->state_action_choices[state][action]) {
uint64_t product_choice = numberOfTranslatedChoices();
this->product_choice_to_choice.push_back(choice);
for(auto const &entry: this->quotient.getTransitionMatrix().getRow(choice)) {
uint64_t state_dst = entry.getColumn();
uint64_t translated_dst = this->state_translator.translate(state_dst,std::make_tuple(memory,invalidAction(),true));
builder.addNextValue(product_choice, translated_dst, entry.getValue());
// executing memory update
for(auto [memory_dst,prob] : update_function[memory][observation]) {
uint64_t translated_dst = this->state_translator.translate(state_dst,std::make_pair(memory_dst,invalidAction()));
builder.addNextValue(product_choice, translated_dst, entry.getValue()*prob);
}
}
}
} else { // action == invalidAction() and transitioned
// executing memory update
uint64_t product_choice = numberOfTranslatedChoices();
this->product_choice_to_choice.push_back(invalidChoice());
for(auto [memory_dst,prob] : update_function[memory][observation]) {
uint64_t translated_dst = this->state_translator.translate(state,std::make_tuple(memory_dst,invalidAction(),false));
builder.addNextValue(product_choice, translated_dst, prob);
}
}
}

Expand Down
15 changes: 7 additions & 8 deletions payntbind/src/synthesis/pomdp_family/FscUnfolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace synthesis {
* Create a product of the quotient POMDP and the given FSC.
* @param action_function for each node in the FSC and for each observation class, a dictionary containing
* entries (action,probability)
* @param update_function for each node in the FSC and for each (posterior) observation class, a dictionary
* @param update_function for each node in the FSC and for each observation class, a dictionary
* containing entries (memory,probability)
*/
void applyFsc(
Expand All @@ -41,8 +41,8 @@ namespace synthesis {
std::vector<uint64_t> product_choice_to_choice;
/** For each state of the product MDP, the original state. */
std::vector<uint64_t> product_state_to_state;
/** For each state of the product MDP, the corresponding state-memory-action-transitioned tuple. */
std::vector<std::pair<uint64_t,std::tuple<uint64_t,uint64_t,bool>>> product_state_to_state_memory_action_transitioned;
/** For each state of the product MDP, the corresponding state-memory-action tuple. */
// std::vector<std::pair<uint64_t,std::pair<uint64_t,uint64_t>>> product_state_to_state_memory_action;


private:
Expand All @@ -62,12 +62,11 @@ namespace synthesis {
uint64_t invalidChoice();

/**
* Each state is a tuple (s,n,act,tr) with the following semantics:
* - from state (s,n,-,-), an action act is selected according to gamma(n,O(s)), transitioning to (s,n,act,-)
* - from state (s,n,act,-), a variant of action act is executed, transitioning to (s',n,-,+)
* - from state (s',n,-,+), a memory update n' is selected according to delta(n,O(s')), transitioning to (s',n',-,-)
* Each state is a tuple (s,n,act) with the following semantics:
* - from state (s,n,-), an action act is selected according to gamma(n,O(s)), transitioning to (s,n,act)
* - from state (s,n,act), a variant of action act is executed and n' is selected according to delta(n,O(s)), transitioning to (s',n',-)
**/
ItemKeyTranslator<std::tuple<uint64_t,uint64_t,bool>> state_translator;
ItemKeyTranslator<std::pair<uint64_t,uint64_t>> state_translator;
uint64_t translateInitialState();
uint64_t numberOfTranslatedStates();
uint64_t numberOfTranslatedChoices();
Expand Down

0 comments on commit ebbef0c

Please sign in to comment.