news 2026/4/16 16:42:54

【LLaVA-NeXT】LLaVATrainer说明

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【LLaVA-NeXT】LLaVATrainer说明

LLaVATrainer

classllava.train.llava_trainer.LLaVATrainer(Trainer)

用于训练 LLaVA (Large Language and Vision Assistant) 多模态模型的训练器类,继承自transformers.Trainer

该类在标准 Transformer Trainer 基础上扩展了以下功能:

  • 支持MeZO (Memory-efficient Zeroth-Order Optimization)零阶优化训练模式
  • 提供多种基于长度和模态的数据采样策略
  • 支持DeepSpeedFSDP分布式训练
  • 提供针对多模态适配器 (MM Adapter) 的特定检查点保存功能

参数

该类接受所有transformers.Trainer支持的关键字参数,同时支持以下额外参数(通过args传入):

参数类型默认值描述
trainer_modestr"regular"训练模式。可选"regular"(常规反向传播训练)或"zo"(MeZO 零阶优化训练)。
zo_epsfloat1e-3MeZO 超参数 epsilon,控制参数扰动的幅度。
zo_num_directionsint1MeZO 优化中使用的随机方向数量。
group_by_lengthboolFalse是否按序列长度分组采样。
group_by_modality_lengthboolFalse是否按模态长度分组采样。
group_by_modality_length_autoboolFalse是否使用自动模态长度分组采样。
group_by_varlenboolFalse是否使用可变长度分组采样。
mm_projector_lrfloat,optionalNone多模态投影层的独立学习率。
mm_vision_tower_lrfloat,optionalNone视觉编码器的独立学习率。

属性

属性类型描述
trainer_modestr当前训练模式("regular""zo")。
zo_epsfloatMeZO epsilon 超参数。
zo_num_directionsintMeZO 随机方向数量。
trainable_paramsList[Tuple[str, Parameter]]可训练参数列表,包含参数名称和参数本身。
mezo_update_historyList[Dict]MeZO 更新历史记录,用于检查点恢复。

方法

zo_perturb_parameters

zo_perturb_parameters(scaling_factor:float=1.0)->None

使用随机向量z zz扰动模型参数。

参数:

  • scaling_factor(float) – 扰动的缩放因子。正值表示正向扰动,负值表示反向扰动。

示例:

# 正向扰动trainer.zo_perturb_parameters(scaling_factor=1.0)# 反向扰动(恢复原始参数后再扰动)trainer.zo_perturb_parameters(scaling_factor=-2.0)

zo_forward

zo_forward(model:nn.Module,inputs:Dict)->torch.Tensor

在推理模式下计算前向传播损失。

参数:

  • model(nn.Module) – 需要计算损失的模型。
  • inputs(Dict) – 输入批次数据。

返回:

  • torch.Tensor– 计算得到的损失值(已 detach)。

zo_step

zo_step(model:nn.Module,inputs:Dict)->torch.Tensor

使用 MeZO 算法执行单步梯度估计。通过正向和反向扰动的损失差来近似梯度。

参数:

  • model(nn.Module) – 模型实例。
  • inputs(Dict) – 输入批次数据。

返回:

  • torch.Tensor– 归一化后的损失值。

注意事项:

该方法在gradient_accumulation_steps期间累积多个方向的梯度估计,在zo_update中统一应用。


zo_update

zo_update(learning_rate:float)->None

根据累积的梯度估计更新模型参数。

参数:

  • learning_rate(float) – 当前学习率。

注意事项:

  • 该方法自动处理 weight decay
  • biaslayer_normlayernorm参数不会应用 weight decay
  • 调用后会清空累积的梯度估计

save_model

save_model(output_dir:Optional[str]=None,_internal_call:bool=False)

保存模型检查点。当使用 MeZO 模式时,会额外保存轻量级的 MeZO 状态检查点。

参数:

  • output_dir(str,optional) – 保存路径。默认使用args.output_dir
  • _internal_call(bool) – 是否为内部调用。

_save_checkpoint

_save_checkpoint(model,trial,metrics=None)->None

保存训练检查点。该方法重写了父类的检查点保存逻辑,以支持仅保存多模态适配器 (MM Adapter) 权重的场景。

参数:

  • model– 需要保存的模型实例。
  • trial– 超参数搜索试验对象(用于确定输出目录)。
  • metrics(Dict,optional) – 评估指标字典。

行为说明:

当满足以下任一条件时,仅保存适配器权重:

  • args.tune_mm_mlp_adapter=True
  • args.mm_tunable_parts仅包含"mm_mlp_adapter""mm_vision_resampler"

在这种情况下,会保存:

  • 模型配置文件 (config.json)
  • 适配器权重文件 (mm_projector.bin)

保存的权重包括:

  • mm_projector相关参数
  • vision_resampler相关参数
  • 如果use_im_start_end=True,还包括embed_tokensembed_in

