Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 21, 2024
2 parents ad61eb2 + 4750f8e commit bbaa142
Show file tree
Hide file tree
Showing 51 changed files with 503 additions and 199 deletions.
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- future
- cloudpickle
- pygame
- moviepy
- moviepy<2.0.0
- tqdm
- pytest
- pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_distributed/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- future
- cloudpickle
- pygame
- moviepy
- moviepy<2.0.0
- tqdm
- pytest
- pytest-cov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- future
- cloudpickle
- pygame
- moviepy
- moviepy<2.0.0
- pytest-cov
- pytest-mock
- pytest-instafail
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_gym/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- future
- cloudpickle
- pygame
- moviepy
- moviepy<2.0.0
- tqdm
- pytest
- pytest-cov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- future
- cloudpickle
- pygame
- moviepy
- moviepy<2.0.0
- tqdm
- pytest
- pytest-cov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- cloudpickle
- gym[atari]==0.13
- pygame
- moviepy
- moviepy<2.0.0
- tqdm
- pytest
- pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_sota/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- future
- cloudpickle
- pygame
- moviepy
- moviepy<2.0.0
- tqdm
- pytest
- pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ make of torchrl:
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher

# rendering
pip3 install moviepy
pip3 install "moviepy<2.0.0"

# deepmind control suite
pip3 install dm_control
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer.
OnlineDTActor
RSSMPosterior
RSSMPrior
set_recurrent_mode
recurrent_mode

Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions examples/distributed/collectors/multi_nodes/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.action_spec.space.low,
"high": env.action_spec.space.high,
"low": env.action_spec_unbatched.space.low,
"high": env.action_spec_unbatched.space.high,
},
return_log_prob=True,
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _main(argv):
],
"dm_control": ["dm_control"],
"gym_continuous": ["gymnasium<1.0", "mujoco"],
"rendering": ["moviepy"],
"rendering": ["moviepy<2.0.0"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"utils": [
"tensorboard",
Expand Down
2 changes: 1 addition & 1 deletion sota-check/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export MUJOCO_GL=egl
conda create -n rl-sota-bench python=3.10 -y
conda install anaconda::libglu -y
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame moviepy imageio submitit hydra-core transformers
pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame "moviepy<2.0.0" imageio submitit hydra-core transformers

cd /path/to/tensordict
python setup.py develop
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low.to(device),
"high": proof_environment.action_spec.space.high.to(device),
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
}

# Define input keys
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low.to(device),
"high": proof_environment.action_spec.space.high.to(device),
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
"safe_tanh": True,
}
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def make_offline_replay_buffer(rb_cfg):
def make_cql_model(cfg, train_env, eval_env, device="cpu"):
model_cfg = cfg.model

action_spec = train_env.single_action_spec
action_spec = train_env.action_spec_unbatched

actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
in_keys = ["observation"]
Expand Down
4 changes: 1 addition & 3 deletions sota-implementations/crossq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ def make_crossQ_agent(cfg, train_env, device):
"""Make CrossQ agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.actor_hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def make_dt_model(cfg):
make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1
)

action_spec = proof_environment.action_spec
action_spec = proof_environment.action_spec_unbatched
for key, value in proof_environment.observation_spec.items():
if key == "observation":
state_dim = value.shape[-1]
Expand Down
10 changes: 9 additions & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)

# mixed precision training
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.utils import clip_grad_norm_
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
Expand Down Expand Up @@ -321,6 +321,14 @@ def compile_rssms(module):

t_collect_init = time.time()

test_env.close()
train_env.close()
collector.shutdown()

del test_env
del train_env
del collector


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low,
"high": proof_environment.action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
}

Expand Down
4 changes: 1 addition & 3 deletions sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
model_cfg = cfg.model

in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)

out_keys = ["loc", "scale"]
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Install vmas and dependencies:

```bash
pip install vmas
pip install wandb moviepy
pip install wandb "moviepy<2.0.0"
pip install hydra-core
```

Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821
# Policy
net = MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=env.action_spec.space.n,
n_agent_outputs=env.full_action_spec["agents", "action"].space.n,
n_agents=env.n_agents,
centralised=False,
share_params=cfg.model.shared_parameters,
Expand All @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
Expand All @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
),
)

Expand Down
8 changes: 4 additions & 4 deletions sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,21 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "param")],
out_keys=[env.action_key],
distribution_class=TanhDelta,
distribution_kwargs={
"low": env.unbatched_action_spec[("agents", "action")].space.low,
"high": env.unbatched_action_spec[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=False,
)

policy_explore = TensorDictSequential(
policy,
AdditiveGaussianModule(
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
device=cfg.train.device,
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.unbatched_action_spec[("agents", "action")].space.low,
"high": env.unbatched_action_spec[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
Expand Down
8 changes: 4 additions & 4 deletions sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821
# Policy
net = MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=env.action_spec.space.n,
n_agent_outputs=env.full_action_spec["agents", "action"].space.n,
n_agents=env.n_agents,
centralised=False,
share_params=cfg.model.shared_parameters,
Expand All @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
Expand All @@ -103,14 +103,14 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
),
)

if cfg.loss.mixer_type == "qmix":
mixer = TensorDictModule(
module=QMixer(
state_shape=env.unbatched_observation_spec[
state_shape=env.observation_spec_unbatched[
"agents", "observation"
].shape,
mixing_embed_dim=32,
Expand Down
12 changes: 6 additions & 6 deletions sota-implementations/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821

policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.unbatched_action_spec[("agents", "action")].space.low,
"high": env.unbatched_action_spec[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
Expand Down Expand Up @@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "logits")],
out_keys=[env.action_key],
distribution_class=OneHotCategorical
Expand Down Expand Up @@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821
actor_network=policy,
qvalue_network=value_module,
delay_qvalue=True,
action_spec=env.unbatched_action_spec,
action_spec=env.full_action_spec_unbatched,
)
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
Expand All @@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821
qvalue_network=value_module,
delay_qvalue=True,
num_actions=env.action_spec.space.n,
action_space=env.unbatched_action_spec,
action_space=env.full_action_spec_unbatched,
)
loss_module.set_keys(
action_value=("agents", "action_value"),
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low,
"high": proof_environment.action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
}

# Define input keys
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low,
"high": proof_environment.action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
}

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def make_redq_model(
default_policy_scale = cfg.network.default_policy_scale
gSDE = cfg.exploration.gSDE

action_spec = proof_environment.action_spec
action_spec = proof_environment.action_spec_unbatched

if actor_net_kwargs is None:
actor_net_kwargs = {}
Expand Down
Loading

0 comments on commit bbaa142

Please sign in to comment.