Skip to content
版 本

Transpose

产 品 支 持 情 况

产 品

是 否 支 持

Ascend 950PR/Ascend 950DT

Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品

Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品

Atlas 200I/500 A2 推 理 产 品

Atlas 推 理 系 列 产 品AI Core

Atlas 推 理 系 列 产 品Vector Core

x

Atlas 训 练 系 列 产 品

Kirin X90

Kirin 9030

功 能 说 明

头 文 件 路 径 为:"basic_api/kernel_operator_vec_transpose_intf.h"。

Transpose接 口 用 于 实 现16*16的 二 维 矩 阵 数 据 块 转 置 或 者[N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换。

16*16的 普 通 转 置 接 口 计 算 原 理 和 参 考 伪 代 码 如 下:

Python
import numpy as np

src = np.random.randn(16, 16).astype(np.float16)
dst = src.T

[N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换 的 增 强 转 置 计 算 原 理 和 参 考 伪 代 码 如 下:

Python
import numpy as np

# transposeParams.transposeType : TRANSPOSE_NCHW2NHWC
src_nchw = np.random.randn(transposeParams.nSize, transposeParams.cSize, transposeParams.hSize, transposeParams.wSize).astype(np.float16)
dst_nhwc = np.transpose(src_nchw, axes=(0,2,3,1))

# transposeParams.transposeType : TRANSPOSE_NHWC2NCHW
src_nhwc = np.random.randn(transposeParams.nSize, transposeParams.hSize, transposeParams.wSize, transposeParams.cSize).astype(np.float16)
dst_nchw = np.transpose(src_nhwc, axes=(0,3,1,2))

函 数 原 型

  • 普 通 转 置,支 持16*16的 二 维 矩 阵 数 据 块 进 行 转 置。

    C++
    template <typename T>
    __aicore__ inline void Transpose(const LocalTensor<T>& dst, const LocalTensor<T>& src)
    
  • 增 强 转 置,支 持16*16的 二 维 矩 阵 数 据 块 转 置,支 持[N,C,H,W]与[N,H,W,C]互 相 转 换。

    C++
    template <typename T>
    __aicore__ inline void Transpose(const LocalTensor<T>& dst, const LocalTensor<T> &src, const LocalTensor<uint8_t> &sharedTmpBuffer, const TransposeParamsExt &transposeParams)
    

参 数 说 明

模 板 参 数 说 明

参 数 名描 述
T操 作 数 的 数 据 类 型。

接 口 参 数 说 明

参 数 名 称输 入/输 出含 义
dst输 出目 的 操 作 数。
类 型 为LocalTensor,支 持 的TPosition为VECIN/VECCALC/VECOUT(存 储 位 置 为Unified Buffer)。
LocalTensor的 起 始 地 址 需 要32字 节 对 齐。
src输 入源 操 作 数。
类 型 为LocalTensor,支 持 的TPosition为VECIN/VECCALC/VECOUT(存 储 位 置 为Unified Buffer)。
LocalTensor的 起 始 地 址 需 要32字 节 对 齐。
数 据 类 型 需 要 与dst保 持 一 致。
sharedTmpBuffer输 入共 享 的 临 时Buffer,sharedTmpBuffer的 大 小 参 考表 sharedTmpBuffer所 需 的 内 存
transposeParams输 入控 制Transpose的 数 据 结 构。结 构 体 内 包 含:输 入 的shape信 息 和transposeType参 数。该 数 据 结 构 的 定 义 请 参 考表 TransposeParamsExt结 构 体 内 参 数 说 明
struct TransposeParamsExt {
__aicore__ TransposeParamsExt() {}
__aicore__ TransposeParamsExt(const uint16_t nSizeIn, const uint16_t cSizeIn, const uint16_t hSizeIn,
const uint16_t wSizeIn, const TransposeType transposeTypeIn)
: nSize(nSizeIn),
cSize(cSizeIn),
hSize(hSizeIn),
wSize(wSizeIn),
transposeType(transposeTypeIn)
{}
uint16_t nSize = 0;
uint16_t cSize = 0;
uint16_t hSize = 0;
uint16_t wSize = 0;
TransposeType transposeType = TransposeType::TRANSPOSE_ND2ND_B16;
};

TransposeParamsExt结 构 体 内 参 数 说 明

参 数 名 称含 义
nSizen轴 长 度。默 认 值 为0。
•二 维 矩 阵 数 据 块 转 置,无 需 传 入,传 入 数 值 无 效。
•[N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换,取 值 范 围:nSize∈[0, 65535]。
cSizec轴 长 度。默 认 值 为0。
•二 维 矩 阵 数 据 块 转 置,无 需 传 入,传 入 数 值 无 效。
•[N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换,取 值 范 围:cSize∈[0, 4095]。
hSizeh轴 长 度。默 认 值 为0。
•二 维 矩 阵 数 据 块 转 置,固 定 传 入16。
•[N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换,取 值 范 围:hSize * wSize ∈[0, 4095],hSize * wSize * sizeof(T)需 要 保 证32B对 齐。
wSizew轴 长 度。默 认 值 为0。
•二 维 矩 阵 数 据 块 转 置,固 定 传 入16。
•[N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换,取 值 范 围:hSize * wSize ∈[0, 4095],hSize * wSize * sizeof(T)需 要 保 证32B对 齐。
transposeType数 据 排 布 及reshape的 类 型,类 型 为TransposeType枚 举 类。默 认 值 为TRANSPOSE_ND2ND_B16。
enum class TransposeType : uint8_t {
TRANSPOSE_TYPE_NONE, // API不 做 任 何 处 理
TRANSPOSE_NZ2ND_0213, // 当 前 不 支 持
TRANSPOSE_NZ2NZ_0213, // 当 前 不 支 持
TRANSPOSE_NZ2NZ_012_WITH_N, // 当 前 不 支 持
TRANSPOSE_NZ2ND_012_WITH_N, // 当 前 不 支 持
TRANSPOSE_NZ2ND_012_WITHOUT_N, // 当 前 不 支 持
TRANSPOSE_NZ2NZ_012_WITHOUT_N, // 当 前 不 支 持
TRANSPOSE_ND2ND_ONLY, // 当 前 不 支 持
TRANSPOSE_ND_UB_GM, // 当 前 不 支 持
TRANSPOSE_GRAD_ND_UB_GM, // 当 前 不 支 持
TRANSPOSE_ND2ND_B16, // [16,16]二 维 矩 阵 转 置
TRANSPOSE_NCHW2NHWC, // [N,C,H,W]->[N,H,W,C],
TRANSPOSE_NHWC2NCHW // [N,H,W,C]->[N,C,H,W]
};

Ascend 950PR/Ascend 950DT sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。
TRANSPOSE_NCHW2NHWC临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize + 2) * h0 * w0 * sizeof(type);
TRANSPOSE_NHWC2NCHW临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize * 2 + 1) * h0 * w0 * sizeof(type);

Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。
TRANSPOSE_NCHW2NHWC临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize + 2) * h0 * w0 * sizeof(type);
TRANSPOSE_NHWC2NCHW临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize * 2 + 1) * h0 * w0 * sizeof(type);

Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。
TRANSPOSE_NCHW2NHWC临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize + 2) * h0 * w0 * sizeof(type);
TRANSPOSE_NHWC2NCHW临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize * 2 + 1) * h0 * w0 * sizeof(type);

