asc_mmad_mx
产 品 支 持 情 况
| 产 品 | 是 否 支 持 |
|---|---|
Ascend 950PR/Ascend 950DT | √ |
功 能 说 明
完 成 包 含 放 缩 功 能 的 矩 阵 乘 加 操 作。计 算 公 式 如 下:
$$ c_{matrix} = (a_{matrix} * b_{matrix}) + c_{matrix} $$
函 数 原 型
常 规 计 算
C++__aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp4x2_e1m2_t* a_matrix, __cb__ fp4x2_e1m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp4x2_e1m2_t* a_matrix, __cb__ fp4x2_e2m1_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp4x2_e2m1_t* a_matrix, __cb__ fp4x2_e1m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp4x2_e2m1_t* a_matrix, __cb__ fp4x2_e2m1_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp8_e4m3fn_t* a_matrix, __cb__ fp8_e4m3fn_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp8_e4m3fn_t* a_matrix, __cb__ fp8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp8_e5m2_t* a_matrix, __cb__ fp8_e4m3fn_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx(__cc__ float* c_matrix, __ca__ fp8_e5m2_t* a_matrix, __cb__ fp8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val)同 步 计 算
C++__aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp4x2_e1m2_t* a_matrix, __cb__ fp4x2_e1m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp4x2_e1m2_t* a_matrix, __cb__ fp4x2_e2m1_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp4x2_e2m1_t* a_matrix, __cb__ fp4x2_e1m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp4x2_e2m1_t* a_matrix, __cb__ fp4x2_e2m1_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp8_e4m3fn_t* a_matrix, __cb__ fp8_e4m3fn_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp8_e4m3fn_t* a_matrix, __cb__ fp8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp8_e5m2_t* a_matrix, __cb__ fp8_e4m3fn_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val) __aicore__ inline void asc_mmad_mx_sync(__cc__ float* c_matrix, __ca__ fp8_e5m2_t* a_matrix, __cb__ fp8_e5m2_t* b_matrix, uint16_t left_height, uint16_t n_dim, uint16_t right_width, uint8_t unit_flag, bool disable_gemv, bool c_matrix_source, bool c_matrix_init_val)
参 数 说 明
| 参 数 名 | 输 入/输 出 | 描 述 |
|---|---|---|
| c_matrix | 输 出 | 目 的 操 作 数,结 果 矩 阵。 |
| a_matrix | 输 入 | 源 操 作 数,左 矩 阵A。 |
| b_matrix | 输 入 | 源 操 作 数,右 矩 阵B。 |
| left_height | 输 入 | 左 矩 阵height ,取 值 范 围 为[0,4095]。 |
| n_dim | 输 入 | 左 矩 阵width、右 矩 阵height,取 值 范 围 为[0,4095]。 |
| right_width | 输 入 | 右 矩 阵width,取 值 范 围 为[0,4095]。 |
| unit_flag | 输 入 | unit_flag是 一 种asc_mmad_mx接 口 细 粒 度 的 并 行,开 启 该 功 能 后,硬 件 每 计 算 完 一 个 分 形,计 算 结 果 就 会 被 搬 出,该 功 能 不 适 用 于L0C Buffer累 加 的 场 景。取 值 说 明 如 下: • 0:保 留 值; • 2:开 启unit_flag,硬 件 执 行 完 指 令 后,不 会 关 闭unit_flag功 能; • 3:开 启unit_flag,硬 件 执 行 完 指 令 后,会 关 闭unit_flag功 能。 开 启 该 功 能 时,矩 阵 计 算 的unit_flag在 最 后 一 个 分 形 设 置 为3,其 余 分 形 计 算 设 置 为2即 可。 |
| disable_gemv | 输 入 | 是 否 关 闭GEMV模 式,false表 示 开 启GEMV模 式,true表 示 关 闭GEMV模 式。 GEMV(General Matrix-Vector Multiplication)表 示 实 现 矩 阵 和 向 量 的 乘 积。当left_height=1时,开 启GEMV后,从L0A Buffer读 取 数 据 时,将 以ND格 式 进 行 读 取,而 不 会 将 其 视 为ZZ格 式。 |
| c_matrix_source | 输 入 | 配 置C矩 阵 初 始 值 是 否 来 源 于BiasTable(存 放Bias的 硬 件 缓 存 区)。取 值 说 明 如 下: • true:来 源 于BiasTable。 • false:来 源 于L0C。 |
| c_matrix_init_val | 输 入 | 配 置C矩 阵 初 始 值 是 否 为0。取 值 说 明 如 下: • true:C矩 阵 初 始 值 为0。 • false:C矩 阵 初 始 值 通 过c_matrix_source参 数 进 行 配 置。 |
返 回 值 说 明
无
流 水 类 型
PIPE_M
约 束 说 明
- 当left_height、right_width、n_dim中 的 任 意 一 个 值 为0时,该 指 令 不 会 被 执 行。
- 当 开 启GEMV模 式,即disable_gemv=false时,必 须 要 满 足 left_height=1。
- 操 作 数 地 址 对 齐 约 束 请 参 考通 用 地 址 对 齐 约 束。
调 用 示 例
C++
// total_length指 参 与 搬 运 的 数 据 总 长 度
constexpr uint64_t total_length = 128;
// 以 下 三 个 参 数 分 别 对 应 矩 阵c,a,b的 地 址
__cc__ float c_matrix[total_length];
__ca__ fp8_e4m3fn_t a_matrix[total_length];
__cb__ fp8_e4m3fn_t b_matrix[total_length];
uint16_t left_height = 16;
uint16_t n_dim = 16;
uint16_t right_width =16;
uint8_t unit_flag = 0;
bool disable_gemv = false;
bool c_matrix_source = false;
bool c_matrix_init_val = true;
// 函 数 调 用
asc_mmad_mx_sync(c_matrix, a_matrix, b_matrix, left_height, n_dim, right_width, unit_flag, disable_gemv, c_matrix_source, c_matrix_init_val);