SetBatchInfoForNormal
【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言,原生支持C和C++标准规范,主要由类库和语言扩展层构成,提供多层级API,满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit
功能说明
设置A/B矩阵的M/N/K轴信息,以及A/B矩阵的Batch数。Layout类型为NORMAL的场景,调用IterateBatch或者IterateNBatch接口之前,需要在Host侧Tiling实现中通过本接口设置A/B矩阵的M/N/K轴等信息。
函数原型
int32_t SetBatchInfoForNormal(int32_t batchA, int32_t batchB, int32_t m, int32_t n, int32_t k)参数说明
表1参数说明
| 参数名 | 输入/输出 | 描述 |
|---|---|---|
| batchA | 输入 | A矩阵的batch数 |
| batchB | 输入 | B矩阵的batch数 |
| m | 输入 | A矩阵的M轴信息 |
| n | 输入 | B矩阵的N轴信息 |
| k | 输入 | A/B矩阵的K轴信息 |
返回值说明
-1表示设置失败; 0表示设置成功。
约束说明
Layout类型为NORMAL的场景,调用IterateBatch或者IterateNBatch接口之前,需要在Host侧Tiling实现中通过本接口设置A/B矩阵的M/N/K轴等信息。
调用示例
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); matmul_tiling::MultiCoreMatmulTiling tiling(ascendcPlatform); int32_t M = 32; int32_t N = 256; int32_t K = 64; tiling.SetDim(1); tiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); tiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); tiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); tiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); tiling.SetShape(M, N, K); tiling.SetOrgShape(M, N, K); tiling.SetBias(true); tiling.SetBufferSpace(-1, -1, -1); constexpr int32_t BATCH_NUM = 3; tiling.SetBatchInfoForNormal(BATCH_NUM, BATCH_NUM, M, N, K); // 设置矩阵排布 tiling.SetBufferSpace(-1, -1, -1); optiling::TCubeTiling tilingData; int ret = tiling.GetTiling(tilingData);【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言,原生支持C和C++标准规范,主要由类库和语言扩展层构成,提供多层级API,满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考