diff --git a/.github/scripts/m1_script.sh b/.github/scripts/td_script.sh similarity index 71% rename from .github/scripts/m1_script.sh rename to .github/scripts/td_script.sh index 6552d8e4622..6da1cad5d79 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/td_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export TORCHRL_BUILD_VERSION=0.4.0 +export TORCHRL_BUILD_VERSION=0.5.0 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index f91c050dde0..f8b700c0410 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -9,9 +9,19 @@ # # -set -e +#set -e set -v +# Initialize an error flag +error_occurred=0 +# Function to handle errors +error_handler() { + echo "Error on line $1" + error_occurred=1 +} +# Trap ERR to call the error_handler function with the failing line number +trap 'error_handler $LINENO' ERR + export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" @@ -24,6 +34,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export CUDA_LAUNCH_BLOCKING=1 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 #python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 @@ -163,18 +174,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cr env.name=Pendulum-v1 \ network.device= \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - collector.total_frames=200 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=200 \ - env.n_parallel_envs=4 \ - optimization.optim_steps_per_batch=1 \ - logger.video=True \ - logger.backend=csv \ - replay_buffer.buffer_size=120 \ - replay_buffer.batch_size=24 \ - replay_buffer.batch_length=12 \ - networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -214,8 +213,8 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr collector.frames_per_batch=200 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ - logger.backend=csv \ logger.video=True \ + logger.backend=csv \ replay_buffer.buffer_size=120 \ replay_buffer.batch_size=24 \ replay_buffer.batch_length=12 \ @@ -312,3 +311,11 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ba coverage combine coverage xml -i + +# Check if any errors occurred during the script execution +if [ "$error_occurred" -ne 0 ]; then + echo "Errors occurred during script execution" + exit 1 +else + echo "Script executed successfully" +fi diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml index 5171a7c3e2a..f51c5ed79b6 100644 --- a/.github/workflows/build-wheels-linux.yml +++ b/.github/workflows/build-wheels-linux.yml @@ -45,3 +45,4 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 84fe79d09d2..73a365a79f2 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -46,4 +46,4 @@ jobs: runner-type: macos-m1-stable smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} - env-var-script: .github/scripts/m1_script.sh + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 683f2a93f69..1beef7318f4 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -46,3 +46,4 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh diff --git a/setup.py b/setup.py index 0196cb4a8f4..95dc0802a4f 100644 --- a/setup.py +++ b/setup.py @@ -172,7 +172,7 @@ def _main(argv): if is_nightly: tensordict_dep = "tensordict-nightly" else: - tensordict_dep = "tensordict>=0.4.0" + tensordict_dep = "tensordict>=0.5.0" if is_nightly: version = get_nightly_version() diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 775dcfe206d..f8c18147306 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 0276039058f..d115174eb9c 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval final = collected_frames >= collector.total_frames diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index d8185c8091c..5ca70f83b53 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 5f8f81357c8..cf629ed0733 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // evaluation_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index 4b6f14cd058..d0d6693eb97 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index eb0b88c26f7..a92ee6185c3 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 59dbcafd8c9..9cca9fd8af5 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -56,7 +56,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + test_env = make_env( + cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device + ) if cfg.logger.video: test_env = test_env.append_transform( VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) @@ -114,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 to_log = {"train/loss": loss_vals["loss"]} # Evaluation - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): if i % pretrain_log_interval == 0: eval_td = test_env.rollout( max_steps=eval_steps, diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 5cb297e5c0b..da2241ce9fa 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } # Evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 7c9500aa4e7..409833c75fa 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -57,7 +57,7 @@ # ----------------- -def make_base_env(env_cfg, from_pixels=False): +def make_base_env(env_cfg, from_pixels=False, device=None): set_gym_backend(env_cfg.backend).set() env_library = LIBS[env_cfg.library] @@ -73,7 +73,7 @@ def make_base_env(env_cfg, from_pixels=False): if env_library is DMControlEnv: env_task = env_cfg.task env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) + env = env_library(**env_kwargs, device=device) return env @@ -134,7 +134,9 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): return transformed_env -def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): +def make_parallel_env( + env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None +): if train: num_envs = env_cfg.num_train_envs else: @@ -142,10 +144,12 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False) def make_env(): with set_gym_backend(env_cfg.backend): - return make_base_env(env_cfg, from_pixels=from_pixels) + return make_base_env(env_cfg, from_pixels=from_pixels, device="cpu") env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True), + ParallelEnv( + num_envs, EnvCreator(make_env), serial_for_single=True, device=device + ), env_cfg, obs_loc, obs_std, @@ -154,11 +158,15 @@ def make_env(): return env -def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): - env = make_parallel_env( - env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels +def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None): + return make_parallel_env( + env_cfg, + obs_loc, + obs_std, + train=train, + from_pixels=from_pixels, + device=device, ) - return env # ==================================================================== diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 6e100f92dc3..386f743c7d3 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // eval_iter final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 90f93551d4d..906273ee2f5 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index ac3f17a9203..173f88f7028 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index ab101e8486a..604e1ac546a 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -9,17 +9,13 @@ env: image_size : 64 horizon: 500 n_parallel_envs: 8 - device: - _target_: dreamer_utils._default_device - device: null + device: cpu collector: total_frames: 5_000_000 init_random_frames: 3000 frames_per_batch: 1000 device: - _target_: dreamer_utils._default_device - device: null optimization: train_every: 1000 @@ -41,8 +37,6 @@ optimization: networks: exploration_noise: 0.3 device: - _target_: dreamer_utils._default_device - device: null state_dim: 30 rssm_hidden_dim: 200 hidden_dim: 400 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e7b346b2b22..e521b9df386 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -10,6 +10,7 @@ import torch.cuda import tqdm from dreamer_utils import ( + _default_device, dump_video, log_metrics, make_collector, @@ -17,7 +18,6 @@ make_environments, make_replay_buffer, ) -from hydra.utils import instantiate # mixed precision training from torch.cuda.amp import GradScaler @@ -38,7 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # cfg = correct_for_frame_skip(cfg) - device = torch.device(instantiate(cfg.networks.device)) + device = _default_device(cfg.networks.device) # Create logger exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) @@ -284,7 +284,7 @@ def compile_rssms(module): # Evaluation if (i % eval_iter) == 0: # Real env - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_rollout = test_env.rollout( eval_rollout_steps, policy, @@ -298,7 +298,9 @@ def compile_rssms(module): log_metrics(logger, eval_metrics, collected_frames) # Simulated env if model_based_env_eval is not None: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index ff14871b011..73baa310821 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from hydra.utils import instantiate from tensordict import NestedKey from tensordict.nn import ( InteractionType, @@ -88,6 +87,7 @@ def _make_env(cfg, device, from_pixels=False): cfg.env.task, from_pixels=cfg.env.from_pixels or from_pixels, pixels_only=cfg.env.from_pixels, + device=device, ) else: raise NotImplementedError(f"Unknown lib {lib}.") @@ -98,7 +98,6 @@ def _make_env(cfg, device, from_pixels=False): env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) ) - assert env is not None return env @@ -129,7 +128,7 @@ def transform_env(cfg, env): def make_environments(cfg, parallel_envs=1, logger=None): """Make environments for training and evaluation.""" - func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device) + func = functools.partial(_make_env, cfg=cfg, device=_default_device(cfg.env.device)) train_env = ParallelEnv( parallel_envs, EnvCreator(func), @@ -138,7 +137,10 @@ def make_environments(cfg, parallel_envs=1, logger=None): train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) func = functools.partial( - _make_env, cfg=cfg, device=cfg.env.device, from_pixels=cfg.logger.video + _make_env, + cfg=cfg, + device=_default_device(cfg.env.device), + from_pixels=cfg.logger.video, ) eval_env = ParallelEnv( 1, @@ -332,7 +334,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - policy_device=instantiate(cfg.collector.device), + policy_device=_default_device(cfg.collector.device), env_device=train_env.device, storing_device="cpu", ) @@ -535,7 +537,7 @@ def _dreamer_make_actor_real( SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=[action_key], - default_interaction_type=InteractionType.MODE, + default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, spec=CompositeSpec( diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index 0482a595ffa..1998c044305 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -247,7 +247,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index ce96cf06ce8..fdee4256c42 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -239,7 +239,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index bb0f314197a..cf583909620 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 33513dd3973..ae1894379fd 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index d98724e1371..d1a16fd8192 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -130,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index b66c6f9dcf2..d50ff806294 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index 81551ebefb7..a4d2b88a9d0 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index 9d14ff04b04..bd44bb0a043 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index e752c4d73f2..fa006a7d4a2 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -236,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index d294a9c783e..4e6a962c556 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index 30b7e7e98bc..f7b2523010b 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 908cb7924a3..2b02254032a 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index e3e74971a49..219ae1b59b6 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -210,7 +210,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index f7a399cda72..9904fe072ab 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 97fd039c238..5fbc9b032d7 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/test/test_env.py b/test/test_env.py index e6ca38b729c..f8f242f3955 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2061,6 +2061,7 @@ def main_collector(j, q=None): total_frames=N * n_workers * 100, storing_device=device, device=device, + cat_results=-1, ) single_collectors = [ SyncDataCollector( diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 50e3dd5cc49..32294a25edd 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2065,18 +2065,18 @@ def _queue_len(self) -> int: def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: - cat_results = 0 + cat_results = "stack" warnings.warn( f"`cat_results` was not specified in the constructor of {type(self).__name__}. " f"For MultiSyncDataCollector, `cat_results` indicates how the data should " - f"be packed: the preferred option is `cat_results='stack'` which provides " - f"the best interoperability across torchrl components. " + f"be packed: the preferred option and current default is `cat_results='stack'` " + f"which provides the best interoperability across torchrl components. " f"Other accepted values are `cat_results=0` (previous behaviour) and " f"`cat_results=-1` (cat along time dimension). Among these two, the latter " f"should be preferred for consistency across environment configurations. " - f"Currently, the default value is `0` (using torch.cat along first dimension)." - f"From v0.5 onward, this will default to `'stack'`. " - f"To suppress this warning, set stack_results to the desired value.", + f"Currently, the default value is `'stack'`." + f"From v0.6 onward, this warning will be removed. " + f"To suppress this warning, set `cat_results` to the desired value.", category=DeprecationWarning, ) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 04c24cb8d57..0006213cd27 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1143,6 +1143,7 @@ def __eq__(self, other): if not isinstance(other, LazyStackedTensorSpec): return False if self.device != other.device: + raise RuntimeError((self, other)) return False if len(self._specs) != len(other._specs): return False diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 4241f6613a0..4996e527527 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -406,17 +406,16 @@ def _find_sync_values(self): return _do_nothing, _do_nothing if worker_device is None: - worker_not_main = [False] + worker_not_main = False - def find_all_worker_devices(item, worker_not_main=worker_not_main): + def find_all_worker_devices(item): + nonlocal worker_not_main if hasattr(item, "device"): - worker_not_main[0] = worker_not_main[0] or ( - item.device != self_device - ) + worker_not_main = worker_not_main or (item.device != self_device) for td in self.shared_tensordicts: td.apply(find_all_worker_devices, filter_empty=True) - if worker_not_main[0]: + if worker_not_main: if torch.cuda.is_available(): worker_device = ( torch.device("cuda") @@ -431,6 +430,8 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main): ) else: raise RuntimeError("Did not find a valid worker device") + else: + worker_device = self_device if ( worker_device is not None @@ -460,6 +461,7 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main): and self_device.type == "mps" ): return _mps_sync(self_device), _mps_sync(self_device) + return _do_nothing, _do_nothing def __getstate__(self): out = copy(self.__dict__) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c965e7dedf3..e30de3534d9 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.base import NO_DEFAULT from tensordict.utils import NestedKey from torchrl._utils import ( _ends_with, @@ -3020,21 +3019,11 @@ class _EnvWrapper(EnvBase): def __init__( self, *args, - device: DEVICE_TYPING = NO_DEFAULT, + device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, **kwargs, ): - if device is NO_DEFAULT: - warnings.warn( - "Your wrapper was not given a device. Currently, this " - "value will default to 'cpu'. From v0.5 it will " - "default to `None`. With a device of None, no device casting " - "is performed and the resulting tensordicts are deviceless. " - "Please set your device accordingly.", - category=DeprecationWarning, - ) - device = torch.device("cpu") super().__init__( device=device, batch_size=batch_size, diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 47f93f09779..c7935272c91 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -348,8 +348,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: batch_size=tensordict.batch_size, ) if self.device is not None: - tensordict_out = tensordict_out.to(self.device, non_blocking=True) - self._sync_device() + tensordict_out = tensordict_out.to(self.device) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -393,8 +392,7 @@ def _reset( if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() if self.device is not None: - tensordict_out = tensordict_out.to(self.device, non_blocking=True) - self._sync_device() + tensordict_out = tensordict_out.to(self.device) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 07c48587c14..9195929e31d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -27,7 +27,6 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -246,8 +245,8 @@ def _gym_to_torchrl_spec_transform( ).expand(batch_size) gym_spaces = gym_backend("spaces") if isinstance(spec, gym_spaces.tuple.Tuple): - result = LazyStackedTensorSpec( - *[ + result = torch.stack( + [ _gym_to_torchrl_spec_transform( s, device=device, diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 8e30fdb2a7e..9751e84a3ac 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -795,7 +795,9 @@ def _build_env( env=vmas.make_env( scenario=scenario, num_envs=num_envs, - device=self.device, + device=self.device + if self.device is not None + else torch.get_default_device(), continuous_actions=continuous_actions, max_steps=max_steps, seed=seed, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index eb9cdce923d..70aef03e041 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3411,14 +3411,7 @@ def __init__( out_keys_inv: Sequence[NestedKey] | None = None, ): if in_keys is not None and in_keys_inv is None: - warnings.warn( - "in_keys have been provided but not in_keys_inv. From v0.5, " - "this will result in in_keys_inv being an empty list whereas " - "now the input keys are retrieved automatically. " - "To silence this warning, pass the (possibly empty) " - "list of in_keys_inv.", - category=DeprecationWarning, - ) + in_keys_inv = [] self.dtype_in = dtype_in self.dtype_out = dtype_out diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 087cabe4186..38d8d1dfd02 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -481,9 +481,10 @@ def root_dist(self): @property def mode(self): warnings.warn( - "This computation of the mode is based on the first-order Taylor expansion " - "of the transform around the normal mean value, which can be inaccurate. " + "This computation of the mode is based on an inaccurate estimation of the mode " + "given the base_dist mode. " "To use a more stable implementation of the mode, use dist.get_mode() method instead. " + "To silence this warning, consider using the DETERMINISTIC exploration_type." "This implementation will be removed in v0.6.", category=DeprecationWarning, ) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 17b1ea77ee4..83b6a8d1fb3 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union import torch @@ -922,10 +921,9 @@ def __init__( out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, + raise RuntimeError( + "Using specs in action_space is deprecated. " + "Please use the 'spec' argument if you want to provide an action spec" ) action_space, _ = _process_action_space_spec(action_space, None) @@ -1136,10 +1134,9 @@ def __init__( action_mask_key: Optional[NestedKey] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, + raise RuntimeError( + "Using specs in action_space is deprecated." + "Please use the 'spec' argument if you want to provide an action spec" ) action_space, spec = _process_action_space_spec(action_space, spec) diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 2c2f3fb21ac..b7fb8ab4ed2 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -221,11 +221,11 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation_trsf = make_grid( obs_flat, nrow=int(math.ceil(math.sqrt(obs_flat.shape[0]))) ) - self.obs.append(observation_trsf.to(torch.uint8)) + self.obs.append(observation_trsf.to("cpu", torch.uint8)) elif observation_trsf.ndimension() >= 4: - self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4)) + self.obs.extend(observation_trsf.to("cpu", torch.uint8).flatten(0, -4)) else: - self.obs.append(observation_trsf.to(torch.uint8)) + self.obs.append(observation_trsf.to("cpu", torch.uint8)) return observation def forward(self, tensordict: TensorDictBase) -> TensorDictBase: diff --git a/version.txt b/version.txt index 1d0ba9ea182..8f0916f768f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.0 +0.5.0