diff --git a/README.md b/README.md index 812c6596..9fa2a284 100644 --- a/README.md +++ b/README.md @@ -51,8 +51,8 @@ as examples of how to use this library. You can find them in the [mbrl/algorithms](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms) folder. Currently, we have implemented [PETS](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/pets.py), [MBPO](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/mbpo.py), -[PlaNet](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/planet.py), -we plan to keep increasing this list in the future. +[PlaNet](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/planet.py), +[Dreamer](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/algorithms/dreamer.py), we plan to keep increasing this list in the future. The implementations rely on [Hydra](https://github.com/facebookresearch/hydra) to handle configuration. You can see the configuration files in diff --git a/mbrl/algorithms/dreamer.py b/mbrl/algorithms/dreamer.py new file mode 100644 index 00000000..45f96bfc --- /dev/null +++ b/mbrl/algorithms/dreamer.py @@ -0,0 +1,221 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os +import pathlib +from typing import List, Optional, Union + +import gym +import hydra +import numpy as np +import omegaconf +import torch +from tqdm import tqdm + +import mbrl.constants +from mbrl.env.termination_fns import no_termination +from mbrl.models import ModelEnv, ModelTrainer, PlaNetModel +from mbrl.planning import DreamerAgent, RandomAgent, create_dreamer_agent_for_model +from mbrl.util import Logger +from mbrl.util.common import ( + create_replay_buffer, + get_sequence_buffer_iterator, + rollout_agent_trajectories, +) + +METRICS_LOG_FORMAT = [ + ("observations_loss", "OL", "float"), + ("reward_loss", "RL", "float"), + ("gradient_norm", "GN", "float"), + ("kl_loss", "KL", "float"), + ("policy_loss", "PL", "float"), + ("critic_loss", "CL", "float"), +] + + +def train( + env: gym.Env, + cfg: omegaconf.DictConfig, + silent: bool = False, + work_dir: Union[Optional[str], pathlib.Path] = None, +) -> np.float32: + # Experiment initialization + debug_mode = cfg.get("debug_mode", False) + + if work_dir is None: + work_dir = os.getcwd() + work_dir = pathlib.Path(work_dir) + print(f"Results will be saved at {work_dir}.") + + if silent: + logger = None + else: + logger = Logger(work_dir) + logger.register_group("metrics", METRICS_LOG_FORMAT, color="yellow") + logger.register_group( + mbrl.constants.RESULTS_LOG_NAME, + [ + ("env_step", "S", "int"), + ("train_episode_reward", "RT", "float"), + ("episode_reward", "ET", "float"), + ], + color="green", + ) + + rng = torch.Generator(device=cfg.device) + rng.manual_seed(cfg.seed) + np_rng = np.random.default_rng(seed=cfg.seed) + + # Create replay buffer and collect initial data + replay_buffer = create_replay_buffer( + cfg, + env.observation_space.shape, + env.action_space.shape, + collect_trajectories=True, + rng=np_rng, + ) + rollout_agent_trajectories( + env, + cfg.algorithm.num_initial_trajectories, + RandomAgent(env), + agent_kwargs={}, + replay_buffer=replay_buffer, + collect_full_trajectories=True, + trial_length=cfg.overrides.trial_length, + agent_uses_low_dim_obs=False, + ) + + # Create PlaNet model + cfg.dynamics_model.action_size = 
env.action_space.shape[0] + planet: PlaNetModel = hydra.utils.instantiate(cfg.dynamics_model) + model_env = ModelEnv(env, planet, no_termination, generator=rng) + trainer = ModelTrainer(planet, logger=logger, optim_lr=1e-3, optim_eps=1e-4) + + # Create Dreamer agent + # This agent rolls outs trajectories using ModelEnv, which uses planet.sample() + # to simulate the trajectories from the prior transition model + # The starting point for trajectories is each imagined state output by the + # representation model from the dataset of environment observations + agent: DreamerAgent = create_dreamer_agent_for_model( + planet, model_env, cfg.algorithm.agent + ) + + # Callback and containers to accumulate training statistics and average over batch + rec_losses: List[float] = [] + reward_losses: List[float] = [] + policy_losses: List[float] = [] + critic_losses: List[float] = [] + kl_losses: List[float] = [] + model_grad_norms: List[float] = [] + agent_grad_norms: List[float] = [] + + def get_metrics_and_clear_metric_containers(): + metrics_ = { + "observations_loss": np.mean(rec_losses).item(), + "reward_loss": np.mean(reward_losses).item(), + "policy_loss": np.mean(policy_losses).item(), + "critic_loss": np.mean(critic_losses).item(), + "model_gradient_norm": np.mean(model_grad_norms).item(), + "agent_gradient_norm": np.mean(agent_grad_norms).item(), + "kl_loss": np.mean(kl_losses).item(), + } + + for c in [ + rec_losses, + reward_losses, + policy_losses, + critic_losses, + kl_losses, + model_grad_norms, + agent_grad_norms, + ]: + c.clear() + + return metrics_ + + def model_batch_callback(_epoch, _loss, meta, _mode): + if meta: + rec_losses.append(meta["observations_loss"]) + reward_losses.append(meta["reward_loss"]) + kl_losses.append(meta["kl_loss"]) + if "grad_norm" in meta: + model_grad_norms.append(meta["grad_norm"]) + + def agent_batch_callback(_epoch, _loss, meta, _mode): + if meta: + policy_losses.append(meta["policy_loss"]) + critic_losses.append(meta["critic_loss"]) + if "grad_norm" in meta: + agent_grad_norms.append(meta["grad_norm"]) + + def is_test_episode(episode_): + return episode_ % cfg.algorithm.test_frequency == 0 + + # Dreamer loop + step = replay_buffer.num_stored + total_rewards = 0.0 + for episode in tqdm(range(cfg.algorithm.num_episodes)): + # Train the model for one epoch of `num_grad_updates` + dataset, _ = get_sequence_buffer_iterator( + replay_buffer, + cfg.overrides.batch_size, + 0, # no validation data + cfg.overrides.sequence_length, + max_batches_per_loop_train=cfg.overrides.num_grad_updates, + use_simple_sampler=True, + ) + trainer.train( + dataset, num_epochs=1, batch_callback=model_batch_callback, evaluate=False + ) + agent.train(dataset, num_epochs=1, batch_callback=agent_batch_callback) + planet.save(work_dir) + agent.save(work_dir) + if cfg.overrides.get("save_replay_buffer", False): + replay_buffer.save(work_dir) + metrics = get_metrics_and_clear_metric_containers() + logger.log_data("metrics", metrics) + + # Collect one episode of data + episode_reward = 0.0 + obs = env.reset() + agent.reset() + planet.reset_posterior() + action = None + done = False + pbar = tqdm(total=500) + while not done: + latent_state = planet.update_posterior(obs, action=action, rng=rng) + action_noise = ( + 0 + if is_test_episode(episode) + else cfg.overrides.action_noise_std + * np_rng.standard_normal(env.action_space.shape[0]) + ) + action = agent.act(latent_state) + action = action.detach().cpu().squeeze(0).numpy() + action = action + action_noise + action = np.clip( + action, 
-1.0, 1.0, dtype=env.action_space.dtype + ) # to account for the noise and fix dtype + next_obs, reward, done, info = env.step(action) + replay_buffer.add(obs, action, next_obs, reward, done) + episode_reward += reward + obs = next_obs + if debug_mode: + print(f"step: {step}, reward: {reward}.") + step += 1 + pbar.update(1) + pbar.close() + total_rewards += episode_reward + logger.log_data( + mbrl.constants.RESULTS_LOG_NAME, + { + "episode_reward": episode_reward * is_test_episode(episode), + "train_episode_reward": episode_reward * (1 - is_test_episode(episode)), + "env_step": step, + }, + ) + + # returns average episode reward (e.g., to use for tuning learning curves) + return total_rewards / cfg.algorithm.num_episodes diff --git a/mbrl/algorithms/planet.py b/mbrl/algorithms/planet.py index d35d3b7e..4729d6fd 100644 --- a/mbrl/algorithms/planet.py +++ b/mbrl/algorithms/planet.py @@ -11,6 +11,7 @@ import numpy as np import omegaconf import torch +from tqdm import tqdm import mbrl.constants from mbrl.env.termination_fns import no_termination @@ -130,7 +131,7 @@ def is_test_episode(episode_): # PlaNet loop step = replay_buffer.num_stored total_rewards = 0.0 - for episode in range(cfg.algorithm.num_episodes): + for episode in tqdm(range(cfg.algorithm.num_episodes)): # Train the model for one epoch of `num_grad_updates` dataset, _ = get_sequence_buffer_iterator( replay_buffer, @@ -143,8 +144,9 @@ def is_test_episode(episode_): trainer.train( dataset, num_epochs=1, batch_callback=batch_callback, evaluate=False ) - planet.save(work_dir / "planet.pth") - replay_buffer.save(work_dir) + planet.save(work_dir) + if cfg.overrides.get("save_replay_buffer", False): + replay_buffer.save(work_dir) metrics = get_metrics_and_clear_metric_containers() logger.log_data("metrics", metrics) @@ -155,6 +157,7 @@ def is_test_episode(episode_): planet.reset_posterior() action = None done = False + pbar = tqdm(total=500) while not done: planet.update_posterior(obs, action=action, rng=rng) action_noise = ( @@ -164,7 +167,9 @@ def is_test_episode(episode_): * np_rng.standard_normal(env.action_space.shape[0]) ) action = agent.act(obs) + action_noise - action = np.clip(action, -1.0, 1.0) # to account for the noise + action = np.clip( + action, -1.0, 1.0, dtype=env.action_space.dtype + ) # to account for the noise and fix dtype next_obs, reward, done, info = env.step(action) replay_buffer.add(obs, action, next_obs, reward, done) episode_reward += reward @@ -172,6 +177,8 @@ def is_test_episode(episode_): if debug_mode: print(f"step: {step}, reward: {reward}.") step += 1 + pbar.update(1) + pbar.close() total_rewards += episode_reward logger.log_data( mbrl.constants.RESULTS_LOG_NAME, diff --git a/mbrl/examples/conf/algorithm/dreamer.yaml b/mbrl/examples/conf/algorithm/dreamer.yaml new file mode 100644 index 00000000..fa42a8c0 --- /dev/null +++ b/mbrl/examples/conf/algorithm/dreamer.yaml @@ -0,0 +1,24 @@ +# @package _group_ +name: "dreamer" + +agent: + _target_: mbrl.planning.DreamerAgent + action_lb: ??? + action_ub: ??? 
+ horizon: 15 + policy_lr: 0.00008 + critic_lr: 0.00008 + gamma: 0.99 + lam: 0.95 + grad_clip_norm: 100.0 + min_std: 0.0001 + init_std: 5 + mean_scale: 5 + activation_function: "elu" + device: ${device} + +num_initial_trajectories: 5 +action_noise_std: 0.3 +test_frequency: 25 +num_episodes: 1000 +dataset_size: 1000000 diff --git a/mbrl/examples/conf/overrides/dreamer_cartpole_balance.yaml b/mbrl/examples/conf/overrides/dreamer_cartpole_balance.yaml new file mode 100644 index 00000000..d7a0c54a --- /dev/null +++ b/mbrl/examples/conf/overrides/dreamer_cartpole_balance.yaml @@ -0,0 +1,30 @@ +# @package _group_ +env: "dmcontrol_cartpole_balance" # used to set the hydra dir, ignored otherwise + +env_cfg: + _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper" + domain_name: "cartpole" + task_name: "balance" + task_kwargs: + random: ${seed} + visualize_reward: false + from_pixels: true + height: 64 + width: 64 + frame_skip: 2 + bit_depth: 5 + +term_fn: "no_termination" + +# General configuration overrides +trial_length: 500 +action_noise_std: 0.3 + +# Model overrides +num_grad_updates: 100 +sequence_length: 50 +batch_size: 50 +free_nats: 3 +kl_scale: 1.0 + +# Dreamer configuration overrides diff --git a/mbrl/examples/conf/overrides/dreamer_cheetah_run.yaml b/mbrl/examples/conf/overrides/dreamer_cheetah_run.yaml new file mode 100644 index 00000000..afb7ca5a --- /dev/null +++ b/mbrl/examples/conf/overrides/dreamer_cheetah_run.yaml @@ -0,0 +1,30 @@ +# @package _group_ +env: "dmcontrol_cheetah_run" # used to set the hydra dir, ignored otherwise + +env_cfg: + _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper" + domain_name: "cheetah" + task_name: "run" + task_kwargs: + random: ${seed} + visualize_reward: false + from_pixels: true + height: 64 + width: 64 + frame_skip: 2 + bit_depth: 5 + +term_fn: "no_termination" + +# General configuration overrides +trial_length: 500 +action_noise_std: 0.3 + +# Model overrides +num_grad_updates: 100 +sequence_length: 50 +batch_size: 50 +free_nats: 3 +kl_scale: 1.0 + +# Dreamer configuration overrides diff --git a/mbrl/examples/conf/overrides/dreamer_walker_run.yaml b/mbrl/examples/conf/overrides/dreamer_walker_run.yaml new file mode 100644 index 00000000..24708f28 --- /dev/null +++ b/mbrl/examples/conf/overrides/dreamer_walker_run.yaml @@ -0,0 +1,30 @@ +# @package _group_ +env: "dmcontrol_walker_run" # used to set the hydra dir, ignored otherwise + +env_cfg: + _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper" + domain_name: "walker" + task_name: "run" + task_kwargs: + random: ${seed} + visualize_reward: false + from_pixels: true + height: 64 + width: 64 + frame_skip: 2 + bit_depth: 5 + +term_fn: "no_termination" + +# General configuration overrides +trial_length: 500 +action_noise_std: 0.3 + +# Model overrides +num_grad_updates: 100 +sequence_length: 50 +batch_size: 50 +free_nats: 3 +kl_scale: 1.0 + +# Dreamer configuration overrides diff --git a/mbrl/examples/conf/overrides/dreamer_walker_stand.yaml b/mbrl/examples/conf/overrides/dreamer_walker_stand.yaml new file mode 100644 index 00000000..3022e68b --- /dev/null +++ b/mbrl/examples/conf/overrides/dreamer_walker_stand.yaml @@ -0,0 +1,30 @@ +# @package _group_ +env: "dmcontrol_walker_stand" # used to set the hydra dir, ignored otherwise + +env_cfg: + _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper" + domain_name: "walker" + task_name: "stand" + task_kwargs: + random: ${seed} + visualize_reward: false + from_pixels: true + height: 64 + width: 64 + frame_skip: 2 + bit_depth: 5 + +term_fn: 
"no_termination" + +# General configuration overrides +trial_length: 500 +action_noise_std: 0.3 + +# Model overrides +num_grad_updates: 100 +sequence_length: 50 +batch_size: 50 +free_nats: 3 +kl_scale: 1.0 + +# Dreamer configuration overrides diff --git a/mbrl/examples/conf/overrides/dreamer_walker_walk.yaml b/mbrl/examples/conf/overrides/dreamer_walker_walk.yaml new file mode 100644 index 00000000..501f27d0 --- /dev/null +++ b/mbrl/examples/conf/overrides/dreamer_walker_walk.yaml @@ -0,0 +1,30 @@ +# @package _group_ +env: "dmcontrol_walker_walk" # used to set the hydra dir, ignored otherwise + +env_cfg: + _target_: "mbrl.third_party.dmc2gym.wrappers.DMCWrapper" + domain_name: "walker" + task_name: "walk" + task_kwargs: + random: ${seed} + visualize_reward: false + from_pixels: true + height: 64 + width: 64 + frame_skip: 2 + bit_depth: 5 + +term_fn: "no_termination" + +# General configuration overrides +trial_length: 500 +action_noise_std: 0.3 + +# Model overrides +num_grad_updates: 100 +sequence_length: 50 +batch_size: 50 +free_nats: 3 +kl_scale: 1.0 + +# Dreamer configuration overrides diff --git a/mbrl/examples/main.py b/mbrl/examples/main.py index c2e1a57f..ec6cb17e 100644 --- a/mbrl/examples/main.py +++ b/mbrl/examples/main.py @@ -7,6 +7,7 @@ import omegaconf import torch +import mbrl.algorithms.dreamer as dreamer import mbrl.algorithms.mbpo as mbpo import mbrl.algorithms.pets as pets import mbrl.algorithms.planet as planet @@ -25,6 +26,8 @@ def run(cfg: omegaconf.DictConfig): return mbpo.train(env, test_env, term_fn, cfg) if cfg.algorithm.name == "planet": return planet.train(env, cfg) + if cfg.algorithm.name == "dreamer": + return dreamer.train(env, cfg) if __name__ == "__main__": diff --git a/mbrl/models/planet.py b/mbrl/models/planet.py index 8f206413..e8841755 100644 --- a/mbrl/models/planet.py +++ b/mbrl/models/planet.py @@ -101,7 +101,7 @@ def forward( class MeanStdCat(nn.Module): - # Convenience module to avoid having to write chuck and softplus in multiple places + # Convenience module to avoid having to write chunk and softplus in multiple places # (since it's needed for prior and posterior params) def __init__(self, latent_state_size: int, min_std: float): super().__init__() diff --git a/mbrl/planning/__init__.py b/mbrl/planning/__init__.py index 463bbadc..b956d529 100644 --- a/mbrl/planning/__init__.py +++ b/mbrl/planning/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from .core import Agent, RandomAgent, complete_agent_cfg, load_agent +from .dreamer_agent import DreamerAgent, create_dreamer_agent_for_model from .linear_feedback import PIDAgent from .trajectory_opt import ( CEMOptimizer, diff --git a/mbrl/planning/core.py b/mbrl/planning/core.py index 29ca7f93..ad82d844 100644 --- a/mbrl/planning/core.py +++ b/mbrl/planning/core.py @@ -19,27 +19,27 @@ class Agent: """Abstract class for all agents.""" @abc.abstractmethod - def act(self, obs: np.ndarray, **_kwargs) -> np.ndarray: + def act(self, obs: Any, **_kwargs) -> mbrl.types.TensorType: """Issues an action given an observation. Args: - obs (np.ndarray): the observation for which the action is needed. + obs (Any): the observation for which the action is needed. Returns: - (np.ndarray): the action. + (TensorType): the action. """ pass - def plan(self, obs: np.ndarray, **_kwargs) -> np.ndarray: + def plan(self, obs: Any, **_kwargs) -> mbrl.types.TensorType: """Issues a sequence of actions given an observation. 
Unless overridden by a child class, this will be equivalent to :meth:`act`. Args: - obs (np.ndarray): the observation for which the sequence is needed. + obs (Any): the observation for which the sequence is needed. Returns: - (np.ndarray): a sequence of actions. + (TensorType): a sequence of actions. """ return self.act(obs, **_kwargs) @@ -150,8 +150,15 @@ def load_agent(agent_path: Union[str, pathlib.Path], env: gym.Env) -> Agent: from .sac_wrapper import SACAgent complete_agent_cfg(env, cfg.algorithm.agent) - agent: pytorch_sac.SAC = hydra.utils.instantiate(cfg.algorithm.agent) - agent.load_checkpoint(ckpt_path=agent_path / "sac.pth") - return SACAgent(agent) + sac: pytorch_sac.SAC = hydra.utils.instantiate(cfg.algorithm.agent) + sac.load_checkpoint(ckpt_path=agent_path / "sac.pth") + return SACAgent(sac) + elif cfg.algorithm.agent == "mbrl.planning.dreamer_agent.DreamerAgent": + from mbrl.planning.dreamer_agent import DreamerAgent + + complete_agent_cfg(env, cfg.algorithm.agent) + dreamer_agent: DreamerAgent = hydra.utils.instantiate(cfg.algorithm.agent) + dreamer_agent.load(agent_path) + return dreamer_agent else: raise ValueError("Invalid agent configuration.") diff --git a/mbrl/planning/dreamer_agent.py b/mbrl/planning/dreamer_agent.py new file mode 100644 index 00000000..e25bd836 --- /dev/null +++ b/mbrl/planning/dreamer_agent.py @@ -0,0 +1,343 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pathlib +from typing import Callable, Dict, Optional, Sequence, Union + +import hydra +import numpy as np +import omegaconf +import torch +import tqdm +from torch import nn +from torch.distributions import ( + Independent, + Normal, + TanhTransform, + TransformedDistribution, +) +from torch.nn import functional as F +from torch.optim import Adam + +import mbrl.models +from mbrl.models.planet import PlaNetModel +from mbrl.types import TensorType +from mbrl.util.replay_buffer import TransitionIterator + +from .core import Agent, complete_agent_cfg + + +def freeze(module: nn.Module): + for p in module.parameters(): + p.requires_grad = False + + +def unfreeze(module: nn.Module): + for p in module.parameters(): + p.requires_grad = True + + +class PolicyModel(nn.Module): + def __init__( + self, + latent_size: int, + action_size: int, + hidden_size: int, + min_std: float = 1e-4, + init_std: float = 5, + mean_scale: float = 5, + activation_function="elu", + ): + super().__init__() + self.act_fn = getattr(F, activation_function) + self.fc1 = nn.Linear(latent_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.fc3 = nn.Linear(hidden_size, action_size * 2) + self.min_std = min_std + self.init_std = init_std + self.mean_scale = mean_scale + self.raw_init_std = np.log(np.exp(self.init_std) - 1) + + def forward(self, belief, latent): + hidden = self.act_fn(self.fc1(torch.cat([belief, latent], dim=-1))) + hidden = self.act_fn(self.fc2(hidden)) + model_out = self.fc3(hidden).squeeze(dim=1) + mean, std = torch.chunk(model_out, 2, -1) + mean = self.mean_scale * torch.tanh(mean / self.mean_scale) + std = F.softplus(std + self.raw_init_std) + self.min_std + dist = Normal(mean, std) + dist = TransformedDistribution(dist, TanhTransform()) + dist = Independent(dist, 1) + return dist + + +class ValueModel(nn.Module): + def __init__(self, latent_size, hidden_size, activation_function="elu"): + super().__init__() + self.act_fn = 
getattr(F, activation_function) + self.fc1 = nn.Linear(latent_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.fc3 = nn.Linear(hidden_size, 1) + + def forward(self, belief, latent): + hidden = self.act_fn(self.fc1(torch.cat([belief, latent], dim=-1))) + hidden = self.act_fn(self.fc2(hidden)) + value = self.fc3(hidden).squeeze(dim=1) + return value + + +class DreamerAgent(Agent): + def __init__( + self, + action_size: int, + action_lb: Sequence[float] = [-1.0], + action_ub: Sequence[float] = [1.0], + belief_size: int = 200, + latent_state_size: int = 30, + hidden_size: int = 300, + horizon: int = 15, + policy_lr: float = 8e-5, + min_std: float = 1e-4, + init_std: float = 5, + mean_scale: float = 5, + critic_lr: float = 8e-5, + gamma: float = 0.99, + lam: float = 0.95, + grad_clip_norm: float = 100.0, + activation_function: str = "elu", + device: Union[str, torch.device] = "cpu", + ): + super().__init__() + self.belief_size = belief_size + self.latent_state_size = latent_state_size + self.action_size = action_size + self.gamma = gamma + self.lam = lam + self.grad_clip_norm = grad_clip_norm + self.horizon = horizon + self.action_lb = action_lb + self.action_ub = action_ub + self.device = device + self.planet_model: PlaNetModel = None + + self.policy = PolicyModel( + belief_size + latent_state_size, + action_size, + hidden_size, + min_std, + init_std, + mean_scale, + activation_function, + ).to(device) + self.policy_optim = Adam(self.policy.parameters(), policy_lr) + self.critic = ValueModel( + belief_size + latent_state_size, hidden_size, activation_function + ).to(device) + self.critic_optim = Adam(self.critic.parameters(), critic_lr) + + def parameters(self): + return list(self.policy.parameters()) + list(self.critic.parameters()) + + def act( + self, obs: Dict[str, TensorType], training: bool = True, **_kwargs + ) -> TensorType: + action_dist = self.policy(obs["belief"], obs["latent"]) + if training: + action = action_dist.rsample() + else: + action = action_dist.mode() + return action + + def train( + self, + dataset_train: TransitionIterator, + num_epochs: Optional[int] = None, + batch_callback: Optional[Callable] = None, + silent: bool = False, + ) -> None: + + # only enable tqdm if training for a single epoch, + # otherwise it produces too much output + disable_tqdm = silent or (num_epochs is None or num_epochs > 1) + + meta = {} + + freeze(self.planet_model) + + for batch in tqdm.tqdm(dataset_train, disable=disable_tqdm): + obs, actions, rewards = self.planet_model._process_batch( + batch, + pixel_obs=True, + ) + + ( + _, + _, + _, + latents, + beliefs, + _, + rewards, + ) = self.planet_model(obs[:, 1:], actions[:, :-1], rewards[:, :-1]) + + for epoch in range(num_epochs): + B, L, _ = beliefs.shape + imag_beliefs = [] + imag_latents = [] + imag_actions = [] + imag_rewards = [] + states = { + "belief": beliefs.reshape(B * L, -1), + "latent": latents.reshape(B * L, -1), + } + for _ in range(self.horizon): + actions = self.act(states) + imag_beliefs.append(states["belief"]) + imag_latents.append(states["latent"]) + imag_actions.append(actions) + + _, rewards, _, states = self.planet_model.sample(actions, states) + imag_rewards.append(rewards) + + # I x (B*L) x _ + imag_beliefs = torch.stack(imag_beliefs) + imag_latents = torch.stack(imag_latents) + imag_actions = torch.stack(imag_actions) + with torch.no_grad(): + imag_values = self.critic(imag_beliefs, imag_latents) + + imag_rewards = torch.stack(imag_rewards) + discount_arr = self.gamma * 
torch.ones_like(imag_rewards)
+                returns = self._compute_return(
+                    imag_rewards[:-1],
+                    imag_values[:-1],
+                    discount_arr[:-1],
+                    bootstrap=imag_values[-1],
+                    lambda_=self.lam,
+                )
+                # Make the top row 1 so the cumulative product starts with discount^0
+                discount_arr = torch.cat(
+                    [torch.ones_like(discount_arr[:1]), discount_arr[1:]]
+                )
+                discount = torch.cumprod(discount_arr[:-1], 0)
+                policy_loss = -torch.mean(discount * returns)
+
+                # Detach tensors which have gradients through policy model for value loss
+                value_beliefs = imag_beliefs.detach()[:-1]  # type: ignore
+                value_latents = imag_latents.detach()[:-1]  # type: ignore
+                value_discount = discount.detach()
+                value_target = returns.detach()
+                value_pred = self.critic(value_beliefs, value_latents)
+                critic_loss = F.mse_loss(value_discount * value_target, value_pred)
+
+                self.policy_optim.zero_grad()
+                self.critic_optim.zero_grad()
+
+                policy_loss.backward()
+                critic_loss.backward()
+
+                nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clip_norm)
+                nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_clip_norm)
+
+                meta["policy_loss"] = policy_loss.item()
+                meta["critic_loss"] = critic_loss.item()
+
+                with torch.no_grad():
+                    grad_norm = 0.0
+                    for p in list(
+                        filter(lambda p: p.grad is not None, self.parameters())
+                    ):
+                        grad_norm += p.grad.data.norm(2).item()
+                    meta["grad_norm"] = grad_norm
+
+                self.policy_optim.step()
+                self.critic_optim.step()
+                batch_callback(epoch, None, meta, "train")
+        unfreeze(self.planet_model)
+
+    def save(self, save_dir: Union[str, pathlib.Path]):
+        """Saves the agent to the given directory."""
+        save_path = pathlib.Path(save_dir) / "agent.pth"
+        print("Saving models to {}".format(save_path))
+        torch.save(
+            {
+                "policy_state_dict": self.policy.state_dict(),
+                "policy_optimizer_state_dict": self.policy_optim.state_dict(),
+                "critic_state_dict": self.critic.state_dict(),
+                "critic_optimizer_state_dict": self.critic_optim.state_dict(),
+            },
+            save_path,
+        )
+
+    def load(self, load_dir: Union[str, pathlib.Path], evaluate=False):
+        """Loads the agent from the given directory."""
+        load_path = pathlib.Path(load_dir) / "agent.pth"
+        print("Loading models from {}".format(load_path))
+        checkpoint = torch.load(load_path)
+        self.policy.load_state_dict(checkpoint["policy_state_dict"])
+        self.policy_optim.load_state_dict(checkpoint["policy_optimizer_state_dict"])
+        self.critic.load_state_dict(checkpoint["critic_state_dict"])
+        self.critic_optim.load_state_dict(checkpoint["critic_optimizer_state_dict"])
+
+        if evaluate:
+            self.policy.eval()
+            self.critic.eval()
+        else:
+            self.policy.train()
+            self.critic.train()
+
+    def _compute_return(
+        self,
+        reward: torch.Tensor,
+        value: torch.Tensor,
+        discount: torch.Tensor,
+        bootstrap: torch.Tensor,
+        lambda_: float,
+    ):
+        """
+        Compute the discounted reward for a batch of data.
+        reward, value, and discount are all shape [horizon - 1, batch, 1]
+        (last element is cut off)
+        Bootstrap is [batch, 1]
+        """
+        next_values = torch.cat([value[1:], bootstrap[None]], 0)
+        target = reward + discount * next_values * (1 - lambda_)
+        outputs = []
+        accumulated_reward = bootstrap
+        for t in range(reward.shape[0] - 1, -1, -1):
+            inp = target[t]
+            discount_factor = discount[t]
+            accumulated_reward = inp + discount_factor * lambda_ * accumulated_reward
+            outputs.append(accumulated_reward)
+        returns = torch.flip(torch.stack(outputs), [0])
+        return returns
+
+
+def create_dreamer_agent_for_model(
+    planet: mbrl.models.PlaNetModel,
+    model_env: mbrl.models.ModelEnv,
+    agent_cfg: omegaconf.DictConfig,
+) -> DreamerAgent:
+    """Utility function for creating a Dreamer agent for a model environment.
+
+    This is a convenience function for creating a :class:`DreamerAgent`.
+
+    Args:
+        planet (mbrl.models.PlaNetModel): the PlaNet model used for imagination.
+        model_env (mbrl.models.ModelEnv): the model environment.
+        agent_cfg (omegaconf.DictConfig): the agent's configuration.
+
+    Returns:
+        (:class:`DreamerAgent`): the agent.
+
+    """
+    complete_agent_cfg(model_env, agent_cfg)
+    with omegaconf.open_dict(agent_cfg):
+        agent_cfg.latent_state_size = planet.latent_state_size
+        agent_cfg.belief_size = planet.belief_size
+        agent_cfg.action_size = planet.action_size
+    agent = hydra.utils.instantiate(agent_cfg)
+    # Not a primitive, so assigned after initialization
+    agent.planet_model = planet
+    return agent
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..19385131
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,19 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "wheel"
+]
+build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 88
+exclude = '''
+(
+  /(
+    .eggs # exclude a few common directories in the
+    | .git # root of the project
+    | .mypy_cache
+    | docs
+  )
+)
+'''
\ No newline at end of file
diff --git a/requirements/main.txt b/requirements/main.txt
index cdae3633..a5509bc3 100644
--- a/requirements/main.txt
+++ b/requirements/main.txt
@@ -5,7 +5,7 @@ tensorboard>=2.4.0
 imageio>=2.9.0
 numpy>=1.19.1
 matplotlib>=3.3.1
-gym==0.17.2
+gym>=0.20.0,<0.25.0
 jupyter>=1.0.0
 pytest>=6.0.1
 sk-video>=1.1.10
diff --git a/setup.py b/setup.py
index bf1d2db7..6c0d013e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from pathlib import Path
-from setuptools import setup, find_packages
+
+from setuptools import find_packages, setup
 
 
 def parse_requirements_file(path):
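
Note on the value target used in the patch: `DreamerAgent._compute_return` implements the lambda-return recursion R_t = r_t + d_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), where d_t is the per-step discount, bootstrapped with the critic's value at the last imagined step. The standalone sketch below is only an illustration and is not part of the patch (`lambda_return` is an ad-hoc name); it mirrors the recursion so it can be sanity-checked in isolation. With lambda = 1 and a zero bootstrap it reduces to a plain discounted return.

import torch


def lambda_return(reward, value, discount, bootstrap, lambda_):
    # Backward recursion mirroring DreamerAgent._compute_return above.
    # reward, value, discount: [horizon, batch]; bootstrap: [batch]
    next_values = torch.cat([value[1:], bootstrap[None]], 0)
    target = reward + discount * next_values * (1 - lambda_)
    outputs = []
    accumulated = bootstrap
    for t in range(reward.shape[0] - 1, -1, -1):
        accumulated = target[t] + discount[t] * lambda_ * accumulated
        outputs.append(accumulated)
    return torch.flip(torch.stack(outputs), [0])


# Sanity check: lambda_=1 and a zero bootstrap give the plain discounted return.
horizon, batch = 5, 2
reward = torch.ones(horizon, batch)
value = torch.zeros(horizon, batch)
discount = 0.99 * torch.ones(horizon, batch)
bootstrap = torch.zeros(batch)
print(lambda_return(reward, value, discount, bootstrap, lambda_=1.0)[0])
# tensor([4.9010, 4.9010]) == 1 + 0.99 + 0.99**2 + 0.99**3 + 0.99**4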