1. 变分联合嵌入(VJE)的核心思想解析
变分联合嵌入(VJE)是一种基于变分推断的表示学习方法,它通过构建一个概率生成模型来学习数据的低维表示。与传统确定性方法不同,VJE显式地建模了表示空间中的不确定性,为每个数据点学习一个分布而非单个点表示。
VJE的核心创新在于其特殊的似然函数设计。它采用了因子化的Student-t分布作为似然函数,将其分解为方向性(directional)和径向(radial)两个独立的部分:
- 方向性部分:建模单位球面上的角度关系,使用定义在球面SD-1上的Student-t分布
- 径向部分:建模表示向量范数差异,使用一维Student-t分布
这种分解的关键优势在于:
- 避免了标准Student-t似然中方向与范数的耦合导致的优化问题
- 更符合表示学习任务的内在几何特性
- 提供了更稳定的训练动态,防止后验坍塌
在实际实现中,方向性部分涉及复杂的微分几何运算,包括测地线对数映射、Schur补马氏距离计算等,这些我们将在后续章节详细解析。
2. VJE的概率图模型与ELBO推导
2.1 生成模型设定
VJE的概率图模型可以表示为:
x → z → s → y其中:
- x:输入数据(如图像)
- z:确定性编码器fθ输出的表示
- s:潜在变量,由变分后验q(s|z)生成
- y:观测到的表示空间数据,包括归一化方向ˆz和范数∥z∥
条件概率分解为: p(y|s) = p(ˆz|s) · p(∥z∥ | ∥s∥)
2.2 变分下界(ELBO)推导
对于两个视图x1,x2,VJE优化以下对称条件ELBO:
L = E[log p(y2|s1)] + E[log p(y1|s2)] - β(KL(q(s1|z1)∥p(s)) + KL(q(s2|z2)∥p(s)))
其中:
- 期望项对应两个视图间的相互预测似然
- KL项正则化变分后验,防止过拟合
- β控制正则化强度
2.3 停止梯度(Stop-gradient)的数学含义
在实现中,计算log p(yj|si)时会冻结(停止梯度)zj。这对应于概率建模中的固定观测语义:
∇θL = E[∇s log p(yj|si)∇θs] + E[∇y log p(yj|si)∇θyj]
停止梯度即令∇θyj=0,仅保留第一项,确保模型参数更新是为了解释固定观测,而非同时修改观测本身。
3. 核心算法实现细节
3.1 方向性负对数似然计算
方向性负对数似然计算涉及多个几何运算步骤,如算法1所示:
def directional_nll(z, s, sigma_sq, nu, D): # 单位方向向量计算 z_hat = z / max(np.linalg.norm(z), 1e-6) s_hat = s / max(np.linalg.norm(s), 1e-6) n = s_hat # 测地线距离计算 cos_theta = z_hat.T @ s_hat theta = np.arccos(cos_theta) sin_theta = np.sqrt(1 - cos_theta**2) # 对数映射(log-map) t = (theta / sin_theta) * (z_hat - cos_theta * s_hat) # 精度权重 w = 1 / sigma_sq # Schur补马氏距离计算 c = np.sum(s_hat**2 * w) a = np.sum(t**2 * w) b = np.sum(t * s_hat * w) Q = a - b**2 / c # 切线空间对数行列式 logdet = 0.5 * np.sum(np.log(sigma_sq)) + 0.5 * np.log(c) # 指数映射雅可比校正 jac = (D-2) * (np.log(sin_theta) - np.log(theta)) k = D - 1 return 0.5*(nu + k)*np.log(1 + Q/nu) + logdet + jac关键实现细节:
- 数值稳定性处理:归一化时添加小常数(1e-6)防止除零
- 测地线运算:精确计算球面上的几何量
- 方差处理:使用Softplus激活确保σ²>0,并设置下限(1e-5)
3.2 径向负对数似然计算
径向部分相对简单,主要处理范数差异:
def radial_nll(z, s, nu): rz = np.linalg.norm(z) rs = np.linalg.norm(s) delta_r = rz - rs return 0.5*(nu + 1)*np.log(1 + delta_r**2/nu)3.3 KL散度计算
对角高斯后验与标准高斯先验间的KL散度有解析解:
def kl_div(mu, sigma_sq): return 0.5 * np.sum(sigma_sq + mu**2 - 1 - np.log(sigma_sq))4. 推断网络架构设计
VJE的推断网络gϕ采用瓶颈(bottleneck)结构设计:
z → Linear(D,H) → LN → ReLU → Linear(H,H) → LN → ReLU → [Linear(H,D), Linear(H,D)] → [µ, σ²]关键设计选择:
- 层归一化(LayerNorm):稳定训练,替代偏置项
- 瓶颈比例r=0.25:平衡表达能力和计算效率
- 双头输出:分别预测均值µ和方差σ²
- 方差激活:Softplus确保正定性,加小常数(1e-5)下限
对于ResNet-18(D=512),隐藏层H=128;ResNet-50(D=2048),H=512。
5. 训练流程与实现技巧
5.1 完整训练步骤
如算法4所述,一个训练步骤包含:
- 编码器前向:z1=fθ(x1), z2=fθ(x2)
- 推断网络:µ1,σ²1=gϕ(z1); µ2,σ²2=gϕ(z2)
- 重参数化采样:s = µ + σ⊙ε, ε∼N(0,I)
- 损失计算:
- 方向性似然
- 径向似然
- KL散度
- 反向传播更新
5.2 关键实现技巧
EMA与停止梯度:两种实现固定观测语义的方式
- 停止梯度:简单直接,计算图中断开梯度
- EMA目标编码器:更稳定,但增加内存
蒙特卡洛采样:实验发现K=1足够,增加K无显著改进
方差居中技巧:对高维表示(D=2048),将σ²归一化为单位几何平均,防止log-determinant数值爆炸
各向异性方差必要性:实验表明标量方差(各向同性)会导致模型失效,必须使用特征级方差
6. 与能量基方法的对比
6.1 目标层面的差异
能量基方法(如SimSiam)最小化预测与目标间的点态差异:
L = d(gϕ(zi), zj)
而VJE优化条件对数似然:
L = E[log p(yj|si)] - βKL
6.2 几何正则化的来源
能量基方法依赖:
- 架构技巧(停止梯度、动量编码器)
- 显式正则(如VICReg的协方差约束)
VJE则通过:
- 解析KL项锚定后验
- 似然项的几何结构
6.3 特殊情形下的等价性
在某些极限情况下,VJE可退化为能量基方法:
- 高斯似然,固定方差λI,σ²→0,β→0 ⇒ 平方误差损失
- 方向似然,σ²=1,ν→∞ ⇒ 余弦相似度
但这些极限会失去VJE的概率解释和正则化优势。
7. 实验分析与实用建议
7.1 因子化似然的必要性
标准Student-t似然会导致:
- 后验方差坍塌(σ²→0)
- KL爆炸性增长
- 表示空间秩崩溃
因子化似然通过分离方向与径向分量,避免了这些问题。
7.2 超参数选择经验
- 自由度ν:小值(如1.0)提供更重尾分布,对异常值鲁棒
- KL权重β:1.0通常表现良好,可微调平衡重构与正则
- 学习率:与标准SSL方法相当,无需特殊调整
7.3 计算效率考量
VJE的计算开销主要来自:
- 编码器前向(与基线相同)
- 推断网络(比典型投影头更小)
- 几何运算(方向性似然)
实际测量显示,VJE与SimSiam等方法的每步耗时相当。
8. 扩展应用与未来方向
VJE框架可扩展至:
- 多模态学习:不同模态作为不同"视图"
- 异常检测:利用似然分数识别分布外样本
- 不确定性量化:后验方差作为置信度指标
可能的改进方向包括:
- 更丰富的后验分布族
- 自适应几何结构
- 大规模分布式训练优化