Skip to content
版 本

SetVectorMask

产 品 支 持 情 况

产 品

是 否 支 持

Ascend 950PR/Ascend 950DT

Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品

Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品

Atlas 200I/500 A2 推 理 产 品

Atlas 推 理 系 列 产 品AI Core

Atlas 推 理 系 列 产 品Vector Core

x

Atlas 训 练 系 列 产 品

Kirin X90

Kirin 9030

功 能 说 明

头 文 件 路 径 为:"basic_api/kernel_common.h"

本 接 口 用 于 在 矢 量 计 算 时 设 置mask。

  • 当 模 板 参 数isSetMask = false时,推 荐 使 用 本 接 口。使 用 前 需 要 先 调 用SetMaskCount/SetMaskNorm手 动 设 置Mask模 式,并 通 过 调 用 本 接 口 设 置Counter/Normal模 式 下 的mask。在 不 同 的 模 式 下,mask的 含 义 不 同:

    • Normal模 式 下,mask参 数 用 来 控 制 单 次 迭 代 内 参 与 计 算 的 元 素 个 数。此 时 又 可 以 划 分 为 如 下 两 种 模 式:

      • 连 续 模 式:表 示 前 面 连 续 的 多 少 个 元 素 参 与 计 算。取 值 范 围 和 操 作 数 的 数 据 类 型 有 关,数 据 类 型 不 同,每 次 迭 代 内 能 够 处 理 的 元 素 个 数 最 大 值 不 同。当 操 作 数 为16位 时,mask∈[1, 128];当 操 作 数 为32位 时,mask∈[1, 64];当 操 作 数 为64位 时,mask∈[1, 32]。

      • 逐 比 特 模 式:可 以 按 位 控 制 哪 些 元 素 参 与 计 算,bit位 的 值 为1表 示 参 与 计 算,0表 示 不 参 与。分 为maskHigh(高 位mask)和maskLow(低 位mask)。参 数 取 值 范 围 和 操 作 数 的 数 据 类 型 有 关,数 据 类 型 不 同,每 次 迭 代 内 能 够 处 理 的 元 素 个 数 最 大 值 不 同。当 操 作 数 为16位 时,maskLow、maskHigh∈[0, 264-1],并 且 不 同 时 为0;当 操 作 数 为32位 时,maskHigh为0,maskLow∈(0, 264-1];当 操 作 数 为64位 时,maskHigh为0,maskLow∈(0, 232-1]。

    • Counter模 式 下,mask参 数 表 示 整 个 矢 量 计 算 参 与 计 算 的 元 素 个 数。

  • 当 模 板 参 数isSetMask = true时,不 需 要 调 用 本 接 口,参 考接 口 内 设 置Mask

函 数 原 型

  • 适 用 于Normal模 式 下mask逐 比 特 模 式 和Counter模 式

    C++
    template <typename T, MaskMode mode = MaskMode::NORMAL>
    __aicore__ static inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow)
    
  • 适 用 于Normal模 式 下Mask连 续 模 式 和Counter模 式

    C++
    template <typename T, MaskMode mode = MaskMode::NORMAL>
    __aicore__ static inline void SetVectorMask(int32_t len)
    

参 数 说 明

表 1 模 板 参 数 说 明

参 数 名描 述
T矢 量 计 算 操 作 数 数 据 类 型。
modeMask模 式,MaskMode类 型,定 义 如 下:
enum class MaskMode : uint8_t {
NORMAL = 0, // Normal模 式
COUNTER // Counter模 式
};

表 2 参 数 说 明

参 数 名输 入/输 出描 述
maskHigh输 入Normal模 式:对 应Normal模 式 下 的 逐 比 特 模 式,可 以 按 位 控 制 哪 些 元 素 参 与 计 算。传 入 高 位mask值。
Counter模 式:需 要 置0,本 入 参 不 生 效。
maskLow输 入Normal模 式:对 应Normal模 式 下 的 逐 比 特 模 式,可 以 按 位 控 制 哪 些 元 素 参 与 计 算。传 入 低 位mask值。
Counter模 式:整 个 矢 量 计 算 过 程 中,参 与 计 算 的 元 素 个 数。
len输 入Normal模 式:对 应Normal模 式 下 的Mask连 续 模 式,表 示 单 次 迭 代 内 前 面 连 续 的 多 少 个 元 素 参 与 计 算。
Counter模 式:整 个 矢 量 计 算 过 程 中,参 与 计 算 的 元 素 个 数。

