[QUESTION] How to add an action mask to DiscreteCQL algorithm? #356

Open
rgraziosi-fbk opened this issue Dec 4, 2023 · 2 comments

@rgraziosi-fbk

Hi everyone!

I'm trying to implement action masking for the discrete CQL algorithm, i.e. I'd like to make some actions impossible to choose given some conditions on the current observation.

At inference time this should be easy: predict_value can be used to get the value of every possible action, the mask can then filter out the impossible actions, and finally argmax picks the action to execute.
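As a rough illustration of that inference-time idea, here is a minimal sketch. It assumes `predict_value` is a callable taking `(observations, actions)` and returning one value per row, and that the mask is a boolean array with `True` for allowed actions. The helper name `masked_greedy_actions` and the mask layout are hypothetical, not part of d3rlpy.

```python
import numpy as np


def masked_greedy_actions(predict_value, observations, action_mask):
    """Return the highest-valued allowed action for each observation.

    predict_value: callable (observations, actions) -> values of shape (batch,).
        Assumed signature; adapt it to whatever your model exposes.
    action_mask: bool array of shape (batch, n_actions), True = allowed.
    """
    action_mask = np.asarray(action_mask, dtype=bool)
    batch_size, n_actions = action_mask.shape

    # argmax over a row of -inf would silently return action 0, so fail early.
    if not action_mask.any(axis=1).all():
        raise ValueError("every observation needs at least one allowed action")

    # Query the value of each action in turn to build a (batch, n_actions) table.
    q_values = np.empty((batch_size, n_actions), dtype=np.float64)
    for a in range(n_actions):
        actions = np.full(batch_size, a, dtype=np.int64)
        q_values[:, a] = predict_value(observations, actions)

    # Forbidden actions get -inf so that argmax can never select them.
    masked = np.where(action_mask, q_values, -np.inf)
    return masked.argmax(axis=1)
```

Using `-inf` rather than a large negative constant avoids a masked action winning when all legitimate values happen to be very negative.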

However, I'm unsure how to implement action masking during training. Is there a way to do that without changing the d3rlpy source code? If not, could you please give me some hints about which parts of the codebase would need to change?

Thank you a lot in advance!

@rgraziosi-fbk changed the title from "How to add an action mask to DiscreteCQL algorithm?" to "[QUESTION] How to add an action mask to DiscreteCQL algorithm?" on Dec 4, 2023
@Lucien-Evans-123

I also want to ask this question!

@takuseno
Owner

takuseno commented Dec 6, 2023

@rgraziosi-fbk Thanks for the issue. I assume that you want to mask actions when the bootstrap target is computed. In that case, you need to modify the action selection here:

action = self.inner_predict_best_action(batch.next_observations)

This method is inherited by DiscreteCQL, so if you change it, DiscreteCQL's behavior changes as well.
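To make the idea concrete, the following is a framework-agnostic NumPy sketch of a masked bootstrap target, not d3rlpy code. It assumes a double-DQN-style target, where the online network selects the next action and the target network evaluates it, and it assumes you can compute a boolean mask of allowed actions from each next observation. The function name and all argument names are hypothetical.

```python
import numpy as np


def masked_double_dqn_target(
    q_next_online,     # (batch, n_actions) online-network values at next observations
    q_next_target,     # (batch, n_actions) target-network values at next observations
    next_action_mask,  # (batch, n_actions) bool, True = action allowed in next state
    rewards,           # (batch,)
    terminals,         # (batch,) 1.0 where the episode ended, else 0.0
    gamma=0.99,
):
    """Compute r + gamma * (1 - done) * Q_target(s', a*) with a* chosen among allowed actions."""
    # Masked action selection: forbidden actions can never be the argmax.
    masked = np.where(next_action_mask, q_next_online, -np.inf)
    best_actions = masked.argmax(axis=1)

    # Evaluate the selected actions with the target network.
    rows = np.arange(best_actions.shape[0])
    next_values = q_next_target[rows, best_actions]

    # Terminal transitions do not bootstrap.
    return rewards + gamma * (1.0 - terminals) * next_values
```

In an actual implementation the same masking step would go wherever the best next action is selected, and the mask has to be derived from the next observations rather than the current ones.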

3 participants