CANN/asc-devkit 随路转换NZ2ND搬运
随路转换NZ2ND搬运
产品支持情况
功能说明
支持在数据搬运时进行NZ到ND格式的转换。
函数原型
template <typename T>
__aicore__ inline void DataCopy(const GlobalTensor<T>& dst, const LocalTensor<T>& src, const Nz2NdParamsFull& intriParams)
[!NOTE]说明 各原型支持的具体数据通路和数据类型,请参考支持的通路和数据类型。
参数说明
表 1 模板参数说明
|
源操作数或者目的操作数的数据类型。支持的数据类型请参考支持的通路和数据类型。 |
表 2 接口参数说明
|
搬运参数,类型为Nz2NdParamsFull。 具体定义请参考${INSTALL_DIR}/include/ascendc/basic_api/interface/kernel_struct_data_copy.h,${INSTALL_DIR}请替换为CANN软件安装后文件存储路径。 |
表 3 Nz2NdParamsFull结构体内参数定义
|
源相邻NZ矩阵的偏移(头与头),取值范围:srcNdMatrixStride∈[1, 512],单位256 (16 * 16) 个元素。 |
|
|
目的ND矩阵中,来自源相邻NZ矩阵的偏移(头与头),取值范围:dstNdMatrixStride∈[1, 65535],单位为元素。 |
以half数据类型为例,NZ2ND转换示意图如下,样例中参数设置值和解释说明如下:
- ndNum = 2,表示源NZ矩阵的数目为2 (NZ矩阵1为A1~A4 + B1~B4,NZ矩阵2为C1~C4 + D1~D4)。
- nValue = 4,NZ矩阵的行数,也就是矩阵的高度为4。
- dValue = 32,NZ矩阵的列数,也就是矩阵的宽度为32个元素。
- srcNdMatrixStride = 1,表达相邻NZ矩阵起始地址间的偏移,即为A1~C1的距离,即为256个元素(16个DataBlock * 16个元素)。
- srcNStride = 4, 表示同一个源NZ矩阵的相邻Z排布的偏移,即为A1到B1的距离,即为64个元素(4个DataBlock* 16个元素)。
- dstDStride = 160,表达一个目的ND矩阵的相邻行之间的偏移,即A1和A2之间的距离,即为10个DataBlock,即10 * 16 = 160个元素。
- dstNdMatrixStride = 48,表达dst中第x个目的ND矩阵的起点和第x+1个目的ND矩阵的起点的偏移,即A1和C1之间的距离,即为3个DataBlock,3 * 16 = 48个元素。
图 1 NZ2ND转换示意图(half数据类型)
.png "NZ2ND转换示意图(half数据类型)"?utm_source=gitcode_repo_files)
以float数据类型为例,NZ2ND转换示意图如下,样例中参数设置值和解释说明如下:
- ndNum = 2,表示源NZ矩阵的数目为2 (NZ矩阵1为A1~A8 + B1~B8,NZ矩阵2为C1~C8 + D1~D8)。
- nValue = 4,NZ矩阵的行数,也就是矩阵的高度为4。
- dValue = 32,NZ矩阵的列数,也就是矩阵的宽度为32个元素。
- srcNdMatrixStride = 1,表达相邻NZ矩阵起始地址间的偏移,即A1到C1的距离,为256个元素(32个DataBlock * 8个元素)
- srcNStride = 4, 表示同一个源NZ矩阵的相邻Z排布的偏移,即A1到B1的距离,为64个元素 (8个DataBlock * 8个元素)。
- dstDStride = 144,表示一个目的ND矩阵的相邻行之间的偏移,即A1和A3之间的距离,为18个DataBlock,即18 * 8 = 144个元素。
- dstNdMatrixStride = 40,表示dst中第x个目的ND矩阵的起点和第x+1个目的ND矩阵的起点的偏移,即A1和C1之间的距离,为5个DataBlock,5 * 8 = 40个元素。
图 2 NZ2ND转换示意图(float数据类型)
.png "NZ2ND转换示意图(float数据类型)"?utm_source=gitcode_repo_files)
返回值说明
无
约束说明
无
支持的通路和数据类型
下文的数据通路均通过逻辑位置TPosition来表达,并注明了对应的物理通路。TPosition与物理内存的映射关系见表1。
表 4 Local Memory -> Global Memory具体通路和支持的数据类型
|
bool、int8_t、uint8_t、hifloat8_t、fp8_e5m2_t、fp8_e4m3fn_t、fp8_e8m0_t、int16_t、uint16_t、half、bfloat16_t、int32_t、uint32_t、float、complex32、int64_t、uint64_t、double、complex64 |
||
调用示例
intriParams参数解析请参考图1。
// dstLocal为half类型的LocalTensor,dstGlobal为half类型的GlobalTensor
AscendC::Nz2NdParamsFull intriParams{1, 32, 32, 1, 32, 32, 1};
// Local Memory -> Global Memory
AscendC::DataCopy(dstGlobal, dstLocal, intriParams);
结果示例:
输入数据(srcGlobal): [1 2 3 ... 1024]
输出数据(dstGlobal):[1 2 ... 15 16 513 514 ... 527 528 17 18 ... 31 32 529 530 ... 543 544 ...497 498 ... 511 512 1009 1010... 1023 1024]
更多推荐


所有评论(0)