news 2026/5/12 10:59:18

告别梯度下降的‘之’字路:用Python手把手实现共轭梯度法(CG)求解Ax=b

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
告别梯度下降的‘之’字路:用Python手把手实现共轭梯度法(CG)求解Ax=b

告别梯度下降的“之”字路:用Python手把手实现共轭梯度法(CG)求解Ax=b

在机器学习和科学计算领域,求解线性方程组Ax=b是一个基础但至关重要的问题。传统梯度下降法虽然简单易懂,但其"之"字形收敛路径常常导致迭代效率低下。本文将带你用Python从零实现共轭梯度法(Conjugate Gradient, CG),这种优雅的算法能在最多n步内找到n维问题的精确解(理论上),特别适合处理大规模稀疏矩阵问题。

1. 为什么需要共轭梯度法?

1.1 梯度下降的困境

梯度下降法在优化二次函数ϕ(x)=1/2xᵀAx - bᵀx时,每次迭代沿着当前点的梯度方向前进。这种策略存在两个明显缺陷:

  • 之字形路径:相邻迭代方向正交,导致收敛路径曲折
  • 收敛速度慢:迭代次数与矩阵A的条件数κ(A)成正比
# 梯度下降法简单实现 def gradient_descent(A, b, x0, max_iter=100): x = x0.copy() history = [x0] for _ in range(max_iter): r = A @ x - b # 计算残差(即梯度) alpha = (r.T @ r) / (r.T @ A @ r) # 最优步长 x = x - alpha * r history.append(x) return x, history

1.2 共轭梯度法的优势

CG方法通过精心选择的共轭方向,解决了梯度下降的核心痛点:

  • 共轭方向:pᵢᵀApⱼ=0 (i≠j),确保每个方向上的优化互不干扰
  • 有限步收敛:对于n维问题,最多n步即可得到精确解
  • 内存友好:只需存储几个向量,适合大规模问题

实际应用中,由于数值误差,CG通常作为迭代方法使用,在远小于n步时就能获得满意解。

2. 共轭梯度法原理剖析

2.1 算法核心思想

CG方法的神奇之处在于它动态构建共轭方向。每个新方向pₖ是当前残差(负梯度)与前一个共轭方向的线性组合:

pₖ = -rₖ + βₖpₖ₋₁

其中βₖ的选择要保证pₖ与pₖ₋₁关于A共轭。常用的计算方式有:

  • Fletcher-Reeves:βₖ = (rₖᵀrₖ)/(rₖ₋₁ᵀrₖ₋₁)
  • Polak-Ribière:βₖ = [rₖᵀ(rₖ - rₖ₋₁)]/(rₖ₋₁ᵀrₖ₋₁)

2.2 算法步骤分解

完整的线性CG算法包含以下关键步骤:

  1. 初始化:x₀, r₀ = b - Ax₀, p₀ = r₀
  2. 迭代直到收敛:
    • 计算步长:αₖ = (rₖᵀrₖ)/(pₖᵀApₖ)
    • 更新解:xₖ₊₁ = xₖ + αₖpₖ
    • 更新残差:rₖ₊₁ = rₖ - αₖApₖ
    • 计算βₖ₊₁ = (rₖ₊₁ᵀrₖ₊₁)/(rₖᵀrₖ)
    • 更新方向:pₖ₊₁ = rₖ₊₁ + βₖ₊₁pₖ

注意:CG要求系数矩阵A对称正定。对于非对称问题,可考虑广义最小残差法(GMRES)等替代方案。

3. Python实现与可视化对比

3.1 完整CG算法实现

import numpy as np import matplotlib.pyplot as plt def conjugate_gradient(A, b, x0, max_iter=None, tol=1e-6): if max_iter is None: max_iter = len(b) x = x0.copy() r = b - A @ x p = r.copy() rs_old = r.T @ r history = [x0] residuals = [np.sqrt(rs_old)] for i in range(max_iter): Ap = A @ p alpha = rs_old / (p.T @ Ap) x = x + alpha * p r = r - alpha * Ap rs_new = r.T @ r residuals.append(np.sqrt(rs_new)) if np.sqrt(rs_new) < tol: break beta = rs_new / rs_old p = r + beta * p rs_old = rs_new history.append(x.copy()) return x, history, residuals

3.2 与梯度下降的对比实验

让我们构造一个二维问题直观比较两种方法:

# 构造测试问题 A = np.array([[3, 2], [2, 6]]) b = np.array([2, -8]) x0 = np.array([-2, -2]) # 运行两种算法 x_gd, hist_gd = gradient_descent(A, b, x0, max_iter=20) x_cg, hist_cg, res_cg = conjugate_gradient(A, b, x0, max_iter=5) # 绘制收敛路径 def plot_contour(): x = np.linspace(-2.5, 1.5, 100) y = np.linspace(-3.5, 0.5, 100) X, Y = np.meshgrid(x, y) Z = 0.5*(3*X**2 + 4*X*Y + 6*Y**2) - 2*X + 8*Y plt.figure(figsize=(10, 6)) plt.contour(X, Y, Z, levels=20) plt.plot(*zip(*hist_gd), 'o-', label='Gradient Descent') plt.plot(*zip(*hist_cg), 's-', label='Conjugate Gradient') plt.legend() plt.xlabel('x1') plt.ylabel('x2') plt.title('Optimization Paths Comparison') plt.show() plot_contour()

