LLaVATrainer
classllava.train.llava_trainer.LLaVATrainer(Trainer)用于训练 LLaVA (Large Language and Vision Assistant) 多模态模型的训练器类,继承自transformers.Trainer。
该类在标准 Transformer Trainer 基础上扩展了以下功能:
- 支持MeZO (Memory-efficient Zeroth-Order Optimization)零阶优化训练模式
- 提供多种基于长度和模态的数据采样策略
- 支持DeepSpeed和FSDP分布式训练
- 提供针对多模态适配器 (MM Adapter) 的特定检查点保存功能
参数
该类接受所有transformers.Trainer支持的关键字参数,同时支持以下额外参数(通过args传入):
| 参数 | 类型 | 默认值 | 描述 |
|---|---|---|---|
trainer_mode | str | "regular" | 训练模式。可选"regular"(常规反向传播训练)或"zo"(MeZO 零阶优化训练)。 |
zo_eps | float | 1e-3 | MeZO 超参数 epsilon,控制参数扰动的幅度。 |
zo_num_directions | int | 1 | MeZO 优化中使用的随机方向数量。 |
group_by_length | bool | False | 是否按序列长度分组采样。 |
group_by_modality_length | bool | False | 是否按模态长度分组采样。 |
group_by_modality_length_auto | bool | False | 是否使用自动模态长度分组采样。 |
group_by_varlen | bool | False | 是否使用可变长度分组采样。 |
mm_projector_lr | float,optional | None | 多模态投影层的独立学习率。 |
mm_vision_tower_lr | float,optional | None | 视觉编码器的独立学习率。 |
属性
| 属性 | 类型 | 描述 |
|---|---|---|
trainer_mode | str | 当前训练模式("regular"或"zo")。 |
zo_eps | float | MeZO epsilon 超参数。 |
zo_num_directions | int | MeZO 随机方向数量。 |
trainable_params | List[Tuple[str, Parameter]] | 可训练参数列表,包含参数名称和参数本身。 |
mezo_update_history | List[Dict] | MeZO 更新历史记录,用于检查点恢复。 |
方法
zo_perturb_parameters
zo_perturb_parameters(scaling_factor:float=1.0)->None使用随机向量z zz扰动模型参数。
参数:
- scaling_factor(
float) – 扰动的缩放因子。正值表示正向扰动,负值表示反向扰动。
示例:
# 正向扰动trainer.zo_perturb_parameters(scaling_factor=1.0)# 反向扰动(恢复原始参数后再扰动)trainer.zo_perturb_parameters(scaling_factor=-2.0)zo_forward
zo_forward(model:nn.Module,inputs:Dict)->torch.Tensor在推理模式下计算前向传播损失。
参数:
- model(
nn.Module) – 需要计算损失的模型。 - inputs(
Dict) – 输入批次数据。
返回:
torch.Tensor– 计算得到的损失值(已 detach)。
zo_step
zo_step(model:nn.Module,inputs:Dict)->torch.Tensor使用 MeZO 算法执行单步梯度估计。通过正向和反向扰动的损失差来近似梯度。
参数:
- model(
nn.Module) – 模型实例。 - inputs(
Dict) – 输入批次数据。
返回:
torch.Tensor– 归一化后的损失值。
注意事项:
该方法在
gradient_accumulation_steps期间累积多个方向的梯度估计,在zo_update中统一应用。
zo_update
zo_update(learning_rate:float)->None根据累积的梯度估计更新模型参数。
参数:
- learning_rate(
float) – 当前学习率。
注意事项:
- 该方法自动处理 weight decay
bias、layer_norm和layernorm参数不会应用 weight decay- 调用后会清空累积的梯度估计
save_model
save_model(output_dir:Optional[str]=None,_internal_call:bool=False)保存模型检查点。当使用 MeZO 模式时,会额外保存轻量级的 MeZO 状态检查点。
参数:
- output_dir(
str,optional) – 保存路径。默认使用args.output_dir。 - _internal_call(
bool) – 是否为内部调用。
_save_checkpoint
_save_checkpoint(model,trial,metrics=None)->None保存训练检查点。该方法重写了父类的检查点保存逻辑,以支持仅保存多模态适配器 (MM Adapter) 权重的场景。
参数:
- model– 需要保存的模型实例。
- trial– 超参数搜索试验对象(用于确定输出目录)。
- metrics(
Dict,optional) – 评估指标字典。
行为说明:
当满足以下任一条件时,仅保存适配器权重:
args.tune_mm_mlp_adapter=Trueargs.mm_tunable_parts仅包含"mm_mlp_adapter"或"mm_vision_resampler"
在这种情况下,会保存:
- 模型配置文件 (
config.json) - 适配器权重文件 (
mm_projector.bin)
保存的权重包括:
mm_projector相关参数vision_resampler相关参数- 如果
use_im_start_end=True,还包括embed_tokens和embed_in
其他情况下,调用父类Trainer._save_checkpoint()进行完整模型保存。
注意事项:
- 该方法支持DeepSpeed ZeRO-3模式,会正确收集分布在多个 GPU 上的参数
- 仅在主进程(
local_rank == 0或local_rank == -1)上执行实际的保存操作
示例:
# 仅微调 MM Adapter 时的配置training_args.tune_mm_mlp_adapter=True# 或者通过 mm_tunable_parts 指定training_args.mm_tunable_parts="mm_mlp_adapter"# 训练过程中的检查点将只包含适配器权重# 保存路径示例: output_dir/checkpoint-1000/mm_projector.bincreate_optimizer
create_optimizer()->torch.optim.Optimizer创建优化器。支持为不同模块设置独立学习率(如mm_projector和vision_tower)。
返回:
torch.optim.Optimizer– 配置好的优化器实例。
注意事项:
在 MeZO 模式下,会创建一个虚拟优化器(dummy optimizer),实际参数更新由
zo_update方法执行。
get_train_dataloader
get_train_dataloader()->DataLoader创建并返回训练数据加载器。
返回:
torch.utils.data.DataLoader– 训练数据加载器。
示例
基本使用
fromllava.train.llava_trainerimportLLaVATrainerfromtransformersimportTrainingArguments# 配置训练参数training_args=TrainingArguments(output_dir="./output",per_device_train_batch_size=4,gradient_accumulation_steps=8,learning_rate=2e-5,num_train_epochs=1,group_by_modality_length=True,# 启用模态长度分组)# 创建训练器trainer=LLaVATrainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,data_collator=data_collator,)# 开始训练trainer.train()# 保存模型trainer.save_model("./final_model")使用 MeZO 训练模式
fromllava.train.llava_trainerimportLLaVATrainer# 配置 MeZO 相关参数training_args.trainer_mode="zo"training_args.zo_eps=1e-3training_args.zo_num_directions=1# 创建训练器trainer=LLaVATrainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,)# MeZO 模式训练trainer.train()设置模块独立学习率
# 为多模态投影层和视觉编码器设置独立学习率training_args.mm_projector_lr=1e-4training_args.mm_vision_tower_lr=2e-6trainer=LLaVATrainer(model=model,args=training_args,...)参见
transformers.Trainer– 基类文档LengthGroupedSampler– 长度分组采样器LLaVADPOTrainer– 用于 DPO (Direct Preference Optimization) 训练的变体