MXFP4MatmulTla Example Readme
【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass
注意:社区包暂不支持 950 能力,后续支持的版本敬请期待
功能介绍
- 演示 Ascend 950 上的MX FP4 矩阵乘:左矩阵 A、右矩阵 B 经 MX 缩放(
float8_e8m0)后在 Cube 上完成乘加,输出为 FP32。 - 本示例中 A、B 元素类型为
float4_e2m1x2_t(E2M1 打包);缩放因子为float8_e8m0_t。未启用 Bias(ElementBias为void)。 - 默认布局为 A
RowMajor、BColumnMajor、CRowMajor,与gen_data.py在trans_a=0, trans_b=1时生成的数据一致。
代码组织
├── 54_ascend950_fp4_mx_matmul │ ├── CMakeLists.txt # CMake编译文件 │ ├── README.md │ ├── gen_data.py │ └── fp4_mx_matmul.cpp # 主文件使用示例
- 获取代码之后编译相应的算子可执行文件,可参考quickstart,本用例为 Ascend950(3510)算子,编译时需加
-DCATLASS_ARCH=3510。L1 分块为 256×256×448、L0 为 256×256×128,以满足 512KiB L1 与 L0 容量约束(勿随意增大 L1 的 K,否则L1TileShape exceeding the L1 space)。 - 执行算子
# 编译指定用例 bash scripts/build.sh 54_ascend950_fp4_mx_matmul -DCATLASS_ARCH=3510 # 生成测试样例(在 examples/54_ascend950_fp4_mx_matmul/data 下生成 input/ 与 golden/) python3 examples/54_ascend950_fp4_mx_matmul/gen_data.py 256 512 1024 0 1 # 输入参数分别对应 m, n, k, trans_a, trans_b # trans_a表示A矩阵是否转置,0是不转置,1是转置 # trans_b表示B矩阵是否转置,0是不转置,1是转置 # 执行测试样例 ./output/bin/54_ascend950_fp4_mx_matmul 256 512 1024 0 # 可执行文件名 |矩阵m轴|n轴|k轴|Device ID # Device ID可选,默认为0执行结果如下,说明精度比对成功。
Compare success.使用说明
1、gen_data.py的输入支持trans_a和trans_b,但54_ascend950_fp4_mx_matmul可执行文件不支持,仅仅是trans_a为0及trans_b为1的example示例。
若要对应转置情况请修改example示例中的layout,因为layout隐式表征转置状态,即layout::RowMajor表示不转置,layout::ColumnMajor表示转置。
其对应关系如下表:
| trans_a | trans_b | LayoutA | LayoutB |
|---|---|---|---|
| 0 | 0 | layout::RowMajor | layout::RowMajor |
| 0 | 1 | layout::RowMajor | layout::ColumnMajor |
| 1 | 0 | layout::ColumnMajor | layout::RowMajor |
| 1 | 1 | layout::ColumnMajor | layout::ColumnMajor |
2、 本example完成mx量化矩阵乘: C = (MxScaleA x A) * (MxScaleB x B) + Bias A、B支持数据类型为float4_e1m2或float4_e2m1 MxScaleA、MxScaleB支持数据类型为float8_e8m0
其中对于MxScaleA、MxScaleB的数据排布要求如下: 当A为RowMajor时,MxScaleA的shape为(m, ceil(k/64), 2) 当A为ColumnMajor时,MxScaleA的shape为(ceil(k/64), m, 2) 当B为RowMajor时,MxScaleB的shape为(ceil(k/64), n, 2) 当B为ColumnMajor时,MxScaleB的shape为(n, ceil(k/64), 2)
3、MxMatmulTla与BlockMmadTla搭配使用的 DispatchPolicy 为Gemm::MmadMx(定义见include/catlass/gemm/dispatch_policy.hpp),模板参数顺序与默认值如下:
| 模板参数 | 默认值 | 参数说明 |
|---|---|---|
ArchTag | 无 | 架构标签,例如Arch::Ascend950 |
ENABLE_UNIT_FLAG | false | 是否开启 UnitFlag;当L0C_STAGES > 1(L0C 多缓冲)时必须为false |
L1_SCALE_FACTOR_K | 16 | GM→L1 的 MX scale 一次驻留所覆盖的L1 K 方向条带个数;为1时表示每个 L1 K 条带各搬一次 scale(见类型内注释) |
L0C_STAGES | 1 | L0C 缓冲段数;设为2可开启 L0C 双缓冲(需与ENABLE_UNIT_FLAG约束一致) |
ENABLE_L1_RESIDENT | false | 是否开启 L1 常驻 |
L1A_STAGES | 2 | L1 上加载矩阵 A 的 buffer 数量 |
L1B_STAGES | 2 | L1 上加载矩阵 B 的 buffer 数量 |
L0A_STAGES | 2 | L0 上加载矩阵 A 的 buffer 数量 |
L0B_STAGES | 2 | L0 上加载矩阵 B 的 buffer 数量 |
设矩阵Shape为M N K, L1上的分块大小为m1 n1 k1,M方向的分块数量mTiles = CeilDiv(M, m1),N方向的分块数量nTiles = CeilDiv(N, n1),总任务数为taskBlocks = mTiles * nTiles,在以下两种情况下可以选择开启enableL1Resident:
1.mTiles = 1,且nTiles > CoreNum,且K < 2 * k1。此时还可以设置l0CStages=2(需要关闭enableUnitFlag),如果空间不足无法设置l0CStages=2,则将n1设置为原来的一半。
2.nTiles = 1,且mTiles > CoreNum, 且K < 2 * k1。此时还可以设置l0CStages=2(需要关闭enableUnitFlag),如果空间不足无法设置l0CStages=2,则将m1设置为原来的一半。
【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考