Atlas 200I/500 A2 推 理 产 品sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。

Atlas 推 理 系 列 产 品AI Core sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。
TRANSPOSE_NCHW2NHWC不 需 要 临 时Buffer。
TRANSPOSE_NHWC2NCHW不 需 要 临 时Buffer。

Kirin X90 sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。
TRANSPOSE_NCHW2NHWC临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize + 2) * h0 * w0 * sizeof(type);
TRANSPOSE_NHWC2NCHW临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize * 2 + 1) * h0 * w0 * sizeof(type);

Kirin 9030 sharedTmpBuffer所 需 的 内 存

transposeTypesharedTmpBuffer所 需 的 大 小
TRANSPOSE_ND2ND_B16不 需 要 临 时Buffer。
TRANSPOSE_NCHW2NHWC临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize + 2) * h0 * w0 * sizeof(type);
TRANSPOSE_NHWC2NCHW临 时Buffer的 大 小 按 照 下 述 计 算 规 则(伪 代 码)进 行 计 算。
auto h0 = 16; // 当 数 据 类 型 的 位 宽 为8时,h0 = 32;其 他 情 况 下,h0 = 16
auto w0 = 32 / sizeof(type); // type代 表 数 据 类 型
auto tmpBufferSize = (cSize * 2 + 1) * h0 * w0 * sizeof(type);

