news 2026/4/16 7:21:45

【推荐系统】深度学习训练框架(十六):模型并行——推荐系统的TorchRec和大语言模型的FSDP(Fully Sharded Data Parallel)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【推荐系统】深度学习训练框架(十六):模型并行——推荐系统的TorchRec和大语言模型的FSDP(Fully Sharded Data Parallel)

📦 第一部分:TorchRec 实战教程

TorchRec是PyTorch的领域库,专为大规模推荐系统设计。其核心是解决超大规模嵌入表在多GPU/多节点上的高效训练问题。

1. 安装与环境配置
首先安装TorchRec及其依赖。推荐使用CUDA环境以获得最佳性能。

# 1. 安装对应CUDA版本的PyTorch (以CUDA 12.1为例)pipinstalltorch --index-url https://download.pytorch.org/whl/cu121# 2. 安装FBGEMM GPU版本和TorchRecpipinstallfbgemm-gpu --index-url https://download.pytorch.org/whl/cu121 pipinstalltorchrec --index-url https://download.pytorch.org/whl/cu121# 3. 如果是纯CPU环境(性能较低)# pip uninstall fbgemm-gpu -y# pip install fbgemm-gpu-cpu# pip install torchrec

2. 核心概念:分片与并行
理解以下两个关键模块是使用TorchRec的基础:

  • 分片器(Sharder):定义如何将巨大的嵌入表切割并分布到不同设备上。TorchRec支持多种分片策略,如按行(row_wise)、按表(table_wise)等。
  • 分布式模型并行(DistributedModelParallel, DMP):这是TorchRec最核心的高级API。它类似于PyTorch的DistributedDataParallel,但专为封装已分片的稀疏模型部分(嵌入表)和稠密模型部分(如MLP)而设计。

3. 实战:构建一个分布式推荐模型
下面通过一个简化的代码示例,展示如何使用TorchRec的关键组件。

importtorchimporttorch.nnasnnfromtorchrec.distributedimportDistributedModelParallelfromtorchrec.distributed.plannerimportEmbeddingShardingPlannerfromtorchrec.modules.embedding_configsimportEmbeddingBagConfigfromtorchrec.modules.embedding_modulesimportEmbeddingBagCollectionfromtorchrec.distributed.model_parallelimport(get_default_sharders,)# 1. 定义模型(以最简单的稠密-稀疏交互为例)classSimpleRecModel(nn.Module):def__init__(self,embedding_bag_collection):super().__init__()self.ebc=embedding_bag_collection# 假设稀疏特征维度总和为512self.dense=nn.Linear(512,1)defforward(self,sparse_features):embeddings=self.ebc(sparse_features)# 获得稀疏特征嵌入向量concatenated=torch.cat([embforembinembeddings.values()],dim=1)returnself.dense(concatenated)# 2. 初始化分布式环境(必须在代码最开头)importtorch.distributedasdist dist.init_process_group(backend="nccl")# GPU用NCCL,CPU用gloolocal_rank=int(os.environ["LOCAL_RANK"])device=torch.device(f"cuda:{local_rank}")# 3. 配置嵌入表embedding_configs=[EmbeddingBagConfig(name="user",embedding_dim=128,num_embeddings=100_0000,# 一百万用户feature_names=["user_feature"],),EmbeddingBagConfig(name="item",embedding_dim=128,num_embeddings=50_0000,# 五十万物品feature_names=["item_feature"],),]# 4. 在CPU上实例化模型(重要!DMP会处理设备移动)ebc=EmbeddingBagCollection(tables=embedding_configs,device=torch.device("cpu"))model=SimpleRecModel(ebc)# 5. 使用分布式模型并行(DMP)包装模型# get_default_sharders() 提供了适用于常见嵌入模块的分片器model=DistributedModelParallel(module=model,device=device,sharders=get_default_sharders(),# planner=EmbeddingShardingPlanner() # 可选的自动规划器,用于生成优化的分片计划)# 6. 定义优化器(TorchRec的优化器支持稀疏更新,高效处理嵌入梯度)fromtorchrec.optimimportapply_optimizer_in_backwardfromtorch.optimimportSGD# 为嵌入参数设置稀疏优化器apply_optimizer_in_backward(SGD,model.module.ebc.parameters(),{"lr":0.1})# 为稠密参数设置标准优化器dense_optimizer=SGD(model.module.dense.parameters(),lr=0.01)# 此后,在训练循环中,前向传播、反向传播和优化器步骤与非分布式模型基本一致。# DMP会自动处理跨设备的梯度同步和稀疏参数的更新。

⚙️ 第二部分:FSDP 快速指南

FSDP是PyTorch原生的分布式训练策略,核心思想是将模型的参数、梯度和优化器状态全部分片存储,在需要时再通过通信收集,从而极大节省单卡显存。

1. 基本使用模式
以下是使用FSDP包装一个Transformer模型的典型代码:

