实验有需求,需要对tensor中每一行取一个不同的索引的元素,其中tensor为2维(本文方法适合任意维),因此本文以2维tensor为例。

# 二维tensor
g = tf.constant([[1,2,3,4,5,6,7,8],[9,8,7,6,5,4,3,2]])
# 每一行取的index,在本例中,正确取值为[3, 2],即第一行index=2的元素和第二行index=7的元素
h_index = np.array([2, 7]).reshape(-1, 1)

# 构建一个numpy的arange列表,其长度为tensor的行数
line = np.arange(2).reshape(-1, 1)

# 注意上面两个numpy数组的格式都是(-1, 1)
# 将h_index和line合并
index = np.hstack((line, h_index))

# 使用tf.gather_nd来取值
result = tf.gather_nd(g, index)

如上即可,返回仍为tensor

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