从可视化结果可以清晰看到:

  • 梯度下降呈现典型的"之"字形路径
  • CG方法几乎沿直线快速收敛到最优解

4. 工程实践中的关键技巧

4.1 预处理技术

当矩阵A条件数较大时,CG收敛速度会显著下降。预处理技术通过引入预处理矩阵M≈A⁻¹来改善收敛性:

预处理类型构造方法适用场景
雅可比预处理M = diag(A)⁻¹对角占优矩阵
不完全CholeskyM = LLᵀ ≈ A稀疏矩阵
多项式预处理M = p(A)特定特征值分布

预处理CG算法只需修改残差计算为zₖ = M⁻¹rₖ,其余步骤保持不变。

4.2 实用终止条件

实际实现中应考虑以下收敛标准组合:

  • 相对残差:‖rₖ‖/‖b‖ < ε
  • 绝对残差:‖rₖ‖ < ε
  • 最大迭代次数:k ≥ max_iter
  • 停滞检测:|ϕ(xₖ)-ϕ(xₖ₋₁)| < δ
# 增强的终止条件实现 def should_stop(r, b, k, max_iter, tol=1e-6, rel_tol=1e-6, min_improve=1e-8): norm_r = np.linalg.norm(r) norm_b = np.linalg.norm(b) criteria = [ norm_r < tol, norm_r/norm_b < rel_tol, k >= max_iter, (k > 0 and norm_r < min_improve) ] return any(criteria)

4.3 非线性扩展

对于非线性优化问题min f(x),非线性CG方法只需做两点调整:

  1. 残差rₖ替换为梯度∇f(xₖ)
  2. 步长αₖ通过线搜索确定

常用的非线性CG变体包括:

  • Fletcher-Reeves (FR)
  • Polak-Ribière (PR)
  • Hestenes-Stiefel (HS)

PR方法通常表现最好,建议配合强Wolfe条件线搜索使用。

5. 性能优化与常见陷阱

5.1 计算效率优化

CG算法的主要计算开销在于矩阵-向量乘积Ap。针对不同矩阵类型可采用特定优化:

稀疏矩阵

from scipy.sparse import csr_matrix A_sparse = csr_matrix(A) Ap = A_sparse.dot(p) # 高效稀疏乘法

结构化矩阵

# 例如A是Laplacian矩阵 def laplacian_mult(p): # 利用矩阵结构实现快速乘法 return ... # 在CG中直接使用 Ap = laplacian_mult(p)

5.2 数值稳定性问题

CG理论上应在有限步收敛,但实际计算中可能遇到:

  • 舍入误差累积:导致方向向量失去共轭性
  • 重启策略:每k步重置pₖ = -rₖ
  • 混合精度计算:关键计算使用更高精度

实践中发现,当残差不再显著下降时重启CG通常能改善收敛性。

5.3 实用调试技巧

当CG表现异常时,检查以下方面:

  1. 矩阵对称性:验证A是否对称
  2. 正定性:检查A的最小特征值
  3. 预处理质量:评估M⁻¹A的条件数
  4. 残差行为:监控‖rₖ‖的下降曲线
# 对称性检查 def is_symmetric(A, tol=1e-8): return np.allclose(A, A.T, atol=tol) # 正定性检查 def is_positive_definite(A): try: np.linalg.cholesky(A) return True except np.linalg.LinAlgError: return False

在实现CG算法时,我经常遇到预处理选择不当导致收敛缓慢的情况。一个实用的经验是先用简单预处理(如雅可比)快速验证算法正确性,再尝试更复杂的预处理方案。对于特别困难的问题,结合Krylov子空间方法的多重网格预处理往往能带来惊喜。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/12 10:55:29

老股东腾讯跟投阶跃星辰新一轮融资,双方合作深化发力AI座舱Agent

《科创板日报》消息称&#xff0c;老股东腾讯跟投阶跃星辰新一轮融资。上周有消息透露阶跃将完成25亿美金融资并拆除红筹架构&#xff0c;加速赴港IPO。此前腾讯已连续跟投&#xff0c;且双方刚签署战略合作。腾讯跟投新一轮融资据《科创板日报》&#xff0c;老股东腾讯跟投了阶…

作者头像 李华
网站建设 2026/5/12 10:51:26

CentOS停服后,除了改仓库地址,你的vim和net-tools还能这样‘抢救’安装

CentOS停服后的运维生存指南&#xff1a;高效安装vim与net-tools的实战方案 当CentOS官方停止维护后&#xff0c;许多依赖其软件仓库的运维工作突然变得棘手起来。vim和net-tools这类基础工具无法通过常规方式安装&#xff0c;确实会让日常运维陷入困境。但别担心&#xff0c;这…

作者头像 李华
网站建设 2026/5/12 10:47:51

3分钟学会离线语音转文字:TMSpeech让你的会议记录不再遗漏

3分钟学会离线语音转文字&#xff1a;TMSpeech让你的会议记录不再遗漏 【免费下载链接】TMSpeech 腾讯会议摸鱼工具 项目地址: https://gitcode.com/gh_mirrors/tm/TMSpeech 你是否经常因为会议内容太多记不住而焦虑&#xff1f;是否担心网络语音识别会泄露你的隐私&…

作者头像 李华