# exploration.py
import torch
from torch import nn
from torch import distributions as torchd
import models
import networks
import tools
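# Exploration behaviors for the surrounding Dreamer-style agent: a uniform
# random baseline and Plan2Explore, which explores by seeking states where an
# ensemble of learned predictors disagrees.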
class Random(nn.Module):

  def __init__(self, config):
    super().__init__()
    self._config = config

  def actor(self, feat):
    # One independent action distribution per leading (batch/time) element.
    shape = feat.shape[:-1] + (self._config.num_actions,)
    if self._config.actor_dist == 'onehot':
      # Zero logits give a uniform categorical over the discrete actions.
      return tools.OneHotDist(torch.zeros(shape))
    else:
      # Continuous actions: uniform over [-1, 1] in every dimension.
      ones = torch.ones(shape)
      return tools.ContDist(torchd.uniform.Uniform(-ones, ones))

  def train(self, start, context):
    # A random policy has nothing to learn.
    return None, {}
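# Plan2Explore (Sekar et al., 2020, "Planning to Explore via Self-Supervised
# World Models"): an ensemble of `disag_models` DenseHead predictors is
# trained on a chosen target (`disag_target`: embedding, stochastic or
# deterministic latent, or the full features), and the imagined policy is
# rewarded in proportion to how much the heads disagree.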
class Plan2Explore(nn.Module):

  def __init__(self, config, world_model, reward=None):
    super().__init__()
    self._config = config
    self._reward = reward
    self._behavior = models.ImagBehavior(config, world_model)
    self.actor = self._behavior.actor
    stoch_size = config.dyn_stoch
    if config.dyn_discrete:
      stoch_size *= config.dyn_discrete
    # Size of the quantity each ensemble head is trained to predict.
    size = {
        'embed': 32 * config.cnn_depth,
        'stoch': stoch_size,
        'deter': config.dyn_deter,
        'feat': config.dyn_stoch + config.dyn_deter,
    }[self._config.disag_target]
    kw = dict(
        inp_dim=config.dyn_stoch,  # pytorch version
        shape=size, layers=config.disag_layers, units=config.disag_units,
        act=config.act)
    # nn.ModuleList registers the ensemble heads so that self.parameters()
    # below actually includes their weights.
    self._networks = nn.ModuleList(
        [networks.DenseHead(**kw) for _ in range(config.disag_models)])
    self._opt = tools.optimizer(config.opt, self.parameters(),
                                config.model_lr, config.opt_eps,
                                config.weight_decay)
  def train(self, start, context, data):
    metrics = {}
    stoch = start['stoch']
    if self._config.dyn_discrete:
      # Flatten the discrete latent (categoricals x classes) into one vector.
      stoch = torch.reshape(
          stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1],))
    # Prediction target for the disagreement ensemble.
    target = {
        'embed': context['embed'],
        'stoch': stoch,
        'deter': start['deter'],
        'feat': context['feat'],
    }[self._config.disag_target]
    inputs = context['feat']
    if self._config.disag_action_cond:
      inputs = torch.cat([inputs, data['action']], -1)
    metrics.update(self._train_ensemble(inputs, target))
    # Train the imagination behavior against the intrinsic disagreement reward.
    metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
    return None, metrics
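  # Intrinsic reward used by the imagined rollouts above:
  #   r_int = expl_intr_scale * mean_dims( std_heads( head_k(inputs) ) )
  # optionally log-transformed, plus expl_extr_scale times the extrinsic
  # reward when that scale is nonzero.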
  def _intrinsic_reward(self, feat, state, action):
    inputs = feat
    if self._config.disag_action_cond:
      inputs = torch.cat([inputs, action], -1)
    # Stack the mean prediction of every ensemble head: [heads, ..., size].
    # (The explicit float32 cast from the TF version is dropped; tensors are
    # float32 by default here.)
    preds = torch.stack([head(inputs).mean() for head in self._networks], 0)
    # Disagreement: per-dimension std across heads, averaged over dimensions.
    disag = torch.std(preds, 0).mean(-1)
    if self._config.disag_log:
      disag = torch.log(disag)
    reward = self._config.expl_intr_scale * disag
    if self._config.expl_extr_scale:
      # Optionally mix in the extrinsic (task) reward.
      reward += self._config.expl_extr_scale * self._reward(
          feat, state, action).to(torch.float32)
    return reward
  def _train_ensemble(self, inputs, targets):
    if self._config.disag_offset:
      # Predict disag_offset steps ahead by shifting the targets in time.
      targets = targets[:, self._config.disag_offset:]
      inputs = inputs[:, :-self._config.disag_offset]
    # The ensemble must not backpropagate into the world-model features.
    targets = targets.detach()
    inputs = inputs.detach()
    preds = [head(inputs) for head in self._networks]
    likes = [torch.mean(pred.log_prob(targets)) for pred in preds]
    loss = -torch.sum(torch.stack(likes))
    # Plain PyTorch update; this assumes tools.optimizer (see __init__)
    # returns a torch.optim-style optimizer. The metric key is chosen here
    # for logging and is not dictated by the original code.
    self._opt.zero_grad()
    loss.backward()
    self._opt.step()
    metrics = {'disag_ensemble_loss': loss.item()}
    return metrics
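# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# `config`, `world_model`, `start`, `context`, `data`, and `feat` are
# stand-ins for objects produced by the surrounding Dreamer training loop;
# their exact construction lives in that code.
#
#   expl = Plan2Explore(config, world_model)
#   # One exploration update: fit the disagreement ensemble on the posterior
#   # rollout (`start`, `context`) and batch `data`, then train the policy.
#   _, expl_metrics = expl.train(start, context, data)
#   # Acting: query the exploration actor on latent features `feat`.
#   action = expl.actor(feat).sample()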