ComplexMatDot算子实现
【免费下载链接】ops-blas本项目是CANN提供的高性能线性代数计算以及轻量化GEMM调用算子库。项目地址: https://gitcode.com/cann/ops-blas
概述
BLAS ComplexMatDot算子实现。
ComplexMatDot(复数矩阵点乘)算子实现了两个复数矩阵的逐元素乘法运算,是BLAS基础线性代数库中的扩展算子之一。
该算子针对复数运算特性进行了优化,使用GatherMask操作高效完成复数矩阵的逐元素乘法。
支持的产品
- Atlas A3 训练系列产品/Atlas A3 推理系列产品
- Atlas A2 训练系列产品/Atlas A2 推理系列产品
目录结构介绍
├── complex_mat_dot │ ├── CMakeLists.txt // 编译工程文件 │ ├── README.md // 说明文档 │ └── complex_mat_dot_test.cpp // 算子调用样例算子描述
- 算子功能:
ComplexMatDot算子实现了两个复数矩阵的逐元素乘法。对应的数学表达式为:
result[i, j] = matx[i, j] * maty[i, j]matx和maty是复数矩阵,result是输出复数矩阵
复数乘法公式:(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
对应的接口为:
int aclblasComplexMatDot(const float *matx, const float *maty, float *result, const int64_t m, const int64_t n, void *stream);| 参数 | complex_mat_dot 参数说明 | |||
| 参数列表 | Param. | Memory | in/out | 含义 |
| m | in | 矩阵的行数。 | ||
| n | in | 矩阵的列数。 | ||
| matx | device | in | 复数矩阵,维度为 m × n,存储为 2*m*n 个float。 | |
| maty | device | in | 复数矩阵,维度为 m × n,存储为 2*m*n 个float。 | |
| result | device | out | 复数矩阵,维度为 m × n,存储为 2*m*n 个float。 | |
算子规格:
算子类型(OpType) ComplexMatDot 算子输入 name shape data type format matx m × n complex ND maty m × n complex ND 算子输出 result m × n complex ND 核函数名 complex_mat_dot_kernel 算子实现:
将输入数据从matx和maty的GM地址分块搬运到UB,使用GatherMask分离实部和虚部,进行复数乘法计算后再搬出到result所在的GM地址。
调用实现
使用内核调用符<<<>>>调用核函数。
编译运行
在本样例根目录下执行如下步骤,编译并执行算子。
配置环境变量
请根据当前环境上CANN开发套件包的安装方式,选择对应配置环境变量的命令。默认路径,root用户安装CANN软件包
source /usr/local/Ascend/cann/set_env.sh默认路径,非root用户安装CANN软件包
source $HOME/Ascend/cann/set_env.sh指定路径install_path,安装CANN软件包
source ${install_path}/cann/set_env.sh
样例执行
bash build.sh --ops=complex_mat_dot --run # --ops=<算子名> --run可选参数,执行测试样例执行结果如下,说明精度对比成功。
[Success] Case accuracy is verification passed.
【免费下载链接】ops-blas本项目是CANN提供的高性能线性代数计算以及轻量化GEMM调用算子库。项目地址: https://gitcode.com/cann/ops-blas
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考