news 2026/4/23 18:16:48

别光看理论了!用PyTorch手把手实现一个Actor-Critic玩CartPole(附完整代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别光看理论了!用PyTorch手把手实现一个Actor-Critic玩CartPole(附完整代码)

从零实现Actor-Critic:用PyTorch征服CartPole的实战指南

在强化学习领域,理论推导和代码实现之间往往存在巨大的鸿沟。许多学习者能够理解策略梯度定理的数学证明,却在面对具体实现时束手无策。本文将带你跨越这道鸿沟,使用PyTorch从零开始构建一个完整的Actor-Critic算法,并在经典的CartPole环境中验证其效果。

1. 环境搭建与核心概念

CartPole(倒立摆)是强化学习中最经典的测试环境之一。游戏目标是通过左右移动小车来保持顶部的杆子竖直不倒。这个看似简单的任务包含了强化学习的核心挑战:如何在连续的状态空间中进行决策,并通过稀疏的奖励信号来优化策略。

关键组件准备

import gym import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import matplotlib.pyplot as plt env = gym.make('CartPole-v1') state_dim = env.observation_space.shape[0] action_dim = env.action_space.n

Actor-Critic架构巧妙结合了策略梯度(Actor)和价值函数(Critic)的优点:

  • Actor:策略网络,负责根据当前状态选择动作
  • Critic:价值网络,评估当前状态-动作对的质量

提示:CartPole-v1环境中,杆子保持直立每步获得+1奖励,最大步数为500。相比v0版本,v1的杆子更长,控制难度更高。

2. 网络架构设计

2.1 Actor网络实现

Actor网络输出的是动作的概率分布。对于CartPole这样的离散动作空间,我们使用softmax输出:

class Actor(nn.Module): def __init__(self, state_dim, action_dim, hidden_size=128): super(Actor, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, action_dim) def forward(self, state): x = torch.relu(self.fc1(state)) x = torch.relu(self.fc2(x)) x = torch.softmax(self.fc3(x), dim=-1) return x

2.2 Critic网络实现

Critic网络评估状态-动作对的价值,指导Actor的更新方向:

class Critic(nn.Module): def __init__(self, state_dim, hidden_size=128): super(Critic, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, 1) def forward(self, state): x = torch.relu(self.fc1(state)) x = torch.relu(self.fc2(x)) value = self.fc3(x) return value

网络设计要点对比

组件输入维度输出维度激活函数作用
Actor状态维度动作维度Softmax生成动作概率
Critic状态维度1评估状态价值

3. 训练流程实现

3.1 数据收集与预处理

我们采用在线更新的方式,实时收集轨迹数据:

def collect_trajectory(env, actor, max_steps=1000): states, actions, rewards, next_states, dones = [], [], [], [], [] state = env.reset() for _ in range(max_steps): state_tensor = torch.FloatTensor(state).unsqueeze(0) action_probs = actor(state_tensor) action = torch.multinomial(action_probs, 1).item() next_state, reward, done, _ = env.step(action) states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(done) state = next_state if done: break return states, actions, rewards, next_states, dones

3.2 核心训练循环

完整的Actor-Critic更新包含三个关键步骤:

  1. 计算TD误差
  2. 更新Critic网络
  3. 更新Actor网络
def train(env, actor, critic, actor_optimizer, critic_optimizer, gamma=0.99, epochs=1000): reward_history = [] for epoch in range(epochs): # 收集数据 states, actions, rewards, next_states, dones = collect_trajectory(env, actor) # 转换为张量 states = torch.FloatTensor(states) actions = torch.LongTensor(actions).unsqueeze(1) rewards = torch.FloatTensor(rewards).unsqueeze(1) next_states = torch.FloatTensor(next_states) dones = torch.FloatTensor(dones).unsqueeze(1) # 计算TD目标 with torch.no_grad(): next_values = critic(next_states) td_targets = rewards + gamma * next_values * (1 - dones) # 更新Critic values = critic(states) critic_loss = nn.MSELoss()(values, td_targets) critic_optimizer.zero_grad() critic_loss.backward() critic_optimizer.step() # 更新Actor action_probs = actor(states) selected_probs = action_probs.gather(1, actions) advantages = td_targets - values.detach() actor_loss = -(torch.log(selected_probs) * advantages).mean() actor_optimizer.zero_grad() actor_loss.backward() actor_optimizer.step() # 记录结果 total_reward = sum(rewards) reward_history.append(total_reward) if epoch % 50 == 0: print(f"Epoch {epoch}, Reward: {total_reward}") return reward_history

注意:advantages的计算使用了Critic网络的评估值,但通过detach()切断了梯度传播,避免影响Critic的训练。

4. 超参数调优与训练技巧

4.1 关键超参数设置

经过多次实验验证,以下参数组合在CartPole环境中表现良好:

hyperparams = { 'hidden_size': 64, # 网络隐藏层维度 'gamma': 0.99, # 折扣因子 'actor_lr': 0.001, # Actor学习率 'critic_lr': 0.005, # Critic学习率(通常设置更大) 'epochs': 500 # 训练轮数 }

学习率设置原理

  • Critic需要更快收敛以提供准确的评估
  • Actor更新步长应较小以保证策略稳定改进

4.2 训练过程监控

实现实时渲染和奖励曲线绘制,直观观察训练进展:

def plot_rewards(reward_history, window_size=50): moving_avg = [] for i in range(len(reward_history) - window_size + 1): window = reward_history[i:i+window_size] moving_avg.append(sum(window) / window_size) plt.plot(reward_history, alpha=0.3, label='Raw') plt.plot(moving_avg, label=f'Moving Avg ({window_size} eps)') plt.xlabel('Episode') plt.ylabel('Total Reward') plt.legend() plt.show()

4.3 常见问题排查

训练不稳定问题解决方案

  1. 奖励不增长

    • 检查网络结构是否足够复杂
    • 尝试增大折扣因子gamma
    • 调整学习率组合
  2. 策略过早收敛

    • 引入熵正则项鼓励探索
    • 在损失函数中添加:entropy = (probs * torch.log(probs)).sum(1).mean()
  3. Critic估值偏差

    • 使用目标网络稳定训练
    • 实现经验回放缓冲(需注意同策略限制)
# 熵正则化示例 def actor_loss_with_entropy(probs, actions, advantages, beta=0.01): selected_probs = probs.gather(1, actions) policy_loss = -(torch.log(selected_probs) * advantages).mean() entropy = (probs * torch.log(probs)).sum(1).mean() return policy_loss - beta * entropy

5. 进阶优化与扩展

5.1 目标网络实现

为Critic添加目标网络可以显著提升训练稳定性:

class ActorCritic: def __init__(self, state_dim, action_dim): self.actor = Actor(state_dim, action_dim) self.critic = Critic(state_dim) self.target_critic = Critic(state_dim) self.update_target(tau=1.0) # 初始完全同步 def update_target(self, tau=0.01): """软更新目标网络""" for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()): target_param.data.copy_(tau*param.data + (1-tau)*target_param.data)

5.2 多步TD学习

单步TD更新可能引入较大偏差,实现n步TD改进:

def compute_nstep_td(rewards, next_values, dones, gamma, n_step=5): n_step_returns = [] R = 0 for i in reversed(range(len(rewards))): R = rewards[i] + gamma * R * (1 - dones[i]) if i + n_step < len(rewards): R = R - (gamma**n_step) * rewards[i+n_step] n_step_returns.insert(0, R) returns = torch.FloatTensor(n_step_returns).unsqueeze(1) with torch.no_grad(): next_n_values = next_values[-len(returns):] td_targets = returns + (gamma**n_step) * next_n_values * (1 - dones[-len(returns):]) return td_targets

5.3 并行环境采样

加速数据收集的终极方案是使用多环境并行采样:

from multiprocessing import Process, Queue def worker(env_name, actor, queue, num_episodes): env = gym.make(env_name) for _ in range(num_episodes): states, actions, rewards, next_states, dones = collect_trajectory(env, actor) queue.put((states, actions, rewards, next_states, dones)) queue.put(None) # 结束信号 def parallel_collect(env_name, actor, num_workers=4, episodes_per_worker=2): queue = Queue() workers = [] for _ in range(num_workers): p = Process(target=worker, args=(env_name, actor, queue, episodes_per_worker)) p.start() workers.append(p) trajectories = [] finished_workers = 0 while finished_workers < num_workers: data = queue.get() if data is None: finished_workers += 1 else: trajectories.append(data) for p in workers: p.join() return trajectories

在实际项目中,我发现并行采样能显著提升训练效率,特别是在环境交互耗时较长的情况下。一个实用的技巧是将不同worker的探索率设置为略有差异的值,这样可以增加样本的多样性。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/21 8:13:19

手把手教你用脚本自动化安装Nvidia驱动到Ubuntu实时内核

手把手教你用脚本自动化安装Nvidia驱动到Ubuntu实时内核 上周给实验室三台实时系统工作站部署Nvidia驱动时&#xff0c;发现每次手动操作都要重复近20个步骤&#xff0c;稍不留神就会在某个环节出错。于是花了两天时间封装了个全自动安装脚本&#xff0c;现在新机器部署时间从原…

作者头像 李华
网站建设 2026/4/21 13:26:06

国际化技术中的多语言本地化与文化适配

在全球化的数字时代&#xff0c;国际化技术已成为企业拓展市场的核心战略。多语言本地化与文化适配不仅是简单的文本翻译&#xff0c;更是跨越语言障碍、融入目标市场文化的关键过程。从跨国电商到社交媒体平台&#xff0c;如何让产品和服务被不同地区的用户自然接受&#xff1…

作者头像 李华
网站建设 2026/4/22 0:46:12

蓝牙HID实战:从零构建Android触控板,解锁多设备跨屏操控新姿势

1. 为什么需要Android蓝牙触控板&#xff1f; 每次看到抽屉里吃灰的旧手机&#xff0c;总觉得浪费了那块高清触摸屏。你有没有想过&#xff0c;其实只需要200行代码就能把它变成跨平台的无线触控板&#xff1f;我去年用一台退役的华为P30给工作室的三台电脑做共享触控板&#x…

作者头像 李华
网站建设 2026/4/22 9:04:16

别再死记硬背了!ROS开发者必备:rosbag record/play/info 高频命令速查手册(附常用场景组合)

ROS开发者效率手册&#xff1a;rosbag高阶场景化命令实战指南 在机器人开发流程中&#xff0c;数据采集与分析环节往往占据30%以上的调试时间。许多中高级ROS开发者虽然熟悉基础指令&#xff0c;却在复杂场景组合命令时频繁查阅文档。本文将彻底改变这种低效模式——我们不是简…

作者头像 李华