-
Notifications
You must be signed in to change notification settings - Fork 1
/
differentiability.py
171 lines (144 loc) · 4.76 KB
/
differentiability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# %%
import os
import matplotlib.pyplot as plt
import numpy as np
import optax
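# If the JAX_STL_BACKEND environment variable equals "jax", the STL classes are
# imported from the JAX implementation (ds.stl_jax); otherwise the PyTorch
# implementation (ds.stl) is used.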
if os.environ.get("JAX_STL_BACKEND") == "jax":
print("Using JAX backend")
import jax
from ds.stl_jax import STL, RectAvoidPredicate, RectReachPredicate
from ds.utils import default_tensor
else:
print("Using PyTorch backend")
import torch
from torch.optim import Adam
from ds.stl import STL, RectAvoidPredicate, RectReachPredicate
from ds.utils import default_tensor
def eval_reach_avoid(mute=False):
    """
    Evaluate a reach-avoid formula on a batch of two paths.

    The formula is "eventually reach ``goal`` within steps 0..10 AND always
    avoid ``obs`` within steps 0..10". It is evaluated twice on the same
    batch: once at the default time and once with ``t=2``.

    Args:
        mute: when True, nothing is printed.

    Returns:
        A 2-tuple ``(res1, res2)`` holding the result of ``form.eval`` at the
        default time and at ``t=2`` respectively.
    """
    # Predicates: goal is a rectangle centered in [0, 0], obs a rectangle
    # centered in [3, 2]; both have width and height 1.
    reach_goal = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal"))
    avoid_obs = STL(RectAvoidPredicate(np.array([3, 2]), np.array([1, 1]), "obs"))

    # goal eventually in 0 to 10, and obs avoided always in 0 to 10
    spec = reach_goal.eventually(0, 10) & avoid_obs.always(0, 10)

    # Both paths walk down the diagonal from [9, 9] to [1, 1]; the first then
    # stays at [0, 0] for four steps, the second stays at [1, 1] for four steps.
    diagonal = [[k, k] for k in range(9, 0, -1)]
    ends_at_origin = diagonal + [[0, 0] for _ in range(4)]
    ends_at_one = diagonal + [[1, 1] for _ in range(4)]
    batch = default_tensor(np.array([ends_at_origin, ends_at_one]))

    # First evaluation uses the default time, the second one time step 2.
    runs = (
        ("eval result at time 0: ", {}),
        ("eval result at time 2: ", {"t": 2}),
    )
    outcomes = []
    for banner, extra in runs:
        value = spec.eval(path=batch, **extra)
        if not mute:
            print(banner, value)
        outcomes.append(value)

    return tuple(outcomes)
def backward(mute=True, lr=0.1, num_iterations=1000):
    """
    Planning with gradient descent.

    Starting from a fixed 14-step 2-D path (batch of one), run Adam so that
    the formula ``(goal_1.eventually(0, 5) & goal_2.eventually(0, 5))
    .always(0, 8)`` evaluates as high as possible. The JAX backend is used
    when the ``JAX_STL_BACKEND`` environment variable equals ``"jax"``,
    otherwise the PyTorch backend is used.

    Args:
        mute: when False, print the final loss and path and plot the path.
        lr: learning rate handed to Adam (default 0.1, the original value).
        num_iterations: number of optimisation steps (default 1000, the
            original value).

    Returns:
        ``(path, loss)``. The sign convention of ``loss`` differs per backend
        (kept as-is for backward compatibility):

        * JAX: ``form.eval(path)`` on the final path (not negated).
        * PyTorch: ``-mean(form.eval(path))`` from the last iteration, i.e.
          computed before the last optimizer step; ``None`` when
          ``num_iterations`` is 0.
    """
    # Read the backend switch once so optimisation and plotting agree.
    use_jax = os.environ.get("JAX_STL_BACKEND") == "jax"

    # Define the formula predicates
    # goal_1 is a rectangle area centered in [0, 0] with width and height 1
    goal_1 = STL(RectReachPredicate(np.array([0, 0]), np.array([1, 1]), "goal_1"))
    # goal_2 is a rectangle area centered in [2, 2] with width and height 1
    goal_2 = STL(RectReachPredicate(np.array([2, 2]), np.array([1, 1]), "goal_2"))

    # form is the formula goal_1 eventually in 0 to 5 and goal_2 eventually in 0 to 5
    # and that holds always in 0 to 8
    # In other words, the path will repeatedly visit goal_1 and goal_2 in 0 to 13
    form = (goal_1.eventually(0, 5) & goal_2.eventually(0, 5)).always(0, 8)

    # Initial path: 4 x [1, 0], then 8 x [0, 1], then 2 x [1, 0] (14 steps).
    initial_steps = [[1, 0]] * 4 + [[0, 1]] * 8 + [[1, 0]] * 2
    path = default_tensor(np.array([initial_steps]))

    loss = None

    if use_jax:
        solver = optax.adam(lr)
        solver_state = solver.init(path)

        @jax.jit
        def train_step(params, state):
            # Performs a one step update. The gradient of form.eval is
            # negated so that Adam's descent step increases form.eval.
            _, grad = jax.value_and_grad(form.eval)(params)
            updates, state = solver.update(-grad, state)
            return optax.apply_updates(params, updates), state

        for _ in range(num_iterations):
            path, solver_state = train_step(path, solver_state)

        loss = form.eval(path)
    else:
        # PyTorch backend (slower when num_iterations is high)
        path.requires_grad = True
        opt = Adam(params=[path], lr=lr)

        for _ in range(num_iterations):
            loss = -torch.mean(form.eval(path))
            opt.zero_grad()
            loss.backward()
            opt.step()

    if not mute:
        # loss is None only for the PyTorch backend with zero iterations.
        final_loss = loss.item() if loss is not None else None
        print(f"final loss: {final_loss}")
        print(path)

        if use_jax:
            # JAX arrays have no torch-style .numpy(force=True); convert
            # through NumPy's array protocol instead.
            xs = np.asarray(path[0, :, 0])
            ys = np.asarray(path[0, :, 1])
        else:
            # force=True detaches the tensor and moves it to CPU if needed.
            xs = path[0, :, 0].numpy(force=True)
            ys = path[0, :, 1].numpy(force=True)
        plt.plot(xs, ys)
        plt.show()

    return path, loss
if __name__ == "__main__":
    # Script entry point: run the evaluation demo, then the planning demo,
    # each with its default arguments.
    for demo in (eval_reach_avoid, backward):
        demo()
# %%