从零实现Actor-Critic:用PyTorch征服CartPole的实战指南
在强化学习领域,理论推导和代码实现之间往往存在巨大的鸿沟。许多学习者能够理解策略梯度定理的数学证明,却在面对具体实现时束手无策。本文将带你跨越这道鸿沟,使用PyTorch从零开始构建一个完整的Actor-Critic算法,并在经典的CartPole环境中验证其效果。
1. 环境搭建与核心概念
CartPole(倒立摆)是强化学习中最经典的测试环境之一。游戏目标是通过左右移动小车来保持顶部的杆子竖直不倒。这个看似简单的任务包含了强化学习的核心挑战:如何在连续的状态空间中进行决策,并通过稀疏的奖励信号来优化策略。
关键组件准备:
```python
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]  # 4:小车位置、小车速度、杆角度、杆角速度
action_dim = env.action_space.n             # 2:向左推 / 向右推
```

Actor-Critic架构巧妙结合了策略梯度(Actor)和价值函数(Critic)的优点:
- Actor:策略网络,负责根据当前状态选择动作
- Critic:价值网络,评估当前状态的价值,为Actor提供更新方向
提示:CartPole-v1环境中,杆子保持直立每步获得+1奖励,单回合最大步数为500。相比v0版本(上限200步),v1的回合更长,对策略的长期稳定性要求更高。
2. 网络架构设计
2.1 Actor网络实现
Actor网络输出的是动作的概率分布。对于CartPole这样的离散动作空间,我们使用softmax输出:
```python
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=128):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        x = torch.softmax(self.fc3(x), dim=-1)
        return x
```

2.2 Critic网络实现
Critic网络评估状态的价值,指导Actor的更新方向:
```python
class Critic(nn.Module):
    def __init__(self, state_dim, hidden_size=128):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.fc3(x)
        return value
```

网络设计要点对比:
| 组件 | 输入维度 | 输出维度 | 激活函数 | 作用 |
|---|---|---|---|---|
| Actor | 状态维度 | 动作维度 | Softmax | 生成动作概率 |
| Critic | 状态维度 | 1 | 无 | 评估状态价值 |
3. 训练流程实现
3.1 数据收集与预处理
我们按回合收集轨迹数据,每收集完一条完整轨迹就立即用它更新网络:
```python
def collect_trajectory(env, actor, max_steps=1000):
    """与环境交互一个回合,返回轨迹数据(旧版gym接口:reset()返回状态,step()返回4元组)"""
    states, actions, rewards, next_states, dones = [], [], [], [], []
    state = env.reset()
    for _ in range(max_steps):
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():                 # 采样阶段不需要构建计算图
            action_probs = actor(state_tensor)
        action = torch.multinomial(action_probs, 1).item()
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(done)
        state = next_state
        if done:
            break
    return states, actions, rewards, next_states, dones
```

3.2 核心训练循环
完整的Actor-Critic更新包含三个关键步骤(对应的公式见列表之后):
- 计算TD误差
- 更新Critic网络
- 更新Actor网络
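这三个步骤对应的单步公式如下($d_t$ 表示回合是否在第 $t$ 步结束),可与下方代码逐行对照:

$$
\begin{aligned}
y_t &= r_t + \gamma\, V_\phi(s_{t+1})\,(1 - d_t) \\
L_{\text{critic}} &= \big(V_\phi(s_t) - y_t\big)^2 \\
A_t &= y_t - V_\phi(s_t), \qquad L_{\text{actor}} = -\log \pi_\theta(a_t \mid s_t)\, A_t
\end{aligned}
$$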
```python
def train(env, actor, critic, actor_optimizer, critic_optimizer, gamma=0.99, epochs=1000):
    reward_history = []
    for epoch in range(epochs):
        # 收集数据
        states, actions, rewards, next_states, dones = collect_trajectory(env, actor)

        # 转换为张量
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones).unsqueeze(1)

        # 计算TD目标
        with torch.no_grad():
            next_values = critic(next_states)
            td_targets = rewards + gamma * next_values * (1 - dones)

        # 更新Critic
        values = critic(states)
        critic_loss = nn.MSELoss()(values, td_targets)
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        # 更新Actor
        action_probs = actor(states)
        selected_probs = action_probs.gather(1, actions)
        advantages = td_targets - values.detach()
        actor_loss = -(torch.log(selected_probs) * advantages).mean()
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # 记录结果
        total_reward = rewards.sum().item()
        reward_history.append(total_reward)
        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Reward: {total_reward}")

    return reward_history
```

注意:advantages的计算使用了Critic网络的评估值,但通过detach()切断了梯度传播,避免Actor的损失影响Critic的训练。
4. 超参数调优与训练技巧
4.1 关键超参数设置
经过多次实验验证,以下参数组合在CartPole环境中表现良好:
```python
hyperparams = {
    'hidden_size': 64,     # 网络隐藏层维度
    'gamma': 0.99,         # 折扣因子
    'actor_lr': 0.001,     # Actor学习率
    'critic_lr': 0.005,    # Critic学习率(通常设置更大)
    'epochs': 500          # 训练轮数
}
```

学习率设置原理如下,列表之后给出把这些超参数接入训练的最小组装示例:
- Critic需要更快收敛以提供准确的评估
- Actor更新步长应较小以保证策略稳定改进
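下面的组装示例展示了如何用上述超参数创建网络与优化器并启动训练。其中选用 Adam 优化器是本文的假设,Actor、Critic、train、env 均沿用前文定义:

```python
# 按超参数创建网络
actor = Actor(state_dim, action_dim, hidden_size=hyperparams['hidden_size'])
critic = Critic(state_dim, hidden_size=hyperparams['hidden_size'])

# Critic学习率更大,让价值估计先于策略收敛
actor_optimizer = optim.Adam(actor.parameters(), lr=hyperparams['actor_lr'])
critic_optimizer = optim.Adam(critic.parameters(), lr=hyperparams['critic_lr'])

# 启动训练,返回每回合总奖励
reward_history = train(env, actor, critic, actor_optimizer, critic_optimizer,
                       gamma=hyperparams['gamma'], epochs=hyperparams['epochs'])
```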
4.2 训练过程监控
实现奖励曲线绘制与训练后的实时渲染评估,直观观察训练进展(渲染评估的示例见绘图函数之后):
```python
def plot_rewards(reward_history, window_size=50):
    moving_avg = []
    for i in range(len(reward_history) - window_size + 1):
        window = reward_history[i:i+window_size]
        moving_avg.append(sum(window) / window_size)

    plt.plot(reward_history, alpha=0.3, label='Raw')
    plt.plot(moving_avg, label=f'Moving Avg ({window_size} eps)')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.legend()
    plt.show()
```
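训练完成后,可以让策略以贪心方式运行几个回合并渲染画面,直观检查效果。下面是一个假设性的评估函数(沿用前文的旧版 gym 接口,env.render() 能否弹出窗口取决于本地 gym 版本与显示环境):

```python
def evaluate(env, actor, episodes=5, render=True):
    """用当前策略贪心地评估若干回合,可选渲染"""
    for ep in range(episodes):
        state, total_reward, done = env.reset(), 0.0, False
        while not done:
            if render:
                env.render()
            with torch.no_grad():
                probs = actor(torch.FloatTensor(state).unsqueeze(0))
            action = probs.argmax(dim=-1).item()  # 评估时直接取概率最大的动作
            state, reward, done, _ = env.step(action)
            total_reward += reward
        print(f"Eval episode {ep}: reward = {total_reward}")
```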
4.3 常见问题排查
训练不稳定问题解决方案:
奖励不增长:
- 检查网络结构是否足够复杂
- 尝试增大折扣因子gamma
- 调整学习率组合
策略过早收敛:
- 引入熵正则项鼓励探索
- 在损失函数中加入熵项:`entropy = -(probs * torch.log(probs)).sum(1).mean()`(熵为 -Σp·log p,注意负号),再从策略损失中减去 `beta * entropy`,完整示例见下方代码
Critic估值偏差:
- 使用目标网络稳定训练(见下文5.1节)
- 实现经验回放缓冲(需注意同策略限制)
```python
# 熵正则化示例
def actor_loss_with_entropy(probs, actions, advantages, beta=0.01):
    selected_probs = probs.gather(1, actions)
    policy_loss = -(torch.log(selected_probs) * advantages).mean()
    # 熵 H = -Σ p·log p,加上1e-8避免log(0)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(1).mean()
    # 减去beta*entropy:熵越大损失越小,从而鼓励探索
    return policy_loss - beta * entropy
```
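接入方式很直接:把 train() 中计算 actor_loss 的那几行替换为 `actor_loss = actor_loss_with_entropy(action_probs, actions, advantages)` 即可。beta 控制探索强度,训练后期若策略已经稳定,可以考虑逐步减小 beta。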
5. 进阶优化与扩展
5.1 目标网络实现
为Critic添加目标网络可以显著提升训练稳定性:
```python
class ActorCritic:
    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim)
        self.target_critic = Critic(state_dim)
        self.update_target(tau=1.0)  # 初始完全同步

    def update_target(self, tau=0.01):
        """软更新目标网络"""
        for target_param, param in zip(self.target_critic.parameters(),
                                       self.critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```
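使用方式是:计算TD目标时改用 target_critic,并在每次Critic更新后做一次软更新。下面是一个假设性的辅助函数示意(compute_td_targets_with_target_net 为本文虚构的名字,参数含义沿用前文 train() 中的张量):

```python
def compute_td_targets_with_target_net(ac, rewards, next_states, dones, gamma=0.99):
    """用目标网络计算TD目标,降低自举目标的抖动(假设性辅助函数)"""
    with torch.no_grad():
        next_values = ac.target_critic(next_states)
    return rewards + gamma * next_values * (1 - dones)

# 在train()中:每次critic_optimizer.step()之后调用软更新
# ac.update_target(tau=0.01)
```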
5.2 多步TD学习
单步TD更新可能引入较大偏差,实现n步TD改进:
```python
def compute_nstep_td(rewards, next_values, dones, gamma, n_step=5):
    """n步TD目标: r_t + γ·r_{t+1} + ... + γ^(n-1)·r_{t+n-1} + γ^n·V(s_{t+n})
    rewards/dones为collect_trajectory返回的列表(或等长张量),
    next_values为critic(next_states)的输出,形状(T, 1)"""
    T = len(rewards)
    td_targets = []
    for t in range(T):
        R, discount = 0.0, 1.0
        last = t
        for k in range(t, min(t + n_step, T)):   # 累加最多n步的即时奖励
            R += discount * float(rewards[k])
            discount *= gamma
            last = k
            if float(dones[k]) > 0.5:            # 回合提前结束则停止累加
                break
        if float(dones[last]) < 0.5:
            # 轨迹未终止时,用窗口末尾的下一状态价值自举
            R += discount * float(next_values[last])
        td_targets.append(R)
    return torch.FloatTensor(td_targets).unsqueeze(1)
```
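接入前文 train() 时,只需把单步TD目标的计算替换为 `td_targets = compute_nstep_td(rewards, next_values, dones, gamma, n_step=5)`(next_values 仍是 critic(next_states) 在 no_grad 下的输出),Critic与Actor的更新逻辑保持不变。n_step 越大,目标越接近蒙特卡洛回报:偏差更小,但方差更大。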
5.3 并行环境采样
加速数据收集的终极方案是使用多环境并行采样:
```python
from multiprocessing import Process, Queue

def worker(env_name, actor, queue, num_episodes):
    env = gym.make(env_name)
    for _ in range(num_episodes):
        states, actions, rewards, next_states, dones = collect_trajectory(env, actor)
        queue.put((states, actions, rewards, next_states, dones))
    queue.put(None)  # 结束信号

def parallel_collect(env_name, actor, num_workers=4, episodes_per_worker=2):
    queue = Queue()
    workers = []
    for _ in range(num_workers):
        p = Process(target=worker, args=(env_name, actor, queue, episodes_per_worker))
        p.start()
        workers.append(p)

    trajectories = []
    finished_workers = 0
    while finished_workers < num_workers:
        data = queue.get()
        if data is None:
            finished_workers += 1
        else:
            trajectories.append(data)

    for p in workers:
        p.join()
    return trajectories
```

在实际项目中,我发现并行采样能显著提升训练效率,特别是在环境交互耗时较长的情况下。一个实用的技巧是将不同worker的探索率设置为略有差异的值,这样可以增加样本的多样性。
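调用时需要注意 multiprocessing 的入口保护,尤其在使用 spawn 启动方式的系统(Windows/macOS)上。一个假设性的调用示意如下(actor 为前文已创建的策略网络):

```python
# 入口保护:防止子进程在导入模块时重复执行采样代码
if __name__ == '__main__':
    trajectories = parallel_collect('CartPole-v1', actor,
                                    num_workers=4, episodes_per_worker=2)
    print(f"共收集到 {len(trajectories)} 条轨迹")
```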