强化学习算法总结

强化学习经典算法总结

各经典算法总结

算法 off/on policy 说明
DQN off policy 维护Q网络和Target Q网络
DDQN off policy 维护Q网络和Target Q网络,下一步的action从当前Q网络选择,代入Target Q网络,避免Target过大
SARSA on policy 收集一次transition,更新一次模型
PG on policy 需要完整采样,解决连续动作问题,P网络直接输出动作的概率,优化目标为最大化R
PPO off policy 需要完整采样,PG的变形,引入importance sampling可使用任意分布的采样。优化方式和PG类似:policy网络最大化R,critic网络最小化误差。引入KL散度和clip。
A2C/A3C on policy PG的改进,不需要完整采样,结合了PG和时序差分,policy输出动作,最大化R。critic网络预估V值,计算优势函数$Advantage_t = r_t+V(s_{t+1})-V(s_t)$
DDPQ off policy DQN的连续动作版本,Actor给出动作P值,Critic评估动作的Q值,Target Actor和Target Critic作为目标网络,Actor目标为最大化期望回报,Critic目标为最小化与真实回报的误差。
TD3 off policy DDPG的改进版本,Critic使用两个相同结构的Q网络,作为目标Q值时取两个Q网络中较小的那个。actor的更新相对critic延迟。

各算法核心公式

DQN

初始化$Q$网络和Target$\hat Q$网络

根据Q值获取 action:

$a_t= argmax_a(Q(s_t,\theta)\space if \space random > \epsilon\ else \ random (A)$

存储 transition:$(s_t,a_t,r_t,s_{t+1})$

计算 Loss:

根据 transition:$(s_t,a_t,r_t,s_{t+1})$

$\hat y_j = Q(s_j,a_j,\theta)$

$y_j=r_j+\gamma max_{a_{j+1}}\hat Q(s_{j+1},a_{j+1};\hat \theta)
\space if\ episode \ not \ terminates \space at \space step \space j+1 \ else \ r_j $

$loss = (\hat y_j-y_j)^2$

DDQN

初始化$Q$网络和Target$\hat Q$网络

根据Q值获取 action:

$a_t=argmax_a(Q(s_t,\theta)\space if \space random > \epsilon \ else \ random (A) $

存储 transition:$(s_t,a_t,r_t,s_{t+1})$

计算 Loss:

根据 transition:$(s_t,a_t,r_t,s_{t+1})$

$\hat y_j = Q(s_j,a_j,\theta)$

$y_j=r_j+\gamma \hat Q(s_{j+1},argmax_{a_{j+1}}(Q(s_{j+1},\theta));\hat \theta) \space if\space episode \space not \ terminates \space at \space step \space j+1 \ else \ r_j $

$loss = (\hat y_j-y_j)^2$

SARSA(表格方式)

初始化$Q$网络

根据Q值和$\epsilon-greedy$获取 action:

记录transition:$(s_t,a_t,r_t,s_{t+1})$

$\hat y_t = Q(s_t,a_t,\theta)$

$y_t = r_t+\gamma (Q(s_{t+1},a_{t+1}))$,$a_{t+1}$根据Q值和$\epsilon-greedy$获取

更新:$Q(s,a)\leftarrow$ $Q(s,a)$+$lr*(y_t - \hat y_t)$

Policy Gradient

初始化P网络

使用P网络获取一个episode的transition:$(s_1,a_1,r_1,s_{2},a_2,r_2…s_t,a_{t},r_{t})$

计算该episode的$G$:$G_t=r_{t+1}+\gamma G_{t+1}$

对于该episode中的每一个动作,计算$loss_{t} = -G_{t}*logP(a_t)$

批量更新梯度

Loss 的推导过程:

https://github.com/abcdcamey/RL-learning/blob/main/PolicyGradient/agent.png

PPO-Clip(Proximal Policy Optimization)

初始化Actor网络和Critic网络,Actor网络用来采集数据同时记录下动作概率P,Critic网络用来评估优势函数。

使用Actor网络收集若干个episode的transition$(s_t,a_t,dist_t,value_t,logprob_{a_t},r_t…)$

对于每一个transition:

1:计算优势函数$Advantage_t = \sum_{k=t}^{T}discount_{k-t}(r_{k}+\gamma v_{k+1}-v_{k})$

2:计算Importance Samples的调整项:$prob ratio_t = e^{newprob_t}/e^{oldprob_t}$,$new prob_{t}$为最新Actor产生的$logP(a_t)$

3、根据公式,计算目标函数,并根据目标函数计算Actor的更新梯度:$J(\theta)=actor \ loss=-min(prob \ ratio \cdot Advantage_t,Clip(prob \ ratio,1-\epsilon,1+\epsilon) \cdot Advantage_t)$

4、计算Critic的网络的目标函数:

$critic \ loss = ((Advantage_t+old \ critic \ value_t)-new \ critic \ value_t)^2$,$new \ critic \ value_t$为最新Critic产生的状态价值函数

5、根据两个目标函数梯度,反向传播更新网络参数

A2C

初始化Actor网络和Critic网络,Actor网络用来采集数据同时记录下动作概率P,Critic网络用来评估优势函数。

使用Actor网络收集n个step的transition$(s_t,a_t,dist_t,value_t,logprob_{a_t},r_t…)$

每n个step的transition:

1、计算优势函数$Advantage_t = \sum_{k=n}^{1}r_{k}+\gamma v_{k+1}-v_{k}$

2、对于每一个动作,计算$actor \ loss_{t} = -Advantage_t \cdot logP(a_t)$

3、对于每一个动作,计算$critic \ loss_t=Advantage_t^2$

4、根据两个目标函数梯度,反向传播更新网络参数

DDPG(Deep Deterministic Policy Gradient)

初始化Actor网络($u(s)$)和Critic网络($Q(s,a)$),copy初始化Target Actor网络$u’$和Target Critic网络$Q’$

根据Actor获取 action,存储 transition:$(s_t,a_t,r_t,s_{t+1})$

对于一个batch的transition:

1、计算Critic的目标$y_t=r_t+\gamma Q’(s_{t+1},u’(s_{t+1}))$

2、计算$critic \ loss^{\theta^Q} = (y_t-Q(s_t,a_t))^2$,更新$Q$的参数

3、计算$actor \ loss^{\theta^u} = -Q(s_t,u(s_t))$,更新$u$的参数

4、软更新Target Actor网络$u’$和Target Critic网络$Q’$

TD3(Twin Delayed DDPG)

初始化Actor网络($u(s)$)和Critic网络($Q(s,a)$),copy初始化Target Actor网络$u’$和Target Critic网络$Q’$

根据Actor获取 action,存储 transition:$(s_t,a_t,r_t,s_{t+1})$

对于一个batch的transition:

1、计算Critic的目标$y_t=r_t+\gamma min(Q_1’(s_{t+1},u’(s_{t+1})),Q_2’(s_{t+1},u’(s_{t+1})))$

2、计算$critic \ loss = 0.5 \cdot (y_t-Q_1(s_t,a_t))^2+0.5 \cdot (y_t-Q_2(s_t,a_t))^2$

3、计算$actor \ loss^{\theta^u} = -Q_1(s_t,u(s_t))$,更新$u$的参数

4、更新critic网络和actor网络,actor更新相对critic延迟

5、定期软更新Target Actor网络$u’$和Target Critic网络$Q_1’,Q_2’$