Merge pull request #97 from cpnota/release/0.3.0
Release/0.3.0
cpnota authored Aug 2, 2019
2 parents c9c85ef + 28f1e81 commit e4fbe6e
Showing 74 changed files with 924 additions and 576 deletions.
5 changes: 3 additions & 2 deletions Makefile
@@ -1,6 +1,7 @@
install:
- pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
- pip install torchvision tensorflow
+ pip install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
+ pip install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp37-cp37m-linux_x86_64.whl
+ pip install tensorflow
pip install -e .

lint:
1 change: 1 addition & 0 deletions README.md
@@ -34,6 +34,7 @@ We provide out-of-the-box modules for:
- [x] Generalized Advantage Estimation (GAE)
- [x] Target networks
- [x] Polyak averaging
+ - [x] Easy parameter and learning rate scheduling
- [x] An enhanced `nn` module (includes dueling layers, noisy layers, action bounds, and the coveted `nn.Flatten`)
- [x] `gym` to `pytorch` wrappers
- [x] Atari wrappers
3 changes: 2 additions & 1 deletion all/agents/_agent.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
+ from all.optim import Schedulable

- class Agent(ABC):
+ class Agent(ABC, Schedulable):
"""
A reinforcement learning agent.
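`Agent` now mixes in `Schedulable` from `all.optim`, the hook behind the README's new "easy parameter and learning rate scheduling" entry. The snippet below is only a rough illustration of the idea of a scheduled hyperparameter — a value that is read each step and advanced along a schedule — and is not the `all.optim` API:

```python
# Hypothetical sketch of a scheduled hyperparameter (not the all.optim API):
# the agent reads .value every step, and the schedule advances once per update.
class LinearSchedule:
    def __init__(self, start, end, steps):
        self.start, self.end, self.steps = start, end, steps
        self.t = 0

    @property
    def value(self):
        frac = min(self.t / self.steps, 1.0)
        return self.start + frac * (self.end - self.start)

    def step(self):
        self.t += 1

epsilon = LinearSchedule(1.0, 0.02, steps=10000)
for _ in range(3):
    print(round(epsilon.value, 4))  # 1.0, then slightly lower on each step
    epsilon.step()
```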
30 changes: 18 additions & 12 deletions all/agents/a2c.py
@@ -1,5 +1,3 @@
- import torch
- from all.environments import State
from all.memory import NStepAdvantageBuffer
from ._agent import Agent

@@ -22,26 +20,34 @@ def __init__(
self.n_envs = n_envs
self.n_steps = n_steps
self.discount_factor = discount_factor
+ self._states = None
+ self._actions = None
self._batch_size = n_envs * n_steps
self._buffer = self._make_buffer()
self._features = []

def act(self, states, rewards):
- self._buffer.store(states, torch.zeros(self.n_envs), rewards)
- self._train()
- features = self.features(states)
- self._features.append(features)
- return self.policy(features)
+ self._store_transitions(rewards)
+ self._train(states)
+ self._states = states
+ self._actions = self.policy.eval(self.features.eval(states))
+ return self._actions

- def _train(self):
+ def _store_transitions(self, rewards):
+ if self._states:
+ self._buffer.store(self._states, self._actions, rewards)

+ def _train(self, states):
if len(self._buffer) >= self._batch_size:
- states = State.from_list(self._features)
- _, _, advantages = self._buffer.sample(self._batch_size)
- self.v(states)
+ states, actions, advantages = self._buffer.advantages(states)
+ # forward pass
+ features = self.features(states)
+ self.v(features)
+ self.policy(features, actions)
+ # backward pass
self.v.reinforce(advantages)
self.policy.reinforce(advantages)
self.features.reinforce()
- self._features = []

def _make_buffer(self):
return NStepAdvantageBuffer(
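The A2C rewrite switches to delayed transition storage: instead of pushing each incoming state with a dummy action (`torch.zeros(self.n_envs)`) and accumulating features on the side, the agent remembers the last state/action pair and only writes it to the buffer once the following reward arrives, recomputing features inside `_train`. A minimal sketch of that storage pattern (the policy and buffer here are stand-ins, not the library classes):

```python
# Minimal sketch of the delayed transition storage used by the new A2C/PPO act().
class DelayedStorageAgent:
    def __init__(self, policy, buffer):
        self.policy = policy
        self.buffer = buffer
        self._state = None
        self._action = None

    def act(self, state, reward):
        # pair the *previous* state/action with the reward observed now
        if self._state is not None:
            self.buffer.store(self._state, self._action, reward)
        self._state = state
        self._action = self.policy(state)
        return self._action

# toy usage: the buffer just records tuples
class ListBuffer(list):
    def store(self, s, a, r):
        self.append((s, a, r))

agent = DelayedStorageAgent(policy=lambda s: s * 2, buffer=ListBuffer())
agent.act(1, reward=0.0)   # nothing stored yet
agent.act(2, reward=1.0)   # stores (1, 2, 1.0)
```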
2 changes: 1 addition & 1 deletion all/agents/ddpg.py
@@ -46,7 +46,7 @@ def _train(self):
# train q function
td_errors = (
rewards +
- self.discount_factor * self.q.eval(next_states, self.policy.eval(next_states)) -
+ self.discount_factor * self.q.target(next_states, self.policy.target(next_states)) -
self.q(states, torch.cat(actions))
)
self.q.reinforce(weights * td_errors)
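DDPG's bootstrap now queries the target networks (`q.target`, `policy.target`) instead of a no-grad pass of the online networks (`eval`); dqn.py, vac.py, vqn.py, and vsarsa.py below receive the same `eval` → `target` change, matching the `Approximation.target()` method added later in this commit. A sketch of the distinction, assuming a Polyak-averaged target copy (the library also provides `FixedTarget` and `TrivialTarget`):

```python
import copy
import torch

# Sketch only: eval() and target() differ in which weights they query.
class WithTarget:
    def __init__(self, model, polyak=0.005):
        self.model = model
        self.target_model = copy.deepcopy(model)
        self.polyak = polyak

    def eval(self, *inputs):
        with torch.no_grad():           # online weights, no gradient
            return self.model(*inputs)

    def target(self, *inputs):
        with torch.no_grad():           # slowly-updated copy
            return self.target_model(*inputs)

    def update_target(self):
        # Polyak averaging: target <- (1 - rho) * target + rho * online
        for p, tp in zip(self.model.parameters(), self.target_model.parameters()):
            tp.data.mul_(1 - self.polyak).add_(self.polyak * p.data)
```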
2 changes: 1 addition & 1 deletion all/agents/dqn.py
@@ -45,7 +45,7 @@ def _train(self):
self.minibatch_size)
td_errors = (
rewards +
- self.discount_factor * torch.max(self.q.eval(next_states), dim=1)[0] -
+ self.discount_factor * torch.max(self.q.target(next_states), dim=1)[0] -
self.q(states, actions)
)
self.q.reinforce(weights * td_errors)
Empty file.
29 changes: 15 additions & 14 deletions all/agents/ppo.py
@@ -26,6 +28,8 @@ def __init__(
self.n_steps = n_steps
self.discount_factor = discount_factor
self.lam = lam
+ self._states = None
+ self._actions = None
self._epsilon = epsilon
self._epochs = epochs
self._batch_size = n_envs * n_steps
@@ -34,14 +36,19 @@ def __init__(
self._features = []

def act(self, states, rewards):
- self._train()
- actions = self.policy.eval(self.features.eval(states))
- self._buffer.store(states, actions, rewards)
- return actions
+ self._store_transitions(rewards)
+ self._train(states)
+ self._states = states
+ self._actions = self.policy.eval(self.features.eval(states))
+ return self._actions

- def _train(self):
+ def _store_transitions(self, rewards):
+ if self._states:
+ self._buffer.store(self._states, self._actions, rewards)

+ def _train(self, _states):
if len(self._buffer) >= self._batch_size:
- states, actions, advantages = self._buffer.sample(self._batch_size)
+ states, actions, advantages = self._buffer.advantages(_states)
with torch.no_grad():
features = self.features.eval(states)
pi_0 = self.policy.eval(features, actions)
@@ -65,18 +72,12 @@ def _train_minibatch(self, states, actions, pi_0, advantages, targets):
self.v.reinforce(targets - self.v(features))
self.features.reinforce()

- def _compute_targets(self, returns, next_states, lengths):
- return (
- returns +
- (self.discount_factor ** lengths)
- * self.v.eval(self.features.eval(next_states))
- )

def _compute_policy_loss(self, pi_0, advantages):
def _policy_loss(pi_i):
ratios = torch.exp(pi_i - pi_0)
surr1 = ratios * advantages
- surr2 = torch.clamp(ratios, 1.0 - self._epsilon, 1.0 + self._epsilon) * advantages
+ epsilon = self._epsilon
+ surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
return -torch.min(surr1, surr2).mean()
return _policy_loss

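The PPO objective itself is unchanged: `_compute_policy_loss` still builds the standard clipped surrogate, and the diff only removes the now-unused `_compute_targets` helper and hoists `self._epsilon` into a local. For reference, a standalone restatement of the clipped objective (the `epsilon=0.2` default here is illustrative, not the agent's setting):

```python
import torch

def clipped_surrogate_loss(log_pi, log_pi_0, advantages, epsilon=0.2):
    # Standard PPO clipped objective, as in _compute_policy_loss above.
    ratios = torch.exp(log_pi - log_pi_0)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
    return -torch.min(surr1, surr2).mean()

# Example: a positive advantage with a ratio above 1 + epsilon gets clipped.
adv = torch.tensor([1.0, 1.0])
print(clipped_surrogate_loss(torch.tensor([0.5, 0.0]), torch.tensor([0.0, 0.0]), adv))
```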
31 changes: 20 additions & 11 deletions all/agents/sac.py
@@ -1,5 +1,5 @@
import torch
- from all.experiments import DummyWriter
+ from all.logging import DummyWriter
from ._agent import Agent

class SAC(Agent):
@@ -9,7 +9,9 @@ def __init__(self,
q_2,
v,
replay_buffer,
- entropy_regularizer=0.01,
+ entropy_target=-2., # usually -action_space.size[0]
+ temperature_initial=0.1,
+ lr_temperature=1e-4,
discount_factor=0.99,
minibatch_size=32,
replay_start_size=5000,
@@ -28,7 +30,10 @@ def __init__(self,
self.update_frequency = update_frequency
self.minibatch_size = minibatch_size
self.discount_factor = discount_factor
- self.entropy_regularizer = entropy_regularizer
+ # vars for learning the temperature
+ self.entropy_target = entropy_target
+ self.temperature = temperature_initial
+ self.lr_temperature = lr_temperature
# data
self.env = None
self.state = None
@@ -39,8 +44,7 @@ def act(self, state, reward):
self._store_transition(state, reward)
self._train()
self.state = state
- with torch.no_grad():
- self.action = self.policy(state)
+ self.action = self.policy.eval(state)
return self.action

def _store_transition(self, state, reward):
@@ -58,14 +62,17 @@ def _train(self):
# compute targets for Q and V
with torch.no_grad():
_actions, _log_probs = self.policy(states, log_prob=True)
- q_targets = rewards + self.discount_factor * self.v.eval(next_states)
+ q_targets = rewards + self.discount_factor * self.v.target(next_states)
v_targets = torch.min(
- self.q_1.eval(states, _actions),
- self.q_2.eval(states, _actions),
- ) - self.entropy_regularizer * _log_probs
+ self.q_1.target(states, _actions),
+ self.q_2.target(states, _actions),
+ ) - self.temperature * _log_probs
+ temperature_loss = ((_log_probs + self.entropy_target).detach().mean())
self.writer.add_loss('entropy', -_log_probs.mean())
self.writer.add_loss('v_mean', v_targets.mean())
self.writer.add_loss('r_mean', rewards.mean())
+ self.writer.add_loss('temperature_loss', temperature_loss)
+ self.writer.add_loss('temperature', self.temperature)

# update Q-functions
q_1_errors = q_targets - self.q_1(states, actions)
@@ -79,15 +86,17 @@ def _train(self):

# train policy
_actions, _log_probs = self.policy(states, log_prob=True)

loss = -(
self.q_1(states, _actions, detach=False)
- - self.entropy_regularizer * _log_probs
+ - self.temperature * _log_probs
).mean()
loss.backward()
self.policy.step()
self.q_1.zero_grad()

+ # adjust temperature
+ self.temperature += self.lr_temperature * temperature_loss

def _should_train(self):
return (self.frames_seen > self.replay_start_size and
self.frames_seen % self.update_frequency == 0)
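SAC drops the fixed `entropy_regularizer` in favor of a temperature that is adapted toward `entropy_target` with a plain update, `temperature += lr_temperature * mean(log_prob + entropy_target)` (a simpler variant of the learned-temperature scheme from the SAC papers, applied directly to the temperature rather than to a log-parameterized coefficient with its own optimizer). A small numeric sketch of the update direction, using made-up log-probabilities:

```python
import torch

# entropy_target, temperature_initial, lr_temperature as in the constructor above.
entropy_target = -2.
temperature = 0.1
lr_temperature = 1e-4

log_probs = torch.tensor([-1.0, -1.5, -0.5])   # hypothetical batch
temperature_loss = (log_probs + entropy_target).detach().mean()

# If log-probs are high (policy entropy below the target), this term is positive
# and the temperature rises, rewarding entropy more; here entropy already exceeds
# the target, so the temperature drops slightly.
temperature = temperature + lr_temperature * temperature_loss
print(float(temperature))   # ~0.0997
```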
2 changes: 1 addition & 1 deletion all/agents/vac.py
@@ -13,7 +13,7 @@ def act(self, state, reward):
if self._previous_features:
td_error = (
reward
- + self.gamma * self.v.eval(self.features.eval(state))
+ + self.gamma * self.v.target(self.features.eval(state))
- self.v(self._previous_features)
)
self.v.reinforce(td_error)
2 changes: 1 addition & 1 deletion all/agents/vqn.py
@@ -17,7 +17,7 @@ def act(self, state, reward):
if self.previous_state:
td_error = (
reward
- + self.gamma * torch.max(self.q.eval(state), dim=1)[0]
+ + self.gamma * torch.max(self.q.target(state), dim=1)[0]
- self.q(self.previous_state, self.previous_action)
)
self.q.reinforce(td_error)
2 changes: 1 addition & 1 deletion all/agents/vsarsa.py
@@ -16,7 +16,7 @@ def act(self, state, reward):
if self.previous_state:
td_error = (
reward
- + self.gamma * self.q.eval(state, action)
+ + self.gamma * self.q.target(state, action)
- self.q(self.previous_state, self.previous_action)
)
self.q.reinforce(td_error)
1 change: 1 addition & 0 deletions all/approximation/__init__.py
@@ -4,3 +4,4 @@
from .v_network import VNetwork
from .feature_network import FeatureNetwork
from .target import TargetNetwork, FixedTarget, PolyakTarget, TrivialTarget
+ from .checkpointer import Checkpointer, DummyCheckpointer, PeriodicCheckpointer
48 changes: 39 additions & 9 deletions all/approximation/approximation.py
@@ -1,8 +1,12 @@
+ import os
import torch
from torch.nn import utils
from torch.nn.functional import mse_loss
- from all.experiments import DummyWriter
- from .target import FixedTarget, TrivialTarget
+ from all.logging import DummyWriter
+ from .target import TrivialTarget
+ from .checkpointer import PeriodicCheckpointer

+ DEFAULT_CHECKPOINT_FREQUENCY = 200

class Approximation():
def __init__(
@@ -13,12 +17,15 @@ def __init__(
loss_scaling=1,
loss=mse_loss,
name='approximation',
+ scheduler=None,
target=None,
writer=DummyWriter(),
+ checkpointer=None
):
self.model = model
self.device = next(model.parameters()).device
self._target = target or TrivialTarget()
+ self._scheduler = scheduler
self._target.init(model)
self._updates = 0
self._optimizer = optimizer
@@ -29,17 +36,38 @@ def __init__(
self._writer = writer
self._name = name

+ if checkpointer is None:
+ checkpointer = PeriodicCheckpointer(DEFAULT_CHECKPOINT_FREQUENCY)
+ self._checkpointer = checkpointer
+ self._checkpointer.init(
+ self.model,
+ os.path.join(writer.log_dir, name + '.pt')
+ )

def __call__(self, *inputs, detach=True):
+ '''
+ Run a forward pass of the model.
+ If detach=True, the computation graph is cached and the result is detached.
+ If detach=False, nothing is cached and the attached result is returned.
+ '''
result = self.model(*inputs)
if detach:
self._enqueue(result)
return result.detach()
return result

def eval(self, *inputs):
+ '''Run a forward pass of the model in no_grad mode.'''
with torch.no_grad():
return self.model(*inputs)

+ def target(self, *inputs):
+ '''Run a forward pass of the target network.'''
+ return self._target(*inputs)

def reinforce(self, errors, retain_graph=False):
+ '''Update the model using the cache and the errors passed in.'''
batch_size = len(errors)
cache = self._dequeue(batch_size)
if cache.requires_grad:
@@ -49,11 +77,16 @@ def reinforce(self, errors, retain_graph=False):
self.step()

def step(self):
+ '''Assuming a backward pass has already been made, run an optimization step.'''
if self._clip_grad != 0:
utils.clip_grad_norm_(self.model.parameters(), self._clip_grad)
self._optimizer.step()
self._optimizer.zero_grad()
self._target.update()
+ if self._scheduler:
+ self._writer.add_schedule(self._name + '/lr', self._optimizer.param_groups[0]['lr'])
+ self._scheduler.step()
+ self._checkpointer()

def zero_grad(self):
self._optimizer.zero_grad()
@@ -73,10 +106,7 @@ def _dequeue(self, batch_size):
self._cache = self._cache[i:]
return items

- def _init_target_model(self, target_update_frequency):
- if target_update_frequency is not None:
- self._target = FixedTarget(target_update_frequency)
- self._target.init(self.model)
- else:
- self._target = TrivialTarget()
- self._target.init(self.model)
+ class ConstantLR():
+ '''Dummy LRScheduler'''
+ def step(self):
+ pass
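`Approximation` now exposes a fuller interface: `__call__` runs the online model and caches the graph for a later `reinforce`, `eval` is a no-grad pass of the online weights, `target` queries the target network, and `step()` additionally drives an optional learning-rate `scheduler` and a `checkpointer` (a `PeriodicCheckpointer` saving to `writer.log_dir` by default). A hedged usage sketch; the positional argument order and the `from all.approximation import Approximation` import path are assumptions, since those parts of the code are not shown in this diff:

```python
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from all.approximation import Approximation  # import path assumed

model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3)
v = Approximation(
    model,
    optimizer,                                            # positional order assumed
    scheduler=CosineAnnealingLR(optimizer, T_max=10000),  # stepped inside step()
    name='v',
)

states = torch.randn(8, 4)
values = v(states)             # forward pass: graph cached, detached values returned
targets = v.target(states)     # target-network pass (TrivialTarget unless configured)
v.reinforce(targets - values)  # backward pass through the cache, then step()
```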
(Remaining changed files not shown.)