数 据 类 型

  • 普 通 转 置:

    Ascend 950PR/Ascend 950DT,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Atlas 200I/500 A2 推 理 产 品,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Atlas 推 理 系 列 产 品AI Core,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Atlas 训 练 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Kirin X90,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

    Kirin 9030,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

  • 增 强 转 置:

    • transposeType为TRANSPOSE_ND2ND_B16:

      Ascend 950PR/Ascend 950DT,操 作 数 支 持 的 数 据 类 型 为:int16_t、uint16_t、half。

      Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:uint16_t。

      Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:uint16_t。

      Atlas 200I/500 A2 推 理 产 品,操 作 数 支 持 的 数 据 类 型 为:uint16_t。

      Atlas 推 理 系 列 产 品AI Core,操 作 数 支 持 的 数 据 类 型 为:uint16_t。

    • transposeType为TRANSPOSE_NCHW2NHWC或TRANSPOSE_NHWC2NCHW:

      Ascend 950PR/Ascend 950DT,操 作 数 支 持 的 数 据 类 型 为:int8_t、uint8_t、fp4x2_e2m1_t、fp4x2_e1m2_t、hifloat8_t、fp8_e8m0_t、fp8_e5m2_t、fp8_e4m3fn_t、int4x2_t、int16_t、uint16_t、half、bfloat16_t、int32_t、uint32_t、float、complex32。

      Atlas A3 训 练 系 列 产 品/Atlas A3 推 理 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。

      Atlas A2 训 练 系 列 产 品/Atlas A2 推 理 系 列 产 品,操 作 数 支 持 的 数 据 类 型 为:int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。

      Atlas 推 理 系 列 产 品AI Core,操 作 数 支 持 的 数 据 类 型 为:int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。

      Kirin X90,操 作 数 支 持 的 数 据 类 型 为:int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。

      Kirin 9030,操 作 数 支 持 的 数 据 类 型 为:int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。

返 回 值 说 明

无。

约 束 说 明

  • 操 作 数 地 址 对 齐 要 求 请 参 见Unified Buffer地 址 对 齐 约 束
  • 普 通 转 置 接 口 支 持src和dst复 用。
  • 增 强 转 置 接 口,transposeType为TRANSPOSE_ND2ND_B16时 支 持src和dst复 用,transposeType为TRANSPOSE_NCHW2NHWC、TRANSPOSE_NHWC2NCHW时 不 支 持src和dst复 用。
  • 二 维 矩 阵 数 据 块 转 置 时,nSize、cSize无 需 传 入,传 入 数 值 无 效;hSize、wSize固 定 传 入16。
  • 增 强 转 置 接 口,transposeType为TRANSPOSE_NCHW2NHWC、TRANSPOSE_NHWC2NCHW时,如 果nSize、cSize、hSize、wSize为0,不 会 执 行 计 算 操 作,不 会 对 目 的 操 作 数 进 行 写 入。
  • [N,C,H,W]与[N,H,W,C]数 据 格 式 互 相 转 换,参 数 取 值 范 围:nSize∈[0, 65535],cSize∈[0, 4095],hSize * wSize ∈[0, 4095],hSize * wSize * sizeof(T)需 要 保 证32B对 齐。
  • 转 置 增 强 接 口 中,入 参sharedTmpBuffer的 大 小 不 得 小 于 计 算 所 需 的 最 小 阈 值。

