news 2026/4/18 0:03:13

手把手复现NeurIPS 2023 TIGER模型:从RQ-VAE量化语义ID到Transformer生成式召回全流程

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
手把手复现NeurIPS 2023 TIGER模型:从RQ-VAE量化语义ID到Transformer生成式召回全流程

从零实现TIGER模型:基于语义ID的生成式推荐系统实战指南

在推荐系统领域,传统双塔模型和协同过滤方法长期占据主导地位,但它们面临着冷启动、反馈循环和语义理解不足等固有挑战。Google Research在NeurIPS 2023提出的TIGER框架,通过结合残差量化VAE和Transformer,开创了生成式推荐的新范式。本文将带您从零开始完整复现这一前沿工作,重点解决三个核心问题:如何构建具有层级语义的物品ID?如何训练端到端的生成式推荐模型?以及如何在实际场景中应用这一创新架构?

1. 环境准备与数据预处理

复现TIGER模型需要搭建特定的技术栈。我们推荐使用Python 3.9+和JAX生态系统,这能充分发挥现代硬件加速的优势。以下是关键依赖的配置方案:

# 基础环境 pip install jax==0.4.13 jaxlib==0.4.13 # 确保CUDA版本匹配 pip install flax==0.7.4 t5x==0.9.3 # Transformer实现框架 pip install sentence-transformers==2.2.2 # 语义编码器

对于数据集选择,Amazon Product Data (Beauty类别)是个理想的起点。这个中等规模的数据集包含商品标题、描述、类别和用户交互序列,正好满足我们的需求。数据预处理需要特别注意三个关键转换:

  1. 会话序列构建:将原始点击流按用户分组,并按时间戳排序
  2. 文本特征整合:把商品标题、品牌和类别拼接成统一文本描述
  3. 交互序列截断:设置合理的最大序列长度(建议50-100)
