From f3f9832f90e7012997af8b753600135591483634 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Tue, 19 Sep 2023 10:31:02 +0200
Subject: [PATCH] config

---
 examples/impala/config.yaml        |  9 +++++----
 examples/impala/impala.py          | 27 +++++++++------------------
 examples/impala/utils.py           | 25 +++++++++++++++++++++++--
 torchrl/objectives/value/vtrace.py |  8 ++++----
 4 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml
index 9d67995938f..042a572375e 100644
--- a/examples/impala/config.yaml
+++ b/examples/impala/config.yaml
@@ -4,8 +4,9 @@ env:
 
 # collector
 collector:
-  frames_per_batch: 80
+  frames_per_batch: 128 # 80
   total_frames: 40_000_000
+  num_workers: 12
 
 # logger
 logger:
@@ -16,8 +17,8 @@ logger:
 
 # Optim
 optim:
-  lr: 0.0001
-  eps: 1.0e-8
+  lr: 0.0006 # 0.0001
+  eps: 1.0e-5
   weight_decay: 0.0
   max_grad_norm: 40.0
   anneal_lr: True
@@ -25,7 +26,7 @@
 # loss
 loss:
   gamma: 0.99
-  mini_batch_size: 80
+  mini_batch_size: 128 # 80
   critic_coef: 0.25
   entropy_coef: 0.01
   loss_critic_type: l2
diff --git a/examples/impala/impala.py b/examples/impala/impala.py
index 96fc11ad0c6..aa781fa9dad 100644
--- a/examples/impala/impala.py
+++ b/examples/impala/impala.py
@@ -25,10 +25,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
     from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
-    from torchrl.objectives import A2CLoss, ClipPPOLoss
+    from torchrl.objectives import A2CLoss
     from torchrl.record.loggers import generate_exp_name, get_logger
     from torchrl.objectives.value.vtrace import VTrace
-    from utils import make_parallel_env, make_ppo_models
+    from utils import make_parallel_env, make_ppo_models, eval_model
 
     device = "cpu" if not torch.cuda.is_available() else "cuda"
 
@@ -39,7 +39,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     mini_batch_size = cfg.loss.mini_batch_size // frame_skip
     test_interval = cfg.logger.test_interval // frame_skip
 
-    # Create models (check utils_atari.py)
+    # Create models (check utils.py)
     actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
     actor, critic, critic_head = (
         actor.to(device),
@@ -58,7 +58,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     #     sync=False,
     # )
     collector = MultiaSyncDataCollector(
-        create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 8,
+        create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * cfg.collector.num_workers,
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
@@ -125,7 +125,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     sampling_start = time.time()
     for data in collector:
 
-        log_info = None
+        log_info = {}
         sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
@@ -134,7 +134,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         # Get train reward
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
-            log_info.update({"reward_train": episode_rewards.mean().item()})
+            log_info.update({"train/reward": episode_rewards.mean().item()})
 
         # Apply episodic end of life
         data["done"].copy_(data["end_of_life"])
@@ -205,18 +205,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
             ):
                 actor.eval()
                 eval_start = time.time()
-                test_rewards = []
-                for _ in range(cfg.logger.num_test_episodes):
-                    td_test = test_env.rollout(
-                        policy=actor,
-                        auto_reset=True,
-                        auto_cast_to_device=True,
-                        break_when_any_done=True,
-                        max_steps=10_000_000,
-                    )
-                    reward = td_test["next", "episode_reward"][td_test["next", "done"]]
-                    test_rewards = np.append(test_rewards, reward.cpu().numpy())
-                    del td_test
+                test_rewards = eval_model(
+                    actor, test_env, num_episodes=cfg.logger.num_test_episodes
+                )
                 eval_time = time.time() - eval_start
                 log_info.update(
                     {
diff --git a/examples/impala/utils.py b/examples/impala/utils.py
index 9f48782d2f7..e431dd82255 100644
--- a/examples/impala/utils.py
+++ b/examples/impala/utils.py
@@ -3,9 +3,9 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import random
 
 import gymnasium as gym
+import numpy as np
 import torch.nn
 import torch.optim
 from tensordict.nn import TensorDictModule
@@ -99,7 +99,7 @@ def make_parallel_env(env_name, device, is_test=False):
     if not is_test:
         env.append_transform(RewardClipping(-1, 1))
     env.append_transform(DoubleToFloat())
-    env.append_transform(VecNorm(in_keys=["pixels"]))
+    # env.append_transform(VecNorm(in_keys=["pixels"]))
     return env
 
 
@@ -218,3 +218,24 @@ def make_ppo_models(env_name):
     del proof_environment
 
     return actor, critic, critic_head
+
+
+# ====================================================================
+# Evaluation utils
+# --------------------------------------------------------------------
+
+
+def eval_model(actor, test_env, num_episodes=3):
+    test_rewards = []
+    for _ in range(num_episodes):
+        td_test = test_env.rollout(
+            policy=actor,
+            auto_reset=True,
+            auto_cast_to_device=True,
+            break_when_any_done=True,
+            max_steps=10_000_000,
+        )
+        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+        test_rewards = np.append(test_rewards, reward.cpu().numpy())
+        del td_test
+    return test_rewards.mean()
diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py
index 510f5d8fda5..00e03930169 100644
--- a/torchrl/objectives/value/vtrace.py
+++ b/torchrl/objectives/value/vtrace.py
@@ -31,7 +31,7 @@ def _c_val(
     log_mu: torch.Tensor,
     c: Union[float, torch.Tensor] = 1,
 ) -> torch.Tensor:
-    return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1)  # TODO: is unsqueeze needed?
+    return (log_pi - log_mu).clamp_max(math.log(c)).exp()  # TODO: is unsqueeze needed?
 
 def _dv_val(
     rewards: torch.Tensor,
@@ -313,17 +313,17 @@ def forward(
         # Make sure we have the log prob computed at collection time
         if self.log_prob_key not in tensordict.keys():
             raise ValueError(f"Expected {self.log_prob_key} to be in tensordict")
-        log_mu = tensordict.get(self.log_prob_key)
+        log_mu = tensordict.get(self.log_prob_key).reshape_as(value)
 
         # Compute log prob with current policy
         with hold_out_net(self.actor_network):
             log_pi = self.actor_network(
                 tensordict.select(self.actor_network.in_keys)
-            ).get(self.log_prob_key)
+            ).get(self.log_prob_key).reshape_as(value)
 
         # Compute the V-Trace correction
         done = tensordict.get(("next", self.tensor_keys.done))
-        import ipdb; ipdb.set_trace()
+
         adv, value_target = vtrace_correction(
             gamma,
             log_pi,
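
Note (not part of the patch): the _c_val hunk above drops the trailing unsqueeze, so the truncated importance weight min(c, pi/mu) keeps the shape of log_pi/log_mu, and the forward() hunk reshapes both log probabilities to the value's shape. The following is only a minimal sketch of that clamp with invented tensor values, for illustration; it is not code from the patched files.

import math
import torch

# Hypothetical log probabilities: current (target) policy vs. the behaviour
# policy recorded at collection time. Values are made up for illustration.
log_pi = torch.tensor([-0.1, -2.3, -0.5])
log_mu = torch.tensor([-0.3, -0.2, -0.5])
c = 1.0

# What _c_val computes after this patch: clamp the log-ratio at log(c),
# then exponentiate (numerically stable form of the truncation).
clamped = (log_pi - log_mu).clamp_max(math.log(c)).exp()

# Equivalent direct form of the truncated importance weight min(c, pi/mu).
reference = torch.minimum((log_pi - log_mu).exp(), torch.tensor(c))

assert torch.allclose(clamped, reference)
# With the unsqueeze removed, `clamped` has the same shape as log_pi / log_mu.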