[Deep Reinforcement Learning] The Actor-Critic Algorithm


The Actor-Critic Algorithm

Recall the policy gradient algorithm:
$$\nabla \bar{R}_\theta \approx \frac{1}{N}\sum_{n=1}^{N}\sum_{t=1}^{T_n} R(\tau^n)\,\nabla \log p_\theta(a_t^n \mid s_t^n)$$
The Actor-Critic algorithm differs only in how $R(\tau^n)$ is defined:
(Figure: the three alternative forms of $R(\tau^n)$.)
When $R(\tau^n)$ takes one of the three forms above, we obtain the classic Actor-Critic (AC) algorithm. In AC, a second neural network, called the Critic, is used to estimate $V^{\pi}(s_t)$ (or another quantity, depending on the variant).
This post implements the AC algorithm based on the TD residual; the gradient of its policy network is shown below:
$$\nabla \bar{R}_\theta \approx \frac{1}{N}\sum_{n=1}^{N}\sum_{t=1}^{T_n}\big(r_t^n + \gamma V^{\pi}(s_{t+1}^n) - V^{\pi}(s_t^n)\big)\,\nabla \log p_\theta(a_t^n \mid s_t^n)$$
The corresponding Actor-Critic loss function is:
(Figure: the combined loss, i.e. the policy term $-\log \pi_\theta(a_t^n \mid s_t^n)\,e_t$ summed over steps, plus the Critic's regression loss toward the TD target $r_t^n + \gamma V^{\pi}(s_{t+1}^n)$.)
Here $e_t = r_t^n + \gamma V^{\pi}(s_{t+1}^n) - V^{\pi}(s_t^n)$ is called the TD residual, and it is used as the advantage estimate, $A^{\pi}(s,a) = e_t$.
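
To make the TD residual concrete, here is a toy check in PyTorch; the numbers are arbitrary values chosen only for illustration, and the two loss terms mirror those used in the full code below.

import torch
import torch.nn.functional as F

gamma = 0.98
r = torch.tensor([1.0])                     # one-step reward (toy value)
v_s = torch.tensor([0.40])                  # critic's estimate V(s) (toy value)
v_s_prime = torch.tensor([0.50])            # critic's estimate V(s') (toy value)
log_pi_a = torch.log(torch.tensor([0.7]))   # log-probability of the chosen action (toy value)

td_target = r + gamma * v_s_prime           # 1 + 0.98 * 0.5 = 1.49
e = td_target - v_s                         # TD residual / advantage estimate: 1.09
actor_loss = -log_pi_a * e.detach()         # policy term, with the advantage treated as a constant
critic_loss = F.smooth_l1_loss(v_s, td_target.detach())   # value-regression term toward the TD target
print(e.item(), actor_loss.item(), critic_loss.item())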

Code

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Hyperparameters
learning_rate = 0.0002
gamma = 0.98
n_rollout = 10


class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)      # policy network (actor): outputs each action's probability for the state (after softmax in pi())
        self.fc_v = nn.Linear(256, 1)       # value network (critic): outputs the state value V(s)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)    # softmax turns the logits into action probabilities
        return prob

    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)                        # the critic's estimate of the value of state s
        return v

    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []    #s,a,r,s_,done
        for transition in self.data:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r / 100.0])           # reward scaled by 1/100 so the value targets stay small (CartPole-v1 returns can reach 500)
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0
            done_lst.append([done_mask])        # mask is 0.0 when s_prime is terminal, else 1.0

        s_batch = torch.tensor(s_lst, dtype=torch.float)
        a_batch = torch.tensor(a_lst)
        r_batch = torch.tensor(r_lst, dtype=torch.float)
        s_prime_batch = torch.tensor(s_prime_lst, dtype=torch.float)
        done_batch = torch.tensor(done_lst, dtype=torch.float)
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

    def train_net(self):
        s, a, r, s_prime, done = self.make_batch()
        td_target = r + gamma * self.v(s_prime) * done    # TD target r + gamma * V(s'); the mask zeroes the bootstrap term when s' is terminal, so td_target = r
        delta = td_target - self.v(s)                     # TD residual: r + gamma * V(s') - V(s)

        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(self.v(s), td_target.detach())     # first term: actor (policy) loss; second term: critic (value) loss
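        # delta.detach() keeps the policy-gradient term from backpropagating into the critic through the advantage,
        # and td_target.detach() makes the critic regress toward a fixed bootstrap target (no gradient through V(s')).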

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()


def main():
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    print_interval = 20
    score = 0.0

    for n_epi in range(10000):
        done = False
        s = env.reset()
        while not done:
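            # collect up to n_rollout transitions, then perform one training update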
            for t in range(n_rollout):
                prob = model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()                    # sample action a from the actor's distribution
                s_prime, r, done, info = env.step(a)
                model.put_data((s, a, r, s_prime, done))

                s = s_prime
                score += r

                if done:
                    break

            model.train_net()                            # train the networks on the collected batch

        if n_epi % print_interval == 0 and n_epi != 0:
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, score / print_interval))
            score = 0.0
    env.close()


if __name__ == '__main__':
    main()
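
Note that the loop above uses the older gym signatures, where reset() returns only the observation and step() returns four values. If you run a recent gym (>= 0.26) or gymnasium instead, the interaction code needs a small adaptation; a minimal sketch, assuming gymnasium is installed and using a random action as a stand-in for the actor's choice, is:

import gymnasium as gym   # assumption: gymnasium instead of classic gym

env = gym.make('CartPole-v1')
s, info = env.reset()                                        # reset() now returns (observation, info)
done = False
while not done:
    a = env.action_space.sample()                            # stand-in for the actor's sampled action
    s_prime, r, terminated, truncated, info = env.step(a)    # step() now returns five values
    done = terminated or truncated                           # an episode ends on either signal
    s = s_prime
env.close()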

Results

(Figure: training results.)