背景痛点:ChatTTS 原生 PyTorch 的“慢”与“重”
第一次把 ChatTTS 放到线上做语音合成时,我整个人是懵的:
一张 A10 卡,单条 10 s 音频要 2.3 s 才能吐出来,GPU 显存直接飙到 6 GB+,并发一多就 OOM。
问题根因并不神秘——
- 生成式模型本身自回归,每一步都要把上一帧 hidden 重新喂回网络,计算图无法整图融合。
- PyTorch 每次
forward都重新建图、申请显存,碎片严重。 - Python GIL + 多线程调度,让 batch 推理“假并行”变成真排队。
线上业务可等不起,于是把“模型瘦身”提上日程。
技术选型:ONNX Runtime 为什么胜出
我把 TensorRT、OpenVINO、ONNX Runtime 拉到同一张表格里对比:
| 维度 | TensorRT | OpenVINO | ONNX Runtime |
|---|---|---|---|
| 跨平台 | ×(NVIDIA 专属) | △(x86/ARM) | √(Win/Linux/macOS) |
| 算子完整度 | △(自定义算子需 plugin) | △ | √(官方支持 Transformer 全套) |
| 开发成本 | 高(C++ plugin 编译) | 中 | 低(Python 即可) |
| 量化生态 | 强(FP16/INT8) | 强 | 中(FP16 简单,INT8 需 QDQ) |
结论:
- 公司线上既有 NVIDIA 也有 Intel 节点,ONNX 一次导出、多端运行,最省心。
- ChatTTS 里大量
torch.nn.MultiheadAttention在 ONNX 里已原生映射,无需手写 plugin。 - Python 侧就能完成 FP16 量化,算法同事自己维护,不麻烦运维。
于是拍板:用 ONNX Runtime 作为推理后端。
核心实现:从.pt到.onnx的惊险一跃
1. 模型导出关键参数
ChatTTS 的 TTS 部分接受三个动态轴:batch、seq_len、mel_len,导出脚本如下:
# export_onnx.py import torch from chattts import ChatTTS # 伪代码,替换成你的模型入口 model: torch.nn.Module = ChatTTS.load("checkpoints") model.eval() dummy_x = torch.randn(1, 512, 80) # mel 谱 dummy_y = torch.randint(0, 300, (1, 128)) # phoneme id dynamic_axes = { "mel": {0: "batch", 1: "seq"}, "phoneme": {0: "batch", 1: "seq"}, "audio": {0: "batch", 1: "time"}, } torch.onnx.export( model, (dummy_x, dummy_y), "chattts.onnx", input_names=["mel", "phoneme"], output_names=["audio"], dynamic_axes=dynamic_axes, opset_version=14, do_constant_folding=True, )注意:
opset=14以上才支持trilu,ChatTTS 里 causal mask 会用到。- 如果模型里出现
torch.repeat_interleave,先换成expand+reshape,否则 ONNX 会报“op not supported”。
2. 带异常处理的加载封装
# onnx_wrapper.py from pathlib import Path import onnxruntime as ort import numpy as np from typing import Tuple class ChatTTSOnnx: def __init__(self, onnx_path: Path, providers=None): if providers is None: providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if not onnx_path.exists(): raise FileNotFoundError(f"ONNX 文件不存在: {onnx_path}") try: self.sess = ort.InferenceSession(str(onnx_path), providers=providers) except Exception as e: raise RuntimeError(f"加载 ONNX 失败: {e}") def synthesize(self, mel: np.ndarray, phoneme: np.ndarray) -> np.ndarray: """返回音频波形,float32 [-1,1]""" outputs = self.sess.run( ["audio"], { "mel": mel.astype(np.float32), "phoneme": phoneme.astype(np.int64), }, ) return outputs[0]3. VAD + STFT 后处理集成
ChatTTS 输出的是 22 kHz 波形,但线上常需要 16 kHz、带音量归一化。用 ONNX Runtime 的OnnxVectorized能一次把 VAD、重采样、STFT 打包成子图,减少 Python 来回拷贝。
核心思路:
- VAD 用 Silero VAD ONNX(已经官方提供)。
- STFT 用
onnx.helper建一个Constant+STFT子图,导出为post.onnx。 - 主模型与后处理模型用
Session.run链式调用,显存复用同一块IOBinding。
性能优化:FP16 与 batch 的魔法数字
1. FP16 量化一行代码
from onnxruntime.tools import optimizer optimized = optimizer.optimize_model( "chattts.onnx", model_type="bert", # 通用 transformer 优化 num_heads=16, hidden_size=1024, ) optimized.convert_float_to_float16() optimized.save("chattts_fp16.onnx")实测 A10 上 10 s 音频:
- FP32 显存峰值 6.3 GB → FP16 降到 3.1 GB,降幅51%。
- RTF 从 0.23 降到 0.11,提速 2.1×。
2. batch 大小对 RTF 的影响
| batch | RTF (FP16) | 首包延迟 |
|---|---|---|
| 1 | 0.11 | 180 ms |
| 4 | 0.08 | 190 ms |
| 8 | 0.07 | 210 ms |
可见 batch=4 是吞吐与延迟的甜蜜点,再大收益递减。
避坑指南:自定义算子与多线程
1. 自定义算子注册
ChatTTS 里为了提速,写了一个torch.ops.unfold1d的 C++ extension,ONNX 没有对应算子。解决步骤:
- 把
unfold1d换成nn unfold + reshape,保证纯 ONNX 算子。 - 如果非要用原版,可注册自定义 op:
- 写
my_unfold1d.cc,实现OrtCustomOp接口。 - 编译为
libmyop.so,SessionOptions.RegisterCustomOpsLibrary("libmyop.so")。 - Python 侧无需改代码,只要
.so在LD_LIBRARY_PATH。
- 写
2. 多线程 session 复用
ONNX Runtime 的InferenceSession非线程安全,但创建成本大。线上做法:
- 每个线程预创建 1 个 session,用
threading.local()保存。 - 全局维护 1 个
Queue[Session],请求到达时get(),用完put(),避免反复 new。
import threading from queue import Queue sess_pool = Queue(maxsize=4) for _ in range(4): sess_pool.put(ChatTTSOnnx("chatts_fp16.onnx")) def worker(): sess = sess_pool.get() try: audio = sess.synthesize(mel, phoneme) finally: sess_pool.put(sess)代码规范小结
- 所有公开接口带类型标注,返回
np.ndarray而非List[float],减少隐式拷贝。 - 关键步骤抛出自定义异常,方便 Sentry 聚类。
- 日志统一用
structlog,字段rtf=round(time/audio_len, 3),方便监控大盘。
延伸思考:流式推理与动态 shape
目前方案是“整句合成”,线上最长 30 s 音频,首包 200 ms 左右还能接受。但要做到“边合成边播放”,就得拆成 chunk 级流式。
挑战:
- 自回归模型每步依赖上一帧 hidden,如何跨 chunk 传递 KV-Cache?
- ONNX 动态 shape 虽然支持
-1,但 CUDA provider 在past_key_values变化时会重新 malloc,导致抖动。 - 需要把 cache 大小固定为
max_len,用mask控制实际长度,牺牲一点显存换速度。
下一步计划:
- 把 decoder 拆成
init_decoder.onnx+step_decoder.onnx,用 C++ 写流式调度器,保证 300 ms 首包、RTF<0.05。 - 探索 ONNX Runtime Web,浏览器里直接跑,让 TTS 走端侧,服务端只下发音人 embedding。
把 ChatTTS 搬到 ONNX 后,线上同样一张 A10,并发从 20 QPS 提到 70 QPS,显存还省了一半。最开心的是算法同学——他们继续用 PyTorch 训练,导出脚本一键搞定,无需关心底层硬件。如果你也在为生成式语音合成的延迟和内存头疼,不妨先按本文流程跑一次,相信你会回来点赞的。