调 用 示 例

  • 普 通 接 口 调 用 示 例 片 段,完 整 片 段 请 参 考Transpose类 样 例场 景 一,该 示 例 对[16,16]的half类 型 矩 阵 进 行 转 置。

    C++
    // dstLocal:目 的 操 作 数tensor
    // srcLocal:源 操 作 数tensor
    AscendC::Transpose<half>(dstLocal, srcLocal);
    
    输 入 数 据src_gm:
    [[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.
       14.  15.]
     [ 16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  26.  27.  28.  29.
       30.  31.]
     [ 32.  33.  34.  35.  36.  37.  38.  39.  40.  41.  42.  43.  44.  45.
       46.  47.]
     [ 48.  49.  50.  51.  52.  53.  54.  55.  56.  57.  58.  59.  60.  61.
       62.  63.]
     [ 64.  65.  66.  67.  68.  69.  70.  71.  72.  73.  74.  75.  76.  77.
       78.  79.]
     [ 80.  81.  82.  83.  84.  85.  86.  87.  88.  89.  90.  91.  92.  93.
       94.  95.]
     [ 96.  97.  98.  99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109.
      110. 111.]
     [112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.
      126. 127.]
     [128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140. 141.
      142. 143.]
     [144. 145. 146. 147. 148. 149. 150. 151. 152. 153. 154. 155. 156. 157.
      158. 159.]
     [160. 161. 162. 163. 164. 165. 166. 167. 168. 169. 170. 171. 172. 173.
      174. 175.]
     [176. 177. 178. 179. 180. 181. 182. 183. 184. 185. 186. 187. 188. 189.
      190. 191.]
     [192. 193. 194. 195. 196. 197. 198. 199. 200. 201. 202. 203. 204. 205.
      206. 207.]
     [208. 209. 210. 211. 212. 213. 214. 215. 216. 217. 218. 219. 220. 221.
      222. 223.]
     [224. 225. 226. 227. 228. 229. 230. 231. 232. 233. 234. 235. 236. 237.
      238. 239.]
     [240. 241. 242. 243. 244. 245. 246. 247. 248. 249. 250. 251. 252. 253.
      254. 255.]]
    
    输 出 数 据dst_gm:
    [[  0.  16.  32.  48.  64.  80.  96. 112. 128. 144. 160. 176. 192. 208.
      224. 240.]
     [  1.  17.  33.  49.  65.  81.  97. 113. 129. 145. 161. 177. 193. 209.
      225. 241.]
     [  2.  18.  34.  50.  66.  82.  98. 114. 130. 146. 162. 178. 194. 210.
      226. 242.]
     [  3.  19.  35.  51.  67.  83.  99. 115. 131. 147. 163. 179. 195. 211.
      227. 243.]
     [  4.  20.  36.  52.  68.  84. 100. 116. 132. 148. 164. 180. 196. 212.
      228. 244.]
     [  5.  21.  37.  53.  69.  85. 101. 117. 133. 149. 165. 181. 197. 213.
      229. 245.]
     [  6.  22.  38.  54.  70.  86. 102. 118. 134. 150. 166. 182. 198. 214.
      230. 246.]
     [  7.  23.  39.  55.  71.  87. 103. 119. 135. 151. 167. 183. 199. 215.
      231. 247.]
     [  8.  24.  40.  56.  72.  88. 104. 120. 136. 152. 168. 184. 200. 216.
      232. 248.]
     [  9.  25.  41.  57.  73.  89. 105. 121. 137. 153. 169. 185. 201. 217.
      233. 249.]
     [ 10.  26.  42.  58.  74.  90. 106. 122. 138. 154. 170. 186. 202. 218.
      234. 250.]
     [ 11.  27.  43.  59.  75.  91. 107. 123. 139. 155. 171. 187. 203. 219.
      235. 251.]
     [ 12.  28.  44.  60.  76.  92. 108. 124. 140. 156. 172. 188. 204. 220.
      236. 252.]
     [ 13.  29.  45.  61.  77.  93. 109. 125. 141. 157. 173. 189. 205. 221.
      237. 253.]
     [ 14.  30.  46.  62.  78.  94. 110. 126. 142. 158. 174. 190. 206. 222.
      238. 254.]
     [ 15.  31.  47.  63.  79.  95. 111. 127. 143. 159. 175. 191. 207. 223.
      239. 255.]]
    
  • 增 强 接 口 调 用 示 例 片 段,完 整 代 码 请 参 考Transpose类 样 例场 景 二,完 成half类 型 的[N,C,H,W]->[N,H,W,C]转 置。

    C++
    AscendC::TransposeParamsExt transposeParams;
    transposeParams.nSize = N; // N轴 长 度
    transposeParams.cSize = C; // C轴 长 度
    transposeParams.hSize = H; // H轴 长 度
    transposeParams.wSize = W; // W轴 长 度
    transposeParams.transposeType = transposeType; 
    AscendC::Transpose(dstLocal, srcLocal, stackBuffer, transposeParams);
    

免 责 声 明:本 站 内 容 由 asc-devkit 仓 master 分 支 自 动 编 译 生 成,属 于 持 续 开 发 版 本,可 能 存 在 缺 陷,仅 供 预 览 与 参 考。如 需 稳 定 及 商 用 资 料,请 查 阅 官 方 昇 腾 社 区