数 据 类 型

支 持 数 据 类 型 为:b8/b16/b32。

返 回 值 说 明

约 束 说 明

  • 该 接 口 仅 在 矢 量 计 算API的isSetMask模 板 参 数 为false时 生 效,使 用 完 成 后 需 要 使 用ResetMask将mask恢 复 为 默 认 值。
  • 针 对Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品、Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,mask = 0表 示 指 令 不 会 执 行 计 算 操 作,该 接 口 将 被 视 为NOP(空 操 作)。

调 用 示 例

可 结 合SetMaskCountSetMaskNorm使 用,先 设 置Mask模 式 再 设 置mask:

  • Normal模 式 调 用 示 例

    C++
    AscendC::LocalTensor<half> dstLocal;
    AscendC::LocalTensor<half> src0Local;
    AscendC::LocalTensor<half> src1Local;
    
    // Normal模 式
    AscendC::SetMaskNorm();
    AscendC::SetVectorMask<half, AscendC::MaskMode::NORMAL>(0xffffffffffffffff, 0xffffffffffffffff);  // 逐bit模 式
    
    // SetVectorMask<half, MaskMode::NORMAL>(128);  // 连 续 模 式
    // 多 次 调 用 矢 量 计 算API, 可 以 统 一 设 置 为Normal模 式,并 设 置mask参 数,无 需 在API内 部 反 复 设 置,省 去 了 在API反 复 设 置 的 过 程,会 有 一 定 的 性 能 优 势
    // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单 次 迭 代 内 数 据 连 续 读 取 和 写 入
    // dstRepStride, src0RepStride, src1RepStride = 8, 相 邻 迭 代 间 数 据 连 续 读 取 和 写 入
    AscendC::Add<half, false>(dstLocal, src0Local, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 });
    AscendC::Sub<half, false>(src0Local, dstLocal, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 });
    AscendC::Mul<half, false>(src1Local, dstLocal, src0Local, AscendC::MASK_PLACEHOLDER, 1, { 2, 2, 2, 8, 8, 8 });
    AscendC::ResetMask();
    
  • Counter模 式 调 用 示 例

    C++
    // Counter模 式 和tensor高 维 切 分 计 算 接 口 配 合 使 用
    AscendC::LocalTensor<half> dstLocal;
    AscendC::LocalTensor<half> src0Local;
    AscendC::LocalTensor<half> src1Local;
    int32_t len = 128;  // 参 与 计 算 的 元 素 个 数
    AscendC::SetMaskCount();
    AscendC::SetVectorMask<half, AscendC::MaskMode::COUNTER>(len);
    AscendC::Add<half, false>(dstLocal, src0Local, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 });
    AscendC::Sub<half, false>(src0Local, dstLocal, src1Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 });
    AscendC::Mul<half, false>(src1Local, dstLocal, src0Local, AscendC::MASK_PLACEHOLDER, 1, { 1, 1, 1, 8, 8, 8 });
    AscendC::SetMaskNorm();
    AscendC::ResetMask();
    
    // Counter模 式 和tensor前n个 数 据 计 算 接 口 配 合 使 用
    AscendC::LocalTensor<half> dstLocal;
    AscendC::LocalTensor<half> src0Local;
    half num = 2; 
    AscendC::SetMaskCount();
    AscendC::SetVectorMask<half, AscendC::MaskMode::COUNTER>(128); // 参 与 计 算 的 元 素 个 数 为128
    AscendC::Adds<half, false>(dstLocal, src0Local, num, 1);
    AscendC::Muls<half, false>(dstLocal, src0Local, num, 1);
    AscendC::SetMaskNorm();
    AscendC::ResetMask();
    

更 多 示 例 请 参 考接 口 外 设 置Mask

免 责 声 明:本 站 内 容 由 asc-devkit 仓 master 分 支 自 动 编 译 生 成,属 于 持 续 开 发 版 本,可 能 存 在 缺 陷,仅 供 预 览 与 参 考。如 需 稳 定 及 商 用 资 料,请 查 阅 官 方 昇 腾 社 区