MmadWithSparse
产 品 支 持 情 况
| 产 品 | 是 否 支 持 |
|---|---|
| Ascend 950PR/Ascend 950DT | x |
| Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品 | √ |
| Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品 | √ |
| Atlas 200I/500 A2 推 理 产 品 | x |
| Atlas 推 理 系 列 产 品AI Core | x |
| Atlas 推 理 系 列 产 品Vector Core | x |
| Atlas 训 练 系 列 产 品 | x |
功 能 说 明
头 文 件 路 径 为:#include "basic_api/kernel_operator_mm_intf.h"。
MmadWithSparse接 口 负 责 完 成 特 殊 稀 疏 矩 阵 乘 加 操 作。稀 疏 矩 阵 是 一 种 特 殊 类 型 的 矩 阵,即 矩 阵 中 包 含 较 多 的 零 元 素。4:2结 构 化 稀 疏 要 求 一 个 连 续 的4个 权 重 或 激 活 值 的 组(通 常 是 张 量 中 的 一 行 或 一 列)中,最 多 只 有2个 值 为 非 零,其 余2个 强 制 为 零。
MmadWithSparse接 口 传 入 的 左 矩 阵A为 稀 疏 矩 阵,右 矩 阵B为 稠 密 矩 阵。矩 阵A是 个 全 尺 寸 矩 阵,在MmadWithSparse计 算 时 完 成 稠 密 化;矩 阵B是 经 过4:2结 构 化 稀 疏 过 滤 掉 零 值 之 后 的 稠 密 矩 阵,需 要 在 计 算 执 行 前 的 输 入 数 据 准 备 时 自 行 完 成 稠 密 化(按 照 下 文 中 介 绍 的 稠 密 算 法 进 行 稠 密 化)。B稠 密 矩 阵 需 要 通 过 调 用LoadDataWithSparse载 入,同 时 加 载 索 引 矩 阵,索 引 矩 阵 在 矩 阵B稠 密 化 的 过 程 中 生 成,再 用 于A矩 阵 的 稠 密 化。索 引 矩 阵 存 储 在 内 部 缓 冲 区,该 索 引 矩 阵 的 布 局 和 布 局 大 小 与 矩 阵B相 同,用 于 在 进 行 矩 阵 乘 加 操 作 之 前 进 一 步 将 矩 阵A压 缩。
跟Mmad接 口 实 现 昇 腾NPU矩 阵 乘 计 算 能 力 类 似,MmadWithSparse接 口 的 数 学 表 达 式 为:
$$ C = A \times B + C $$
完 整 示 例 请 参 考:MmadWithSparse样 例
表 1 Sparse矩 阵 计 算 矩 阵A、B、C解 释 说 明
| 矩 阵 计 算 逻 辑 | 矩 阵 计 算 物 理 位 置 | 维 度 | 输 入/输 出 数 据 格 式 | 数 据 类 型 |
|---|---|---|---|---|
| A | L0A Buffer | M x K | Zz | 数 据 类 型 |
| B | L0B Buffer | K/2 x N | Zn | |
| C | L0C Buffer | M x N | Nz |
下 面 的 图 展 示 了Cube如 何 计 算 出 其 中 一 行 和 一 列 的 内 积:
图 1 MmadWithSparse接 口 计 算 流 程 示 意 图 
其 中 矩 阵A原 始 分 形 为(16, 2*C0),索 引 矩 阵Index分 形 为(C0,16),每 一 行 矩 阵A的 数 据 会 基 于 索 引 矩 阵Index中 对 应 的 一 列 数 据 进 行4选2,索 引 矩 阵 分 形 格 式 及 生 成 方 式 请 参 考4选2稀 疏 索 引 矩 阵,选 择 算 法 参 考矩 阵A稀 疏 选 择 算 法 说 明;经 过 选 择 处 理 后 的 矩 阵A分 形 变 成(16, C0),矩 阵B原 始 分 形 为(C0, 16),接 下 来 会 执 行 普 通Mmad运 算,即 矩 阵A中 一 行 和 矩 阵B中 一 列 完 成 内 积 运 算 得 到 结 果 矩 阵C中 对 应 一 个 元 素。
索 引 矩 阵 经 过LoadDataWithSparse指 令 后 存 储 于Cube上 内 置 的 专 用buffer空 间,数 据 类 型 为uint8,分 形 格 式 为 小n大Z,对 应 上 图 中 的 分 形 大 小 为(32,16)。每 一 个uint8类 型 的 索 引 元 素 由4个uint2的 原 始 数 据 组 成,每 两 个2位 索 引 数 据 可 对 应4位 原 始 矩 阵A。针 对 每 一 组2个 索 引 数 据,A矩 阵 的4个 元 素 的 选 择 过 滤 规 则 示 例 如 下 表:
第 一 个 索 引 数 据0用 于 指 示 前3个 元 素 中 第1个 非 零 元 素 的 相 对 位 置。
第 二 个 索 引 数 据1用 于 指 示 第2个 非 零 元 素 在 后3个 元 素 中 的 相 对 位 置。
其 中,“-”表 示 不 关 心 该 位 置 上 的 值,即 会 被 过 滤。
表 2 矩 阵A选 择 过 滤 规 则 表
索 引 数 据0 索 引 数 据1 元 素0 元 素1 元 素2 元 素3 2’b10 2’b10 - - X Y 2’b01 2’b10 - X - Y 2’b00 2’b10 X - - Y 2’b01 2’b01 - X Y - 2’b00 2’b01 X - Y - 2’b00 2’b00 X Y - - 2’b00 2’b10 X - - Y 2’b10 2’b00 - X Y - 2’b01 2’b00 - X/X - - 2’b00 2’b00 X Y - - 2’b00 2’b00 X Y - - 图2展 示 了 一 个uint8类 型 的 索 引 元 素 对 应 选 择8个 原 始 矩 阵A元 素 的 算 法 模 型,最 后 输 出4个 选 择 后 的 矩 阵A元 素。
- 在 正 常 使 用 情 况 下,软 件 应 确 保 最 多 存 在 两 个 非 零 元 素。如 果 发 生 错 误,即 存 在 三 个 或 更 多 非 零 元 素 时,只 会 使 用 最 低 有 效 位(LSB)位 置 的 前 两 个 非 零 元 素。
- 上 表 中 使 用 的“-”表 示“不 关 心 该 位 置 上 的 值”,即 暗 示 可 能 存 在 三 个 或 更 多 非 零 元 素 的 情 况。
函 数 原 型
template <typename T = int32_t, typename U = int8_t, typename Std::enable_if<Std::is_same<PrimT<T>, int32_t>::value, bool>::type = true, typename Std::enable_if<Std::is_same<PrimT<U>, int8_t>::value, bool>::type = true>
__aicore__ inline void MmadWithSparse(const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<U>& filter, const MmadParams& mmadParams)
参 数 说 明
表 3 模 板 参 数 说 明
| 参 数 名 | 描 述 |
|---|---|
| T | dst的 数 据 类 型。 |
| U | fm、filter的 数 据 类 型。 当dst、fm、filter为 基 础 数 据 类 型 时,T必 须 为int32_t类 型,U必 须 为int8_t类 型,否 则 编 译 失 败。 |
表 4 参 数 说 明
| 参 数 名 称 | 输 入/输 出 | 含 义 |
|---|---|---|
| dst | 输 出 | 目 的 操 作 数,结 果 矩 阵,类 型 为LocalTensor,支 持 的 物 理 存 储 位 置 为L0C Buffer(TPosition:CO1)。 LocalTensor的 起 始 地 址 需 要256个 元 素(1024字 节)对 齐。 |
| fm | 输 入 | 源 操 作 数,左 矩 阵A,类 型 为LocalTensor,支 持 的 物 理 存 储 位 置 为L0A Buffer(TPosition: A2)。 LocalTensor的 起 始 地 址 需 要512字 节 对 齐。 |
| filter | 输 入 | 源 操 作 数,右 矩 阵B,类 型 为LocalTensor,支 持 的 物 理 存 储 位 置 为L0B Buffer(TPosition:B2)。 LocalTensor的 起 始 地 址 需 要512字 节 对 齐。 |
| mmadParams | 输 入 | 矩 阵 乘 相 关 参 数,类 型 为MmadParams。 具 体 定 义 请 参 考${INSTALL_DIR}/include/ascendc/basic_api/interface/kernel_struct_mm.h,${INSTALL_DIR}请 替 换 为CANN软 件 安 装 后 文 件 存 储 路 径。 参 数 说 明 请 参 考表5。 |
表 5 MmadParams结 构 体 内 参 数 说 明(Sparse场 景)
| 参 数 名 称 | 含 义 |
|---|---|
| m | 左 矩 阵Height,取 值 范 围:m∈[0,4095]。默 认 值 为0。 |
| n | 右 矩 阵Width,取 值 范 围:n∈[0,4095]。默 认 值 为0。 |
| k | 左 矩 阵Width、右 矩 阵Height,取 值 范 围:k∈[0,4095]。默 认 值 为0。 |
| cmatrixInitVal | 是 否 使 能C矩 阵 默 认 初 始 化 清 零 操 作。默 认 值true。 true:C矩 阵 默 认 初 始 化 为0;false:C矩 阵 不 进 行 默 认 操 作,通 过 设 置cmatrixSource参 数 进 行 初 始 化。 |
| cmatrixSource | 配 置C矩 阵 初 始 值 是 否 来 源 于BT Buffer。默 认 值 为false。 false:不 对L0C进 行 初 始 化 操 作; true:使 用BT Buffer(TPosition:C2)的 数 据 对L0C进 行 初 始 化 操 作。 Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,支 持 配 置 为true/false。 Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品,支 持 配 置 为true/false。 Atlas 200I/500 A2 推 理 产 品,支 持 配 置 为true/false。 注 意:带bias输 入 的 接 口 配 置 该 参 数 无 效,会 根 据bias输 入 的 位 置 来 判 断C矩 阵 初 始 值 是 否 来 源 于BT Buffer。 |
| isBias | 该 参 数 废 弃,新 开 发 内 容 不 要 使 用 该 参 数。如 果 需 要 累 加 初 始 矩 阵,请 使 用 带bias的 接 口 来 实 现;也 可 以 通 过cmatrixInitVal和cmatrixSource参 数 配 置C矩 阵 的 初 始 值 来 源 来 实 现。推 荐 使 用 带bias的 接 口,相 比 于 配 置cmatrixInitVal和cmatrixSource参 数 更 加 简 单 方 便。 配 置 是 否 需 要 累 加 初 始 矩 阵,默 认 值 为false,取 值 说 明 如 下: false:矩 阵 乘,无 需 累 加 初 始 矩 阵,C = A * B。true:矩 阵 乘 加,需 要 累 加 初 始 矩 阵,C += A * B。 |
| unitFlag | unitFlag是 一 种Mmad指 令 和Fixpipe指 令 细 粒 度 的 并 行,使 能 该 功 能 后,硬 件 每 计 算 完 一 个 分 形,计 算 结 果 就 会 被 搬 出。取 值 说 明 如 下: 0(2'b00):不 使 能unitFlag; 1(2'b01):保 留 值; 2(2'b10):使 能unitFlag,硬 件 执 行 完 指 令 之 后,不 复 位 单 元 标 记 位; 3(2'b11):使 能unitFlag,硬 件 执 行 完 指 令 之 后,复 位 单 元 标 记 位。 使 能 该 功 能 时,须 将Mmad指 令 和Fixpipe指 令 的unitFlag值 设 置 为2或3。 该 参 数 仅 支 持 如 下 型 号: Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品 Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品 参 数 设 置 方 案 和 特 性 细 节 可 参 考: UnitFlag |
| kDirectionAlign | Sparse场 景 本 开 关 默 认 为false,不 支 持 配 置 为true。K方 向 对 齐 的 核 心 功 能 是 通 过kDirectionAlign 参 数 控 制 在 使 用float数 据 类 型 时,L0A和L0B矩 阵 在K方 向 上 的 对 齐 方 式。 |
| fmOffset | 左 矩 阵offset(整 个 左 矩 阵 对 应 一 个 值),支 持Scalar(应 与src_fm.dtype一 致)/立 即 数,默 认0。 注:未 使 用,兼 容 旧 款 产 品 接 口 传 入,Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品 及 往 后 产 品 不 做 处 理。 |
| enSsparse | 使 能 结 构 化 稀 疏 特 性,默 认false; 注:未 使 用,兼 容 旧 款 产 品 接 口 传 入,Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品 及 往 后 产 品 不 做 处 理。 |
| enWinogradA | 指 示 矩 阵a是 否 通 过winograd_feature_map_transform() 生 成,用 于 支 持winograd特 性,bool类 型,默 认false; 注:未 使 用,兼 容 旧 款 产 品 接 口 传 入,Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品 及 往 后 产 品 不 做 处 理。 |
| enWinogradB | 指 示 矩 阵b是 否 通 过winograd_weight_transform()生 成,用 于 支 持winograd特 性,bool类 型,默 认false; 注:未 使 用,兼 容 旧 款 产 品 接 口 传 入,Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品 及 往 后 产 品 不 做 处 理。 |
数 据 类 型
表 6 A、B、C支 持 的 精 度 类 型 组 合(Atlas 200I/500 A2 推 理 产 品)(Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品)(Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品)
| 左 矩 阵A | 右 矩 阵B | 结 果 矩 阵C |
|---|---|---|
| int8_t | int8_t | int32_t |
返 回 值 说 明
无
约 束 说 明
不 同 矩 阵 对 于 存 储 位 置 的 约 束:
- 结 果 矩 阵C只 支 持 位 于 物 理 存 储 位 置 为L0C Buffer(TPosition:CO1)
- 左 矩 阵A只 支 持 位 于 物 理 存 储 位 置 为L0A Buffer(TPosition:A2)
- 右 矩 阵B只 支 持 位 于 物 理 存 储 位 置 为L0B Buffer(TPosition:B2)
原 始 稀 疏 矩 阵B每4个 元 素 中 应 保 证 最 多2个 非 零 元 素,如 果 存 在3个 或 更 多 非 零 元 素,则 仅 使 用 前2个 非 零 元 素。
当M、K、N中 的 任 意 一 个 值 为0时,表 示 指 令 不 会 执 行,该 接 口 将 被 视 为NOP(空 操 作)。
MmadWithSparse接 口 不 支 持Gemv模 式。
其 他 特 殊 场 景 约 束 可 参 考Mmad接 口 约 束 说 明。
调 用 示 例
完 整 使 用 样 例 请 参 见MmadWithSparse样 例。
AscendC::LocalTensor<int8_t> a1Local(AscendC::TPosition::A1, a1Addr, aSize);
AscendC::LocalTensor<int8_t> a2Local(AscendC::TPosition::A2, a2Addr, aSize);
AscendC::LocalTensor<int8_t> b1Local(AscendC::TPosition::B1, b1Addr, bSize);
AscendC::LocalTensor<uint8_t> idxB1Local(AscendC::TPosition::B1, idxB1Addr, bSize / 4);
AscendC::LocalTensor<int8_t> b2Local(AscendC::TPosition::B2, b2Addr, bSize);
AscendC::LocalTensor<int32_t> cLocal(AscendC::TPosition::CO1, cAddr, cSize);
// GM->L1,将 原 始 矩 阵a,稠 密 化 矩 阵b与 对 应idx矩 阵 搬 运 至L1
CopyIn(a1Local, b1Local, idxB1Local);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(EVENT_ID0);
// L1->L0, 将 原 始 矩 阵a,稠 密 化 矩 阵b与 对 应idx矩 阵 搬 运 至L0
SplitA(a1Local, a2Local);
SplitB(b2Local, b1Local, idxB1Local);
AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
// mmad 需 要 指 定 矩 阵 的 维 度 进 行 计 算
uint32 m = 128;
uint32 k = 64;
uint32 n = 128;
AscendC::MmadWithSparse(c1Local, a2Local, b2Local, { m, n, k, false, 0, false, false, false });
