Dreamer #151
base: main
Conversation
Hi @Rohan138, took a quick high-level glance and so far it looks good. I will start looking at the code in more detail soon. I noticed you added some changes to other files like the pre-commit config, the requirements, pyproject.toml, etc. Were you running into some errors? If that's the case, would you mind opening a separate PR for these?
@Rohan138 I'm trying to pull from your fork as shown below, but I run into access permission errors. Could you see if I can get read access?

```
git checkout -b Rohan138-dreamer main
git pull git@github.com:Rohan138/mbrl-lib.git dreamer
```
I can add you to my fork, but do you want to add my repo as a remote and try fetching from it instead?
Just sent you a contributor invite; maybe that will fix the access issue.
HTTP worked for me before accepting the invitation, thanks!
On the non-Dreamer fixes:
I can definitely move these to a different PR, though.
Another PR for these would be great, so that we can merge them w/o waiting for this more involved PR to be ready. Thanks!
@Rohan138 I'm planning to spend most of today and then Friday playing around with your code. Is there anything in particular you'd like me to focus on or help with? It seems I'm able to run Dreamer, but I haven't checked whether it learns correctly yet. What's the current status?
Hi! So I tried running it, but despite the losses dropping, it doesn't seem to learn right now. Here are the results; I'm planning to look through the prior implementations next.
Still looking into things, but left some initial comments. I'm also wondering about the way you are computing the value estimates in `_compute_return`, but I need to look into this more carefully. Are you confident about that part of the code? I was thinking of maybe spending some time on Friday writing a utility function for this (along the lines of the sketch below) and maybe adding some unit tests. Let me know if this would be useful.
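For reference, this is the recursion I have in mind, as a standalone sketch (not the PR's actual `_compute_return`; the `(T, B)` shapes and scalar `discount`/`lambda_` arguments are assumptions):

```python
import torch

def lambda_return(reward, value, bootstrap, discount=0.99, lambda_=0.95):
    """Lambda-returns computed backwards in time.

    reward, value: (T, B) tensors; bootstrap: (B,) value estimate for step T.
    Recursion: G_t = r_t + discount * ((1 - lambda_) * v_{t+1} + lambda_ * G_{t+1}),
    with G_T = bootstrap.
    """
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    returns = torch.empty_like(reward)
    last = bootstrap
    for t in reversed(range(reward.shape[0])):
        last = reward[t] + discount * ((1 - lambda_) * next_values[t] + lambda_ * last)
        returns[t] = last
    return returns
```

This would also be easy to unit-test at the extremes: `lambda_=0` should give one-step TD targets, and `lambda_=1` should give Monte Carlo returns with a bootstrapped tail.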
```python
    rewards,
) = self.planet_model(obs[:, 1:], actions[:, :-1], rewards[:, :-1])

for epoch in range(num_epochs):
```
I think the following might be clearer, since it doesn't seem that you use `beliefs` and `latents` except for initializing the state for unrolling.
```python
B, L, _ = beliefs.shape
for epoch in range(num_epochs):
    imag_beliefs = []
    imag_latents = []
    imag_actions = []
    imag_rewards = []
    states = {
        "belief": beliefs.reshape(B * L, -1),
        "latent": latents.reshape(B * L, -1),
    }
    for i in range(self.horizon):
        ...
```
mbrl/planning/dreamer_agent.py (Outdated)
```python
imag_rewards.append(rewards)

# I x (B*L) x _
imag_beliefs = torch.stack(imag_beliefs).to(self.device)
```
Is `to(self.device)` necessary? These are computed from tensors that should already be on `self.device` at this point.
mbrl/planning/dreamer_agent.py (Outdated)
```python
imag_beliefs = torch.stack(imag_beliefs).to(self.device)
imag_latents = torch.stack(imag_latents).to(self.device)
imag_actions = torch.stack(imag_actions).to(self.device)
freeze(self.critic)
```
Curious about the use of `freeze(self.critic)` instead of `with torch.no_grad()`, since the next line only calls the critic and no other parameters are being used.
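For what it's worth, a standalone sketch of the difference (illustrative code, not from the PR): freezing parameters still lets gradients flow through the critic to its inputs, e.g. imagined states produced by the actor, while `torch.no_grad()` builds no graph at all.

```python
import torch
import torch.nn as nn

critic = nn.Linear(8, 1)
states = torch.randn(4, 8, requires_grad=True)  # stand-in for imagined states

# Freezing parameters: no gradient accumulates on the critic's weights, but
# the graph is still built, so gradients reach the inputs.
for p in critic.parameters():
    p.requires_grad_(False)
critic(states).sum().backward()
print(states.grad is not None)     # True: inputs still receive gradients
print(critic.weight.grad is None)  # True: frozen parameters do not

# torch.no_grad(): no graph at all, so nothing upstream gets gradients.
other_states = torch.randn(4, 8, requires_grad=True)
with torch.no_grad():
    values = critic(other_states)
print(values.requires_grad)        # False: the output is detached
```

So if the actor loss backpropagates through these critic values, `freeze` and `no_grad` are not interchangeable; if it doesn't, `no_grad` would also save memory.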
mbrl/planning/dreamer_agent.py (Outdated)
""" | ||
next_values = torch.cat([value[1:], bootstrap[None]], 0) | ||
target = reward + discount * next_values * (1 - lambda_) | ||
timesteps = list(range(reward.shape[0] - 1, -1, -1)) |
No need to make a `list` here, since you only need the iterator (see the sketch below).
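For instance (shapes illustrative):

```python
import torch

reward = torch.zeros(10, 4)  # illustrative (T, B) reward tensor

# Reverse iteration over timesteps without materializing a list:
for t in reversed(range(reward.shape[0])):
    pass  # per-timestep return computation would go here
```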
```python
trainer.train(
    dataset, num_epochs=1, batch_callback=model_batch_callback, evaluate=False
)
agent.train(dataset, num_epochs=1, batch_callback=agent_batch_callback)
```
I'm wondering if we should be passing a different iterator for training the agent. If I understand the paper correctly, the Dreamer agent is trained on trajectories whose start states are sampled from the experience buffer, but where all subsequent states are obtained by rolling out the model. In this case, we only need to sample individual transitions to get start states, and not the full sequences that `dataset` would return here.

If what I said above is correct, then maybe the cleanest option would be to modify `DreamerAgent.train()` to directly receive `replay_buffer` and an additional parameter called `num_updates`. The agent train code can then loop `num_updates` times, each time doing: 1) `replay_buffer.sample(batch_size)`, 2) roll out the planet model with a batch of start states, 3) update the agent parameters; see the sketch below.

Does the above make sense? Let me know if I'm missing something or if anything is unclear. I guess your current code is serving more data to the Dreamer agent, but it seems like it'd be easier to make a mistake with the current implementation?
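A minimal sketch of what I mean, with hypothetical names throughout (`sample`, `encode`, `rollout_model`, and `update_actor_critic` are placeholders, not the library's actual API):

```python
def train(self, replay_buffer, num_updates: int, batch_size: int = 50):
    """Hypothetical DreamerAgent.train() under the proposal above."""
    for _ in range(num_updates):
        # 1) Sample individual transitions; observations serve as start states.
        batch = replay_buffer.sample(batch_size)
        # 2) Encode observations and roll out the planet model for
        #    self.horizon steps under the current policy.
        start_states = self.planet_model.encode(batch.obs)         # placeholder
        imagined = self.rollout_model(start_states, self.horizon)  # placeholder
        # 3) Update actor and critic from the imagined trajectories.
        self.update_actor_critic(imagined)                         # placeholder
```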
> I'm wondering if we should be passing a different iterator for training the agent. If I understand the paper correctly, the Dreamer agent is trained on trajectories whose start states are sampled from the experience buffer, but where all subsequent states are obtained by rolling out the model. In this case, we only need to sample individual transitions to get start states, and not full sequences, which is what `dataset` would return here.
I might have misunderstood the paper, but I'm not sure this is correct. In Algorithm 1 (page 3), they:

- Draw `B` data sequences (episodes) `{(a_t, o_t, r_t)}_{t=k}^{k+L}`. Here `k` is the outer variable looping over episodes, while `t` is the inner variable looping over timesteps in an episode.
- Compute model states `s_t` for all `t` in `[k, k + L)` and all `k` in the batch of size `B`, using the RSSM transition model.
- Imagine trajectories `{(s_\tau, a_\tau)}_{\tau=t}^{\tau=t+H}` from each state `s_t` in the batch, not just the initial state `s_k` of each episode.

I'm not sure if this explanation was clear, and I'll take another look at the prior implementations linked in the other comment to confirm.
We do currently have a minor divergence and performance hit: instead of computing the model states just once, as in the paper and references, we run the forward and backward passes on the model in `model.train()`, and then run the forward pass again in `self.planet_model._process_batch(...)` inside `agent.train()`. I haven't figured out a way to cleanly fix this yet. Maybe return the states from `model.train()`? Or append them to the `TransitionIterator` somehow?
Looking at the paper again, I think your interpretation is correct, because the "compute model states" step occurs for all sampled `o_t`, and trajectories are then imagined from all model states `s_t`. I find it a bit confusing how they are using the index `k`; I guess it increases in increments of size `L`? That is, the j-th trajectory goes from `t = L*(j-1)` to `L*j - 1`? In any case, confirming with prior implementations is a good idea.

Regarding the performance hit, one idea that wouldn't require a lot of changes would be to add get/set methods for the random state of the iterator, so that we can have it return the same set of samples for both the model and agent loops. We should then be able to use the model trainer callback to store all computed model states and pass them to the agent trainer in the correct order; a rough sketch is below.

Does that make sense?
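A hypothetical sketch of that idea (`get_rng_state`/`set_rng_state` don't exist on the iterator today, and `precomputed_states` is made up; all names are illustrative):

```python
# Capture the iterator's RNG state before the model epoch, so the agent loop
# can replay exactly the same batches in the same order.
rng_state = dataset.get_rng_state()  # hypothetical method

model_states = []

def model_batch_callback(batch, model_output):
    # Store the model states computed during the model update.
    model_states.append(model_output.states)  # illustrative field

trainer.train(dataset, num_epochs=1, batch_callback=model_batch_callback)

# Rewind the iterator and reuse the stored states instead of re-running the
# model's forward pass in the agent loop.
dataset.set_rng_state(rng_state)  # hypothetical method
agent.train(dataset, num_epochs=1, precomputed_states=model_states)
```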
```yaml
action_noise_std: 0.3
test_frequency: 25
num_episodes: 1000
dataset_size: 1000000
```
Missing newline at end of file.
I'm thinking about starting a project that builds off the Dreamer style of dynamics model. If I spend a couple of hours here and there on this PR, what would be most useful?
Sorry for the delay; I'll try to address all of the comments over the next day or so. I moved the non-Dreamer fixes to #161. @natolambert: the core Dreamer implementation is in `mbrl/planning/dreamer_agent.py`.
@luisenp @Rohan138 -- is there anything I or @RajGhugare19 can do to get this moving again?
Hi @natolambert. Unfortunately, it's almost impossible for me at this point to take the lead in development, due to other more pressing commitments. But I'm happy to support with reviews, general advice, and some amount of coding, if someone else is willing to drive this feature to completion.
Gotcha, so I'm guessing it's at the point where there are small issues left and we need to verify performance? @luisenp
There were some comments I left early on that I'm not sure were addressed (mostly high-level stuff). But leaving that aside, I don't think the implementation was fully working yet; @Rohan138 would have more details, though.
Great. I want to take a look, and I have chatted with @danijar, who didn't know it was being worked on. Let's see if I can un-stick it and, if needed, talk to Danijar.
Hello @natolambert -- I can take the lead on developing this. You can review and sanity-check the code afterwards. I will take a deeper look today at the code and what changes are still required.
Pitching in: I can help answer questions and debug the implementation over the weekend. The main function I'm unsure about is `DreamerAgent.train(...)`. There are also some minor conflicts due to gym versions and gym's type checking that seem to be breaking CI; there's an open PR, #161, for these.
Types of changes

Still a WIP, but I've managed to add most of Dreamer. The main thing left is computing the loss in `DreamerAgent.train()` in `planning/dreamer_agent.py`.