This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Dreamer #151

Draft
Rohan138 wants to merge 31 commits into main from the dreamer branch
31 commits
42cbfbc
added skeleton
veds12 Feb 14, 2022
51f4c21
Merge https://github.com/veds12/mbrl-lib into dreamer
Rohan138 Apr 11, 2022
52218db
update gym and test commit hooks
Rohan138 Apr 11, 2022
3fb3ee9
Initial commit
Rohan138 Apr 11, 2022
630c1ce
pre-commit fixes
Rohan138 Apr 11, 2022
d591586
fixed pyproject.toml
Rohan138 Apr 11, 2022
85e7c14
dreamer core; bug fixes
Rohan138 Apr 12, 2022
fd38c62
dtype fix
Rohan138 Apr 12, 2022
f44d500
remove breakpoint
Rohan138 Apr 12, 2022
6b845fb
working on config
Rohan138 Apr 14, 2022
a6e6a2a
wip
Rohan138 May 30, 2022
4852125
Finish dreamer loss
Rohan138 Jun 6, 2022
d54a0e1
Add Dreamer to README
Rohan138 Jun 6, 2022
64441ca
Added config yamls
Rohan138 Jun 7, 2022
3ffcf2d
rename pyproject
Rohan138 Jun 11, 2022
656a6ce
Make saving replay buffer optional
Rohan138 Jun 12, 2022
03d383d
drop deprecation test
Rohan138 Jun 12, 2022
da7f83c
Fix num_grad_updates
Rohan138 Jun 12, 2022
3b4a00c
Add policy and critic loss to metrics
Rohan138 Jun 12, 2022
68d55f5
Freeze planet during dreamer train
Rohan138 Jun 15, 2022
e9b7196
Merge branch 'main' of https://github.com/facebookresearch/mbrl-lib i…
Rohan138 Aug 3, 2022
e37958b
Merge branch 'main' into dreamer
Rohan138 Aug 3, 2022
3fb69be
wip
Rohan138 Aug 3, 2022
3215e6e
wip
Rohan138 Aug 3, 2022
65cfa05
wip
Rohan138 Aug 3, 2022
85447e6
wip
Rohan138 Aug 3, 2022
b5beaa4
wip
Rohan138 Aug 3, 2022
71b54eb
wip
Rohan138 Aug 3, 2022
ca1604b
wip
Rohan138 Aug 9, 2022
39c7845
Merge branch 'main' into dreamer
Rohan138 Aug 9, 2022
84bda8d
wip
Rohan138 Feb 23, 2023
4 changes: 2 additions & 2 deletions README.md
@@ -51,8 +51,8 @@ as examples of how to use this library. You can find them in the
[mbrl/algorithms](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms) folder. Currently, we have implemented
[PETS](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/pets.py),
[MBPO](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/mbpo.py),
[PlaNet](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/planet.py),
we plan to keep increasing this list in the future.
[PlaNet](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/planet.py),
[Dreamer](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/dreamer.py), we plan to keep increasing this list in the future.

The implementations rely on [Hydra](https://github.com/facebookresearch/hydra)
to handle configuration. You can see the configuration files in
221 changes: 221 additions & 0 deletions mbrl/algorithms/dreamer.py
@@ -0,0 +1,221 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pathlib
from typing import List, Optional, Union

import gym
import hydra
import numpy as np
import omegaconf
import torch
from tqdm import tqdm

import mbrl.constants
from mbrl.env.termination_fns import no_termination
from mbrl.models import ModelEnv, ModelTrainer, PlaNetModel
from mbrl.planning import DreamerAgent, RandomAgent, create_dreamer_agent_for_model
from mbrl.util import Logger
from mbrl.util.common import (
create_replay_buffer,
get_sequence_buffer_iterator,
rollout_agent_trajectories,
)

METRICS_LOG_FORMAT = [
# Each entry is (metric key, short display name, value type) for mbrl's Logger.
# Keys must match those returned by get_metrics_and_clear_metric_containers below.
("observations_loss", "OL", "float"),
("reward_loss", "RL", "float"),
("model_gradient_norm", "MGN", "float"),
("agent_gradient_norm", "AGN", "float"),
("kl_loss", "KL", "float"),
("policy_loss", "PL", "float"),
("critic_loss", "CL", "float"),
]


def train(
env: gym.Env,
cfg: omegaconf.DictConfig,
silent: bool = False,
work_dir: Union[Optional[str], pathlib.Path] = None,
) -> np.float32:
# Experiment initialization
debug_mode = cfg.get("debug_mode", False)

if work_dir is None:
work_dir = os.getcwd()
work_dir = pathlib.Path(work_dir)
print(f"Results will be saved at {work_dir}.")

if silent:
logger = None
else:
logger = Logger(work_dir)
logger.register_group("metrics", METRICS_LOG_FORMAT, color="yellow")
logger.register_group(
mbrl.constants.RESULTS_LOG_NAME,
[
("env_step", "S", "int"),
("train_episode_reward", "RT", "float"),
("episode_reward", "ET", "float"),
],
color="green",
)

rng = torch.Generator(device=cfg.device)
rng.manual_seed(cfg.seed)
np_rng = np.random.default_rng(seed=cfg.seed)

# Create replay buffer and collect initial data
replay_buffer = create_replay_buffer(
cfg,
env.observation_space.shape,
env.action_space.shape,
collect_trajectories=True,
rng=np_rng,
)
rollout_agent_trajectories(
env,
cfg.algorithm.num_initial_trajectories,
RandomAgent(env),
agent_kwargs={},
replay_buffer=replay_buffer,
collect_full_trajectories=True,
trial_length=cfg.overrides.trial_length,
agent_uses_low_dim_obs=False,
)

# Create PlaNet model
cfg.dynamics_model.action_size = env.action_space.shape[0]
planet: PlaNetModel = hydra.utils.instantiate(cfg.dynamics_model)
model_env = ModelEnv(env, planet, no_termination, generator=rng)
trainer = ModelTrainer(planet, logger=logger, optim_lr=1e-3, optim_eps=1e-4)

# Create Dreamer agent
# This agent rolls out trajectories using ModelEnv, which uses planet.sample()
# to simulate the trajectories from the prior transition model
# The starting point for trajectories is each imagined state output by the
# representation model from the dataset of environment observations
agent: DreamerAgent = create_dreamer_agent_for_model(
planet, model_env, cfg.algorithm.agent
)

# Callback and containers to accumulate training statistics and average over batch
rec_losses: List[float] = []
reward_losses: List[float] = []
policy_losses: List[float] = []
critic_losses: List[float] = []
kl_losses: List[float] = []
model_grad_norms: List[float] = []
agent_grad_norms: List[float] = []

def get_metrics_and_clear_metric_containers():
metrics_ = {
"observations_loss": np.mean(rec_losses).item(),
"reward_loss": np.mean(reward_losses).item(),
"policy_loss": np.mean(policy_losses).item(),
"critic_loss": np.mean(critic_losses).item(),
"model_gradient_norm": np.mean(model_grad_norms).item(),
"agent_gradient_norm": np.mean(agent_grad_norms).item(),
"kl_loss": np.mean(kl_losses).item(),
}

for c in [
rec_losses,
reward_losses,
policy_losses,
critic_losses,
kl_losses,
model_grad_norms,
agent_grad_norms,
]:
c.clear()

return metrics_

def model_batch_callback(_epoch, _loss, meta, _mode):
if meta:
rec_losses.append(meta["observations_loss"])
reward_losses.append(meta["reward_loss"])
kl_losses.append(meta["kl_loss"])
if "grad_norm" in meta:
model_grad_norms.append(meta["grad_norm"])

def agent_batch_callback(_epoch, _loss, meta, _mode):
if meta:
policy_losses.append(meta["policy_loss"])
critic_losses.append(meta["critic_loss"])
if "grad_norm" in meta:
agent_grad_norms.append(meta["grad_norm"])

def is_test_episode(episode_):
return episode_ % cfg.algorithm.test_frequency == 0

# Dreamer loop
step = replay_buffer.num_stored
total_rewards = 0.0
for episode in tqdm(range(cfg.algorithm.num_episodes)):
# Train the model for one epoch of `num_grad_updates`
dataset, _ = get_sequence_buffer_iterator(
replay_buffer,
cfg.overrides.batch_size,
0, # no validation data
cfg.overrides.sequence_length,
max_batches_per_loop_train=cfg.overrides.num_grad_updates,
use_simple_sampler=True,
)
trainer.train(
dataset, num_epochs=1, batch_callback=model_batch_callback, evaluate=False
)
agent.train(dataset, num_epochs=1, batch_callback=agent_batch_callback)
Contributor

I'm wondering if we should be passing a different iterator for training the agent. If I understand the paper correctly, the Dreamer agent is trained on trajectories whose start states are sampled from the experience buffer, but where all subsequent states are obtained by rolling out the model. In this case, we only need to sample individual transitions to get start states, and not full sequences, which is what dataset would return here.

If what I said above is correct, then maybe the cleanest approach would be to modify DreamerAgent.train() to directly receive replay_buffer plus an additional parameter called num_updates. The agent train code could then loop num_updates times, each time (1) calling replay_buffer.sample(batch_size), (2) rolling out the planet model from the batch of start states, and (3) updating the agent parameters.

Does the above make sense? Let me know if I'm missing something or if anything is unclear. I guess your current code is serving more data to the Dreamer agent, but it seems like it would be easier to make a mistake with the current implementation?
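
For concreteness, here is a rough sketch of the interface I have in mind. The helper names (`infer_latent`, `imagine`, `update_actor_critic`) and the `agent.device` attribute are placeholders for illustration, not existing mbrl-lib methods, and this is a suggestion rather than the PR's current code:

```python
import torch


def train_dreamer_agent(agent, planet_model, replay_buffer, num_updates, batch_size, horizon):
    """Sketch of the suggested agent training loop; helper names are hypothetical."""
    for _ in range(num_updates):
        # 1) sample individual transitions; only the observations are needed
        #    as start states for imagination
        batch = replay_buffer.sample(batch_size)
        obs = torch.as_tensor(batch.obs, device=agent.device)

        # 2) encode the observations into latent states and roll the learned
        #    model forward for `horizon` imagined steps (hypothetical helpers)
        start_states = planet_model.infer_latent(obs)
        imagined = agent.imagine(start_states, horizon)

        # 3) update the policy and critic on the imagined trajectories
        agent.update_actor_critic(imagined)
```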

Contributor Author
@Rohan138 Aug 3, 2022

> I'm wondering if we should be passing a different iterator for training the agent. If I understand the paper correctly, the Dreamer agent is trained on trajectories whose start states are sampled from the experience buffer, but where all subsequent states are obtained by rolling out the model. In this case, we only need to sample individual transitions to get start states, and not full sequences, which is what dataset would return here.

I might have misunderstood the paper, but I'm not sure this is correct. In Algorithm 1 (Page 3), they:

  1. Draw B data sequences or episodes {(a_t, o_t, r_t)}_{t=k}^{k+L}. Here k is the outer variable looping over episodes, while t is the inner variable looping over timesteps in an episode.
  2. Compute model states s_t for all t in [k, k + L) for all k in B using the RSSM transition model.
  3. Imagine trajectories {(s_\tau, a_\tau)}_{\tau=t}^{\tau=t+H} from each state s_t in B, not just the initial state s_k in each episode (see the toy sketch at the end of this comment).

I'm not sure if this explanation was clear, and I'll take another look at the prior implementations linked in the other comment to confirm.

We do have a minor divergence and performance hit at the moment: instead of computing the model states just once, as in the paper and the reference implementations, we run the forward and backward passes on the model in model.train(), and then run the forward pass again in self.planet_model._process_batch(...) inside agent.train(). I haven't figured out a clean way to fix this yet; maybe return the states from model.train()? Or append them to the TransitionIterator somehow?
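
To make step 3 above concrete, here is a tiny self-contained toy (plain PyTorch, with stand-in linear layers for the RSSM prior and the policy; the shapes are illustrative and none of this is the PR's code). The point is that imagination starts from every posterior state s_t of every sampled sequence, so a batch of B sequences of length L produces B*L imagined rollouts of length H:

```python
import torch
import torch.nn as nn

B, L, H = 50, 50, 15                 # sequences per batch, sequence length, imagination horizon
state_dim, action_dim = 230, 1       # illustrative sizes only

# Stand-ins for the RSSM prior and the policy (not mbrl-lib modules).
prior = nn.Linear(state_dim + action_dim, state_dim)
policy = nn.Linear(state_dim, action_dim)

posterior_states = torch.randn(B, L, state_dim)              # s_t for every (sequence, timestep)
start_states = posterior_states.reshape(B * L, state_dim)    # one imagined rollout per s_t

states = [start_states]
for _ in range(H):
    actions = torch.tanh(policy(states[-1]))
    states.append(prior(torch.cat([states[-1], actions], dim=-1)))

imagined = torch.stack(states, dim=1)                         # shape (B * L, H + 1, state_dim)
```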

Contributor

Looking at the paper again, I think your interpretation is correct, because the "Compute model states" step occurs for all sampled o_t, and trajectories are then imagined for all model states s_t. I find it a bit confusing how they use the index k; I guess it increases in increments of size L? That is, the j-th trajectory goes from t = L*(j-1) to t = L*j - 1? In any case, confirming against prior implementations is a good idea.

Regarding the performance hit, one idea that wouldn't require a lot of changes would be to add get/set methods for the random state of the iterator, so that we can have it return the same set of samples for both the model and agent loops. We should then be able to use the model trainer callback to store all of the computed model states and pass them to the agent trainer in the correct order.

Does that make sense?
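
A rough sketch of that idea, under the assumption that we add get/set methods for the sampling RNG state; nothing like this exists on the current iterators, and the names are placeholders:

```python
import numpy as np


class ReplayableSequenceIterator:
    """Wrapper sketch: expose the sampling RNG state so the agent loop can
    replay exactly the batches the model loop saw. Hypothetical; not part of
    mbrl-lib's iterator API."""

    def __init__(self, iterator, rng: np.random.Generator):
        self._iterator = iterator
        self._rng = rng  # assumed to be the generator the iterator samples with

    def get_rng_state(self):
        return self._rng.bit_generator.state

    def set_rng_state(self, state):
        self._rng.bit_generator.state = state

    def __iter__(self):
        return iter(self._iterator)


# Intended usage (pseudocode): store model states from the trainer callback,
# rewind the RNG, then feed the agent the same batches in the same order.
#
#   rng_state = dataset.get_rng_state()
#   trainer.train(dataset, num_epochs=1, batch_callback=model_batch_callback)
#   dataset.set_rng_state(rng_state)
#   agent.train(dataset, num_epochs=1, batch_callback=agent_batch_callback)
```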

planet.save(work_dir)
agent.save(work_dir)
if cfg.overrides.get("save_replay_buffer", False):
replay_buffer.save(work_dir)
metrics = get_metrics_and_clear_metric_containers()
logger.log_data("metrics", metrics)

# Collect one episode of data
episode_reward = 0.0
obs = env.reset()
agent.reset()
planet.reset_posterior()
action = None
done = False
pbar = tqdm(total=500)
while not done:
latent_state = planet.update_posterior(obs, action=action, rng=rng)
action_noise = (
0
if is_test_episode(episode)
else cfg.overrides.action_noise_std
* np_rng.standard_normal(env.action_space.shape[0])
)
action = agent.act(latent_state)
action = action.detach().cpu().squeeze(0).numpy()
action = action + action_noise
action = np.clip(
action, -1.0, 1.0, dtype=env.action_space.dtype
) # to account for the noise and fix dtype
next_obs, reward, done, info = env.step(action)
replay_buffer.add(obs, action, next_obs, reward, done)
episode_reward += reward
obs = next_obs
if debug_mode:
print(f"step: {step}, reward: {reward}.")
step += 1
pbar.update(1)
pbar.close()
total_rewards += episode_reward
logger.log_data(
mbrl.constants.RESULTS_LOG_NAME,
{
"episode_reward": episode_reward * is_test_episode(episode),
"train_episode_reward": episode_reward * (1 - is_test_episode(episode)),
"env_step": step,
},
)

# returns average episode reward (e.g., to use for tuning learning curves)
return total_rewards / cfg.algorithm.num_episodes
15 changes: 11 additions & 4 deletions mbrl/algorithms/planet.py
@@ -11,6 +11,7 @@
import numpy as np
import omegaconf
import torch
from tqdm import tqdm

import mbrl.constants
from mbrl.env.termination_fns import no_termination
@@ -130,7 +131,7 @@ def is_test_episode(episode_):
# PlaNet loop
step = replay_buffer.num_stored
total_rewards = 0.0
for episode in range(cfg.algorithm.num_episodes):
for episode in tqdm(range(cfg.algorithm.num_episodes)):
# Train the model for one epoch of `num_grad_updates`
dataset, _ = get_sequence_buffer_iterator(
replay_buffer,
@@ -143,8 +144,9 @@ def is_test_episode(episode_):
trainer.train(
dataset, num_epochs=1, batch_callback=batch_callback, evaluate=False
)
planet.save(work_dir / "planet.pth")
replay_buffer.save(work_dir)
planet.save(work_dir)
if cfg.overrides.get("save_replay_buffer", False):
replay_buffer.save(work_dir)
metrics = get_metrics_and_clear_metric_containers()
logger.log_data("metrics", metrics)

@@ -155,6 +157,7 @@ def is_test_episode(episode_):
planet.reset_posterior()
action = None
done = False
pbar = tqdm(total=500)
while not done:
planet.update_posterior(obs, action=action, rng=rng)
action_noise = (
@@ -164,14 +167,18 @@
* np_rng.standard_normal(env.action_space.shape[0])
)
action = agent.act(obs) + action_noise
action = np.clip(action, -1.0, 1.0) # to account for the noise
action = np.clip(
action, -1.0, 1.0, dtype=env.action_space.dtype
) # to account for the noise and fix dtype
next_obs, reward, done, info = env.step(action)
replay_buffer.add(obs, action, next_obs, reward, done)
episode_reward += reward
obs = next_obs
if debug_mode:
print(f"step: {step}, reward: {reward}.")
step += 1
pbar.update(1)
pbar.close()
total_rewards += episode_reward
logger.log_data(
mbrl.constants.RESULTS_LOG_NAME,
24 changes: 24 additions & 0 deletions mbrl/examples/conf/algorithm/dreamer.yaml
@@ -0,0 +1,24 @@
# @package _group_
name: "dreamer"

agent:
  _target_: mbrl.planning.DreamerAgent
  action_lb: ???
  action_ub: ???
  horizon: 15
  policy_lr: 0.00008
  critic_lr: 0.00008
  gamma: 0.99
  lam: 0.95
  grad_clip_norm: 100.0
  min_std: 0.0001
  init_std: 5
  mean_scale: 5
  activation_function: "elu"
  device: ${device}

num_initial_trajectories: 5
action_noise_std: 0.3
test_frequency: 25
num_episodes: 1000
dataset_size: 1000000
30 changes: 30 additions & 0 deletions mbrl/examples/conf/overrides/dreamer_cartpole_balance.yaml
@@ -0,0 +1,30 @@
# @package _group_
env: "dmcontrol_cartpole_balance" # used to set the hydra dir, ignored otherwise

env_cfg:
  _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper"
  domain_name: "cartpole"
  task_name: "balance"
  task_kwargs:
    random: ${seed}
  visualize_reward: false
  from_pixels: true
  height: 64
  width: 64
  frame_skip: 2
  bit_depth: 5

term_fn: "no_termination"

# General configuration overrides
trial_length: 500
action_noise_std: 0.3

# Model overrides
num_grad_updates: 100
sequence_length: 50
batch_size: 50
free_nats: 3
kl_scale: 1.0

# Dreamer configuration overrides
30 changes: 30 additions & 0 deletions mbrl/examples/conf/overrides/dreamer_cheetah_run.yaml
@@ -0,0 +1,30 @@
# @package _group_
env: "dmcontrol_cheetah_run" # used to set the hydra dir, ignored otherwise

env_cfg:
  _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper"
  domain_name: "cheetah"
  task_name: "run"
  task_kwargs:
    random: ${seed}
  visualize_reward: false
  from_pixels: true
  height: 64
  width: 64
  frame_skip: 2
  bit_depth: 5

term_fn: "no_termination"

# General configuration overrides
trial_length: 500
action_noise_std: 0.3

# Model overrides
num_grad_updates: 100
sequence_length: 50
batch_size: 50
free_nats: 3
kl_scale: 1.0

# Dreamer configuration overrides