其他情况下,调用父类Trainer._save_checkpoint()进行完整模型保存。

注意事项:

  • 该方法支持DeepSpeed ZeRO-3模式,会正确收集分布在多个 GPU 上的参数
  • 仅在主进程(local_rank == 0local_rank == -1)上执行实际的保存操作

示例:

# 仅微调 MM Adapter 时的配置training_args.tune_mm_mlp_adapter=True# 或者通过 mm_tunable_parts 指定training_args.mm_tunable_parts="mm_mlp_adapter"# 训练过程中的检查点将只包含适配器权重# 保存路径示例: output_dir/checkpoint-1000/mm_projector.bin

create_optimizer

create_optimizer()->torch.optim.Optimizer

创建优化器。支持为不同模块设置独立学习率(如mm_projectorvision_tower)。

返回:

  • torch.optim.Optimizer– 配置好的优化器实例。

注意事项:

在 MeZO 模式下,会创建一个虚拟优化器(dummy optimizer),实际参数更新由zo_update方法执行。


get_train_dataloader

get_train_dataloader()->DataLoader

创建并返回训练数据加载器。

返回:

  • torch.utils.data.DataLoader– 训练数据加载器。

示例

基本使用

fromllava.train.llava_trainerimportLLaVATrainerfromtransformersimportTrainingArguments# 配置训练参数training_args=TrainingArguments(output_dir="./output",per_device_train_batch_size=4,gradient_accumulation_steps=8,learning_rate=2e-5,num_train_epochs=1,group_by_modality_length=True,# 启用模态长度分组)# 创建训练器trainer=LLaVATrainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,data_collator=data_collator,)# 开始训练trainer.train()# 保存模型trainer.save_model("./final_model")

使用 MeZO 训练模式

fromllava.train.llava_trainerimportLLaVATrainer# 配置 MeZO 相关参数training_args.trainer_mode="zo"training_args.zo_eps=1e-3training_args.zo_num_directions=1# 创建训练器trainer=LLaVATrainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,)# MeZO 模式训练trainer.train()

设置模块独立学习率

# 为多模态投影层和视觉编码器设置独立学习率training_args.mm_projector_lr=1e-4training_args.mm_vision_tower_lr=2e-6trainer=LLaVATrainer(model=model,args=training_args,...)

参见

  • transformers.Trainer– 基类文档
  • LengthGroupedSampler– 长度分组采样器
  • LLaVADPOTrainer– 用于 DPO (Direct Preference Optimization) 训练的变体
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/16 12:18:17

Java毕设选题推荐:基于springboot的医药配药管理系统【附源码、mysql、文档、调试+代码讲解+全bao等】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/4/16 13:37:15

用AI写测试用例?这5个提示词模板让你效率翻倍

一、AI重构测试生产力:从耗时手工到精准自动化 在持续交付成为行业标配的今天,测试工程师面临用例设计耗时与覆盖率不足的双重压力。传统手工编写用例模式下,一个中级工程师完成核心功能测试需30-40分钟,而AI辅助可将此过程压缩至…

作者头像 李华
网站建设 2026/4/16 12:27:56

5个月学习GIS开发计划:带你揭秘特训营都在学习哪些内容?

第一阶段:Web开发入门 主要学习web前端三件套,能手动制作一些静态、动态的网页效果。网页中每一个地图界面、每一个弹窗、每一个交互面板,都是由HTML和CSS构建的;此外,地图的缩放、平移、点击查询、图层切换等所有交互…

作者头像 李华
网站建设 2026/4/16 12:58:11

springboot_ssm847儿童福利院管理系统ssm

目录具体实现截图儿童福利院管理系统(SpringBootSSM框架)摘要系统所用技术介绍写作提纲源码文档获取/同行可拿货,招校园代理 :文章底部获取博主联系方式!具体实现截图 儿童福利院管理系统(SpringBootSSM框架&#xff…

作者头像 李华
网站建设 2026/4/15 20:23:00

直面Oracle国产化替代的典型陷阱与攻坚策略

Oracle数据库迁移实战 KingbaseES 集成了丰富的 Oracle 兼容特性,这在实际迁移场景中通常只需对原导出脚本进行少量调整,甚至在全功能兼容时无需修改。此外,系统还支持使用 KDTS、KFS 等多种辅助工具,进一步简化迁移流程。 本节…

作者头像 李华
网站建设 2026/4/16 14:04:11

同一篇论文,维普AI率67%→9%,我是怎么做到的

维普AIGC检测高?6款工具帮你降到合格线 TL;DR:维普AIGC检测算法和知网不同,很多知网能过的工具在维普可能过不了。实测对维普效果最好的是嘎嘎降AI(67%→9%),其次是比话降AI(60%→12%&#xff0…

作者头像 李华