importtorchimporttorch.nnasnnfromtorch.distributed.fsdpimportFullyShardedDataParallelasFSDPfromtorch.distributed.fsdp.wrapimportdefault_auto_wrap_policy# 1. 初始化分布式环境 (同上,略)# ...# 2. 定义一个大的模型 (例如Transformer)model=nn.Transformer(d_model=2048,nhead=16,num_encoder_layers=12,num_decoder_layers=12)# 3. 定义自动包装策略(按子模块分片)my_auto_wrap_policy=default_auto_wrap_policy(transformer_layer_cls={nn.TransformerEncoderLayer,nn.TransformerDecoderLayer})# 4. 用FSDP包装模型fsdp_model=FSDP(model,auto_wrap_policy=my_auto_wrap_policy,device_id=torch.cuda.current_device(),)# 5. 定义优化器(FSDP会自动处理优化器状态分片)optimizer=torch.optim.Adam(fsdp_model.parameters(),lr=1e-4)# 6. 训练循环与非分布式模型一致# FSDP会在前向传播时透明地收集所需参数,并在反向传播后同步梯度和更新分片。

🤔 第三部分:TorchRec 与 FSDP 核心对比

这两种技术都是为“大模型”设计,但目标完全不同。下表清晰地展示了两者的区别:

对比维度TorchRecFSDP (Fully Sharded Data Parallel)
核心目标专门用于大规模推荐系统,解决稀疏嵌入表的并行训练。通用的大规模稠密模型(如LLM、CV大模型)训练,解决参数显存瓶颈。
主要并行范式混合并行:嵌入表常采用模型并行/张量并行切分,稠密部分使用数据并行。增强的数据并行:在数据并行的基础上,对参数、梯度、优化器状态进行分片
优化核心嵌入表的分片策略(行、列、表),以及稀疏梯度的高效聚合与更新通信与计算的重叠,以及分片策略(全分片、混合分片)以平衡显存和通信开销。
关键优势1. 原生支持超大规模嵌入(十亿/万亿级)。
2. 为推荐系统提供专用原语(如EmbeddingBagCollection)。
3. 优化器支持稀疏更新,计算高效。
1.通用性强,几乎适用于任何PyTorch模型。
2.显存节省显著,是训练千亿参数大模型的标配技术
3.与PyTorch生态无缝集成
典型应用场景电商推荐、广告点击率(CTR)预估、社交网络推荐等具有海量稀疏特征的场景。大语言模型(LLM)预训练与微调、大规模视觉模型训练、稠密科学计算模型。
关系互补。一个复杂模型可同时使用两者:其稀疏嵌入部分用TorchRec分片,而稠密神经网络部分用FSDP分片

💡 第四部分:如何选择与后续建议

如何选择:

  • 如果你的模型核心是处理用户ID、商品ID等海量离散特征,嵌入表参数占模型绝大部分,请直接选择TorchRec
  • 如果你的模型是Transformer、ResNet等稠密结构,参数巨大但并非稀疏特征,应选择FSDP
  • 对于混合模型(大嵌入表+大稠密网络),可以研究组合使用两者
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/16 7:20:52

Dify Custom Tool 调用超时问题排查与解决方案(claude-4.5-opus-high)

在使用 Dify 的 Custom Tool(自定义工具)功能调用外部 API 时,你是否遇到过这样的问题: 工具调用反复重试,日志中出现多次相同请求API 明明执行成功了,但 Dify 显示超时失败复杂的 AI 处理流程总是在中途断…

作者头像 李华
网站建设 2026/4/14 17:40:25

day123—二分查找—H 指数 II(LeetCode-275)

题目描述 给你一个整数数组 citations ,其中 citations[i] 表示研究者的第 i 篇论文被引用的次数,citations 已经按照 非降序排列 。计算并返回该研究者的 h 指数。 h 指数的定义:h 代表“高引用次数”(high citations&#xff…

作者头像 李华
网站建设 2026/4/12 0:17:13

从零搭建VSCode量子作业监控面板:3小时快速上手教程,错过等于落伍

第一章:VSCode 的量子作业监控面板在现代量子计算开发中,可视化与实时监控是提升调试效率的关键。VSCode 通过扩展插件架构,支持集成定制化的量子作业监控面板,使开发者能够在编码环境中直接观察量子电路执行状态、资源分配及任务…

作者头像 李华
网站建设 2026/4/16 7:20:45

【收藏必备】2023年大模型转型完全指南:从零入门到就业的全方位攻略

这篇文章提供了大模型领域从零到就业的全面转型攻略,包括明确职业方向、掌握基础知识、深入学习大模型技术、参与实践项目、加入开源社区、利用学习资源以及职业发展建议等内容。文章不仅提供了技术学习路径,还包含了职业规划和持续学习的方法&#xff0…

作者头像 李华
网站建设 2026/4/13 3:43:22

基于大数据挖掘技术的台风灾害预测系统(毕业设计项目源码+文档)

课题摘要 基于大数据挖掘技术的台风灾害预测系统,直击台风灾害防控 “数据来源分散、预测模型单一、预警响应滞后” 的核心痛点,依托 HadoopSparkTensorFlow 大数据挖掘技术体系,构建 “多源数据融合 智能模型预测 可视化预警赋能” 的一体…

作者头像 李华
网站建设 2026/4/9 8:39:51

车载通信测试60天学习计划:Day5 核心知识点(纯干货)

一、车载诊断核心协议:DoIP与UDS(岗位核心技能)1. DoIP协议基础(诊断通信-over-IP)(1)核心定位与价值DoIP(Diagnostic over IP)是基于以太网的诊断协议,替代传…

作者头像 李华