-
Notifications
You must be signed in to change notification settings - Fork 6
/
enjoy_tls.py
32 lines (23 loc) · 873 Bytes
/
enjoy_tls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gym
from baselines import deepq
from TrafficLightFlow import GetTrafficLightEnv, getFlowParamsForTls
from flow.core.experiment import Experiment
# Module-level call counter shared by every invocation of static_rl_actions;
# it is never reset, so it keeps counting across rollouts.
t = 0


def static_rl_actions(state):
    """Fixed-cycle baseline policy for the traffic light.

    The observation is ignored. The function counts how many times it has
    been called and requests a light switch on every 20th call.

    Args:
        state: Current observation (unused).

    Returns:
        bool: True on every 20th call, False on all others.
    """
    global t
    t = t + 1
    remainder = t % 20
    switch_now = not remainder
    return switch_now
def main(num_runs=10, *, use_static=False, model_path="tls_model.pkl"):
    """Evaluate a traffic-light controller in the Flow experiment.

    By default the DQN policy saved by the training script is restored from
    ``model_path`` and rolled out. With ``use_static=True`` the fixed-cycle
    baseline ``static_rl_actions`` is used instead. The two controllers can
    therefore be compared without editing this file. Results are written to
    CSV by ``Experiment.run``.

    Args:
        num_runs: Number of rollouts passed to ``Experiment.run``.
        use_static: Use the fixed-cycle baseline instead of the trained model.
        model_path: Path of the saved DQN model. It is only read when
            ``use_static`` is False.
    """
    if use_static:
        # Fixed-cycle baseline: requests a switch on every 20th call
        # (presumably one call per simulation step -- confirm with Flow).
        agent = static_rl_actions
    else:
        env = GetTrafficLightEnv()
        # total_timesteps=0 means no training steps are run; deepq.learn is
        # used here only to restore the saved weights from load_path and
        # return the acting function.
        act = deepq.learn(env, network='mlp', total_timesteps=0,
                          load_path=model_path)

        def agent(state):
            # state[None] adds a leading batch axis and [0] takes the single
            # resulting action. Assumes a NumPy observation -- TODO confirm.
            return act(state[None])[0]

    exp = Experiment(getFlowParamsForTls())
    exp.run(num_runs, agent, convert_to_csv=True)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()