This repository has been archived by the owner on Sep 1, 2024. It is now read-only.
Dreamer #151 (Draft)
Rohan138 wants to merge 31 commits into facebookresearch:main from Rohan138:dreamer
Changes from all commits (31):
42cbfbc  added skeleton (veds12)
51f4c21  Merge https://github.com/veds12/mbrl-lib into dreamer (Rohan138)
52218db  update gym and test commit hooks (Rohan138)
3fb3ee9  Initial commit (Rohan138)
630c1ce  pre-commit fixes (Rohan138)
d591586  fixed pyproject.toml (Rohan138)
85e7c14  dreamer core; bug fixes (Rohan138)
fd38c62  dtype fix (Rohan138)
f44d500  remove breakpoint (Rohan138)
6b845fb  working on config (Rohan138)
a6e6a2a  wip (Rohan138)
4852125  Finish dreamer loss (Rohan138)
d54a0e1  Add Dreamer to README (Rohan138)
64441ca  Added config yamls (Rohan138)
3ffcf2d  rename pyproject (Rohan138)
656a6ce  Make saving replay buffer optional (Rohan138)
03d383d  drop deprecation test (Rohan138)
da7f83c  Fix num_grad_updates (Rohan138)
3b4a00c  Add policy and critic loss to metrics (Rohan138)
68d55f5  Freeze planet during dreamer train (Rohan138)
e9b7196  Merge branch 'main' of https://github.com/facebookresearch/mbrl-lib i… (Rohan138)
e37958b  Merge branch 'main' into dreamer (Rohan138)
3fb69be  wip (Rohan138)
3215e6e  wip (Rohan138)
65cfa05  wip (Rohan138)
85447e6  wip (Rohan138)
b5beaa4  wip (Rohan138)
71b54eb  wip (Rohan138)
ca1604b  wip (Rohan138)
39c7845  Merge branch 'main' into dreamer (Rohan138)
84bda8d  wip (Rohan138)
New file (+221 lines): Dreamer training script (defines the train() entry point)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pathlib
from typing import List, Optional, Union

import gym
import hydra
import numpy as np
import omegaconf
import torch
from tqdm import tqdm

import mbrl.constants
from mbrl.env.termination_fns import no_termination
from mbrl.models import ModelEnv, ModelTrainer, PlaNetModel
from mbrl.planning import DreamerAgent, RandomAgent, create_dreamer_agent_for_model
from mbrl.util import Logger
from mbrl.util.common import (
    create_replay_buffer,
    get_sequence_buffer_iterator,
    rollout_agent_trajectories,
)

METRICS_LOG_FORMAT = [
    ("observations_loss", "OL", "float"),
    ("reward_loss", "RL", "float"),
    ("gradient_norm", "GN", "float"),
    ("kl_loss", "KL", "float"),
    ("policy_loss", "PL", "float"),
    ("critic_loss", "CL", "float"),
]


def train(
    env: gym.Env,
    cfg: omegaconf.DictConfig,
    silent: bool = False,
    work_dir: Union[Optional[str], pathlib.Path] = None,
) -> np.float32:
    # Experiment initialization
    debug_mode = cfg.get("debug_mode", False)

    if work_dir is None:
        work_dir = os.getcwd()
    work_dir = pathlib.Path(work_dir)
    print(f"Results will be saved at {work_dir}.")

    if silent:
        logger = None
    else:
        logger = Logger(work_dir)
        logger.register_group("metrics", METRICS_LOG_FORMAT, color="yellow")
        logger.register_group(
            mbrl.constants.RESULTS_LOG_NAME,
            [
                ("env_step", "S", "int"),
                ("train_episode_reward", "RT", "float"),
                ("episode_reward", "ET", "float"),
            ],
            color="green",
        )

    rng = torch.Generator(device=cfg.device)
    rng.manual_seed(cfg.seed)
    np_rng = np.random.default_rng(seed=cfg.seed)

    # Create replay buffer and collect initial data
    replay_buffer = create_replay_buffer(
        cfg,
        env.observation_space.shape,
        env.action_space.shape,
        collect_trajectories=True,
        rng=np_rng,
    )
    rollout_agent_trajectories(
        env,
        cfg.algorithm.num_initial_trajectories,
        RandomAgent(env),
        agent_kwargs={},
        replay_buffer=replay_buffer,
        collect_full_trajectories=True,
        trial_length=cfg.overrides.trial_length,
        agent_uses_low_dim_obs=False,
    )

    # Create PlaNet model
    cfg.dynamics_model.action_size = env.action_space.shape[0]
    planet: PlaNetModel = hydra.utils.instantiate(cfg.dynamics_model)
    model_env = ModelEnv(env, planet, no_termination, generator=rng)
    trainer = ModelTrainer(planet, logger=logger, optim_lr=1e-3, optim_eps=1e-4)

    # Create Dreamer agent
    # This agent rolls out trajectories using ModelEnv, which uses planet.sample()
    # to simulate the trajectories from the prior transition model.
    # The starting point for trajectories is each imagined state output by the
    # representation model from the dataset of environment observations.
    agent: DreamerAgent = create_dreamer_agent_for_model(
        planet, model_env, cfg.algorithm.agent
    )

    # Callback and containers to accumulate training statistics and average over batch
    rec_losses: List[float] = []
    reward_losses: List[float] = []
    policy_losses: List[float] = []
    critic_losses: List[float] = []
    kl_losses: List[float] = []
    model_grad_norms: List[float] = []
    agent_grad_norms: List[float] = []

    def get_metrics_and_clear_metric_containers():
        metrics_ = {
            "observations_loss": np.mean(rec_losses).item(),
            "reward_loss": np.mean(reward_losses).item(),
            "policy_loss": np.mean(policy_losses).item(),
            "critic_loss": np.mean(critic_losses).item(),
            "model_gradient_norm": np.mean(model_grad_norms).item(),
            "agent_gradient_norm": np.mean(agent_grad_norms).item(),
            "kl_loss": np.mean(kl_losses).item(),
        }

        for c in [
            rec_losses,
            reward_losses,
            policy_losses,
            critic_losses,
            kl_losses,
            model_grad_norms,
            agent_grad_norms,
        ]:
            c.clear()

        return metrics_

    def model_batch_callback(_epoch, _loss, meta, _mode):
        if meta:
            rec_losses.append(meta["observations_loss"])
            reward_losses.append(meta["reward_loss"])
            kl_losses.append(meta["kl_loss"])
            if "grad_norm" in meta:
                model_grad_norms.append(meta["grad_norm"])

    def agent_batch_callback(_epoch, _loss, meta, _mode):
        if meta:
            policy_losses.append(meta["policy_loss"])
            critic_losses.append(meta["critic_loss"])
            if "grad_norm" in meta:
                agent_grad_norms.append(meta["grad_norm"])

    def is_test_episode(episode_):
        return episode_ % cfg.algorithm.test_frequency == 0

    # Dreamer loop
    step = replay_buffer.num_stored
    total_rewards = 0.0
    for episode in tqdm(range(cfg.algorithm.num_episodes)):
        # Train the model for one epoch of `num_grad_updates`
        dataset, _ = get_sequence_buffer_iterator(
            replay_buffer,
            cfg.overrides.batch_size,
            0,  # no validation data
            cfg.overrides.sequence_length,
            max_batches_per_loop_train=cfg.overrides.num_grad_updates,
            use_simple_sampler=True,
        )
        trainer.train(
            dataset, num_epochs=1, batch_callback=model_batch_callback, evaluate=False
        )
        agent.train(dataset, num_epochs=1, batch_callback=agent_batch_callback)
        planet.save(work_dir)
        agent.save(work_dir)
        if cfg.overrides.get("save_replay_buffer", False):
            replay_buffer.save(work_dir)
        metrics = get_metrics_and_clear_metric_containers()
        logger.log_data("metrics", metrics)

        # Collect one episode of data
        episode_reward = 0.0
        obs = env.reset()
        agent.reset()
        planet.reset_posterior()
        action = None
        done = False
        pbar = tqdm(total=500)
        while not done:
            latent_state = planet.update_posterior(obs, action=action, rng=rng)
            action_noise = (
                0
                if is_test_episode(episode)
                else cfg.overrides.action_noise_std
                * np_rng.standard_normal(env.action_space.shape[0])
            )
            action = agent.act(latent_state)
            action = action.detach().cpu().squeeze(0).numpy()
            action = action + action_noise
            action = np.clip(
                action, -1.0, 1.0, dtype=env.action_space.dtype
            )  # to account for the noise and fix dtype
            next_obs, reward, done, info = env.step(action)
            replay_buffer.add(obs, action, next_obs, reward, done)
            episode_reward += reward
            obs = next_obs
            if debug_mode:
                print(f"step: {step}, reward: {reward}.")
            step += 1
            pbar.update(1)
        pbar.close()
        total_rewards += episode_reward
        logger.log_data(
            mbrl.constants.RESULTS_LOG_NAME,
            {
                "episode_reward": episode_reward * is_test_episode(episode),
                "train_episode_reward": episode_reward * (1 - is_test_episode(episode)),
                "env_step": step,
            },
        )

    # returns average episode reward (e.g., to use for tuning learning curves)
    return total_rewards / cfg.algorithm.num_episodes
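For context, a minimal sketch of how this entry point might be driven, mirroring how the existing PlaNet example is launched through Hydra. The module path `mbrl.algorithms.dreamer` and the use of `hydra.utils.instantiate` on `cfg.overrides.env_cfg` are assumptions about how this PR is wired, not confirmed details:

```python
# Hypothetical driver for the train() function above; config group names assumed.
import hydra
import omegaconf
import torch

import mbrl.algorithms.dreamer as dreamer  # module name assumed for this PR


@hydra.main(config_path="conf", config_name="main")
def run(cfg: omegaconf.DictConfig):
    torch.manual_seed(cfg.seed)
    # The overrides files below define env_cfg with a _target_, so the pixel env
    # can presumably be built by instantiating that config directly.
    env = hydra.utils.instantiate(cfg.overrides.env_cfg)
    return dreamer.train(env, cfg, silent=False)


if __name__ == "__main__":
    run()
```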
New file (+24 lines): Dreamer algorithm configuration
# @package _group_
name: "dreamer"

agent:
  _target_: mbrl.planning.DreamerAgent
  action_lb: ???
  action_ub: ???
  horizon: 15
  policy_lr: 0.00008
  critic_lr: 0.00008
  gamma: 0.99
  lam: 0.95
  grad_clip_norm: 100.0
  min_std: 0.0001
  init_std: 5
  mean_scale: 5
  activation_function: "elu"
  device: ${device}

num_initial_trajectories: 5
action_noise_std: 0.3
test_frequency: 25
num_episodes: 1000
dataset_size: 1000000
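As a sketch of how this block gets consumed: the `???` action bounds presumably come from the environment's action space before Hydra instantiates the `_target_`, based on the `create_dreamer_agent_for_model(planet, model_env, cfg.algorithm.agent)` call in the training script. The helper below is hypothetical; the real one likely also wires in the PlaNet model and ModelEnv:

```python
import gym
import hydra
import omegaconf


def make_dreamer_agent(env: gym.Env, agent_cfg: omegaconf.DictConfig):
    # Illustrative only: fill the ??? placeholders from the env, then instantiate
    # mbrl.planning.DreamerAgent via its _target_. Attaching the PlaNet model /
    # ModelEnv is left to create_dreamer_agent_for_model in the actual code.
    agent_cfg.action_lb = env.action_space.low.tolist()
    agent_cfg.action_ub = env.action_space.high.tolist()
    return hydra.utils.instantiate(agent_cfg)
```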
New file (+30 lines): mbrl/examples/conf/overrides/dreamer_cartpole_balance.yaml
# @package _group_
env: "dmcontrol_cartpole_balance"  # used to set the hydra dir, ignored otherwise

env_cfg:
  _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper"
  domain_name: "cartpole"
  task_name: "balance"
  task_kwargs:
    random: ${seed}
  visualize_reward: false
  from_pixels: true
  height: 64
  width: 64
  frame_skip: 2
  bit_depth: 5

term_fn: "no_termination"

# General configuration overrides
trial_length: 500
action_noise_std: 0.3

# Model overrides
num_grad_updates: 100
sequence_length: 50
batch_size: 50
free_nats: 3
kl_scale: 1.0

# Dreamer configuration overrides
New file (+30 lines): Dreamer overrides for the dmcontrol cheetah_run task
# @package _group_
env: "dmcontrol_cheetah_run"  # used to set the hydra dir, ignored otherwise

env_cfg:
  _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper"
  domain_name: "cheetah"
  task_name: "run"
  task_kwargs:
    random: ${seed}
  visualize_reward: false
  from_pixels: true
  height: 64
  width: 64
  frame_skip: 2
  bit_depth: 5

term_fn: "no_termination"

# General configuration overrides
trial_length: 500
action_noise_std: 0.3

# Model overrides
num_grad_updates: 100
sequence_length: 50
batch_size: 50
free_nats: 3
kl_scale: 1.0

# Dreamer configuration overrides
Conversations
I'm wondering if we should be passing a different iterator for training the agent. If I understand the paper correctly, the Dreamer agent is trained on trajectories whose start states are sampled from the experience buffer, but where all subsequent states are obtained by rolling out the model. In this case, we only need to sample individual transitions to get start states, not the full sequences that `dataset` would return here.

If what I said above is correct, then maybe the cleanest option would be to modify `DreamerAgent.train()` to directly receive `replay_buffer` and an additional parameter called `num_updates`. The agent train code can then loop `num_updates` times, each time doing 1) `replay_buffer.sample(batch_size)`, 2) rolling out the planet model with the batch of start states, and 3) updating the agent parameters.

Does the above make sense? Let me know if I'm missing something or if anything is unclear. I guess your current code is serving more data to the Dreamer agent, but it seems like it'd be easier to make a mistake with the current implementation?
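A rough sketch of what that interface could look like, just to make the suggestion concrete; everything except the `replay_buffer.sample(batch_size)` call is a hypothetical name, not an existing mbrl API:

```python
def train_agent(agent, planet, replay_buffer, batch_size, num_updates):
    """Hypothetical variant of DreamerAgent.train() per the suggestion above."""
    for _ in range(num_updates):
        batch = replay_buffer.sample(batch_size)             # 1) start states from real data
        states = planet.get_model_state(batch)               # encode to latents (name assumed)
        trajectories = agent.imagine(states, agent.horizon)  # 2) roll out the PlaNet model
        agent.update(trajectories)                           # 3) update policy and critic
```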
I might have misunderstood the paper, but I'm not sure this is correct. In Algorithm 1 (Page 3), they:

1. Sample `B` data sequences or episodes `{(a_t, o_t, r_t)}_{t=k}^{k+L}`. Here `k` is the outer variable looping over episodes, while `t` is the inner variable looping over timesteps in an episode.
2. Compute model states `s_t` for all `t` in `[k, k + L)` for all `k` in `B` using the RSSM transition model.
3. Imagine trajectories `{s_\tau, a_\tau}_{\tau=t}^{\tau=t+H}` from each state `s_t` in `B`, not just the initial state `s_k` in each episode.

I'm not sure if this explanation was clear, and I'll take another look at the prior implementations linked in the other comment to confirm.
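In code-level terms, that reading of Algorithm 1 corresponds to something like the sketch below. This is pure pseudocode in Python syntax: none of these functions exist in mbrl, and `B`, `L`, and `H` are the paper's batch size, sequence length, and imagination horizon.

```python
# Schematic restatement of the three steps above (not an mbrl API).
for update in range(num_updates):
    sequences = sample_sequences(buffer, B, L)        # 1) B chunks of length L from real data
    for (actions, obs, rewards) in sequences:
        s = compute_model_states(actions, obs)        # 2) posterior s_t for every t via the RSSM
        for t in range(L):
            imagined = imagine_trajectory(s[t], H)    # 3) {s_tau, a_tau}, tau = t .. t+H
            update_actor_and_critic(imagined)         #    actor-critic update on imagined rollouts
```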
We do have a minor divergence and performance hit currently: instead of computing the model states just once as in the paper and references, we're running the forward + backprop passes on the model in `model.train()`, then running the forward pass again in `self.planet_model._process_batch(...)` in `agent.train()`. I haven't figured out a way to cleanly fix this yet; maybe return the states from `model.train()`? Or append them to the TransitionIterator somehow?
Looking at the paper again, I think your interpretation is correct, because the "Compute model states" step occurs for all sampled `o_t`, and trajectories are then imagined from all model states `s_t`. I find it a bit confusing how they use the index `k`; I guess it increases in increments of size `L`? That is, the j-th trajectory goes from `t = L*(j-1)` to `L*j - 1`? In any case, confirming with prior implementations is a good idea.

Regarding the performance hit, one idea that wouldn't require a lot of changes would be to add get/set methods for the random state of the iterator, so that we can have it return the same set of samples for both the model and agent loops. We should then be able to use the model trainer callback to store all computed model states, and pass them to the agent trainer in the correct order.
Does that make sense?
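A minimal sketch of that idea, reusing the names from the training script above; the `"model_states"` meta key and the `precomputed_states` argument are assumptions about the changes this would require, not existing mbrl interfaces:

```python
# Sketch: cache the posterior states computed during the model update via the
# batch callback, then hand them to the agent trainer in the same batch order
# instead of re-running planet._process_batch() inside agent.train().
cached_states = []

def model_batch_callback(_epoch, _loss, meta, _mode):
    if meta and "model_states" in meta:        # key name assumed; PlaNet would need to expose it
        cached_states.append(meta["model_states"].detach())

# `dataset` must yield the same batches in the same order for both loops,
# e.g. via the get/set methods for its random state suggested above.
trainer.train(dataset, num_epochs=1, batch_callback=model_batch_callback, evaluate=False)
agent.train(dataset, num_epochs=1, precomputed_states=cached_states)
```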