SetVectorMask
产 品 支 持 情 况
功 能 说 明
头 文 件 路 径 为:"basic_api/kernel_common.h"。
本 接 口 用 于 在 矢 量 计 算 时 设 置mask。
当 模 板 参 数isSetMask = false时,推 荐 使 用 本 接 口。使 用 前 需 要 先 调 用SetMaskCount/SetMaskNorm手 动 设 置Mask模 式,并 通 过 调 用 本 接 口 设 置Counter/Normal模 式 下 的mask。在 不 同 的 模 式 下,mask的 含 义 不 同:
Normal模 式 下,mask参 数 用 来 控 制 单 次 迭 代 内 参 与 计 算 的 元 素 个 数。此 时 又 可 以 划 分 为 如 下 两 种 模 式:
连 续 模 式:表 示 前 面 连 续 的 多 少 个 元 素 参 与 计 算。取 值 范 围 和 操 作 数 的 数 据 类 型 有 关,数 据 类 型 不 同,每 次 迭 代 内 能 够 处 理 的 元 素 个 数 最 大 值 不 同。当 操 作 数 为16位 时,mask∈[1, 128];当 操 作 数 为32位 时,mask∈[1, 64];当 操 作 数 为64位 时,mask∈[1, 32]。
逐 比 特 模 式:可 以 按 位 控 制 哪 些 元 素 参 与 计 算,bit位 的 值 为1表 示 参 与 计 算,0表 示 不 参 与。分 为maskHigh(高 位mask)和maskLow(低 位mask)。参 数 取 值 范 围 和 操 作 数 的 数 据 类 型 有 关,数 据 类 型 不 同,每 次 迭 代 内 能 够 处 理 的 元 素 个 数 最 大 值 不 同。当 操 作 数 为16位 时,maskLow、maskHigh∈[0, 264-1],并 且 不 同 时 为0;当 操 作 数 为32位 时,maskHigh为0,maskLow∈(0, 264-1];当 操 作 数 为64位 时,maskHigh为0,maskLow∈(0, 232-1]。
Counter模 式 下,mask参 数 表 示 整 个 矢 量 计 算 参 与 计 算 的 元 素 个 数。
当 模 板 参 数isSetMask = true时,不 需 要 调 用 本 接 口,参 考接 口 内 设 置Mask。
函 数 原 型
适 用 于Normal模 式 下mask逐 比 特 模 式 和Counter模 式
C++template <typename T, MaskMode mode = MaskMode::NORMAL> __aicore__ static inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow)适 用 于Normal模 式 下Mask连 续 模 式 和Counter模 式
C++template <typename T, MaskMode mode = MaskMode::NORMAL> __aicore__ static inline void SetVectorMask(int32_t len)
参 数 说 明
表 1 模 板 参 数 说 明
| 参 数 名 | 描 述 |
|---|---|
| T | 矢 量 计 算 操 作 数 数 据 类 型。 |
| mode | Mask模 式,MaskMode类 型,定 义 如 下:enum class MaskMode : uint8_t { |
表 2 参 数 说 明
| 参 数 名 | 输 入/输 出 | 描 述 |
|---|---|---|
| maskHigh | 输 入 | Normal模 式:对 应Normal模 式 下 的 逐 比 特 模 式,可 以 按 位 控 制 哪 些 元 素 参 与 计 算。传 入 高 位mask值。 Counter模 式:需 要 置0,本 入 参 不 生 效。 |
| maskLow | 输 入 | Normal模 式:对 应Normal模 式 下 的 逐 比 特 模 式,可 以 按 位 控 制 哪 些 元 素 参 与 计 算。传 入 低 位mask值。 Counter模 式:整 个 矢 量 计 算 过 程 中,参 与 计 算 的 元 素 个 数。 |
| len | 输 入 | Normal模 式:对 应Normal模 式 下 的Mask连 续 模 式,表 示 单 次 迭 代 内 前 面 连 续 的 多 少 个 元 素 参 与 计 算。 Counter模 式:整 个 矢 量 计 算 过 程 中,参 与 计 算 的 元 素 个 数。 |
数 据 类 型
支 持 数 据 类 型 为:b8/b16/b32。
返 回 值 说 明
无
约 束 说 明
- 该 接 口 仅 在 矢 量 计 算API的isSetMask模 板 参 数 为false时 生 效,使 用 完 成 后 需 要 使 用ResetMask将mask恢 复 为 默 认 值。
- 针 对Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品、Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,mask = 0表 示 指 令 不 会 执 行 计 算 操 作,该 接 口 将 被 视 为NOP(空 操 作)。
调 用 示 例
可 结 合SetMaskCount与SetMaskNorm使 用,先 设 置Mask模 式 再 设 置mask:
Normal模 式 调 用 示 例
C++AscendC::LocalTensor<half> dstLocal; AscendC::LocalTensor<half> src0Local; AscendC::LocalTensor<half> src1Local; // Normal模 式 AscendC::SetMaskNorm(); AscendC::SetVectorMask<half, AscendC::MaskMode::NORMAL>(0xffffffffffffffff, 0xffffffffffffffff); // 逐bit模 式 // SetVectorMask<half, MaskMode::NORMAL>(128); // 连 续 模 式 // 多 次 调 用 矢 量 计 算API, 可 以 统 一 设 置 为Normal模 式,并 设 置mask参 数,无 需 在API内 部 反 复 设 置,省 去 了 在API反 复 设 置 的 过 程,会 有 一 定 的 性 能 优 势 // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单 次 迭 代 内 数 据 连 续 读 取 和 写 入 // dstRepStride, src0RepStride, src1RepStride = 8, 相 邻 迭 代 间 数 据 连 续 读 取 和 写 入 AscendC::Add<half, false>(dstLocal, src0Local, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 }); AscendC::Sub<half, false>(src0Local, dstLocal, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 }); AscendC::Mul<half, false>(src1Local, dstLocal, src0Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 }); AscendC::ResetMask();Counter模 式 调 用 示 例
C++// Counter模 式 和tensor高 维 切 分 计 算 接 口 配 合 使 用 AscendC::LocalTensor<half> dstLocal; AscendC::LocalTensor<half> src0Local; AscendC::LocalTensor<half> src1Local; int32_t len = 128; // 参 与 计 算 的 元 素 个 数 AscendC::SetMaskCount(); AscendC::SetVectorMask<half, AscendC::MaskMode::COUNTER>(len); AscendC::Add<half, false>(dstLocal, src0Local, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 }); AscendC::Sub<half, false>(src0Local, dstLocal, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 }); AscendC::Mul<half, false>(src1Local, dstLocal, src0Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 }); AscendC::SetMaskNorm(); AscendC::ResetMask(); // Counter模 式 和tensor前n个 数 据 计 算 接 口 配 合 使 用 AscendC::LocalTensor<half> dstLocal; AscendC::LocalTensor<half> src0Local; half num = 2; AscendC::SetMaskCount(); AscendC::SetVectorMask<half, AscendC::MaskMode::COUNTER>(128); // 参 与 计 算 的 元 素 个 数 为128 AscendC::Adds<half, false>(dstLocal, src0Local, num, 1); AscendC::Muls<half, false>(dstLocal, src0Local, num, 1); AscendC::SetMaskNorm(); AscendC::ResetMask();
更 多 示 例 请 参 考接 口 外 设 置Mask。