[BugFix] action_spec_unbatched whenever necessary
ghstack-source-id: f346c47cd2d87a9306059e3ca56affcc68a7ff9c
Pull Request resolved: #2592
vmoens committed Nov 20, 2024
1 parent a47b32c commit acd00a1
Showing 24 changed files with 141 additions and 106 deletions.
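
The recurring change: per-env action bounds and shapes are read from the env's action_spec_unbatched property instead of the batched action_spec. A minimal sketch of the difference, assuming any TorchRL env with a non-empty batch_size (the SerialEnv/GymEnv construction below is illustrative only, not taken from the diff):

# Illustrative only: any env with a non-empty batch_size behaves the same.
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))  # batch_size = [2]

print(env.action_spec.shape)            # torch.Size([2, 1]) -> batched spec
print(env.action_spec_unbatched.shape)  # torch.Size([1])    -> per-env view

# The per-env bounds that the scripts below feed to TanhNormal:
distribution_kwargs = {
    "low": env.action_spec_unbatched.space.low,
    "high": env.action_spec_unbatched.space.high,
}
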
examples/distributed/collectors/multi_nodes/ray_train.py (2 additions, 2 deletions)
@@ -85,8 +85,8 @@
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "low": env.action_spec.space.low,
-        "high": env.action_spec.space.high,
+        "low": env.action_spec_unbatched.space.low,
+        "high": env.action_spec_unbatched.space.high,
     },
     return_log_prob=True,
 )

sota-implementations/a2c/utils_atari.py (2 additions, 2 deletions)
@@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low.to(device),
-        "high": proof_environment.action_spec.space.high.to(device),
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
     }

     # Define input keys

sota-implementations/a2c/utils_mujoco.py (2 additions, 2 deletions)
@@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low.to(device),
-        "high": proof_environment.action_spec.space.high.to(device),
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
         "tanh_loc": False,
         "safe_tanh": True,
     }

sota-implementations/cql/utils.py (1 addition, 1 deletion)
@@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg):
 def make_cql_model(cfg, train_env, eval_env, device="cpu"):
     model_cfg = cfg.model

-    action_spec = train_env.action_spec
+    action_spec = train_env.action_spec_unbatched

     actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
     in_keys = ["observation"]

sota-implementations/crossq/utils.py (1 addition, 3 deletions)
@@ -147,9 +147,7 @@ def make_crossQ_agent(cfg, train_env, device):
     """Make CrossQ agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     actor_net_kwargs = {
         "num_cells": cfg.network.actor_hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],

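The crossq, iql and sac utilities previously stripped the batch dimensions by indexing the batched spec with zeros; a short sketch of why the property is a drop-in replacement (the SerialEnv/GymEnv construction is illustrative, not taken from the diff):

from torchrl.envs import GymEnv, SerialEnv

train_env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))  # batch_size = [4]

# Removed idiom: index the batched spec with one zero per batch dimension.
legacy_spec = train_env.action_spec
if train_env.batch_size:
    legacy_spec = legacy_spec[(0,) * len(train_env.batch_size)]

# Replacement: the env exposes the per-env spec directly.
new_spec = train_env.action_spec_unbatched
assert legacy_spec.shape == new_spec.shape

# Either way, the actor head is sized from the per-env action dimension.
out_features = 2 * new_spec.shape[-1]
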
sota-implementations/decision_transformer/utils.py (1 addition, 1 deletion)
@@ -393,7 +393,7 @@ def make_dt_model(cfg):
         make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1
     )

-    action_spec = proof_environment.action_spec
+    action_spec = proof_environment.action_spec_unbatched
     for key, value in proof_environment.observation_spec.items():
         if key == "observation":
             state_dim = value.shape[-1]

sota-implementations/gail/ppo_utils.py (2 additions, 2 deletions)
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low,
+        "high": proof_environment.action_spec_unbatched.space.high,
         "tanh_loc": False,
     }

sota-implementations/iql/utils.py (1 addition, 3 deletions)
@@ -195,9 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
     model_cfg = cfg.model

     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)

     out_keys = ["loc", "scale"]

sota-implementations/multiagent/iql.py (2 additions, 2 deletions)
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
             eps_end=0,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
-            spec=env.unbatched_action_spec,
+            spec=env.full_action_spec_unbatched,
         ),
     )

sota-implementations/multiagent/maddpg_iddpg.py (4 additions, 4 deletions)
@@ -91,21 +91,21 @@ def train(cfg: "DictConfig"): # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         in_keys=[("agents", "param")],
         out_keys=[env.action_key],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "low": env.unbatched_action_spec[("agents", "action")].space.low,
-            "high": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+            "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
         },
         return_log_prob=False,
     )

     policy_explore = TensorDictSequential(
         policy,
         AdditiveGaussianModule(
-            spec=env.unbatched_action_spec,
+            spec=env.full_action_spec_unbatched,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
             device=cfg.train.device,

sota-implementations/multiagent/mappo_ippo.py (3 additions, 3 deletions)
@@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         in_keys=[("agents", "loc"), ("agents", "scale")],
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": env.unbatched_action_spec[("agents", "action")].space.low,
-            "high": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+            "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )

sota-implementations/multiagent/qmix_vdn.py (2 additions, 2 deletions)
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
             ("agents", "action_value"),
             ("agents", "chosen_action_value"),
         ],
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         action_space=None,
     )
     qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
             eps_end=0,
             annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
             action_key=env.action_key,
-            spec=env.unbatched_action_spec,
+            spec=env.full_action_spec_unbatched,
         ),
     )

sota-implementations/multiagent/sac.py (6 additions, 6 deletions)
@@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821

     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         in_keys=[("agents", "loc"), ("agents", "scale")],
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": env.unbatched_action_spec[("agents", "action")].space.low,
-            "high": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+            "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )
@@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821
     )
     policy = ProbabilisticActor(
         module=policy_module,
-        spec=env.unbatched_action_spec,
+        spec=env.full_action_spec_unbatched,
         in_keys=[("agents", "logits")],
         out_keys=[env.action_key],
         distribution_class=OneHotCategorical
@@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821
         actor_network=policy,
         qvalue_network=value_module,
         delay_qvalue=True,
-        action_spec=env.unbatched_action_spec,
+        action_spec=env.full_action_spec_unbatched,
     )
     loss_module.set_keys(
         state_action_value=("agents", "state_action_value"),
@@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821
         qvalue_network=value_module,
         delay_qvalue=True,
         num_actions=env.action_spec.space.n,
-        action_space=env.unbatched_action_spec,
+        action_space=env.full_action_spec_unbatched,
     )
     loss_module.set_keys(
         action_value=("agents", "action_value"),

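The multi-agent scripts rename the legacy unbatched_action_spec attribute to the full_action_spec_unbatched property; both refer to the same per-env composite spec, which the new test_libs assertion further down also checks. A sketch assuming the VMAS backend is installed (the scenario name is illustrative):

# Assumes vmas is installed; "navigation" is an illustrative scenario.
from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(scenario="navigation", num_envs=4, n_agents=3)

legacy = env.unbatched_action_spec          # old attribute name
current = env.full_action_spec_unbatched    # new property name
assert legacy.shape == current.shape

# Per-agent bounds, without the num_envs batch dimension:
low = current[("agents", "action")].space.low
high = current[("agents", "action")].space.high
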
sota-implementations/ppo/utils_atari.py (2 additions, 2 deletions)
@@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low,
+        "high": proof_environment.action_spec_unbatched.space.high,
     }

     # Define input keys

sota-implementations/ppo/utils_mujoco.py (2 additions, 2 deletions)
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec.space.low,
-        "high": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low,
+        "high": proof_environment.action_spec_unbatched.space.high,
         "tanh_loc": False,
     }

sota-implementations/redq/utils.py (1 addition, 1 deletion)
@@ -410,7 +410,7 @@ def make_redq_model(
     default_policy_scale = cfg.network.default_policy_scale
     gSDE = cfg.exploration.gSDE

-    action_spec = proof_environment.action_spec
+    action_spec = proof_environment.action_spec_unbatched

     if actor_net_kwargs is None:
         actor_net_kwargs = {}

sota-implementations/sac/utils.py (1 addition, 3 deletions)
@@ -161,9 +161,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
     """Make SAC agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.action_spec_unbatched
     actor_net_kwargs = {
         "num_cells": cfg.network.hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],

test/mocking_classes.py (8 additions, 34 deletions)
@@ -1388,17 +1388,17 @@ def _make_specs(self):
         obs_spec_unlazy = consolidate_spec(obs_specs)
         action_specs = torch.stack(action_specs, dim=0)

-        self.unbatched_observation_spec = Composite(
+        self.observation_spec_unbatched = Composite(
             lazy=obs_spec_unlazy,
             state=Unbounded(shape=(64, 64, 3)),
             device=self.device,
         )

-        self.unbatched_action_spec = Composite(
+        self.action_spec_unbatched = Composite(
             lazy=action_specs,
             device=self.device,
         )
-        self.unbatched_reward_spec = Composite(
+        self.reward_spec_unbatched = Composite(
             {
                 "lazy": Composite(
                     {"reward": Unbounded(shape=(self.n_nested_dim, 1))},
@@ -1407,7 +1407,7 @@ def _make_specs(self):
             },
             device=self.device,
         )
-        self.unbatched_done_spec = Composite(
+        self.done_spec_unbatched = Composite(
             {
                 "lazy": Composite(
                     {
@@ -1423,19 +1423,6 @@ def _make_specs(self):
             device=self.device,
         )

-        self.action_spec = self.unbatched_action_spec.expand(
-            *self.batch_size, *self.unbatched_action_spec.shape
-        )
-        self.observation_spec = self.unbatched_observation_spec.expand(
-            *self.batch_size, *self.unbatched_observation_spec.shape
-        )
-        self.reward_spec = self.unbatched_reward_spec.expand(
-            *self.batch_size, *self.unbatched_reward_spec.shape
-        )
-        self.done_spec = self.unbatched_done_spec.expand(
-            *self.batch_size, *self.unbatched_done_spec.shape
-        )
-
     def get_agent_obs_spec(self, i):
         camera = Bounded(low=0, high=200, shape=(7, 7, 3))
         vector_3d = Unbounded(shape=(3,))
@@ -1610,21 +1597,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):

         self.make_specs()

-        self.action_spec = self.unbatched_action_spec.expand(
-            *self.batch_size, *self.unbatched_action_spec.shape
-        )
-        self.observation_spec = self.unbatched_observation_spec.expand(
-            *self.batch_size, *self.unbatched_observation_spec.shape
-        )
-        self.reward_spec = self.unbatched_reward_spec.expand(
-            *self.batch_size, *self.unbatched_reward_spec.shape
-        )
-        self.done_spec = self.unbatched_done_spec.expand(
-            *self.batch_size, *self.unbatched_done_spec.shape
-        )
-
     def make_specs(self):
-        self.unbatched_observation_spec = Composite(
+        self.observation_spec_unbatched = Composite(
             nested_1=Composite(
                 observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)),
                 shape=(self.nested_dim_1,),
@@ -1642,7 +1616,7 @@ def make_specs(self):
             ),
         )

-        self.unbatched_action_spec = Composite(
+        self.action_spec_unbatched = Composite(
             nested_1=Composite(
                 action=Categorical(n=2, shape=(self.nested_dim_1,)),
                 shape=(self.nested_dim_1,),
@@ -1654,7 +1628,7 @@ def make_specs(self):
             action=OneHot(n=2),
         )

-        self.unbatched_reward_spec = Composite(
+        self.reward_spec_unbatched = Composite(
             nested_1=Composite(
                 gift=Unbounded(shape=(self.nested_dim_1, 1)),
                 shape=(self.nested_dim_1,),
@@ -1666,7 +1640,7 @@ def make_specs(self):
             reward=Unbounded(shape=(1,)),
         )

-        self.unbatched_done_spec = Composite(
+        self.done_spec_unbatched = Composite(
             nested_1=Composite(
                 done=Categorical(
                     n=2,

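The deleted blocks above expanded each hand-written unbatched spec to the env batch size by hand; with the *_unbatched setters this is no longer needed. A short sketch of what the removed boilerplate was doing, with an illustrative batch size and spec (not taken from the test file):

import torch
from torchrl.data import Bounded, Composite

batch_size = torch.Size([3])  # illustrative
action_spec_unbatched = Composite(action=Bounded(low=-1.0, high=1.0, shape=(2,)))

# What the removed lines did explicitly: broadcast the per-env spec over batch_size.
action_spec = action_spec_unbatched.expand(*batch_size, *action_spec_unbatched.shape)

assert action_spec.shape == torch.Size([3])
assert action_spec["action"].shape == torch.Size([3, 2])
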
test/test_env.py (12 additions, 12 deletions)
@@ -3512,18 +3512,18 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device

 def test_single_env_spec():
     env = NestedCountingEnv(batch_size=[3, 1, 7])
-    assert not env.single_full_action_spec.shape
-    assert not env.single_full_done_spec.shape
-    assert not env.single_input_spec.shape
-    assert not env.single_full_observation_spec.shape
-    assert not env.single_output_spec.shape
-    assert not env.single_full_reward_spec.shape
-
-    assert env.single_action_spec.shape
-    assert env.single_reward_spec.shape
-
-    assert env.output_spec.is_in(env.single_output_spec.zeros(env.shape))
-    assert env.input_spec.is_in(env.single_input_spec.zeros(env.shape))
+    assert not env.full_action_spec_unbatched.shape
+    assert not env.full_done_spec_unbatched.shape
+    assert not env.input_spec_unbatched.shape
+    assert not env.full_observation_spec_unbatched.shape
+    assert not env.output_spec_unbatched.shape
+    assert not env.full_reward_spec_unbatched.shape
+
+    assert env.action_spec_unbatched.shape
+    assert env.reward_spec_unbatched.shape
+
+    assert env.output_spec.is_in(env.output_spec_unbatched.zeros(env.shape))
+    assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))


 if __name__ == "__main__":

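The renamed test exercises the whole *_unbatched property family. A sketch of the same checks on an ordinary batched env rather than the NestedCountingEnv test helper (the SerialEnv/GymEnv construction is illustrative, and assumes the zeros/is_in round trip behaves as in the test above):

from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(3, lambda: GymEnv("Pendulum-v1"))

# Composite specs drop their batch dims in the unbatched view...
assert not env.full_action_spec_unbatched.shape
assert not env.output_spec_unbatched.shape
# ...while leaf specs keep their per-env event dims.
assert env.action_spec_unbatched.shape

# Zeros drawn from the unbatched specs, expanded to env.shape,
# are valid members of the batched specs.
assert env.output_spec.is_in(env.output_spec_unbatched.zeros(env.shape))
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
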
test/test_libs.py (7 additions, 0 deletions)
@@ -2253,6 +2253,13 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents):
             max_steps=n_rollout_samples,
             return_contiguous=False if env.het_specs else True,
         )
+        assert (
+            env.full_action_spec_unbatched.shape == env.unbatched_action_spec.shape
+        ), (
+            env.action_spec,
+            env.batch_size,
+        )
+
         env.close()

         if env.het_specs:
