From 2ecbb3f71f2bdfbc4a76b44f4607a4d91979f9c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sat, 4 Mar 2023 19:05:17 +0800 Subject: [PATCH 01/26] init commit --- ding/entry/__init__.py | 1 + ding/entry/serial_entry_pc_mcts.py | 122 +++++++++++ ding/model/template/procedure_cloning.py | 69 ++++-- .../template/tests/test_procedure_cloning.py | 33 +-- ding/policy/pc.py | 201 ++++++++++++++++++ .../serial/qbert/qbert_pc_mcts_config.py | 44 ++++ 6 files changed, 437 insertions(+), 33 deletions(-) create mode 100644 ding/entry/serial_entry_pc_mcts.py create mode 100644 ding/policy/pc.py create mode 100644 dizoo/atari/config/serial/qbert/qbert_pc_mcts_config.py diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index 1e90351ee4..87d06d2d68 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -27,3 +27,4 @@ from .application_entry_drex_collect_data import drex_collecting_data from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream from .serial_entry_bco import serial_pipeline_bco +from .serial_entry_pc_mcts import serial_pipeline_pc_mcts diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py new file mode 100644 index 0000000000..d6f8c1a42a --- /dev/null +++ b/ding/entry/serial_entry_pc_mcts.py @@ -0,0 +1,122 @@ +from typing import Union, Optional, Tuple +import os +import torch +from functools import partial +from tensorboardX import SummaryWriter +from copy import deepcopy +from torch.utils.data import DataLoader, Dataset +import pickle + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed + + +class MCTSPCDataset(Dataset): + + def __init__(self, data_dic, seq_len=4): + self.observations = data_dic['obs'] + self.actions = data_dic['actions'] + self.hidden_states = data_dic['hidden_state'] + self.seq_len = seq_len + self.length = len(self.observations) - seq_len - 1 + + def __getitem__(self, idx): + """ + Assume the trajectory is: o1, h2, h3, h4 + """ + return { + 'obs': self.observations[idx], + 'hidden_states': list(reversed(self.hidden_states[idx+1: idx+self.seq_len+1])), + 'action': self.actions[idx] + } + + def __len__(self): + return self.length + + +def load_mcts_datasets(path, batch_size=32): + with open(path, 'rb') as f: + dic = pickle.load(f) + return DataLoader(MCTSPCDataset(dic), shuffle=True, batch_size=batch_size) + + +def serial_pipeline_pc_mcts( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + max_iter=int(1e6), +) -> Union['Policy', bool]: # noqa + r""" + Overview: + Serial pipeline entry of imitation learning. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - data_path (:obj:`str`): Path of training data. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ - convergence (:obj:`bool`): whether il training is converged + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + + # Env, Policy + env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + # Random seed + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval']) + + # Main components + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + + # ========== + # Main loop + # ========== + learner.call_hook('before_run') + stop = False + iter_cnt = 0 + for epoch in range(cfg.policy.learn.train_epoch): + # train + criterion = torch.nn.CrossEntropyLoss() + for i, train_data in enumerate(dataloader): + learner.train(train_data) + iter_cnt += 1 + if iter_cnt >= max_iter: + stop = True + break + if epoch % 69 == 0: + policy._optimizer.param_groups[0]['lr'] /= 10 + if stop: + break + losses = [] + acces = [] + for _, test_data in enumerate(test_dataloader): + logits = policy._model.forward_eval(test_data) + + loss = criterion(logits, test_data['action']).item() + preds = torch.argmax(logits, dim=-1) + acc = torch.sum((preds == test_data['action'])) / preds.shape[0] + + losses.append(loss) + acces.append(acc) + print('Test Finished! 
Loss: {} acc: {}'.format(sum(losses) / len(losses), sum(acces) / len(acces))) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + learner.call_hook('after_run') + print('final reward is: {}'.format(reward)) + return policy, stop diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index a86e813933..6fc574f77e 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -48,12 +48,14 @@ class ProcedureCloning(nn.Module): def __init__( self, obs_shape: SequenceType, + hidden_shape: SequenceType, action_dim: int, + seq_len: int, cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256], cnn_activation: Optional[nn.Module] = nn.ReLU(), cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3], cnn_stride: SequenceType = [1, 1, 1, 1, 1], - cnn_padding: Optional[SequenceType] = ['same', 'same', 'same', 'same', 'same'], + cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1], mlp_hidden_list: SequenceType = [256, 256], mlp_activation: Optional[nn.Module] = nn.ReLU(), att_heads: int = 8, @@ -63,15 +65,21 @@ def __init__( feedforward_hidden: int = 256, drop_p: float = 0.5, augment: bool = True, - max_T: int = 17 ) -> None: super().__init__() + self.obs_shape = obs_shape + self.hidden_shape = hidden_shape + self.seq_len = seq_len + max_T = seq_len + 1 #Conv Encoder + print(cnn_padding) self.embed_state = ConvEncoder( obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding ) - self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation) + self.embed_hidden = ConvEncoder( + hidden_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding + ) self.cnn_hidden_list = cnn_hidden_list self.augment = augment @@ -95,25 +103,52 @@ def __init__( cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward ) - self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1]) + self.predict_hidden_state = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1]) self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim) - def forward(self, states: torch.Tensor, goals: torch.Tensor, - actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - B, T, _ = actions.shape + def _compute_embeddings(self, states: torch.Tensor, hidden_states: torch.Tensor): + B, T, *_ = hidden_states.shape - # shape: (B, h_dim) + # shape: (B, 1, h_dim) state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1]) - goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1]) - # shape: (B, context_len, h_dim) - actions_embeddings = self.embed_action(actions) + # shape: (B, T, h_dim) + hidden_state_embeddings = self.embed_hidden(hidden_states.reshape(B * T, *hidden_states.shape[2:])) \ + .reshape(B, T, self.cnn_hidden_list[-1]) + return state_embeddings, hidden_state_embeddings - h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1) + def _compute_transformer(self, h): + B, T, *_ = h.shape h = self.transformer(h) - h = h.reshape(B, T + 2, self.cnn_hidden_list[-1]) + h = h.reshape(B, T, self.cnn_hidden_list[-1]) + + hidden_state_preds = self.predict_hidden_state(h[:, 0:-1, ...]) + action_preds = self.predict_action(h[:, -1, :]) + return hidden_state_preds, action_preds + + def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \ + -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # State is current 
observation. + # Hidden states is a sequence including [L, R, ...]. + # The shape of state and hidden state may be different. + B, T, *_ = hidden_states.shape + assert T == self.seq_len + state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) + + h = torch.cat((state_embeddings, hidden_state_embeddings), dim=1) + hidden_state_preds, action_preds = self._compute_transformer(h) + + return hidden_state_preds, action_preds, hidden_state_embeddings.detach() + + def forward_eval(self, states: torch.Tensor) -> torch.Tensor: + batch_size = states.shape[0] + hidden_states = torch.zeros(batch_size, self.seq_len, *self.hidden_shape, dtype=states.dtype).to(states.device) + embedding_mask = torch.zeros(1, self.seq_len, 1) + + state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) - goal_preds = self.predict_goal(h[:, 0, :]) - action_preds = self.predict_action(h[:, 1:, :]) + for i in range(self.seq_len): + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + hidden_state_embeddings, action_pred = self._compute_transformer(h) + embedding_mask[0, i, 0] = 1 - return goal_preds, action_preds + return action_pred diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index e169ec2cee..adf718a796 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -9,26 +9,27 @@ B = 4 T = 15 -obs_shape = [(64, 64, 3)] -action_dim = [9] +obs_shape = (64, 64, 3) +hidden_shape = (9, 9, 64) +action_dim = 9 obs_embeddings = 256 -args = list(product(*[obs_shape, action_dim])) @pytest.mark.unittest -@pytest.mark.parametrize('obs_shape, action_dim', args) -class TestProcedureCloning: +def test_procedure_cloning(): + inputs = { + 'states': torch.randn(B, *obs_shape), + 'hidden_states': torch.randn(B, T, *hidden_shape), + 'actions': torch.randn(B, action_dim) + } + model = ProcedureCloning(obs_shape=obs_shape, hidden_shape=hidden_shape, + seq_len=T, action_dim=action_dim) - def test_procedure_cloning(self, obs_shape, action_dim): - inputs = { - 'states': torch.randn(B, *obs_shape), - 'goals': torch.randn(B, *obs_shape), - 'actions': torch.randn(B, T, action_dim) - } - model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim) + print(model) - print(model) + hidden_state_preds, action_preds, target_hidden_state = model(inputs['states'], inputs['hidden_states']) + assert hidden_state_preds.shape == (B, T, obs_embeddings) + assert action_preds.shape == (B, action_dim) - goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions']) - assert goal_preds.shape == (B, obs_embeddings) - assert action_preds.shape == (B, T + 1, action_dim) + action_eval = model.forward_eval(inputs['states']) + assert action_eval.shape == (B, action_dim) diff --git a/ding/policy/pc.py b/ding/policy/pc.py new file mode 100644 index 0000000000..6de159b481 --- /dev/null +++ b/ding/policy/pc.py @@ -0,0 +1,201 @@ +import math +import torch +import torch.nn as nn +from torch.optim import Adam, SGD, AdamW +from torch.optim.lr_scheduler import LambdaLR +import logging +from typing import List, Dict, Any, Tuple, Union, Optional +from collections import namedtuple +from easydict import EasyDict +from ding.policy import Policy +from ding.model import model_wrap +from ding.torch_utils import to_device, to_list +from ding.utils import EasyTimer +from ding.utils.data import default_collate, 
default_decollate +from ding.rl_utils import get_nstep_return_data, get_train_sample +from ding.utils import POLICY_REGISTRY +from ding.torch_utils.loss.cross_entropy_loss import LabelSmoothCELoss + + +@POLICY_REGISTRY.register('pc_mcts') +class ProcedureCloningPolicyMCTS(Policy): + + config = dict( + type='pc_mcts', + cuda=True, + on_policy=False, + continuous=False, + learn=dict( + multi_gpu=False, + update_per_collect=1, + batch_size=32, + learning_rate=1e-5, + lr_decay=False, + decay_epoch=30, + decay_rate=0.1, + warmup_lr=1e-4, + warmup_epoch=3, + optimizer='SGD', + momentum=0.9, + weight_decay=1e-4, + ce_label_smooth=False, + show_accuracy=False, + tanh_mask=False, # if actions always converge to 1 or -1, use this. + ), + collect=dict( + unroll_len=1, + noise=False, + noise_sigma=0.2, + noise_range=dict( + min=-0.5, + max=0.5, + ), + ), + eval=dict(), + other=dict(replay_buffer=dict(replay_buffer_size=10000, )), + ) + + def default_model(self) -> Tuple[str, List[str]]: + if self._cfg.continuous: + return 'continuous_bc', ['ding.model.template.bc'] + else: + return 'discrete_bc', ['ding.model.template.bc'] + + def _init_learn(self): + assert self._cfg.learn.optimizer in ['SGD', 'Adam'] + if self._cfg.learn.optimizer == 'SGD': + self._optimizer = SGD( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + weight_decay=self._cfg.learn.weight_decay, + momentum=self._cfg.learn.momentum + ) + elif self._cfg.learn.optimizer == 'Adam': + if self._cfg.learn.weight_decay is None: + self._optimizer = Adam( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + ) + else: + self._optimizer = AdamW( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + weight_decay=self._cfg.learn.weight_decay + ) + if self._cfg.learn.lr_decay: + + def lr_scheduler_fn(epoch): + if epoch <= self._cfg.learn.warmup_epoch: + return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate + else: + ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch + return math.pow(self._cfg.learn.decay_rate, ratio) + + self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn) + self._timer = EasyTimer(cuda=True) + self._learn_model = model_wrap(self._model, 'base') + self._learn_model.reset() + + self._hidden_state_loss = nn.MSELoss() + self._action_loss = nn.CrossEntropyLoss() + + def _forward_learn(self, data): + if self._cuda: + data = to_device(data, self._device) + self._learn_model.train() + with self._timer: + obs, hidden_states, action = data['obs'], data['hidden_states'], data['actions'] + pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) + # When we use bco, action is predicted by idm, gradient is not expected. 
+ loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states)\ + + self._action_loss(pred_action, action) + forward_time = self._timer.value + + with self._timer: + self._optimizer.zero_grad() + loss.backward() + backward_time = self._timer.value + + with self._timer: + if self._cfg.learn.multi_gpu: + self.sync_gradients(self._learn_model) + sync_time = self._timer.value + + self._optimizer.step() + cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups] + cur_lr = sum(cur_lr) / len(cur_lr) + return { + 'cur_lr': cur_lr, + 'total_loss': loss.item(), + 'forward_time': forward_time, + 'backward_time': backward_time, + 'sync_time': sync_time, + } + + def _monitor_vars_learn(self): + return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time'] + + def _init_eval(self): + self._eval_model = model_wrap(self._model, wrapper_name='base') + self._eval_model.reset() + + def _forward_eval(self, data): + data_id = list(data.keys()) + data = default_collate(list(data.values())) + + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + with torch.no_grad(): + output = self._eval_model.forward_eval(data['obs']) + if self._cuda: + output = to_device(output, 'cpu') + output = {'action': output} + return {i: d for i, d in zip(data_id, output)} + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + pass + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + r""" + Overview: + Generate dict type transition data from inputs. + Arguments: + - obs (:obj:`Any`): Env observation + - model_output (:obj:`dict`): Output of collect model, including at least ['action'] + - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ + (here 'obs' indicates obs after env step). + Returns: + - transition (:obj:`dict`): Dict type transition data. + """ + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': model_output['action'], + 'reward': timestep.reward, + 'done': timestep.done, + } + return EasyDict(transition) + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ + or some continuous transitions(DRQN). + Arguments: + - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ + format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`dict`): The list of training samples. + + .. note:: + We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ + And the user can customize the this data processing procecure by overriding this two methods and collector \ + itself. 
+ """ + data = get_nstep_return_data(data, 1, 1) + return get_train_sample(data, self._unroll_len) diff --git a/dizoo/atari/config/serial/qbert/qbert_pc_mcts_config.py b/dizoo/atari/config/serial/qbert/qbert_pc_mcts_config.py new file mode 100644 index 0000000000..099e18c917 --- /dev/null +++ b/dizoo/atari/config/serial/qbert/qbert_pc_mcts_config.py @@ -0,0 +1,44 @@ +from easydict import EasyDict + +qbert_pc_mcts_config = dict( + exp_name='qbert_pc_mcts_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=1000000, + env_id='Qbert-v4', + ), + policy=dict( + cuda=True, + expert_data_path='pong_expert/ez_pong_seed0.pkl', + model=dict( + obs_shape=[3, 96, 96], + hidden_shape=[32, 8, 8], + action_shape=6, + ), + learn=dict( + batch_size=64, + learning_rate=0.01, + learner=dict(hook=dict(save_ckpt_after_iter=1000)), + train_epoch=20, + ), + eval=dict(evaluator=dict(eval_freq=40, )) + ), +) +qbert_pc_mcts_config = EasyDict(qbert_pc_mcts_config) +main_config = qbert_pc_mcts_config +qbert_pc_mcts_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='pc_mcts'), +) +qbert_pc_mcts_create_config = EasyDict(qbert_pc_mcts_create_config) +create_config = qbert_pc_mcts_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_pc_mcts + serial_pipeline_pc_mcts([main_config, create_config], seed=0) From f090d315f2c572b83f6fe90c25cf8443b0c45d69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sat, 4 Mar 2023 19:07:13 +0800 Subject: [PATCH 02/26] init commit --- .../config/serial/pong/pong_pc_mcts_config.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 dizoo/atari/config/serial/pong/pong_pc_mcts_config.py diff --git a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py new file mode 100644 index 0000000000..dad710e66d --- /dev/null +++ b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py @@ -0,0 +1,44 @@ +from easydict import EasyDict + +qbert_pc_mcts_config = dict( + exp_name='pong_pc_mcts_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=1000000, + env_id='Pong-v4', + ), + policy=dict( + cuda=True, + expert_data_path='pong_expert/ez_pong_seed0.pkl', + model=dict( + obs_shape=[3, 96, 96], + hidden_shape=[32, 8, 8], + action_shape=6, + ), + learn=dict( + batch_size=64, + learning_rate=0.01, + learner=dict(hook=dict(save_ckpt_after_iter=1000)), + train_epoch=20, + ), + eval=dict(evaluator=dict(eval_freq=40, )) + ), +) +qbert_pc_mcts_config = EasyDict(qbert_pc_mcts_config) +main_config = qbert_pc_mcts_config +qbert_pc_mcts_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='pc_mcts'), +) +qbert_pc_mcts_create_config = EasyDict(qbert_pc_mcts_create_config) +create_config = qbert_pc_mcts_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_pc_mcts + serial_pipeline_pc_mcts([main_config, create_config], seed=0) From 6dc65e14182b8bcc5fada57e8bbfa5506685d123 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 5 Mar 2023 16:52:57 +0800 Subject: [PATCH 03/26] bug fux --- ding/entry/serial_entry_pc_mcts.py | 18 ++++--- ding/model/template/__init__.py | 2 +- 
ding/model/template/procedure_cloning.py | 18 ++++--- .../template/tests/test_procedure_cloning.py | 10 ++-- ding/policy/__init__.py | 1 + ding/policy/pc.py | 38 +++++++------- .../collector/interaction_serial_evaluator.py | 5 +- .../config/serial/pong/pong_pc_mcts_config.py | 49 +++++++++++++++---- 8 files changed, 93 insertions(+), 48 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index d6f8c1a42a..a859e6418c 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -29,7 +29,7 @@ def __getitem__(self, idx): """ return { 'obs': self.observations[idx], - 'hidden_states': list(reversed(self.hidden_states[idx+1: idx+self.seq_len+1])), + 'hidden_states': list(reversed(self.hidden_states[idx + 1: idx + self.seq_len + 1])), 'action': self.actions[idx] } @@ -37,10 +37,14 @@ def __len__(self): return self.length -def load_mcts_datasets(path, batch_size=32): +def load_mcts_datasets(path, seq_len, batch_size=32): with open(path, 'rb') as f: dic = pickle.load(f) - return DataLoader(MCTSPCDataset(dic), shuffle=True, batch_size=batch_size) + tot_len = len(dic['obs']) + train_dic = {k: v[:-tot_len // 10] for k, v in dic.items()} + test_dic = {k: v[-tot_len // 10:] for k, v in dic.items()} + return DataLoader(MCTSPCDataset(train_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size), \ + DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size) def serial_pipeline_pc_mcts( @@ -79,7 +83,7 @@ def serial_pipeline_pc_mcts( # Main components tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path) + dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path, seq_len=cfg.policy.seq_len) learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) evaluator = InteractionSerialEvaluator( cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name @@ -107,11 +111,11 @@ def serial_pipeline_pc_mcts( losses = [] acces = [] for _, test_data in enumerate(test_dataloader): - logits = policy._model.forward_eval(test_data) + logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) 
- loss = criterion(logits, test_data['action']).item() + loss = criterion(logits, test_data['action'].cuda()).item() preds = torch.argmax(logits, dim=-1) - acc = torch.sum((preds == test_data['action'])) / preds.shape[0] + acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] losses.append(loss) acces.append(acc) diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 11f7aa35b5..fb2223f3da 100644 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -22,4 +22,4 @@ from .madqn import MADQN from .vae import VanillaVAE from .decision_transformer import DecisionTransformer -from .procedure_cloning import ProcedureCloning +from .procedure_cloning import ProcedureCloningMCTS diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 6fc574f77e..09dc66225d 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -19,18 +19,25 @@ def __init__( self.attention_layer = [] self.norm_layer = [nn.LayerNorm(att_hidden)] * n_att + self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p))) for i in range(n_att - 1): self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p))) - + self.attention_layer = nn.ModuleList(self.attention_layer) self.att_drop = nn.Dropout(drop_p) self.fc_blocks = [] self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU())) for i in range(n_feedforward - 1): self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU())) + self.fc_blocks = nn.ModuleList(self.fc_blocks) + self.norm_layer.extend([nn.LayerNorm(feedforward_hidden)] * n_feedforward) - self.mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T) + self.norm_layer = nn.ModuleList(self.norm_layer) + + self.mask = nn.Parameter( + torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T), requires_grad=False + ) def forward(self, x: torch.Tensor): for i in range(self.n_att): @@ -42,8 +49,8 @@ def forward(self, x: torch.Tensor): return x -@MODEL_REGISTRY.register('pc') -class ProcedureCloning(nn.Module): +@MODEL_REGISTRY.register('pc_mcts') +class ProcedureCloningMCTS(nn.Module): def __init__( self, @@ -73,7 +80,6 @@ def __init__( max_T = seq_len + 1 #Conv Encoder - print(cnn_padding) self.embed_state = ConvEncoder( obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding ) @@ -142,7 +148,7 @@ def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \ def forward_eval(self, states: torch.Tensor) -> torch.Tensor: batch_size = states.shape[0] hidden_states = torch.zeros(batch_size, self.seq_len, *self.hidden_shape, dtype=states.dtype).to(states.device) - embedding_mask = torch.zeros(1, self.seq_len, 1) + embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device) state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index adf718a796..9c35c9b93b 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -1,11 +1,7 @@ import torch import pytest -import numpy as np -from itertools import product -from ding.model.template import ProcedureCloning -from ding.torch_utils import is_differentiable 
-from ding.utils import squeeze +from ding.model.template import ProcedureCloningMCTS B = 4 T = 15 @@ -22,8 +18,8 @@ def test_procedure_cloning(): 'hidden_states': torch.randn(B, T, *hidden_shape), 'actions': torch.randn(B, action_dim) } - model = ProcedureCloning(obs_shape=obs_shape, hidden_shape=hidden_shape, - seq_len=T, action_dim=action_dim) + model = ProcedureCloningMCTS(obs_shape=obs_shape, hidden_shape=hidden_shape, + seq_len=T, action_dim=action_dim) print(model) diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 5938334022..b599c0a579 100644 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -44,6 +44,7 @@ from .bc import BehaviourCloningPolicy from .ibc import IBCPolicy +from .pc import ProcedureCloningPolicyMCTS # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 6de159b481..e13a36f49c 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -1,25 +1,24 @@ import math +from typing import List, Dict, Any, Tuple +from collections import namedtuple + import torch import torch.nn as nn from torch.optim import Adam, SGD, AdamW from torch.optim.lr_scheduler import LambdaLR -import logging -from typing import List, Dict, Any, Tuple, Union, Optional -from collections import namedtuple from easydict import EasyDict + from ding.policy import Policy from ding.model import model_wrap -from ding.torch_utils import to_device, to_list +from ding.torch_utils import to_device from ding.utils import EasyTimer from ding.utils.data import default_collate, default_decollate from ding.rl_utils import get_nstep_return_data, get_train_sample from ding.utils import POLICY_REGISTRY -from ding.torch_utils.loss.cross_entropy_loss import LabelSmoothCELoss @POLICY_REGISTRY.register('pc_mcts') class ProcedureCloningPolicyMCTS(Policy): - config = dict( type='pc_mcts', cuda=True, @@ -56,10 +55,7 @@ class ProcedureCloningPolicyMCTS(Policy): ) def default_model(self) -> Tuple[str, List[str]]: - if self._cfg.continuous: - return 'continuous_bc', ['ding.model.template.bc'] - else: - return 'discrete_bc', ['ding.model.template.bc'] + return 'pc_mcts', ['ding.model.template.procedure_cloning'] def _init_learn(self): assert self._cfg.learn.optimizer in ['SGD', 'Adam'] @@ -104,10 +100,11 @@ def _forward_learn(self, data): data = to_device(data, self._device) self._learn_model.train() with self._timer: - obs, hidden_states, action = data['obs'], data['hidden_states'], data['actions'] - pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) - # When we use bco, action is predicted by idm, gradient is not expected. 
- loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states)\ + obs, hidden_states, action = data['obs'], data['hidden_states'], data['action'] + obs = obs.permute(0, 3, 1, 2).float() + hidden_states = torch.stack(hidden_states, dim=1).float() + pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs / 255., hidden_states) + loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) \ + self._action_loss(pred_action, action) forward_time = self._timer.value @@ -141,17 +138,24 @@ def _init_eval(self): def _forward_eval(self, data): data_id = list(data.keys()) - data = default_collate(list(data.values())) + values = list(data.values()) + data = [{'obs': v['observation']} for v in values] + data = default_collate(data) if self._cuda: data = to_device(data, self._device) self._eval_model.eval() with torch.no_grad(): - output = self._eval_model.forward_eval(data['obs']) + output = self._eval_model.forward_eval(data['obs'].permute(0, 3, 1, 2) / 255.) + output = torch.argmax(output, dim=-1) if self._cuda: output = to_device(output, 'cpu') output = {'action': output} - return {i: d for i, d in zip(data_id, output)} + output = default_decollate(output) + # TODO why this bug? + output = [{'action': o['action'].item()} for o in output] + res = {i: d for i, d in zip(data_id, output)} + return res def _init_collect(self) -> None: pass diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index 3c5857c869..57b74e3d69 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -245,7 +245,10 @@ def eval( if self._cfg.figure_path is not None: self._env.enable_save_figure(env_id, self._cfg.figure_path) self._policy.reset([env_id]) - reward = t.info['eval_episode_return'] + if 'final_eval_reward' in t.info.keys(): + reward = t.info['final_eval_reward'] + else: + reward = t.info['eval_episode_return'] if 'episode_info' in t.info: eval_monitor.update_info(env_id, t.info['episode_info']) eval_monitor.update_reward(env_id, reward) diff --git a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py index dad710e66d..f05586befa 100644 --- a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py +++ b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py @@ -1,21 +1,52 @@ from easydict import EasyDict +seq_len = 4 qbert_pc_mcts_config = dict( exp_name='pong_pc_mcts_seed0', env=dict( + manager=dict( + episode_num=float('inf'), + max_retry=5, + step_timeout=None, + auto_reset=True, + reset_timeout=None, + retry_type='reset', + retry_waiting_time=0.1, + shared_memory=False, + copy_on_get=True, + context='fork', + wait_num=float('inf'), + step_wait_timeout=None, + connect_timeout=60, + reset_inplace=False, + cfg_type='SyncSubprocessEnvManagerDict', + type='subprocess', + ), + dqn_expert_data=False, + cfg_type='AtariLightZeroEnvDict', collector_env_num=8, - evaluator_env_num=5, - n_evaluator_episode=5, - stop_value=1000000, - env_id='Pong-v4', + evaluator_env_num=3, + n_evaluator_episode=3, + env_name='PongNoFrameskip-v4', + stop_value=20, + collect_max_episode_steps=10800, + eval_max_episode_steps=108000, + frame_skip=4, + obs_shape=[12, 96, 96], + episode_life=True, + gray_scale=False, + cvt_string=False, + game_wrapper=True, ), policy=dict( cuda=True, - expert_data_path='pong_expert/ez_pong_seed0.pkl', + expert_data_path='pong-v4-expert.pkl', + seq_len=seq_len, 
model=dict( obs_shape=[3, 96, 96], - hidden_shape=[32, 8, 8], - action_shape=6, + hidden_shape=[64, 6, 6], + action_dim=6, + seq_len=seq_len, ), learn=dict( batch_size=64, @@ -30,8 +61,8 @@ main_config = qbert_pc_mcts_config qbert_pc_mcts_create_config = dict( env=dict( - type='atari', - import_names=['dizoo.atari.envs.atari_env'], + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], ), env_manager=dict(type='base'), policy=dict(type='pc_mcts'), From fdd5d34024527d79849d501903d2545486239e33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 6 Mar 2023 10:27:51 +0800 Subject: [PATCH 04/26] reformat --- ding/entry/serial_entry_pc_mcts.py | 2 +- ding/model/template/procedure_cloning.py | 38 +++++++++---------- .../template/tests/test_procedure_cloning.py | 3 +- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index a859e6418c..b0a14c3f12 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -29,7 +29,7 @@ def __getitem__(self, idx): """ return { 'obs': self.observations[idx], - 'hidden_states': list(reversed(self.hidden_states[idx + 1: idx + self.seq_len + 1])), + 'hidden_states': list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])), 'action': self.actions[idx] } diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 09dc66225d..e600c441e1 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -53,25 +53,25 @@ def forward(self, x: torch.Tensor): class ProcedureCloningMCTS(nn.Module): def __init__( - self, - obs_shape: SequenceType, - hidden_shape: SequenceType, - action_dim: int, - seq_len: int, - cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256], - cnn_activation: Optional[nn.Module] = nn.ReLU(), - cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3], - cnn_stride: SequenceType = [1, 1, 1, 1, 1], - cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1], - mlp_hidden_list: SequenceType = [256, 256], - mlp_activation: Optional[nn.Module] = nn.ReLU(), - att_heads: int = 8, - att_hidden: int = 128, - n_att: int = 4, - n_feedforward: int = 2, - feedforward_hidden: int = 256, - drop_p: float = 0.5, - augment: bool = True, + self, + obs_shape: SequenceType, + hidden_shape: SequenceType, + action_dim: int, + seq_len: int, + cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256], + cnn_activation: Optional[nn.Module] = nn.ReLU(), + cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3], + cnn_stride: SequenceType = [1, 1, 1, 1, 1], + cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1], + mlp_hidden_list: SequenceType = [256, 256], + mlp_activation: Optional[nn.Module] = nn.ReLU(), + att_heads: int = 8, + att_hidden: int = 128, + n_att: int = 4, + n_feedforward: int = 2, + feedforward_hidden: int = 256, + drop_p: float = 0.5, + augment: bool = True, ) -> None: super().__init__() self.obs_shape = obs_shape diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index 9c35c9b93b..641986f049 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -18,8 +18,7 @@ def test_procedure_cloning(): 'hidden_states': torch.randn(B, T, *hidden_shape), 'actions': torch.randn(B, action_dim) } - model = ProcedureCloningMCTS(obs_shape=obs_shape, hidden_shape=hidden_shape, - 
seq_len=T, action_dim=action_dim) + model = ProcedureCloningMCTS(obs_shape=obs_shape, hidden_shape=hidden_shape, seq_len=T, action_dim=action_dim) print(model) From 12b9a02444a3c888f5e78b29a08d4a53551afb43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 15 Mar 2023 17:40:02 +0800 Subject: [PATCH 05/26] add visualization --- ding/entry/serial_entry_pc_mcts.py | 2 +- ding/policy/pc.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index b0a14c3f12..465ef73164 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -99,6 +99,7 @@ def serial_pipeline_pc_mcts( # train criterion = torch.nn.CrossEntropyLoss() for i, train_data in enumerate(dataloader): + train_data['obs'] = train_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255. learner.train(train_data) iter_cnt += 1 if iter_cnt >= max_iter: @@ -112,7 +113,6 @@ def serial_pipeline_pc_mcts( acces = [] for _, test_data in enumerate(test_dataloader): logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) - loss = criterion(logits, test_data['action'].cuda()).item() preds = torch.argmax(logits, dim=-1) acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] diff --git a/ding/policy/pc.py b/ding/policy/pc.py index e13a36f49c..6236499701 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -101,11 +101,11 @@ def _forward_learn(self, data): self._learn_model.train() with self._timer: obs, hidden_states, action = data['obs'], data['hidden_states'], data['action'] - obs = obs.permute(0, 3, 1, 2).float() hidden_states = torch.stack(hidden_states, dim=1).float() - pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs / 255., hidden_states) - loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) \ - + self._action_loss(pred_action, action) + pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) + hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) + action_loss = self._action_loss(pred_action, action) + loss = hidden_state_loss + action_loss forward_time = self._timer.value with self._timer: @@ -124,13 +124,16 @@ def _forward_learn(self, data): return { 'cur_lr': cur_lr, 'total_loss': loss.item(), + 'hidden_state_loss': hidden_state_loss.item(), + 'action_loss': action_loss.item(), 'forward_time': forward_time, 'backward_time': backward_time, 'sync_time': sync_time, } def _monitor_vars_learn(self): - return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time'] + return ['cur_lr', 'total_loss', 'hidden_state_loss', 'action_loss', + 'forward_time', 'backward_time', 'sync_time'] def _init_eval(self): self._eval_model = model_wrap(self._model, wrapper_name='base') From a4592780691257b338bf9c0172513d07a0015d1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 20:42:10 +0800 Subject: [PATCH 06/26] feature(whl): update pc model --- ding/model/template/procedure_cloning.py | 119 ++++++++---------- .../template/tests/test_procedure_cloning.py | 8 +- 2 files changed, 58 insertions(+), 69 deletions(-) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index e600c441e1..116672bd6d 100644 --- a/ding/model/template/procedure_cloning.py +++ 
b/ding/model/template/procedure_cloning.py @@ -3,49 +3,53 @@ import torch.nn as nn from ding.utils import MODEL_REGISTRY, SequenceType from ding.torch_utils.network.transformer import Attention -from ding.torch_utils.network.nn_module import fc_block, build_normalization -from ..common import FCEncoder, ConvEncoder +from ..common import ConvEncoder -class Block(nn.Module): - - def __init__( - self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int, - feedforward_hidden: int, n_feedforward: int - ) -> None: +class PreNorm(nn.Module): + def __init__(self, dim, fn): super().__init__() - self.n_att = n_att - self.n_feedforward = n_feedforward - self.attention_layer = [] + self.norm = nn.LayerNorm(dim) + self.fn = fn - self.norm_layer = [nn.LayerNorm(att_hidden)] * n_att + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) - self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p))) - for i in range(n_att - 1): - self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p))) - self.attention_layer = nn.ModuleList(self.attention_layer) - self.att_drop = nn.Dropout(drop_p) - self.fc_blocks = [] - self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU())) - for i in range(n_feedforward - 1): - self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU())) - self.fc_blocks = nn.ModuleList(self.fc_blocks) +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, drop_p=0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(drop_p), + nn.Linear(hidden_dim, dim), + nn.Dropout(drop_p) + ) - self.norm_layer.extend([nn.LayerNorm(feedforward_hidden)] * n_feedforward) - self.norm_layer = nn.ModuleList(self.norm_layer) + def forward(self, x): + return self.net(x) + +class Transformer(nn.Module): + def __init__(self, n_layer: int, n_attn: int, n_head: int, drop_p: float, max_T: int, n_ffn: int): + super().__init__() + self.layers = nn.ModuleList([]) + assert n_attn % n_head == 0 + dim_head = n_attn // n_head + for _ in range(n_layer): + self.layers.append(nn.ModuleList([ + PreNorm(n_attn, Attention(n_attn, dim_head, n_attn, n_head, nn.Dropout(drop_p))), + PreNorm(n_attn, FeedForward(n_attn, n_ffn, drop_p=drop_p)) + ])) self.mask = nn.Parameter( torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T), requires_grad=False ) - def forward(self, x: torch.Tensor): - for i in range(self.n_att): - x = self.att_drop(self.attention_layer[i](x, self.mask)) - x = self.norm_layer[i](x) - for i in range(self.n_feedforward): - x = self.fc_blocks[i](x) - x = self.norm_layer[i + self.n_att](x) + def forward(self, x): + for attn, ff in self.layers: + x = attn(x, mask=self.mask) + x + x = ff(x) + x return x @@ -58,20 +62,20 @@ def __init__( hidden_shape: SequenceType, action_dim: int, seq_len: int, - cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256], + cnn_hidden_list: SequenceType = [128, 256, 512], + cnn_kernel_size: SequenceType = [8, 4, 3], + cnn_stride: SequenceType = [4, 2, 1], + cnn_padding: Optional[SequenceType] = [0, 0, 0], + hidden_state_cnn_hidden_list: SequenceType = [128, 256, 512], + hidden_state_cnn_kernel_size: SequenceType = [3, 3, 3], + hidden_state_cnn_stride: SequenceType = [1, 1, 1], + hidden_state_cnn_padding: Optional[SequenceType] = [1, 1, 1], cnn_activation: Optional[nn.Module] = nn.ReLU(), - 
cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3], - cnn_stride: SequenceType = [1, 1, 1, 1, 1], - cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1], - mlp_hidden_list: SequenceType = [256, 256], - mlp_activation: Optional[nn.Module] = nn.ReLU(), att_heads: int = 8, - att_hidden: int = 128, - n_att: int = 4, - n_feedforward: int = 2, - feedforward_hidden: int = 256, - drop_p: float = 0.5, - augment: bool = True, + att_hidden: int = 512, + n_att_layer: int = 4, + ffn_hidden: int = 512, + drop_p: float = 0., ) -> None: super().__init__() self.obs_shape = obs_shape @@ -79,35 +83,20 @@ def __init__( self.seq_len = seq_len max_T = seq_len + 1 - #Conv Encoder + # Conv Encoder self.embed_state = ConvEncoder( obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding ) self.embed_hidden = ConvEncoder( - hidden_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding + hidden_shape, hidden_state_cnn_hidden_list, cnn_activation, hidden_state_cnn_kernel_size, + hidden_state_cnn_stride, hidden_state_cnn_padding ) self.cnn_hidden_list = cnn_hidden_list - self.augment = augment - - assert cnn_hidden_list[-1] == mlp_hidden_list[-1] - layers = [] - for i in range(n_att): - if i == 0: - layers.append(Attention(cnn_hidden_list[-1], att_hidden, att_hidden, att_heads, nn.Dropout(drop_p))) - else: - layers.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p))) - layers.append(build_normalization('LN')(att_hidden)) - for i in range(n_feedforward): - if i == 0: - layers.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU())) - else: - layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU())) - self.layernorm2 = build_normalization('LN')(feedforward_hidden) - - self.transformer = Block( - cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward - ) + + assert cnn_hidden_list[-1] == att_hidden + self.transformer = Transformer(n_layer=n_att_layer, n_attn=att_hidden, n_head=att_heads, + drop_p=drop_p, max_T=max_T, n_ffn=ffn_hidden) self.predict_hidden_state = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1]) self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim) diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index 641986f049..47346a1da3 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -5,10 +5,10 @@ B = 4 T = 15 -obs_shape = (64, 64, 3) -hidden_shape = (9, 9, 64) -action_dim = 9 -obs_embeddings = 256 +obs_shape = (3, 64, 64) +hidden_shape = (64, 9, 9) +action_dim = 6 +obs_embeddings = 512 @pytest.mark.unittest From f9836023d7bcc42a483ad2732e1d0768b75ceec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 20:54:42 +0800 Subject: [PATCH 07/26] fix_bug --- ding/policy/pc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 6236499701..dea896ed6c 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -101,7 +101,10 @@ def _forward_learn(self, data): self._learn_model.train() with self._timer: obs, hidden_states, action = data['obs'], data['hidden_states'], data['action'] - hidden_states = torch.stack(hidden_states, dim=1).float() + if len(hidden_states) > 0: + hidden_states = torch.stack(hidden_states, dim=1).float() + else: + hidden_states = 
torch.empty(obs.shape[0], 0, obs.shape[1:]) pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) From f2db38ac11d0b16941081979a14142b71dfe7dbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 20:58:40 +0800 Subject: [PATCH 08/26] fix_bug --- ding/policy/pc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index dea896ed6c..2db2a402fa 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -104,7 +104,7 @@ def _forward_learn(self, data): if len(hidden_states) > 0: hidden_states = torch.stack(hidden_states, dim=1).float() else: - hidden_states = torch.empty(obs.shape[0], 0, obs.shape[1:]) + hidden_states = torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape) pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) From 1403f7312e3d0a46fd5d49d736d35cd95946fefb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 21:00:04 +0800 Subject: [PATCH 09/26] fix_bug --- ding/policy/pc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 2db2a402fa..7178d8ec8e 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -104,7 +104,7 @@ def _forward_learn(self, data): if len(hidden_states) > 0: hidden_states = torch.stack(hidden_states, dim=1).float() else: - hidden_states = torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape) + hidden_states = to_device(torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape), self._device) pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) From cd2c87151d78f63e9e01d94bd0c595fcded9bb53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 21:03:56 +0800 Subject: [PATCH 10/26] fix_bug --- ding/model/template/procedure_cloning.py | 3 +++ ding/policy/pc.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 116672bd6d..d49e3fe2f2 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -146,4 +146,7 @@ def forward_eval(self, states: torch.Tensor) -> torch.Tensor: hidden_state_embeddings, action_pred = self._compute_transformer(h) embedding_mask[0, i, 0] = 1 + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + hidden_state_embeddings, action_pred = self._compute_transformer(h) + return action_pred diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 7178d8ec8e..a0b036d41b 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -101,12 +101,16 @@ def _forward_learn(self, data): self._learn_model.train() with self._timer: obs, hidden_states, action = data['obs'], data['hidden_states'], data['action'] - if len(hidden_states) > 0: + zero_hidden_len = len(hidden_states) > 0 + if zero_hidden_len: hidden_states = 
torch.stack(hidden_states, dim=1).float() else: hidden_states = to_device(torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape), self._device) pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) - hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) + if zero_hidden_len: + hidden_state_loss = 0 + else: + hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) loss = hidden_state_loss + action_loss forward_time = self._timer.value From 1a8ea4f248347a7b13c0fd7619fdbe8c9d442276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 21:10:50 +0800 Subject: [PATCH 11/26] fix_bug --- ding/model/template/procedure_cloning.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index d49e3fe2f2..a499fc5ddc 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -107,8 +107,11 @@ def _compute_embeddings(self, states: torch.Tensor, hidden_states: torch.Tensor) # shape: (B, 1, h_dim) state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1]) # shape: (B, T, h_dim) - hidden_state_embeddings = self.embed_hidden(hidden_states.reshape(B * T, *hidden_states.shape[2:])) \ - .reshape(B, T, self.cnn_hidden_list[-1]) + if T > 0: + hidden_state_embeddings = self.embed_hidden(hidden_states.reshape(B * T, *hidden_states.shape[2:])) \ + .reshape(B, T, self.cnn_hidden_list[-1]) + else: + hidden_state_embeddings = None return state_embeddings, hidden_state_embeddings def _compute_transformer(self, h): @@ -128,11 +131,13 @@ def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \ B, T, *_ = hidden_states.shape assert T == self.seq_len state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) - - h = torch.cat((state_embeddings, hidden_state_embeddings), dim=1) + if hidden_state_embeddings: + h = torch.cat((state_embeddings, hidden_state_embeddings), dim=1) + else: + h = state_embeddings hidden_state_preds, action_preds = self._compute_transformer(h) - return hidden_state_preds, action_preds, hidden_state_embeddings.detach() + return hidden_state_preds, action_preds, hidden_state_embeddings.detach() if hidden_state_embeddings else None def forward_eval(self, states: torch.Tensor) -> torch.Tensor: batch_size = states.shape[0] From 4ffaaa4893f07d62c4b170a078ed72436fb06df4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 21:12:42 +0800 Subject: [PATCH 12/26] fix_bug --- ding/policy/pc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index a0b036d41b..e4c528b572 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -101,8 +101,8 @@ def _forward_learn(self, data): self._learn_model.train() with self._timer: obs, hidden_states, action = data['obs'], data['hidden_states'], data['action'] - zero_hidden_len = len(hidden_states) > 0 - if zero_hidden_len: + zero_hidden_len = len(hidden_states) == 0 + if not zero_hidden_len: hidden_states = torch.stack(hidden_states, dim=1).float() else: hidden_states = to_device(torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape), self._device) From 1e15ee4a9b2abc11f227f3dae5e43727570090ff Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 21:13:43 +0800 Subject: [PATCH 13/26] fix_bug --- ding/policy/pc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index e4c528b572..d17fc2efe0 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -108,7 +108,7 @@ def _forward_learn(self, data): hidden_states = to_device(torch.empty(obs.shape[0], 0, *self._learn_model.hidden_shape), self._device) pred_hidden_states, pred_action, target_hidden_states = self._learn_model.forward(obs, hidden_states) if zero_hidden_len: - hidden_state_loss = 0 + hidden_state_loss = torch.tensor(0.) else: hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) From d486ee81331ad710898834ee90a1b576abe50289 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 21:16:07 +0800 Subject: [PATCH 14/26] fix_bug --- ding/model/template/procedure_cloning.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index a499fc5ddc..b87a053a61 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -151,7 +151,10 @@ def forward_eval(self, states: torch.Tensor) -> torch.Tensor: hidden_state_embeddings, action_pred = self._compute_transformer(h) embedding_mask[0, i, 0] = 1 - h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + if self.seq_len > 0: + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + else: + h = state_embeddings hidden_state_embeddings, action_pred = self._compute_transformer(h) return action_pred From c6662c4f151440f7da673106a9118023d2c0d202 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 20 Mar 2023 23:09:32 +0800 Subject: [PATCH 15/26] bug fix --- ding/entry/serial_entry_pc_mcts.py | 5 +++-- ding/model/template/procedure_cloning.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index 465ef73164..c3b8de66b8 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -44,7 +44,7 @@ def load_mcts_datasets(path, seq_len, batch_size=32): train_dic = {k: v[:-tot_len // 10] for k, v in dic.items()} test_dic = {k: v[-tot_len // 10:] for k, v in dic.items()} return DataLoader(MCTSPCDataset(train_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size), \ - DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size) + DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size) def serial_pipeline_pc_mcts( @@ -83,7 +83,8 @@ def serial_pipeline_pc_mcts( # Main components tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path, seq_len=cfg.policy.seq_len) + dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path, seq_len=cfg.policy.seq_len, + batch_size=cfg.policy.learn.batch_size) learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) evaluator = InteractionSerialEvaluator( cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, 
exp_name=cfg.exp_name diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index b87a053a61..92806d0f0e 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -131,13 +131,14 @@ def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \ B, T, *_ = hidden_states.shape assert T == self.seq_len state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) - if hidden_state_embeddings: + if hidden_state_embeddings is not None: h = torch.cat((state_embeddings, hidden_state_embeddings), dim=1) else: h = state_embeddings hidden_state_preds, action_preds = self._compute_transformer(h) - return hidden_state_preds, action_preds, hidden_state_embeddings.detach() if hidden_state_embeddings else None + return hidden_state_preds, action_preds, hidden_state_embeddings.detach() \ + if hidden_state_embeddings is not None else None def forward_eval(self, states: torch.Tensor) -> torch.Tensor: batch_size = states.shape[0] From d92589702a272eeb72a4712db42018cc5e397663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 10:21:13 +0800 Subject: [PATCH 16/26] add visualization for recurrent mode --- ding/entry/serial_entry_pc_mcts.py | 42 +++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index c3b8de66b8..1c8f14bb71 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -55,13 +55,12 @@ def serial_pipeline_pc_mcts( ) -> Union['Policy', bool]: # noqa r""" Overview: - Serial pipeline entry of imitation learning. + Serial pipeline entry of procedure cloning with MCTS. Arguments: - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ ``str`` type means config file path. \ ``Tuple[dict, dict]`` type means [user_config, create_cfg]. - seed (:obj:`int`): Random seed. - - data_path (:obj:`str`): Path of training data. - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. Returns: - policy (:obj:`Policy`): Converged policy. @@ -96,6 +95,7 @@ def serial_pipeline_pc_mcts( learner.call_hook('before_run') stop = False iter_cnt = 0 + epoch_per_test = 10 for epoch in range(cfg.policy.learn.train_epoch): # train criterion = torch.nn.CrossEntropyLoss() @@ -110,17 +110,33 @@ def serial_pipeline_pc_mcts( policy._optimizer.param_groups[0]['lr'] /= 10 if stop: break - losses = [] - acces = [] - for _, test_data in enumerate(test_dataloader): - logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) - loss = criterion(logits, test_data['action'].cuda()).item() - preds = torch.argmax(logits, dim=-1) - acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] - - losses.append(loss) - acces.append(acc) - print('Test Finished! Loss: {} acc: {}'.format(sum(losses) / len(losses), sum(acces) / len(acces))) + + if epoch % epoch_per_test == 0: + losses = [] + acces = [] + for _, test_data in enumerate(test_dataloader): + logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) 
+ loss = criterion(logits, test_data['action'].cuda()).item() + preds = torch.argmax(logits, dim=-1) + acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] + + losses.append(loss) + acces.append(acc) + tb_logger.add_scalar('learn_epoch/recurrent_test_loss', sum(losses) / len(losses), epoch) + tb_logger.add_scalar('learn_epoch/recurrent_test_acc', sum(acces) / len(acces)) + + losses = [] + acces = [] + for _, test_data in enumerate(dataloader): + logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) + loss = criterion(logits, test_data['action'].cuda()).item() + preds = torch.argmax(logits, dim=-1) + acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] + + losses.append(loss) + acces.append(acc) + tb_logger.add_scalar('learn_epoch/recurrent_train_loss', sum(losses) / len(losses), epoch) + tb_logger.add_scalar('learn_epoch/recurrent_train_acc', sum(acces) / len(acces)) stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) learner.call_hook('after_run') print('final reward is: {}'.format(reward)) From 841e013bc87abb1ec228994a6750fec9964e74f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 10:30:32 +0800 Subject: [PATCH 17/26] debug visualization for recurrent mode --- ding/entry/serial_entry_pc_mcts.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index 1c8f14bb71..b29a9c1db2 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -94,7 +94,6 @@ def serial_pipeline_pc_mcts( # ========== learner.call_hook('before_run') stop = False - iter_cnt = 0 epoch_per_test = 10 for epoch in range(cfg.policy.learn.train_epoch): # train @@ -102,8 +101,7 @@ def serial_pipeline_pc_mcts( for i, train_data in enumerate(dataloader): train_data['obs'] = train_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255. 
learner.train(train_data) - iter_cnt += 1 - if iter_cnt >= max_iter: + if learner.train_iter >= max_iter: stop = True break if epoch % 69 == 0: @@ -122,8 +120,8 @@ def serial_pipeline_pc_mcts( losses.append(loss) acces.append(acc) - tb_logger.add_scalar('learn_epoch/recurrent_test_loss', sum(losses) / len(losses), epoch) - tb_logger.add_scalar('learn_epoch/recurrent_test_acc', sum(acces) / len(acces)) + tb_logger.add_scalar('learn_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter) + tb_logger.add_scalar('learn_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter) losses = [] acces = [] @@ -135,8 +133,8 @@ def serial_pipeline_pc_mcts( losses.append(loss) acces.append(acc) - tb_logger.add_scalar('learn_epoch/recurrent_train_loss', sum(losses) / len(losses), epoch) - tb_logger.add_scalar('learn_epoch/recurrent_train_acc', sum(acces) / len(acces)) + tb_logger.add_scalar('learn_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) + tb_logger.add_scalar('learn_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) learner.call_hook('after_run') print('final reward is: {}'.format(reward)) From ffb81548c197e55e0e472c21017cc9a75d1dfdc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 10:34:49 +0800 Subject: [PATCH 18/26] debug visualization for recurrent mode --- ding/entry/serial_entry_pc_mcts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index b29a9c1db2..a1c93b26e1 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -120,8 +120,8 @@ def serial_pipeline_pc_mcts( losses.append(loss) acces.append(acc) - tb_logger.add_scalar('learn_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter) - tb_logger.add_scalar('learn_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter) + tb_logger.add_scalar('learner_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter) + tb_logger.add_scalar('learner_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter) losses = [] acces = [] @@ -133,8 +133,8 @@ def serial_pipeline_pc_mcts( losses.append(loss) acces.append(acc) - tb_logger.add_scalar('learn_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) - tb_logger.add_scalar('learn_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) + tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) + tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) learner.call_hook('after_run') print('final reward is: {}'.format(reward)) From b3c72aad14b55b02842e405450331658ce508ac6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 11:53:45 +0800 Subject: [PATCH 19/26] debug forward eval --- ding/entry/serial_entry_pc_mcts.py | 26 ++++++++++++- ding/model/template/procedure_cloning.py | 49 +++++++++++++++++------- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index a1c93b26e1..085bdc8091 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ 
b/ding/entry/serial_entry_pc_mcts.py @@ -95,9 +95,10 @@ def serial_pipeline_pc_mcts( learner.call_hook('before_run') stop = False epoch_per_test = 10 + criterion = torch.nn.CrossEntropyLoss() + hidden_state_criterion = torch.nn.MSELoss() for epoch in range(cfg.policy.learn.train_epoch): # train - criterion = torch.nn.CrossEntropyLoss() for i, train_data in enumerate(dataloader): train_data['obs'] = train_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255. learner.train(train_data) @@ -123,18 +124,39 @@ def serial_pipeline_pc_mcts( tb_logger.add_scalar('learner_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter) tb_logger.add_scalar('learner_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter) + # losses = [] + # acces = [] + # for _, test_data in enumerate(dataloader): + # logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) + # loss = criterion(logits, test_data['action'].cuda()).item() + # preds = torch.argmax(logits, dim=-1) + # acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] + # + # losses.append(loss) + # acces.append(acc) + # tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) + # tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) + losses = [] + mse_losses = [] acces = [] for _, test_data in enumerate(dataloader): - logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) + test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float() + logits, pred_hidden_states = policy._model.test_forward_eval( + test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255., + test_hidden_states + ) loss = criterion(logits, test_data['action'].cuda()).item() + mse_loss = hidden_state_criterion(pred_hidden_states, test_hidden_states.cuda()).item() preds = torch.argmax(logits, dim=-1) acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] losses.append(loss) acces.append(acc) + mse_losses.append(mse_loss) tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) + tb_logger.add_scalar('learner_iter/recurrent_train_mse_loss', sum(mse_losses) / len(mse_losses), learner.train_iter) stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) learner.call_hook('after_run') print('final reward is: {}'.format(reward)) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 92806d0f0e..699d0445e6 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -141,21 +141,42 @@ def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \ if hidden_state_embeddings is not None else None def forward_eval(self, states: torch.Tensor) -> torch.Tensor: - batch_size = states.shape[0] - hidden_states = torch.zeros(batch_size, self.seq_len, *self.hidden_shape, dtype=states.dtype).to(states.device) - embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device) + with torch.no_grad(): + batch_size = states.shape[0] + hidden_states = torch.zeros(batch_size, self.seq_len, *self.hidden_shape, dtype=states.dtype).to( + states.device) + embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device) + + state_embeddings, hidden_state_embeddings = 
self._compute_embeddings(states, hidden_states) + + for i in range(self.seq_len): + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + hidden_state_embeddings, action_pred = self._compute_transformer(h) + embedding_mask[0, i, 0] = 1 + + if self.seq_len > 0: + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + else: + h = state_embeddings + hidden_state_embeddings, action_pred = self._compute_transformer(h) - state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) + return action_pred - for i in range(self.seq_len): - h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) - hidden_state_embeddings, action_pred = self._compute_transformer(h) - embedding_mask[0, i, 0] = 1 + def test_forward_eval(self, states: torch.Tensor, hidden_states: torch.Tensor) -> Tuple: + # Action pred in this function is supposed to be identical in training phase. + with torch.no_grad(): + embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device) + state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states) - if self.seq_len > 0: - h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) - else: - h = state_embeddings - hidden_state_embeddings, action_pred = self._compute_transformer(h) + for i in range(self.seq_len): + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + _, action_pred = self._compute_transformer(h) + embedding_mask[0, i, 0] = 1 - return action_pred + if self.seq_len > 0: + h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1) + else: + h = state_embeddings + pred_hidden_state_embeddings, action_pred = self._compute_transformer(h) + + return action_pred, pred_hidden_state_embeddings From cfbd27737331f89f7fc883b75223e236de668fd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 11:55:03 +0800 Subject: [PATCH 20/26] debug forward eval --- ding/entry/serial_entry_pc_mcts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index 085bdc8091..2eb87a71ac 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -141,13 +141,13 @@ def serial_pipeline_pc_mcts( mse_losses = [] acces = [] for _, test_data in enumerate(dataloader): - test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float() + test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float().cuda() logits, pred_hidden_states = policy._model.test_forward_eval( test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255., test_hidden_states ) loss = criterion(logits, test_data['action'].cuda()).item() - mse_loss = hidden_state_criterion(pred_hidden_states, test_hidden_states.cuda()).item() + mse_loss = hidden_state_criterion(pred_hidden_states, test_hidden_states).item() preds = torch.argmax(logits, dim=-1) acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] From b8763474c98b3c43c01451e5d3054728ef389111 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 11:56:58 +0800 Subject: [PATCH 21/26] debug forward eval --- ding/entry/serial_entry_pc_mcts.py | 4 ++-- ding/model/template/procedure_cloning.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index 2eb87a71ac..b439fd1857 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -142,12 +142,12 @@ def serial_pipeline_pc_mcts( acces = [] for _, test_data in enumerate(dataloader): test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float().cuda() - logits, pred_hidden_states = policy._model.test_forward_eval( + logits, pred_hidden_states, hidden_state_embeddings = policy._model.test_forward_eval( test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255., test_hidden_states ) loss = criterion(logits, test_data['action'].cuda()).item() - mse_loss = hidden_state_criterion(pred_hidden_states, test_hidden_states).item() + mse_loss = hidden_state_criterion(pred_hidden_states, hidden_state_embeddings).item() preds = torch.argmax(logits, dim=-1) acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 699d0445e6..1dc43b5c97 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -179,4 +179,4 @@ def test_forward_eval(self, states: torch.Tensor, hidden_states: torch.Tensor) - h = state_embeddings pred_hidden_state_embeddings, action_pred = self._compute_transformer(h) - return action_pred, pred_hidden_state_embeddings + return action_pred, pred_hidden_state_embeddings, hidden_state_embeddings From 70baa697ea21140fd780d50f3b6a6e7eb62abcc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 14:57:45 +0800 Subject: [PATCH 22/26] reweight loss --- ding/entry/serial_entry_pc_mcts.py | 45 +++++++++++++++--------------- ding/policy/pc.py | 2 +- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index b439fd1857..e50f6749bf 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -124,39 +124,40 @@ def serial_pipeline_pc_mcts( tb_logger.add_scalar('learner_iter/recurrent_test_loss', sum(losses) / len(losses), learner.train_iter) tb_logger.add_scalar('learner_iter/recurrent_test_acc', sum(acces) / len(acces), learner.train_iter) - # losses = [] - # acces = [] - # for _, test_data in enumerate(dataloader): - # logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) - # loss = criterion(logits, test_data['action'].cuda()).item() - # preds = torch.argmax(logits, dim=-1) - # acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] - # - # losses.append(loss) - # acces.append(acc) - # tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) - # tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) - losses = [] - mse_losses = [] acces = [] for _, test_data in enumerate(dataloader): - test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float().cuda() - logits, pred_hidden_states, hidden_state_embeddings = policy._model.test_forward_eval( - test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255., - test_hidden_states - ) + logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) 
loss = criterion(logits, test_data['action'].cuda()).item() - mse_loss = hidden_state_criterion(pred_hidden_states, hidden_state_embeddings).item() preds = torch.argmax(logits, dim=-1) acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] losses.append(loss) acces.append(acc) - mse_losses.append(mse_loss) tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) - tb_logger.add_scalar('learner_iter/recurrent_train_mse_loss', sum(mse_losses) / len(mse_losses), learner.train_iter) + + # Test for forward eval function. + # losses = [] + # mse_losses = [] + # acces = [] + # for _, test_data in enumerate(dataloader): + # test_hidden_states = torch.stack(test_data['hidden_states'], dim=1).float().cuda() + # logits, pred_hidden_states, hidden_state_embeddings = policy._model.test_forward_eval( + # test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255., + # test_hidden_states + # ) + # loss = criterion(logits, test_data['action'].cuda()).item() + # mse_loss = hidden_state_criterion(pred_hidden_states, hidden_state_embeddings).item() + # preds = torch.argmax(logits, dim=-1) + # acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] + # + # losses.append(loss) + # acces.append(acc) + # mse_losses.append(mse_loss) + # tb_logger.add_scalar('learner_iter/recurrent_train_loss', sum(losses) / len(losses), learner.train_iter) + # tb_logger.add_scalar('learner_iter/recurrent_train_acc', sum(acces) / len(acces), learner.train_iter) + # tb_logger.add_scalar('learner_iter/recurrent_train_mse_loss', sum(mse_losses) / len(mse_losses), learner.train_iter) stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) learner.call_hook('after_run') print('final reward is: {}'.format(reward)) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index d17fc2efe0..5dd30c66b7 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -110,7 +110,7 @@ def _forward_learn(self, data): if zero_hidden_len: hidden_state_loss = torch.tensor(0.) 
else: - hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) + hidden_state_loss = 10 * self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) loss = hidden_state_loss + action_loss forward_time = self._timer.value From ab9eda7c145e5d081b7791b2640ad535165234aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 21 Mar 2023 21:27:47 +0800 Subject: [PATCH 23/26] reweight loss --- ding/entry/serial_entry_pc_mcts.py | 20 ++++++++++++++------ ding/policy/pc.py | 5 +++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index e50f6749bf..465401391f 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -16,20 +16,25 @@ class MCTSPCDataset(Dataset): - def __init__(self, data_dic, seq_len=4): + def __init__(self, data_dic, seq_len=4, hidden_state_noise=0): self.observations = data_dic['obs'] self.actions = data_dic['actions'] self.hidden_states = data_dic['hidden_state'] self.seq_len = seq_len self.length = len(self.observations) - seq_len - 1 + self.hidden_state_noise = hidden_state_noise def __getitem__(self, idx): """ Assume the trajectory is: o1, h2, h3, h4 """ + hidden_states = list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])) + if self.hidden_state_noise > 0: + for i in range(len(hidden_states)): + hidden_states[i] += self.hidden_state_noise * torch.randn_like(hidden_states[i]) return { 'obs': self.observations[idx], - 'hidden_states': list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])), + 'hidden_states': hidden_states, 'action': self.actions[idx] } @@ -37,14 +42,16 @@ def __len__(self): return self.length -def load_mcts_datasets(path, seq_len, batch_size=32): +def load_mcts_datasets(path, seq_len, batch_size=32, hidden_state_noise=0): with open(path, 'rb') as f: dic = pickle.load(f) tot_len = len(dic['obs']) train_dic = {k: v[:-tot_len // 10] for k, v in dic.items()} test_dic = {k: v[-tot_len // 10:] for k, v in dic.items()} - return DataLoader(MCTSPCDataset(train_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size), \ - DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size) + return DataLoader(MCTSPCDataset(train_dic, seq_len=seq_len, hidden_state_noise=hidden_state_noise), shuffle=True + , batch_size=batch_size), \ + DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len, hidden_state_noise=hidden_state_noise), shuffle=True, + batch_size=batch_size) def serial_pipeline_pc_mcts( @@ -83,7 +90,8 @@ def serial_pipeline_pc_mcts( # Main components tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path, seq_len=cfg.policy.seq_len, - batch_size=cfg.policy.learn.batch_size) + batch_size=cfg.policy.learn.batch_size, + hidden_state_noise=cfg.policy.learn.hidden_state_noise) learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) evaluator = InteractionSerialEvaluator( cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 5dd30c66b7..9ec7874ac6 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -92,7 +92,8 @@ def lr_scheduler_fn(epoch): self._learn_model = model_wrap(self._model, 'base') self._learn_model.reset() - 
self._hidden_state_loss = nn.MSELoss() + # self._hidden_state_loss = nn.MSELoss() + self._hidden_state_loss = nn.L1Loss() self._action_loss = nn.CrossEntropyLoss() def _forward_learn(self, data): @@ -110,7 +111,7 @@ def _forward_learn(self, data): if zero_hidden_len: hidden_state_loss = torch.tensor(0.) else: - hidden_state_loss = 10 * self._hidden_state_loss(pred_hidden_states, target_hidden_states) + hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) action_loss = self._action_loss(pred_action, action) loss = hidden_state_loss + action_loss forward_time = self._timer.value From b0080ac974134ceba57288dcca2ad8ebd82a7472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 14 Apr 2023 16:43:13 +0800 Subject: [PATCH 24/26] add seq actions --- ding/entry/serial_entry_pc_mcts.py | 3 ++- ding/model/template/procedure_cloning.py | 4 ++-- .../template/tests/test_procedure_cloning.py | 3 +++ ding/policy/pc.py | 19 ++++++++++++++++++- .../config/serial/pong/pong_pc_mcts_config.py | 2 ++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index 465401391f..083537335d 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -29,13 +29,14 @@ def __getitem__(self, idx): Assume the trajectory is: o1, h2, h3, h4 """ hidden_states = list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])) + actions = list(reversed(self.actions[idx: idx + self.seq_len])) if self.hidden_state_noise > 0: for i in range(len(hidden_states)): hidden_states[i] += self.hidden_state_noise * torch.randn_like(hidden_states[i]) return { 'obs': self.observations[idx], 'hidden_states': hidden_states, - 'action': self.actions[idx] + 'action': actions } def __len__(self): diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index 1dc43b5c97..51594faf7c 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -120,7 +120,7 @@ def _compute_transformer(self, h): h = h.reshape(B, T, self.cnn_hidden_list[-1]) hidden_state_preds = self.predict_hidden_state(h[:, 0:-1, ...]) - action_preds = self.predict_action(h[:, -1, :]) + action_preds = self.predict_action(h[:, 1:, :]) return hidden_state_preds, action_preds def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \ @@ -160,7 +160,7 @@ def forward_eval(self, states: torch.Tensor) -> torch.Tensor: h = state_embeddings hidden_state_embeddings, action_pred = self._compute_transformer(h) - return action_pred + return action_pred[:, -1, :] def test_forward_eval(self, states: torch.Tensor, hidden_states: torch.Tensor) -> Tuple: # Action pred in this function is supposed to be identical in training phase. 
diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index 47346a1da3..534d792f37 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -28,3 +28,6 @@ def test_procedure_cloning(): action_eval = model.forward_eval(inputs['states']) assert action_eval.shape == (B, action_dim) + + hidden_state_preds_new, _, _ = model(inputs['states'], torch.zeros_like(inputs['hidden_states'])) + assert torch.sum(torch.abs(hidden_state_preds_new[:, 0, :] - hidden_state_preds[:, 0, :])).item() < 1e-9 diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 9ec7874ac6..cb7d202d79 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -17,6 +17,23 @@ from ding.utils import POLICY_REGISTRY +class BatchCELoss(nn.Module): + def __init__(self, seq, mask): + super(BatchCELoss, self).__init__() + self.ce = nn.CrossEntropyLoss() + self.mask = mask + self.seq = seq + self.masked_ratio = 0 + + def forward(self, pred_y, target_y): + if not self.seq: + return self.ce(pred_y[:, -1, :], target_y[:, -1]) + losses = 0 + for i in range(target_y.shape[1]): + losses += self.ce(pred_y[:, i, :], target_y[:, i]) + return losses + + @POLICY_REGISTRY.register('pc_mcts') class ProcedureCloningPolicyMCTS(Policy): config = dict( @@ -94,7 +111,7 @@ def lr_scheduler_fn(epoch): # self._hidden_state_loss = nn.MSELoss() self._hidden_state_loss = nn.L1Loss() - self._action_loss = nn.CrossEntropyLoss() + self._action_loss = BatchCELoss(seq=self.config.seq_action, mask=self.config.mask_seq_action) def _forward_learn(self, data): if self._cuda: diff --git a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py index f05586befa..1aeef4dae1 100644 --- a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py +++ b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py @@ -42,6 +42,8 @@ cuda=True, expert_data_path='pong-v4-expert.pkl', seq_len=seq_len, + seq_action=True, + mask_seq_action=False, model=dict( obs_shape=[3, 96, 96], hidden_shape=[64, 6, 6], From 6f5ba7eef9abbe9b224d7a401f3d8edeb3c8b9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 17 Apr 2023 11:24:32 +0800 Subject: [PATCH 25/26] polish loss --- ding/entry/serial_entry_pc_mcts.py | 10 +++---- ding/policy/pc.py | 30 +++++++++++++++---- .../config/serial/pong/pong_pc_mcts_config.py | 7 +++-- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/ding/entry/serial_entry_pc_mcts.py b/ding/entry/serial_entry_pc_mcts.py index 083537335d..671d45c76b 100644 --- a/ding/entry/serial_entry_pc_mcts.py +++ b/ding/entry/serial_entry_pc_mcts.py @@ -29,7 +29,7 @@ def __getitem__(self, idx): Assume the trajectory is: o1, h2, h3, h4 """ hidden_states = list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])) - actions = list(reversed(self.actions[idx: idx + self.seq_len])) + actions = torch.tensor(list(reversed(self.actions[idx: idx + self.seq_len]))) if self.hidden_state_noise > 0: for i in range(len(hidden_states)): hidden_states[i] += self.hidden_state_noise * torch.randn_like(hidden_states[i]) @@ -124,9 +124,9 @@ def serial_pipeline_pc_mcts( acces = [] for _, test_data in enumerate(test_dataloader): logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) 
- loss = criterion(logits, test_data['action'].cuda()).item() + loss = criterion(logits, test_data['action'][:, -1].cuda()).item() preds = torch.argmax(logits, dim=-1) - acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] + acc = torch.sum((preds == test_data['action'][:, -1].cuda())).item() / preds.shape[0] losses.append(loss) acces.append(acc) @@ -137,9 +137,9 @@ def serial_pipeline_pc_mcts( acces = [] for _, test_data in enumerate(dataloader): logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.) - loss = criterion(logits, test_data['action'].cuda()).item() + loss = criterion(logits, test_data['action'][:, -1].cuda()).item() preds = torch.argmax(logits, dim=-1) - acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0] + acc = torch.sum((preds == test_data['action'][:, -1].cuda())).item() / preds.shape[0] losses.append(loss) acces.append(acc) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index cb7d202d79..97be9bb615 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -21,6 +21,7 @@ class BatchCELoss(nn.Module): def __init__(self, seq, mask): super(BatchCELoss, self).__init__() self.ce = nn.CrossEntropyLoss() + self.nce = nn.CrossEntropyLoss(reduction='none') self.mask = mask self.seq = seq self.masked_ratio = 0 @@ -28,10 +29,29 @@ def __init__(self, seq, mask): def forward(self, pred_y, target_y): if not self.seq: return self.ce(pred_y[:, -1, :], target_y[:, -1]) - losses = 0 - for i in range(target_y.shape[1]): - losses += self.ce(pred_y[:, i, :], target_y[:, i]) - return losses + if not self.mask: + losses = 0 + for i in range(target_y.shape[1]): + losses += self.ce(pred_y[:, i, :], target_y[:, i]) + return losses + else: + eqs = [] + losses = 0 + cnt = 0 + + cur_loss = self.nce(pred_y[:, 0, :], target_y[:, 0]) + losses += torch.sum(cur_loss) + cnt += target_y.shape[0] + eqs.append((torch.argmax(pred_y[:, 0, :], dim=-1) == target_y[:, 0])) + + for i in range(1, target_y.shape[1]): + cur_loss = self.nce(pred_y[:, i, :], target_y[:, i]) + losses += torch.sum(cur_loss * eqs[-1]) + cnt += torch.sum(eqs[-1]) + # Update eqs + eqs.append((torch.argmax(pred_y[:, i, :], dim=-1) == target_y[:, i])) + eqs[-1] = eqs[-1] and eqs[-2] + return losses / cnt @POLICY_REGISTRY.register('pc_mcts') @@ -111,7 +131,7 @@ def lr_scheduler_fn(epoch): # self._hidden_state_loss = nn.MSELoss() self._hidden_state_loss = nn.L1Loss() - self._action_loss = BatchCELoss(seq=self.config.seq_action, mask=self.config.mask_seq_action) + self._action_loss = BatchCELoss(seq=self._cfg.seq_action, mask=self._cfg.mask_seq_action) def _forward_learn(self, data): if self._cuda: diff --git a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py index 1aeef4dae1..0c7a93bb4a 100644 --- a/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py +++ b/dizoo/atari/config/serial/pong/pong_pc_mcts_config.py @@ -51,10 +51,11 @@ seq_len=seq_len, ), learn=dict( - batch_size=64, - learning_rate=0.01, + batch_size=32, + learning_rate=5e-4, learner=dict(hook=dict(save_ckpt_after_iter=1000)), - train_epoch=20, + train_epoch=100, + hidden_state_noise=0, ), eval=dict(evaluator=dict(eval_freq=40, )) ), From e30c98d69b5dd4942c6dc57f7fcc3f7379220f4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 17 Apr 2023 11:33:27 +0800 Subject: [PATCH 26/26] update metric monitor --- ding/policy/pc.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 
deletions(-) diff --git a/ding/policy/pc.py b/ding/policy/pc.py index 97be9bb615..b6865a0e8f 100644 --- a/ding/policy/pc.py +++ b/ding/policy/pc.py @@ -28,12 +28,12 @@ def __init__(self, seq, mask): def forward(self, pred_y, target_y): if not self.seq: - return self.ce(pred_y[:, -1, :], target_y[:, -1]) + return self.ce(pred_y[:, -1, :], target_y[:, -1]), 1 if not self.mask: losses = 0 for i in range(target_y.shape[1]): losses += self.ce(pred_y[:, i, :], target_y[:, i]) - return losses + return losses, target_y.shape[1] else: eqs = [] losses = 0 @@ -50,8 +50,8 @@ def forward(self, pred_y, target_y): cnt += torch.sum(eqs[-1]) # Update eqs eqs.append((torch.argmax(pred_y[:, i, :], dim=-1) == target_y[:, i])) - eqs[-1] = eqs[-1] and eqs[-2] - return losses / cnt + eqs[-1] = eqs[-1] & eqs[-2] + return losses / cnt, cnt / target_y.shape[0] @POLICY_REGISTRY.register('pc_mcts') @@ -149,7 +149,7 @@ def _forward_learn(self, data): hidden_state_loss = torch.tensor(0.) else: hidden_state_loss = self._hidden_state_loss(pred_hidden_states, target_hidden_states) - action_loss = self._action_loss(pred_action, action) + action_loss, action_number = self._action_loss(pred_action, action) loss = hidden_state_loss + action_loss forward_time = self._timer.value @@ -173,12 +173,13 @@ def _forward_learn(self, data): 'action_loss': action_loss.item(), 'forward_time': forward_time, 'backward_time': backward_time, + 'action_number': action_number, 'sync_time': sync_time, } def _monitor_vars_learn(self): return ['cur_lr', 'total_loss', 'hidden_state_loss', 'action_loss', - 'forward_time', 'backward_time', 'sync_time'] + 'forward_time', 'backward_time', 'sync_time', 'action_number'] def _init_eval(self): self._eval_model = model_wrap(self._model, wrapper_name='base')
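
Note on PATCH 19/22 (illustrative sketch, not part of the applied patches): `forward_eval` rolls the transformer out autoregressively — it predicts the hidden-state sequence one step at a time, unmasking each prediction before finally reading off the action. Below is a minimal, self-contained Python sketch of that loop; `transformer_step`, the dummy linear heads, and the shapes are assumed stand-ins for `_compute_transformer`, not the repository's actual API.

import torch
import torch.nn as nn


def autoregressive_action_eval(state_emb, transformer_step, seq_len):
    # state_emb: (B, 1, D) embedding of the current observation.
    # transformer_step(h) -> (hidden_state_preds (B, seq_len, D), action_logits (B, A));
    # this signature is an assumption standing in for _compute_transformer.
    B, _, D = state_emb.shape
    hidden_emb = torch.zeros(B, seq_len, D)
    mask = torch.zeros(1, seq_len, 1)
    with torch.no_grad():
        for i in range(seq_len):
            # Feed the state token plus the hidden states predicted so far (the rest masked to zero).
            h = torch.cat((state_emb, hidden_emb * mask), dim=1)
            hidden_emb, action_logits = transformer_step(h)
            mask[0, i, 0] = 1  # reveal one more predicted hidden state for the next step
        h = torch.cat((state_emb, hidden_emb * mask), dim=1) if seq_len > 0 else state_emb
        _, action_logits = transformer_step(h)
    return action_logits


if __name__ == '__main__':
    # Purely illustrative dummy heads.
    B, D, A, seq_len = 2, 8, 6, 4
    hidden_head, action_head = nn.Linear(D, D), nn.Linear(D, A)
    step = lambda h: (hidden_head(h[:, 1:, :]), action_head(h[:, -1, :]))
    print(autoregressive_action_eval(torch.randn(B, 1, D), step, seq_len).shape)  # torch.Size([2, 6])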
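
Note on PATCH 25/26 (illustrative sketch): `BatchCELoss` with `mask=True` supervises step t of the action sequence only for samples whose earlier steps were all predicted correctly, and normalizes by the number of counted terms; with `mask=False` it simply sums the per-step cross-entropy, and with `seq=False` only the final action is supervised. A minimal sketch of the masked branch follows; the toy shapes and data are assumptions for illustration.

import torch
import torch.nn as nn


def masked_seq_cross_entropy(pred_y, target_y):
    # pred_y: (B, T, A) action logits; target_y: (B, T) action indices.
    nce = nn.CrossEntropyLoss(reduction='none')
    B, T, _ = pred_y.shape
    correct_so_far = torch.ones(B, dtype=torch.bool)  # step 0 counts for every sample
    losses, cnt = 0.0, 0.0
    for t in range(T):
        step_loss = nce(pred_y[:, t, :], target_y[:, t])
        losses = losses + torch.sum(step_loss * correct_so_far)
        cnt = cnt + torch.sum(correct_so_far)
        step_correct = torch.argmax(pred_y[:, t, :], dim=-1) == target_y[:, t]
        correct_so_far = correct_so_far & step_correct  # tensor '&', the PATCH 26 fix over 'and'
    # mean loss over counted terms, average number of counted steps per sample
    return losses / cnt, cnt / B


if __name__ == '__main__':
    B, T, A = 4, 3, 6
    loss, counted = masked_seq_cross_entropy(torch.randn(B, T, A), torch.randint(0, A, (B, T)))
    print(loss.item(), counted.item())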