[BugFix] Use single_action_spec whenever necessary
ghstack-source-id: a6748fa882a41fdd50795b46b261e6e214af2c0e
Pull Request resolved: #2592
vmoens committed Nov 20, 2024
1 parent 7564567 · commit c5e83e6
Showing 22 changed files with 53 additions and 54 deletions.
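
For context: the multi-agent scripts previously read the per-env action spec from the `unbatched_action_spec` property, while several single-agent scripts rebuilt it by hand, indexing the batched `env.action_spec` at the origin of the batch dimensions (see the crossq, iql and sac hunks below). This commit standardizes both on `single_action_spec`. A minimal sketch of the before/after pattern; the helper name `get_unbatched_action_spec` is ours, for illustration only:

    # Sketch of the pattern this commit standardizes (the helper name is
    # illustrative, not TorchRL API).
    def get_unbatched_action_spec(env):
        """Old workaround, as it appeared in sota-implementations/sac/utils.py."""
        action_spec = env.action_spec
        if env.batch_size:
            # Strip the leading batch dimensions to recover the per-env spec.
            action_spec = action_spec[(0,) * len(env.batch_size)]
        return action_spec

    # After this commit, the same spec is read directly from the env:
    #     action_spec = env.single_action_spec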
examples/distributed/collectors/multi_nodes/ray_train.py (2 additions, 2 deletions)
@@ -85,8 +85,8 @@
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "low": env.action_spec.space.low,
-        "high": env.action_spec.space.high,
+        "low": env.single_action_spec.space.low,
+        "high": env.single_action_spec.space.high,
     },
     return_log_prob=True,
 )
sota-implementations/a2c/utils_atari.py (2 additions, 2 deletions)
@@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low.to(device),
-        "high": proof_environment.action_spec.space.high.to(device),
+        "low": proof_environment.single_action_spec.space.low.to(device),
+        "high": proof_environment.single_action_spec.space.high.to(device),
     }
 
     # Define input keys
sota-implementations/a2c/utils_mujoco.py (2 additions, 2 deletions)
@@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low.to(device),
-        "high": proof_environment.action_spec.space.high.to(device),
+        "low": proof_environment.single_action_spec.space.low.to(device),
+        "high": proof_environment.single_action_spec.space.high.to(device),
         "tanh_loc": False,
         "safe_tanh": True,
     }
sota-implementations/cql/utils.py (1 addition, 1 deletion)
@@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg):
 def make_cql_model(cfg, train_env, eval_env, device="cpu"):
     model_cfg = cfg.model
 
-    action_spec = train_env.action_spec
+    action_spec = train_env.single_action_spec
 
     actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
     in_keys = ["observation"]
sota-implementations/crossq/utils.py (1 addition, 3 deletions)
@@ -147,9 +147,7 @@ def make_crossQ_agent(cfg, train_env, device):
     """Make CrossQ agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.single_action_spec
     actor_net_kwargs = {
         "num_cells": cfg.network.actor_hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],
sota-implementations/decision_transformer/utils.py (1 addition, 1 deletion)
@@ -393,7 +393,7 @@ def make_dt_model(cfg):
         make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1
     )
 
-    action_spec = proof_environment.action_spec
+    action_spec = proof_environment.single_action_spec
     for key, value in proof_environment.observation_spec.items():
         if key == "observation":
             state_dim = value.shape[-1]
sota-implementations/gail/ppo_utils.py (2 additions, 2 deletions)
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.single_action_spec.space.low,
+        "high": proof_environment.single_action_spec.space.high,
         "tanh_loc": False,
     }
 
sota-implementations/iql/utils.py (1 addition, 3 deletions)
@@ -195,9 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
     model_cfg = cfg.model
 
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.single_action_spec
     actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)
 
     out_keys = ["loc", "scale"]
sota-implementations/multiagent/iql.py (2 additions, 2 deletions)
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.unbatched_action_spec,
+        spec=env.single_action_spec,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             eps_end=0,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
-            spec=env.unbatched_action_spec,
+            spec=env.single_action_spec,
         ),
     )
 
sota-implementations/multiagent/maddpg_iddpg.py (4 additions, 4 deletions)
@@ -91,21 +91,21 @@ def train(cfg: "DictConfig"):  # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.single_action_spec,
         in_keys=[("agents", "param")],
         out_keys=[env.action_key],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "low": env.unbatched_action_spec[("agents", "action")].space.low,
-            "high": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.single_action_spec[("agents", "action")].space.low,
+            "high": env.single_action_spec[("agents", "action")].space.high,
         },
         return_log_prob=False,
     )
 
     policy_explore = TensorDictSequential(
         policy,
         AdditiveGaussianModule(
-            spec=env.unbatched_action_spec,
+            spec=env.single_action_spec,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
             device=cfg.train.device,
sota-implementations/multiagent/mappo_ippo.py (3 additions, 3 deletions)
@@ -92,13 +92,13 @@ def train(cfg: "DictConfig"):  # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.single_action_spec,
         in_keys=[("agents", "loc"), ("agents", "scale")],
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": env.unbatched_action_spec[("agents", "action")].space.low,
-            "high": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.single_action_spec[("agents", "action")].space.low,
+            "high": env.single_action_spec[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )
sota-implementations/multiagent/qmix_vdn.py (2 additions, 2 deletions)
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.unbatched_action_spec,
+        spec=env.single_action_spec,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
             eps_end=0,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
-            spec=env.unbatched_action_spec,
+            spec=env.single_action_spec,
         ),
     )
 
sota-implementations/multiagent/sac.py (6 additions, 6 deletions)
@@ -96,13 +96,13 @@ def train(cfg: "DictConfig"):  # noqa: F821
 
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.single_action_spec,
         in_keys=[("agents", "loc"), ("agents", "scale")],
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": env.unbatched_action_spec[("agents", "action")].space.low,
-            "high": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.single_action_spec[("agents", "action")].space.low,
+            "high": env.single_action_spec[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )
@@ -146,7 +146,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.single_action_spec,
         in_keys=[("agents", "logits")],
         out_keys=[env.action_key],
         distribution_class=OneHotCategorical
@@ -194,7 +194,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         actor_network=policy,
         qvalue_network=value_module,
         delay_qvalue=True,
-        action_spec=env.unbatched_action_spec,
+        action_spec=env.single_action_spec,
     )
     loss_module.set_keys(
         state_action_value=("agents", "state_action_value"),
@@ -209,7 +209,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         qvalue_network=value_module,
         delay_qvalue=True,
         num_actions=env.action_spec.space.n,
-        action_space=env.unbatched_action_spec,
+        action_space=env.single_action_spec,
     )
     loss_module.set_keys(
         action_value=("agents", "action_value"),
sota-implementations/ppo/utils_atari.py (2 additions, 2 deletions)
@@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.single_action_spec.space.low,
+        "high": proof_environment.single_action_spec.space.high,
     }
 
     # Define input keys
sota-implementations/ppo/utils_mujoco.py (2 additions, 2 deletions)
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.single_action_spec.space.low,
+        "high": proof_environment.single_action_spec.space.high,
         "tanh_loc": False,
     }
 
sota-implementations/redq/utils.py (1 addition, 1 deletion)
@@ -410,7 +410,7 @@ def make_redq_model(
     default_policy_scale = cfg.network.default_policy_scale
     gSDE = cfg.exploration.gSDE
 
-    action_spec = proof_environment.action_spec
+    action_spec = proof_environment.single_action_spec
 
     if actor_net_kwargs is None:
         actor_net_kwargs = {}
sota-implementations/sac/utils.py (1 addition, 3 deletions)
@@ -161,9 +161,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
     """Make SAC agent."""
     # Define Actor Network
    in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.single_action_spec
     actor_net_kwargs = {
         "num_cells": cfg.network.hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],
test/mocking_classes.py (6 additions, 6 deletions)
@@ -1394,7 +1394,7 @@ def _make_specs(self):
             device=self.device,
         )
 
-        self.unbatched_action_spec = Composite(
+        self.single_action_spec = Composite(
            lazy=action_specs,
             device=self.device,
         )
@@ -1423,8 +1423,8 @@ def _make_specs(self):
             device=self.device,
         )
 
-        self.action_spec = self.unbatched_action_spec.expand(
-            *self.batch_size, *self.unbatched_action_spec.shape
+        self.action_spec = self.single_action_spec.expand(
+            *self.batch_size, *self.single_action_spec.shape
         )
         self.observation_spec = self.unbatched_observation_spec.expand(
             *self.batch_size, *self.unbatched_observation_spec.shape
@@ -1610,8 +1610,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
 
         self.make_specs()
 
-        self.action_spec = self.unbatched_action_spec.expand(
-            *self.batch_size, *self.unbatched_action_spec.shape
+        self.action_spec = self.single_action_spec.expand(
+            *self.batch_size, *self.single_action_spec.shape
         )
         self.observation_spec = self.unbatched_observation_spec.expand(
             *self.batch_size, *self.unbatched_observation_spec.shape
@@ -1642,7 +1642,7 @@ def make_specs(self):
             ),
         )
 
-        self.unbatched_action_spec = Composite(
+        self.single_action_spec = Composite(
             nested_1=Composite(
                 action=Categorical(n=2, shape=(self.nested_dim_1,)),
                 shape=(self.nested_dim_1,),
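
The mocking-class hunks above spell out the relationship between the two specs: the batched `action_spec` is the single (per-env) spec expanded over the env's batch dimensions. A self-contained sketch of that relationship, assuming the `Bounded` and `Composite` spec classes used elsewhere in this diff; the shapes are made up for illustration:

    import torch
    from torchrl.data import Bounded, Composite

    batch_size = torch.Size([4])  # hypothetical vectorized-env batch
    single_action_spec = Composite(action=Bounded(low=-1.0, high=1.0, shape=(2,)))

    # Mirrors _make_specs in test/mocking_classes.py: the batched spec is the
    # per-env spec expanded over the batch dimensions.
    action_spec = single_action_spec.expand(*batch_size, *single_action_spec.shape)

    assert action_spec.shape == torch.Size([4])
    assert action_spec["action"].shape == torch.Size([4, 2])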
test/test_libs.py (5 additions, 0 deletions)
@@ -2253,6 +2253,11 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents):
             max_steps=n_rollout_samples,
             return_contiguous=False if env.het_specs else True,
         )
+        assert env.single_full_action_spec.shape == env.unbatched_action_spec.shape, (
+            env.action_spec,
+            env.batch_size,
+        )
+
         env.close()
 
         if env.het_specs:
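
The new assertion pins down the intended equivalence for batched VMAS envs: `single_full_action_spec` must match the shape of the pre-existing `unbatched_action_spec`, i.e. the `single_*` properties are shape-for-shape replacements for the `unbatched_*` ones.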
tutorials/sphinx-tutorials/coding_ppo.py (2 additions, 2 deletions)
@@ -431,8 +431,8 @@
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "low": env.action_spec.space.low,
-        "high": env.action_spec.space.high,
+        "low": env.single_action_spec.space.low,
+        "high": env.single_action_spec.space.high,
     },
     return_log_prob=True,
     # we'll need the log-prob for the numerator of the importance weights
tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py (2 additions, 2 deletions)
@@ -486,8 +486,8 @@
     out_keys=[(group, "action")],
     distribution_class=TanhDelta,
     distribution_kwargs={
-        "low": env.full_action_spec[group, "action"].space.low,
-        "high": env.full_action_spec[group, "action"].space.high,
+        "low": env.single_full_action_spec[group, "action"].space.low,
+        "high": env.single_full_action_spec[group, "action"].space.high,
     },
     return_log_prob=False,
 )
tutorials/sphinx-tutorials/multiagent_ppo.py (3 additions, 3 deletions)
@@ -445,13 +445,13 @@
 
 policy = ProbabilisticActor(
     module=policy_module,
-    spec=env.unbatched_action_spec,
+    spec=env.single_action_spec,
     in_keys=[("agents", "loc"), ("agents", "scale")],
     out_keys=[env.action_key],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "low": env.unbatched_action_spec[env.action_key].space.low,
-        "high": env.unbatched_action_spec[env.action_key].space.high,
+        "low": env.single_action_spec[env.action_key].space.low,
+        "high": env.single_action_spec[env.action_key].space.high,
     },
     return_log_prob=True,
     log_prob_key=("agents", "sample_log_prob"),
