tensorflow中tensor,从每行取指定索引元素
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow

·
实验有需求,需要对tensor中每一行取一个不同的索引的元素,其中tensor为2维(本文方法适合任意维),因此本文以2维tensor为例。
# 二维tensor
g = tf.constant([[1,2,3,4,5,6,7,8],[9,8,7,6,5,4,3,2]])
# 每一行取的index,在本例中,正确取值为[3, 2],即第一行index=2的元素和第二行index=7的元素
h_index = np.array([2, 7]).reshape(-1, 1)
# 构建一个numpy的arange列表,其长度为tensor的行数
line = np.arange(2).reshape(-1, 1)
# 注意上面两个numpy数组的格式都是(-1, 1)
# 将h_index和line合并
index = np.hstack((line, h_index))
# 使用tf.gather_nd来取值
result = tf.gather_nd(g, index)
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
如上即可,返回仍为tensor
推荐内容




一个面向所有人的开源机器学习框架
最近提交(Master分支:1 个月前 )
4f64a3d5
Instead, check for this case in `ResolveUsers` and `ResolveOperand`, by querying whether the `fused_expression_root` is part of the `HloFusionAdaptor`.
This prevents us from stepping into nested fusions.
PiperOrigin-RevId: 724311958
1 个月前
aa7e952e
Fix a bug in handling negative strides, and add a test case that exposes it.
We can have negative strides that are not just -1, e.g. with a combining
reshape.
PiperOrigin-RevId: 724293790
1 个月前
更多推荐
相关推荐
查看更多
tensorflow

一个面向所有人的开源机器学习框架
tensorflow

TensorFlow for R
TensorFlow

Project containig related material for my TensorFlow articles
热门开源项目
活动日历
查看更多
直播时间 2025-03-13 18:32:35

全栈自研企业级AI平台:Java核心技术×私有化部署实战
直播时间 2025-03-11 18:35:18

从0到1:Go IoT 开发平台的架构演进与生态蓝图
直播时间 2025-03-05 14:35:37

国产工作流引擎 终结「996」开发困局!
直播时间 2025-02-25 14:38:13

免费开源宝藏 ShopXO,电商系统搭建秘籍大公开!
直播时间 2025-02-18 14:31:04

从数据孤岛到数据智能 - 企业级数据管理利器深度解析
所有评论(0)