Compares
产 品 支 持 情 况
功 能 说 明
头 文 件 路 径 为:"basic_api/kernel_operator_vec_cmpsel_intf.h"。
逐 元 素 比 较 一 个tensor中 的 元 素 和 另 一 个Scalar的 大 小,如 果 比 较 后 的 结 果 为 真,则 输 出 结 果 的 对 应 比 特 位 为1,否 则 为0。
支 持 多 种 比 较 模 式:
LT:小 于(less than)
GT:大 于(greater than)
GE:大 于 或 等 于(greater than or equal to)
EQ:等 于(equal to)
NE:不 等 于(not equal to)
LE:小 于 或 等 于(less than or equal to)
函 数 原 型
tensor前n个 数 据 计 算
C++template <typename T, typename U> __aicore__ inline void Compares(const LocalTensor<U>& dst, const LocalTensor<T>& src0, const T src1Scalar, CMPMODE cmpMode, uint32_t count)tensor高 维 切 分 计 算
mask逐bit模 式
C++template <typename T, typename U, bool isSetMask = true> __aicore__ inline void Compares(const LocalTensor<U>& dst, const LocalTensor<T>& src0, const T src1Scalar, CMPMODE cmpMode, const uint64_t mask[], uint8_t repeatTime, const UnaryRepeatParams& repeatParams)mask连 续 模 式
C++template <typename T, typename U, bool isSetMask = true> __aicore__ inline void Compares(const LocalTensor<U>& dst, const LocalTensor<T>& src0, const T src1Scalar, CMPMODE cmpMode, const uint64_t mask, uint8_t repeatTime, const UnaryRepeatParams& repeatParams)
参 数 说 明
模 板 参 数 及 接 口 参 数 说 明
表 1 模 板 参 数 说 明
| 参 数 名 | 描 述 |
|---|---|
| T | 源 操 作 数 数 据 类 型。 |
| U | 目 的 操 作 数 数 据 类 型。 |
| isSetMask | 是 否 在 接 口 内 部 设 置mask。 • true,表 示 在 接 口 内 部 设 置mask。 • false,表 示 在 接 口 外 部 设 置Mask,开 发 者 需 要 使 用SetVectorMask接 口 设 置mask值。这 种 模 式 下,本 接 口 入 参 中 的mask值 必 须 设 置 为 占 位 符MASK_PLACEHOLDER。 |
表 2 接 口 参 数 说 明
| 参 数 名 称 | 输 入/输 出 | 描 述 |
|---|---|---|
| dst | 输 出 | 目 的 操 作 数。 类 型 为LocalTensor,支 持 的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的 起 始 地 址 需 要32字 节 对 齐。 dst用 于 存 储 比 较 结 果,将dst中uint8_t类 型 的 数 据 按 照bit位 展 开,由 左 至 右 依 次 表 征 对 应 位 置 的src0和src1Scalar的 比 较 结 果,如 果 比 较 后 的 结 果 为 真,则 对 应 比 特 位 为1,否 则 为0。 |
| src0 | 输 入 | 源 操 作 数。 类 型 为LocalTensor,支 持 的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的 起 始 地 址 需 要32字 节 对 齐。 |
| src1Scalar | 输 入 | 源 操 作 数,Scalar标 量。数 据 类 型 和src0保 持 一 致。 |
| cmpMode | 输 入 | CMPMODE类 型,表 示 比 较 模 式,包 括EQ,NE,GE,LE,GT,LT。 • LT:src0小 于(less than)src1Scalar • GT:src0大 于(greater than)src1Scalar • GE:src0大 于 或 等 于(greater than or equal to)src1Scalar • EQ:src0等 于(equal to)src1Scalar • NE:src0不 等 于(not equal to)src1Scalar • LE:src0小 于 或 等 于(less than or equal to)src1Scalar |
| mask/mask[] | 输 入 | mask用 于 控 制 每 次 迭 代 内 参 与 计 算 的 元 素。详 细 设 置 参 考掩 码。 |
| repeatTime | 输 入 | 重 复 迭 代 次 数。矢 量 计 算 单 元,每 次 读 取 连 续 的256Bytes数 据 进 行 计 算,为 完 成 对 输 入 数 据 的 处 理,必 须 通 过 多 次 迭 代(repeat)才 能 完 成 所 有 数 据 的 读 取 与 计 算。repeatTime表 示 迭 代 的 次 数。 关 于 该 参 数 的 具 体 描 述 请 参 考高 维 切 分。 |
| repeatParams | 输 入 | 控 制 操 作 数 地 址 步 长 的 参 数。UnaryRepeatParams类 型,包 含 操 作 数 相 邻 迭 代 间 相 同DataBlock的 地 址 步 长,操 作 数 同 一 迭 代 内 不 同DataBlock的 地 址 步 长 等 参 数。 相 邻 迭 代 间 的 地 址 步 长 参 数 说 明 请 参 考repeatStride;同 一 迭 代 内DataBlock的 地 址 步 长 参 数 说 明 请 参 考dataBlockStride。 |
| count | 输 入 | 参 与 计 算 的 元 素 个 数。设 置count时,需 要 保 证count个 元 素 所 占 空 间256字 节 对 齐。未 对 齐 部 分 元 素 不 参 与 计 算,仅 完 整 对 齐 块 有 效。 |
mask/mask[]参 数 说 明
- 针 对Ascend 950PR/Ascend 950DT,设 置 有 效。
- 针 对Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品,保 留 参 数,设 置 无 效。
- 针 对Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,保 留 参 数,设 置 无 效。
- 针 对Atlas 200I/500 A2 推 理 产 品,设 置 有 效。
- 针 对Atlas 推 理 系 列 产 品AI Core,保 留 参 数,设 置 无 效。
- 针 对Kirin X90,保 留 参 数,设 置 无 效。
- 针 对Kirin 9030,保 留 参 数,设 置 无 效。
数 据 类 型
- 针 对Ascend 950PR/Ascend 950DT
- T支 持 的 数 据 类 型 为:int8_t、uint8_t、int16_t、uint16_t、half、bfloat16_t、int32_t、uint32_t、float、int64_t、uint64_t、double。
- U支 持 的 数 据 类 型 为:uint8_t。
- 针 对Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品
- T支 持 的 数 据 类 型 为:half(所 有CMPMODE都 支 持)、float(所 有CMPMODE都 支 持)、int32_t(只 支 持CMPMODE::EQ)。
- U支 持 的 数 据 类 型 为:uint8_t。
- 针 对Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品
- T支 持 的 数 据 类 型 为:half(所 有CMPMODE都 支 持)、float(所 有CMPMODE都 支 持)、int32_t(只 支 持CMPMODE::EQ)。
- U支 持 的 数 据 类 型 为:uint8_t。
- 针 对Atlas 200I/500 A2 推 理 产 品
- T支 持 的 数 据 类 型 为:half、float。
- U支 持 的 数 据 类 型 为:uint8_t。
- 针 对Atlas 推 理 系 列 产 品AI Core
- T支 持 的 数 据 类 型 为:half、float。
- U支 持 的 数 据 类 型 为:uint8_t。
- 针 对Kirin X90
- T支 持 的 数 据 类 型 为:half(所 有CMPMODE都 支 持)、float(所 有CMPMODE都 支 持)、int32_t(只 支 持CMPMODE::EQ)。
- U支 持 的 数 据 类 型 为:uint8_t。
- 针 对Kirin 9030
- T支 持 的 数 据 类 型 为:half(所 有CMPMODE都 支 持)、float(所 有CMPMODE都 支 持)、int32_t(只 支 持CMPMODE::EQ)。
- U支 持 的 数 据 类 型 为:uint8_t。
返 回 值 说 明
无
约 束 说 明
操 作 数 地 址 对 齐 要 求 请 参 见通 用 地 址 对 齐 约 束。
dst按 照 小 端 顺 序 排 序 成 二 进 制 结 果,对 应src0中 相 应 位 置 的 数 据 比 较 结 果。
使 用tensor前n个 数 据 参 与 计 算 的 接 口,设 置count时,需 要 保 证count个 元 素 所 占 空 间256字 节 对 齐。
调 用 示 例
本 样 例 中,源 操 作 数src0Local存 储 了256个float类 型 的 数 据。样 例 实 现 的 功 能 为,对src0Local中 的 元 素 和src1Local.GetValue(0)中 的 数 据 进 行 比 较,如 果src0Local中 的 元 素 小 于src1Local.GetValue(0)中 的 元 素,dstLocal结 果 中 对 应 的 比 特 位 置1;反 之,则 置0。dst结 果 使 用uint8_t类 型 数 据 存 储。
完 整 的 调 用 样 例 可 参 考Compare类 样 例场 景 三。
tensor前n个 数 据 计 算 接 口 样 例
C++AscendC::Compares(dstLocal, src0Local, src1Scalar, AscendC::CMPMODE::LT, srcDataSize);tensor高 维 切 分 计 算-mask连 续 模 式
C++uint64_t mask = 256 / sizeof(float); // 256为 每 个 迭 代 处 理 的 字 节 数 int repeat = 4; AscendC::UnaryRepeatParams repeatParams = { 1, 1, 8, 8 }; // repeat = 4, 64 elements one repeat, 256 elements total // dstBlkStride, srcBlkStride = 1, no gap between blocks in one repeat // dstRepStride, srcRepStride = 8, no gap between repeats AscendC::Compares(dstLocal, src0Local, src1Scalar, AscendC::CMPMODE::LT, mask, repeat, repeatParams);tensor高 维 切 分 计 算-mask逐bit模 式
C++uint64_t mask[2] = { UINT64_MAX, 0}; int repeat = 4; AscendC::UnaryRepeatParams repeatParams = { 1, 1, 8, 8 }; // repeat = 4, 64 elements one repeat, 256 elements total // srcBlkStride, = 1, no gap between blocks in one repeat // dstRepStride, srcRepStride = 8, no gap between repeats AscendC::Compares(dstLocal, src0Local, src1Scalar, AscendC::CMPMODE::LT, mask, repeat, repeatParams);
结 果 示 例 如 下:
输 入 数 据(src0Local):
[ 16.604824 45.069473 65.108345 -59.68792 21.043684
75.90726 -27.046307 -40.10546 -5.933778 83.56574
58.87062 -12.77814 28.17882 62.549377 -22.310246
-67.69001 81.06072 69.988945 69.10082 -6.667376
96.20256 18.532446 -66.56364 -32.531246 49.980835
35.668995 -16.847628 1.3236234 10.0143795 43.878166
26.628105 31.774637 47.9279 79.7291 -54.09651
95.49459 -18.404795 -86.84594 9.406091 -79.54437
0.49116692 -48.151714 -12.97062 -99.89055 23.475513
-27.366564 -69.229675 83.613304 52.14729 40.98426
-23.422009 -53.386215 1.6576616 -62.36946 54.693733
66.2058 -4.0042257 -25.351263 1.0000885 -6.458584
25.447659 71.647316 82.31162 -7.7359715 28.107353
-79.22045 20.292479 67.7434 -76.054085 -7.754251
38.632687 -4.8460293 -69.791954 -57.574455 -99.96178
-73.29611 -68.57477 98.200035 -55.30482 -55.590027
79.53274 -1.862139 -37.60953 -12.225406 -35.2875
-24.047668 -66.07609 21.9362 80.603516 28.928387
26.579298 97.6649 78.94723 -89.86824 73.29788
18.957182 -73.87053 -23.508097 -51.02931 39.158726
-96.61422 -41.192455 54.973663 47.58695 -3.9818003
-81.05088 -67.62415 -17.491713 -34.916042 -95.993744
-3.4719822 -55.956417 6.223455 12.240832 15.055512
94.70584 -13.33949 -50.46866 54.612816 -28.521824
-87.63997 59.53054 41.000504 -31.266075 -31.419422
-32.940186 53.449913 50.012768 -13.663364 40.931725
-68.80396 -86.63726 76.866585 -83.76385 3.7227867
58.443035 -74.333046 -92.52674 24.249512 -7.935491
24.197245 -34.85033 67.854645 72.65312 13.622443
-70.94266 15.401667 -9.332295 -86.61463 72.659676
-83.63352 9.279887 81.037964 46.285606 -12.967846
-48.72901 69.07614 -40.355286 -94.257034 -45.514374
24.966864 -9.657219 61.803864 -83.09603 77.769035
-97.44226 -89.71987 -53.969315 43.892918 73.88798
67.23104 36.65282 -93.70069 -87.48934 -27.679005
-36.825226 -30.117033 -41.579655 -97.325325 77.1972
-49.883194 33.061394 -63.844925 89.74327 64.549416
80.16943 73.26347 -87.307175 -96.62777 81.8532
7.5365276 28.357092 59.896378 -15.95738 -77.42723
0.03529428 -20.263502 45.59324 -90.160835 89.478004
57.608685 60.71819 45.8125 39.94484 -48.77375
-56.897358 5.2580256 -6.937905 -49.80309 -42.527523
72.91772 89.53271 -62.181187 18.490683 -69.40782
6.141204 13.938042 75.312515 21.766457 -8.157599
55.53147 -30.789118 -12.087165 82.435684 23.4884
82.73172 -2.026827 -8.124383 -10.707488 -74.32759
-54.702602 14.209252 93.73145 98.93554 52.803623
32.200726 41.823833 90.193756 -34.512424 -85.64022
97.47763 33.353424 94.84875 23.03139 99.97347
-72.47978 19.51753 -88.28579 -88.70721 -18.659292
-79.5277 62.90431 21.837631 45.989056 -9.62086
11.4855795 ]
输 入 数 据(src1Scalar):
[-95.16087 -71.4676 51.817818 -12.358237 96.60704 -12.0067835
-44.128048 7.5811195 84.61196 -60.303513 21.470125 98.96244
18.262054 80.014244 48.37233 -75.03457 ]
输 出 数 据(dstLocal):
[ 0 0 0 0 0 8 0 0 0 4 0 0 16 32 0 0 0 0 0 0 32 0 4 16
0 0 0 0 0 0 0 0]