forked from tesslerc/TD3-JAX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
183 lines (150 loc) · 6.76 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import functools
import itertools
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from jax import nn
from jax.experimental import optix
from jax.random import PRNGKey
from networks import Actor, Critic, Constant
from util import single_mse, gaussian_likelihood, double_mse, soft_update
class Agent(object):
    """Actor-critic agent combining acting (``policy``) and learning (``train_loop``).

    Derived from a TD3 agent: it keeps clipped double-Q critics, delayed actor
    updates and soft target updates, but samples actions from a Gaussian policy
    (``mu``, ``log_sig``) that ``postprocess_action`` squashes with a sigmoid into
    ``[min_action, max_action]``. It additionally learns a log-temperature
    ``log_alpha`` towards ``target_entropy = -action_dim``.

    NOTE(review): the learned ``log_alpha`` is only read inside ``update_alpha``;
    neither the actor loss nor the TD target uses it. Confirm that is intended.
    """

    def __init__(
        self,
        min_action: np.ndarray,
        max_action: np.ndarray,
        action_dim: int,
        lr: float,
        discount: float,
        policy_freq: int,
        initial_log_alpha: float = -3.5,
    ):
        """Create the optimisers and the Haiku-transformed networks.

        Args:
            min_action: Lower bound(s) of the action space; broadcast against
                the sampled action in ``postprocess_action``.
            max_action: Upper bound(s) of the action space.
            action_dim: Number of action dimensions; also sets the entropy target.
            lr: Adam learning rate shared by actor, critic and temperature.
            discount: Discount factor used in the TD target.
            policy_freq: Actor, temperature and target networks are updated once
                every ``policy_freq`` critic updates.
            initial_log_alpha: Initial value of the learned log-temperature.
        """
        self.initial_log_alpha = initial_log_alpha
        self.min_action = min_action
        self.action_dim = action_dim
        self.max_action = max_action
        self.lr = lr
        self.discount = discount
        self.policy_freq = policy_freq
        # Entropy target: minus the action dimensionality.
        self.target_entropy = -action_dim

        # Each learner keeps its own (init, update) pair; the temperature's pair
        # must not overwrite the critic's update function.
        self.actor_opt_init, self.actor_opt_update = optix.adam(lr)
        self.critic_opt_init, self.critic_opt_update = optix.adam(lr)
        self.alpha_opt_init, self.alpha_opt_update = optix.adam(lr)

        # Shadow the network-building methods below with their Haiku transforms.
        # The lookups on the right still resolve to the bound methods because
        # the instance attributes do not exist yet. No network needs an RNG at
        # apply time, hence ``without_apply_rng``.
        self.actor = hk.without_apply_rng(hk.transform(self.actor))
        self.critic = hk.without_apply_rng(hk.transform(self.critic))
        self.log_alpha = hk.without_apply_rng(hk.transform(self.log_alpha))

    def actor(self, x):
        """Haiku network fn: observation batch -> ``(mu, log_sig)``."""
        return Actor()(x, action_dim=self.action_dim)

    @staticmethod
    def critic(x, a):
        """Haiku network fn: (observation, action) -> ``(q1, q2)``."""
        return Critic()(x, a)

    def log_alpha(self):
        """Haiku network fn holding the learnable log-temperature parameter."""
        return Constant()(self.initial_log_alpha)

    def train_loop(
        self, rng: jnp.ndarray, sample_obs: np.ndarray, sample_action: np.ndarray,
    ):
        """Generator that owns all learner state and trains on batches sent to it.

        Prime it with ``next(loop)`` to receive the initial actor params, then call
        ``loop.send(batch)`` repeatedly; each send performs one training step and
        yields the current actor params. ``batch`` must expose ``obs``,
        ``action``, ``reward``, ``next_obs`` and ``done`` attributes.

        Args:
            rng: JAX PRNG key; split internally for init and per-step sampling.
            sample_obs: Example observation batch, used only to initialise params.
            sample_action: Example action batch, used only to initialise params.
        """
        rng, actor_rng, critic_rng, alpha_rng = jax.random.split(rng, 4)
        # NOTE(review): target_actor_params is soft-updated below but never read
        # (the TD target uses the online actor params).
        actor_params = target_actor_params = self.actor.init(actor_rng, sample_obs)
        actor_opt_state = self.actor_opt_init(actor_params)
        critic_params = target_critic_params = self.critic.init(
            critic_rng, sample_obs, sample_action
        )
        critic_opt_state = self.critic_opt_init(critic_params)
        alpha_params = self.log_alpha.init(alpha_rng)
        alpha_opt_state = self.alpha_opt_init(alpha_params)

        for update in itertools.count():
            sample = yield actor_params
            rng, actor_rng, td_rng = jax.random.split(rng, 3)

            # The regression target is treated as a constant for the critic.
            target_Q = jax.lax.stop_gradient(
                self.get_td_target(
                    next_obs=sample.next_obs,
                    reward=sample.reward,
                    not_done=1 - sample.done,
                    actor_params=actor_params,
                    critic_target=target_critic_params,
                    rng=td_rng,
                )
            )
            critic_params, critic_opt_state = self.update_critic(
                critic_params=critic_params,
                opt_state=critic_opt_state,
                obs=sample.obs,
                action=sample.action,
                target_q=target_Q,
            )

            # Delayed updates: actor, temperature and targets every policy_freq steps.
            if update % self.policy_freq == 0:
                actor_params, actor_opt_state, log_p = self.update_actor(
                    actor_params=actor_params,
                    critic_params=critic_params,
                    opt_state=actor_opt_state,
                    obs=sample.obs,
                    rng=actor_rng,
                )
                alpha_params, alpha_opt_state = self.update_alpha(
                    alpha_params=alpha_params, opt_state=alpha_opt_state, log_pi=log_p
                )
                target_actor_params = soft_update(target_actor_params, actor_params)
                target_critic_params = soft_update(target_critic_params, critic_params)

    @functools.partial(jax.jit, static_argnums=0)
    def update_actor(self, actor_params, critic_params, opt_state, obs, rng):
        """Take one gradient step on the actor; the critic params are held fixed.

        Returns:
            ``(new_params, new_opt_state, log_pi)``. ``log_pi`` is the per-sample
            log-likelihood of the sampled pre-squash action, including the
            squashing correction; ``train_loop`` feeds it to ``update_alpha``.
        """

        def loss(params):
            mu, log_sig = self.actor.apply(params, obs)
            pi = self.sample_pi(mu, log_sig, rng)
            action = self.postprocess_action(pi)
            likelihood = gaussian_likelihood(pi, mu, log_sig)
            # Change-of-variables correction for the squashing. relu presumably
            # clamps tiny negative round-off and 1e-6 avoids log(0). The sum over
            # axis 1 assumes pi is (batch, action_dim); TODO confirm.
            # NOTE(review): this is the Jacobian of tanh, but postprocess_action
            # squashes with a sigmoid (d/dx sigmoid = sigmoid * (1 - sigmoid)).
            # Confirm which squashing is intended and make the two agree.
            likelihood -= jnp.sum(
                jnp.log(nn.relu(1 - jnp.tanh(pi) ** 2) + 1e-6), axis=1
            )
            q1, q2 = self.critic.apply(critic_params, obs, action)
            min_q = jnp.minimum(q1, q2)
            # NOTE(review): the exact objective is defined by util.single_mse;
            # the learned temperature does not enter it.
            actor_loss = single_mse(likelihood, min_q)
            return jnp.mean(actor_loss), likelihood

        gradient, log_pi = jax.grad(loss, has_aux=True)(actor_params)
        updates, opt_state = self.actor_opt_update(gradient, opt_state)
        new_params = optix.apply_updates(actor_params, updates)
        return new_params, opt_state, log_pi

    @functools.partial(jax.jit, static_argnums=0)
    def update_critic(self, critic_params, opt_state, obs, action, target_q):
        """Take one gradient step regressing both Q-heads onto ``target_q``.

        Returns:
            ``(new_params, new_opt_state)``.
        """

        def loss(params):
            current_Q1, current_Q2 = self.critic.apply(params, obs, action)
            critic_loss = double_mse(current_Q1, current_Q2, target_q)
            return jnp.mean(critic_loss)

        gradient = jax.grad(loss)(critic_params)
        updates, opt_state = self.critic_opt_update(gradient, opt_state)
        new_params = optix.apply_updates(critic_params, updates)
        return new_params, opt_state

    @functools.partial(jax.jit, static_argnums=0)
    def update_alpha(self, alpha_params, opt_state, log_pi):
        """Take one gradient step on the log-temperature.

        Minimises ``mean(log_alpha * (-log_pi - target_entropy))``. ``log_pi`` is
        treated as a constant.

        Returns:
            ``(new_params, new_opt_state)``.
        """
        log_pi = jax.lax.stop_gradient(log_pi)

        def loss(params):
            @jax.vmap
            def alpha_loss_fn(lp):
                return (
                    self.log_alpha.apply(params) * (-lp - self.target_entropy)
                ).mean()

            return jnp.mean(alpha_loss_fn(log_pi))

        gradient = jax.grad(loss)(alpha_params)
        # Use the temperature's own optimiser update, not the critic's.
        updates, opt_state = self.alpha_opt_update(gradient, opt_state)
        new_params = optix.apply_updates(alpha_params, updates)
        return new_params, opt_state

    @functools.partial(jax.jit, static_argnums=0)
    def get_td_target(self, next_obs, reward, not_done, critic_target, **kwargs):
        """Return the clipped double-Q bootstrap target.

        ``target = reward + not_done * discount * min(Q1', Q2')``, evaluated at
        the next action chosen by ``policy``. ``kwargs`` (``actor_params`` and
        optionally ``rng``) are forwarded to ``policy``.
        """
        next_action = self.policy(obs=next_obs, **kwargs)
        target_Q1, target_Q2 = self.critic.apply(critic_target, next_obs, next_action)
        target_Q = jnp.minimum(target_Q1, target_Q2)
        target_Q = reward + not_done * self.discount * target_Q
        return target_Q

    @staticmethod
    def sample_pi(mu, log_sig, rng):
        """Reparameterised Gaussian sample: ``mu + eps * exp(log_sig)``."""
        return mu + jax.random.normal(rng, mu.shape) * jnp.exp(log_sig)

    def postprocess_action(self, pi):
        """Squash an unbounded action into ``(min_action, max_action)`` via a sigmoid."""
        return nn.sigmoid(pi) * (self.max_action - self.min_action) + self.min_action

    @functools.partial(jax.jit, static_argnums=0)
    def policy(
        self, actor_params: hk.Params, obs: np.ndarray, rng: PRNGKey = None
    ) -> jnp.ndarray:
        """Compute bounded actions for ``obs``.

        With ``rng=None`` the action is deterministic (the squashed mean);
        otherwise a squashed Gaussian sample drawn with ``rng``.
        """
        mu, log_sig = self.actor.apply(actor_params, obs)
        pi = mu if rng is None else self.sample_pi(mu, log_sig, rng)
        return self.postprocess_action(pi)