
Commit

config
albertbou92 committed Sep 19, 2023
1 parent 888fbcb commit f3f9832
Showing 4 changed files with 41 additions and 28 deletions.
9 changes: 5 additions & 4 deletions examples/impala/config.yaml
@@ -4,8 +4,9 @@ env:

# collector
collector:
-frames_per_batch: 80
+frames_per_batch: 128 # 80
total_frames: 40_000_000
+num_workers: 12

# logger
logger:
@@ -16,16 +17,16 @@ logger:

# Optim
optim:
-lr: 0.0001
-eps: 1.0e-8
+lr: 0.0006 # 0.0001
+eps: 1.0e-5
weight_decay: 0.0
max_grad_norm: 40.0
anneal_lr: True

# loss
loss:
gamma: 0.99
-mini_batch_size: 80
+mini_batch_size: 128 # 80
critic_coef: 0.25
entropy_coef: 0.01
loss_critic_type: l2
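For reference, impala.py divides these batch sizes by the environment frame skip before use. A minimal sketch of that arithmetic with the updated values, assuming the config is loaded via OmegaConf and a frame skip of 4 (the real value lives in the env section, which is not part of this diff):

from omegaconf import OmegaConf

# Load the example config and derive the sizes used by the training loop.
# frame_skip = 4 is an assumed placeholder for illustration only.
cfg = OmegaConf.load("examples/impala/config.yaml")
frame_skip = 4
frames_per_batch = cfg.collector.frames_per_batch // frame_skip  # 128 // 4 = 32
mini_batch_size = cfg.loss.mini_batch_size // frame_skip  # 128 // 4 = 32
print(frames_per_batch, mini_batch_size, cfg.collector.num_workers)  # 32 32 12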
27 changes: 9 additions & 18 deletions examples/impala/impala.py
@@ -25,10 +25,10 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
-from torchrl.objectives import A2CLoss, ClipPPOLoss
+from torchrl.objectives import A2CLoss
from torchrl.record.loggers import generate_exp_name, get_logger
from torchrl.objectives.value.vtrace import VTrace
-from utils import make_parallel_env, make_ppo_models
+from utils import make_parallel_env, make_ppo_models, eval_model

device = "cpu" if not torch.cuda.is_available() else "cuda"

@@ -39,7 +39,7 @@ def main(cfg: "DictConfig"): # noqa: F821
mini_batch_size = cfg.loss.mini_batch_size // frame_skip
test_interval = cfg.logger.test_interval // frame_skip

-# Create models (check utils_atari.py)
+# Create models (check utils.py)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
actor, critic, critic_head = (
actor.to(device),
@@ -58,7 +58,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# sync=False,
# )
collector = MultiaSyncDataCollector(
-create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 8,
+create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * cfg.collector.num_workers,
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
@@ -125,7 +125,7 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
for data in collector:

-log_info = None
+log_info = {}
sampling_time = time.time() - sampling_start
frames_in_batch = data.numel()
collected_frames += frames_in_batch * frame_skip
@@ -134,7 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Get train reward
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
log_info.update({"reward_train": episode_rewards.mean().item()})
log_info.update({"train/reward": episode_rewards.mean().item()})

# Apply episodic end of life
data["done"].copy_(data["end_of_life"])
@@ -205,18 +205,9 @@ def main(cfg: "DictConfig"): # noqa: F821
):
actor.eval()
eval_start = time.time()
-test_rewards = []
-for _ in range(cfg.logger.num_test_episodes):
-    td_test = test_env.rollout(
-        policy=actor,
-        auto_reset=True,
-        auto_cast_to_device=True,
-        break_when_any_done=True,
-        max_steps=10_000_000,
-    )
-    reward = td_test["next", "episode_reward"][td_test["next", "done"]]
-    test_rewards = np.append(test_rewards, reward.cpu().numpy())
-    del td_test
+test_rewards = eval_model(
+    actor, test_env, num_episodes=cfg.logger.num_test_episodes
+)
eval_time = time.time() - eval_start
log_info.update(
{
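A note on the log_info change above: later code in the loop merges metrics into log_info with .update(), which would raise on None, so each iteration now starts from an empty dict. A standalone illustration of the pattern with made-up values:

# Each collection iteration starts from an empty dict and merges metrics in;
# starting from None would make the first .update() raise AttributeError.
log_info = {}
episode_rewards = [2.0, 3.0]  # stand-in for the finished-episode returns in a batch
if len(episode_rewards) > 0:
    log_info.update({"train/reward": sum(episode_rewards) / len(episode_rewards)})
print(log_info)  # {'train/reward': 2.5}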
25 changes: 23 additions & 2 deletions examples/impala/utils.py
@@ -3,9 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

-import random

import gymnasium as gym
+import numpy as np
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
@@ -99,7 +99,7 @@ def make_parallel_env(env_name, device, is_test=False):
if not is_test:
env.append_transform(RewardClipping(-1, 1))
env.append_transform(DoubleToFloat())
-env.append_transform(VecNorm(in_keys=["pixels"]))
+# env.append_transform(VecNorm(in_keys=["pixels"]))
return env


@@ -218,3 +218,24 @@ def make_ppo_models(env_name):
del proof_environment

return actor, critic, critic_head


+# ====================================================================
+# Evaluation utils
+# --------------------------------------------------------------------
+
+
+def eval_model(actor, test_env, num_episodes=3):
+    test_rewards = []
+    for _ in range(num_episodes):
+        td_test = test_env.rollout(
+            policy=actor,
+            auto_reset=True,
+            auto_cast_to_device=True,
+            break_when_any_done=True,
+            max_steps=10_000_000,
+        )
+        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+        test_rewards = np.append(test_rewards, reward.cpu().numpy())
+        del td_test
+    return test_rewards.mean()
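The new eval_model helper factors out the evaluation loop that was removed from impala.py. A hedged usage sketch built on the other helpers in this file (the Atari env name is a placeholder, not taken from this commit, and a gymnasium Atari installation is assumed):

import torch

from utils import eval_model, make_parallel_env, make_ppo_models

device = "cuda" if torch.cuda.is_available() else "cpu"
# Build a policy and a test environment with the helpers defined above.
actor, critic, critic_head = make_ppo_models("PongNoFrameskip-v4")
actor = actor.to(device).eval()
test_env = make_parallel_env("PongNoFrameskip-v4", device, is_test=True)
# Average undiscounted return over a few evaluation episodes.
mean_return = eval_model(actor, test_env, num_episodes=3)
print(f"average test return: {mean_return:.2f}")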
8 changes: 4 additions & 4 deletions torchrl/objectives/value/vtrace.py
@@ -31,7 +31,7 @@ def _c_val(
log_mu: torch.Tensor,
c: Union[float, torch.Tensor] = 1,
) -> torch.Tensor:
-return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) # TODO: is unsqueeze needed?
+return (log_pi - log_mu).clamp_max(math.log(c)).exp() # TODO: is unsqueeze needed?

def _dv_val(
rewards: torch.Tensor,
@@ -313,17 +313,17 @@ def forward(
# Make sure we have the log prob computed at collection time
if self.log_prob_key not in tensordict.keys():
raise ValueError(f"Expected {self.log_prob_key} to be in tensordict")
-log_mu = tensordict.get(self.log_prob_key)
+log_mu = tensordict.get(self.log_prob_key).reshape_as(value)

# Compute log prob with current policy
with hold_out_net(self.actor_network):
log_pi = self.actor_network(
tensordict.select(self.actor_network.in_keys)
-).get(self.log_prob_key)
+).get(self.log_prob_key).reshape_as(value)

# Compute the V-Trace correction
done = tensordict.get(("next", self.tensor_keys.done))
-import ipdb; ipdb.set_trace()

adv, value_target = vtrace_correction(
gamma,
log_pi,
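The vtrace.py changes keep log_pi and log_mu shaped like value rather than adding a trailing dimension inside _c_val. For intuition, here is a minimal standalone sketch of the clipped importance weight that _c_val computes after the change, c_t = min(exp(log_pi - log_mu), c_bar), with made-up numbers:

import math

import torch

# Clipped importance weights: ratios above c_bar are clamped down to c_bar.
log_pi = torch.tensor([-0.1, -2.0, -0.5])  # log-probs under the learner policy
log_mu = torch.tensor([-0.3, -0.4, -0.5])  # log-probs under the behaviour policy
c_bar = 1.0
c = (log_pi - log_mu).clamp_max(math.log(c_bar)).exp()
print(c)  # tensor([1.0000, 0.2019, 1.0000]); same shape as log_pi, no unsqueeze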