import pandas as pd from collections import defaultdict def preprocess_interactions(df): user_sequences = defaultdict(list) for _, row in df.sort_values('timestamp').iterrows(): user_sequences[row['user_id']].append(row['item_id']) return { uid: seq[-100:] # 截断长序列 for uid, seq in user_sequences.items() if len(seq) >= 3 # 过滤短序列 }

提示:预处理阶段建议保留原始item_id到metadata的映射关系,后续语义ID生成阶段需要用到商品文本特征。

2. 构建层级语义ID系统

TIGER模型的核心创新在于用结构化语义ID替代传统随机ID。我们采用Sentence-T5结合RQ-VAE的方案,将商品语义编码为层级式离散表示。

2.1 语义嵌入提取

首先使用Sentence-T5将商品文本特征转换为稠密向量。这个预训练编码器能捕捉细粒度的语义关系:

from sentence_transformers import SentenceTransformer encoder = SentenceTransformer('sentence-t5-base') item_descriptions = [...] # 从预处理数据加载 item_embeddings = encoder.encode(item_descriptions, batch_size=128, show_progress_bar=True)

2.2 RQ-VAE实现关键细节

残差量化VAE是生成层级ID的核心组件,其实现有几个技术要点:

  1. Codebook初始化:使用k-means对首批样本聚类,避免codebook坍塌
  2. 残差量化过程:分层逐步逼近原始向量,保留最大语义信息
  3. 重构损失设计:平衡量化误差和模型表达能力

以下是RQ-VAE量化层的核心代码:

import jax.numpy as jnp from jax import random class ResidualQuantizer: def __init__(self, num_layers=3, codebook_size=256, latent_dim=32): self.codebooks = [ random.normal(random.PRNGKey(i), (codebook_size, latent_dim)) for i in range(num_layers) ] def quantize(self, z): residuals = [z] codes = [] quantized = jnp.zeros_like(z) for cb in self.codebooks: # 计算当前残差与codebook的距离 distances = jnp.sum((residuals[-1][:, None] - cb)**2, axis=-1) # 选择最近邻 code = jnp.argmin(distances, axis=-1) quantized += cb[code] residuals.append(residuals[-1] - cb[code]) codes.append(code) return jnp.stack(codes, axis=-1), quantized

训练RQ-VAE时,建议监控以下指标确保稳定收敛:

  • Codebook使用率(目标>80%)
  • 重构误差下降曲线
  • 各层残差的L2范数分布

3. 序列数据构建与增强

获得语义ID后,需要将其转换为适合Transformer训练的序列格式。这个过程有几个关键决策点:

3.1 序列格式化策略

原始论文采用展开(flatten)策略,将用户交互序列中的每个商品表示为其完整的语义ID序列。例如,若语义ID长度为4,用户历史包含3个商品,则输入序列长度为12:

用户历史: [itemA, itemB, itemC] 语义ID展开: [A1,A2,A3,A4, B1,B2,B3,B4, C1,C2,C3,C4]

这种表示虽然直观,但会导致序列长度快速增长。我们实验发现两种改进方案:

  1. 分层采样:随机选择某些层级token进行预测
  2. 前缀压缩:对共享前缀的连续商品进行合并

3.2 负采样与课程学习

生成式推荐面临的一个挑战是如何处理海量候选商品。我们采用动态负采样策略:

def generate_training_batch(sequences, item_pool, neg_ratio=5): batch = [] for seq in sequences: # 正样本是序列中的下一个商品 for i in range(len(seq)-1): pos_id = seq[i+1] # 负样本从商品池中随机抽取 neg_ids = random.choice( item_pool, size=min(neg_ratio, len(item_pool)-1), replace=False ) batch.append({ 'context': seq[:i+1], 'positive': pos_id, 'negatives': neg_ids }) return batch

注意:随着训练进行,可以逐步增加负样本比例和难度,模拟课程学习过程。

4. Transformer模型设计与训练

TIGER采用encoder-decoder架构处理语义ID序列,这与传统推荐模型有显著区别。我们的实现重点解决三个工程挑战:

4.1 模型架构优化

基于T5X框架,我们对原始Transformer做了以下调整:

  1. 相对位置编码:更好处理长序列推荐场景
  2. 层级注意力掩码:保持语义ID的层级结构
  3. 共享embedding:减少参数量的同时提升泛化
from flax import linen as nn class TigerTransformer(nn.Module): vocab_size: int num_layers: int = 4 num_heads: int = 6 embed_dim: int = 128 @nn.compact def __call__(self, inputs, targets): # 共享token embedding embed = nn.Embed(self.vocab_size, self.embed_dim) x = embed(inputs) # 编码器处理历史序列 for _ in range(self.num_layers): x = nn.SelfAttention(num_heads=self.num_heads)(x) x = nn.Dense(self.embed_dim*4)(x) x = nn.relu(x) x = nn.Dense(self.embed_dim)(x) # 解码器自回归生成 logits = [] for i in range(targets.shape[1]): pos = nn.Embed(targets.shape[1], self.embed_dim)(jnp.arange(i+1)) decoder_out = x + pos[:i+1] logits.append(nn.Dense(self.vocab_size)(decoder_out)) return jnp.stack(logits, axis=1)

4.2 训练技巧与调参

在实际训练中,我们发现以下几个策略至关重要:

  1. 渐进式序列长度:从短序列开始,逐步增加长度
  2. 动态温度采样:平衡探索与利用
  3. 混合精度训练:大幅提升训练速度

建议的优化器配置:

from optax import chain, add_decayed_weights, scale_by_adam optimizer = chain( add_decayed_weights(0.01), # L2正则 scale_by_adam(b1=0.9, b2=0.98), optax.scale_by_learning_rate_schedule( initial_learning_rate=0.01, transition_steps=10000, transition_begin=0, decay_rate=0.5 ) )

4.3 解码与候选生成

与传统推荐不同,TIGER通过自回归生成预测结果。这带来两个独特挑战:

  1. 无效ID处理:生成的语义ID可能不对应任何商品
  2. 束搜索优化:需要在多样性和相关性间取得平衡

我们实现了一个带后处理的束搜索解码器:

def beam_search_decoder(model, context, beam_size=5, max_len=4): # 初始化束 beams = [([], 0.0)] # (tokens, score) for step in range(max_len): new_beams = [] for seq, score in beams: # 获取下一个token的概率 logits = model(context, jnp.array([seq])) probs = jax.nn.softmax(logits[0, -1]) # 扩展beam top_k = jnp.argsort(probs)[-beam_size:] for token in top_k: new_seq = seq + [token] new_score = score + jnp.log(probs[token]) new_beams.append((new_seq, new_score)) # 选择top-k候选 beams = sorted(new_beams, key=lambda x: -x[1])[:beam_size] # 后处理:过滤无效ID并返回 valid_beams = [] for seq, score in beams: if is_valid_id(seq): # 检查ID是否对应真实商品 valid_beams.append((seq, score)) return valid_beams or beams # 若无有效ID则返回原始结果

5. 评估与生产部署

生成式推荐的评估指标需要特别设计,既要考虑传统推荐指标,也要关注生成质量。

5.1 离线评估方案

我们扩展了标准推荐评估协议:

指标类型具体指标说明
传统推荐指标Recall@K, NDCG@K衡量推荐准确性
生成质量指标Valid ID Rate有效生成ID的比例
多样性指标Intra-list Diversity推荐列表内商品间的差异度
冷启动指标Novelty@K对新商品的推荐能力

实验表明,在Amazon Beauty数据集上,我们的实现达到以下效果:

| 模型变体 | Recall@10 | NDCG@10 | Valid ID Rate | |----------------|-----------|---------|---------------| | 双塔基准 | 0.142 | 0.078 | - | | TIGER (beam=1) | 0.158 | 0.085 | 98.3% | | TIGER (beam=5) | 0.167 | 0.091 | 99.1% |

5.2 生产部署考量

将TIGER部署到实际系统需要考虑几个工程因素:

  1. 实时性要求:自回归生成相比传统检索更耗时
  2. 索引构建:需要维护语义ID到商品的快速查找表
  3. 混合部署:可与传统检索系统并行运行,取长补短

一个可行的部署架构:

用户请求 → 特征提取 → 并行执行: 分支1: TIGER生成式推荐 (高相关性) 分支2: 传统ANN检索 (高召回) → 结果融合与排序 → 返回推荐

5.3 持续学习策略

推荐系统需要持续更新以适应新商品和用户偏好变化。我们设计了两阶段更新机制:

  1. 语义ID增量更新:定期用新商品微调RQ-VAE
  2. Transformer在线学习:通过以下方式适应变化:
    • 新交互数据的fine-tuning
    • 模型蒸馏保持轻量
    • 基于用户反馈的强化学习

在实际项目中,这种架构显著提升了系统对时尚品类等快速变化场景的适应能力。一个有趣的发现是,语义ID的层级结构天然支持"相似推荐"功能——只需固定前几位token,随机生成后几位,就能获得语义相似但又有差异的商品推荐。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/17 23:57:18

在泰山派(RK3566)上给ST7789屏幕写SPI驱动,我踩过的那些设备树和DMA的坑

在RK3566平台为ST7789屏幕开发SPI驱动的实战避坑指南 当一块ST7789 SPI屏幕遇上Rockchip RK3566芯片,看似简单的驱动开发背后隐藏着无数细节陷阱。本文将带你深入设备树配置、DMA优化和SPI时序调校的实战现场,还原从零搭建显示系统的完整思考路径。 1. 设…

作者头像 李华
网站建设 2026/4/17 23:55:00

【平衡小车进阶】(一)蓝牙串口协议解析与多模式遥控实现(附源码)

1. 蓝牙串口通信基础与硬件选型 玩平衡小车最爽的部分莫过于用手机遥控了,但很多小伙伴卡在蓝牙通信这一关。我当年第一次用HC-05模块时,光是AT指令配置就折腾了一整天。现在回头看,其实只要掌握几个关键点就能少走弯路。 核心硬件选择方面&a…

作者头像 李华
网站建设 2026/4/17 23:53:14

CentOS 7防火墙实战:firewall-cmd端口转发配置与排错指南

1. 端口转发基础概念与原理 端口转发就像邮局的分拣员工作。想象你寄往"大楼A-8080房间"的包裹,被分拣员悄悄改成了"大楼B-8088房间"的地址标签,而收件人完全不知道这个变化。在CentOS 7中,firewalld就是这个智能分拣员&…

作者头像 李华
网站建设 2026/4/17 23:53:14

Golang如何部署到Linux服务器_Golang Linux部署教程【实用】

必须手动下载官方tar.gz包解压至/usr/local/go并配置GOROOT、PATH和GOPROXY;禁用apt等包管理器安装,因其版本滞后、路径混乱且不支持embed/泛型等新特性。直接上结论:别用 apt install golang,必须手动下载官方 go1.22.5.linux-am…

作者头像 李华