[Algorithm] Update SAC Example (pytorch#1524)
Co-authored-by: vmoens <[email protected]>
BY571 and vmoens committed Oct 10, 2023
1 parent f847e69 commit 086a7cd
Showing 5 changed files with 197 additions and 124 deletions.
14 changes: 6 additions & 8 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -118,11 +118,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.collector_device=cuda:0 \
optimization.batch_size=10 \
optimization.utd_ratio=1 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
@@ -221,17 +220,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
network.device=cuda:0 \
optimization.batch_size=10 \
optimization.utd_ratio=1 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
logger.backend=
# record_video=True \
# record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
total_frames=48 \
batch_size=10 \
40 changes: 20 additions & 20 deletions examples/sac/config.yaml
@@ -1,49 +1,49 @@
# Environment
# environment and task
env:
name: HalfCheetah-v3
task: ""
exp_name: "HalfCheetah-SAC"
library: gym
frame_skip: 1
seed: 1
exp_name: ${env.name}_SAC
library: gymnasium
max_episode_steps: 1000
seed: 42

# Collection
# collector
collector:
total_frames: 1000000
init_random_frames: 10000
total_frames: 1_000_000
init_random_frames: 25000
frames_per_batch: 1000
max_frames_per_traj: 1000
init_env_steps: 1000
async_collection: 1
collector_device: cpu
env_per_collector: 1
num_workers: 1
reset_at_each_iter: False

# Replay Buffer
# replay buffer
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: ${env.exp_name}_${env.seed}

# Optimization
optimization:
# optim
optim:
utd_ratio: 1.0
gamma: 0.99
loss_function: smooth_l1
lr: 3e-4
weight_decay: 2e-4
lr_scheduler: ""
loss_function: l2
lr: 3.0e-4
weight_decay: 0.0
batch_size: 256
target_update_polyak: 0.995
alpha_init: 1.0
adam_eps: 1.0e-8

# Algorithm
# network
network:
hidden_sizes: [256, 256]
activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"

# Logging
# logging
logger:
backend: wandb
mode: online
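The `optimization` config group is renamed to `optim`, matching the overrides in run_test.sh above. Below is a minimal sketch of how the new file resolves, assuming OmegaConf is installed and the file sits at examples/sac/config.yaml; the dotlist merge only mirrors the Hydra-style command-line overrides and is an illustration, not part of this commit.

# Sketch: load the updated config and apply run_test.sh-style overrides.
# Assumes examples/sac/config.yaml matches the diff above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/sac/config.yaml")

# Hydra-style dotted overrides map onto the nested groups, e.g. optim.batch_size.
overrides = OmegaConf.from_dotlist(["optim.batch_size=10", "optim.utd_ratio=1"])
cfg = OmegaConf.merge(cfg, overrides)

print(cfg.optim.batch_size)  # 10
print(cfg.optim.lr)          # 0.0003
print(cfg.env.exp_name)      # HalfCheetah-v3_SAC, interpolated from env.name
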
158 changes: 94 additions & 64 deletions examples/sac/sac.py
@@ -11,17 +11,20 @@
The helper functions are coded in the utils.py associated with this script.
"""

import time

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm

from tensordict import TensorDict
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
log_metrics,
make_collector,
make_environment,
make_loss_module,
@@ -35,6 +38,7 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("SAC", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
@@ -48,132 +52,158 @@ def main(cfg: "DictConfig"): # noqa: F821
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create Environments
# Create environments
train_env, eval_env = make_environment(cfg)
# Create Agent

# Create agent
model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)

# Create TD3 loss
# Create SAC loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Make Off-Policy Collector
# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)

# Make Replay Buffer
# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optimization.batch_size,
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
device=device,
)

# Make Optimizers
optimizer = make_sac_optimizer(cfg, loss_module)

rewards = []
rewards_eval = []
# Create optimizers
(
optimizer_actor,
optimizer_critic,
optimizer_alpha,
) = make_sac_optimizer(cfg, loss_module)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
r0 = None
q_loss = None

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optimization.utd_ratio
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
env_per_collector = cfg.collector.env_per_collector
eval_iter = cfg.logger.eval_iter
frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for i, tensordict in enumerate(collector):
# update weights of the inference policy
sampling_time = time.time() - sampling_start

# Update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

tensordict = tensordict.view(-1)
tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(actor_losses, q_losses, alpha_losses) = ([], [], [])
for _ in range(num_updates):
# sample from replay buffer
losses = TensorDict(
{},
batch_size=[
num_updates,
],
)
for i in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()

# Compute loss
loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
alpha_loss = loss_td["loss_alpha"]
loss = actor_loss + q_loss + alpha_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())
alpha_losses.append(alpha_loss.item())
# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# update qnet_target params
# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

losses[i] = loss_td.select(
"loss_actor", "loss_qvalue", "loss_alpha"
).detach()

# Update qnet_target params
target_net_updater.step()

# update priority
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append(
(i, tensordict["next", "reward"].sum().item() / env_per_collector)
training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
}
if q_loss is not None:
train_log.update(
{
"actor_loss": np.mean(actor_losses),
"q_loss": np.mean(q_losses),
"alpha_loss": np.mean(alpha_losses),
"alpha": loss_td["alpha"],
"entropy": loss_td["entropy"],
}
episode_rewards = tensordict["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if logger is not None:
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)
if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
metrics_to_log["train/alpha"] = loss_td["alpha"].item()
metrics_to_log["train/entropy"] = loss_td["entropy"].item()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
if logger is not None:
logger.log_scalar(
"evaluation_reward", rewards_eval[-1][1], step=collected_frames
)
if len(rewards_eval):
pbar.set_description(
f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
)
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
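sac.py now imports log_metrics and unpacks three optimizers from make_sac_optimizer, but the updated utils.py is not shown in this view. Below is a hedged sketch of what those helpers could look like, assuming TorchRL's SACLoss exposes actor_network_params, qvalue_network_params and log_alpha; the actual implementations in the commit may differ.

# Sketch only: plausible utils.py helpers matching the call sites in sac.py.
# Assumes a TorchRL SACLoss instance with actor_network_params,
# qvalue_network_params and log_alpha attributes.
from torch import optim


def make_sac_optimizer(cfg, loss_module):
    """Build separate Adam optimizers for the actor, the critic and alpha."""
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())

    optimizer_actor = optim.Adam(
        actor_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    optimizer_critic = optim.Adam(
        critic_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    optimizer_alpha = optim.Adam(
        [loss_module.log_alpha],
        lr=cfg.optim.lr,
    )
    return optimizer_actor, optimizer_critic, optimizer_alpha


def log_metrics(logger, metrics, step):
    """Log a flat dict of scalar metrics to the configured logger backend."""
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step)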