TensorFlow笔记——tf.split()拆分tensor和tf.squeeze()
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
tf.split()
参数说明:
split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
value: 输入的tensor
num_or_size_splits: 如果是个整数n,就将输入的tensor分为n个子tensor。如果是个tensor T,就将输入的tensor分为len(T)个子tensor。
axis: 默认为0,计算value.shape[axis], 一定要能被num_or_size_splits整除。
举例
import tensorflow as tf
import numpy as np
a=np.reshape(range(24),(4,2,3))
print(a)
sess=tf.InteractiveSession()
# 将a分为两个tensor,a.shape(1)为2,可以整除,不会报错。
# 输出应该为2个shape为[4,1,3]的tensor
b= tf.split(a,2,1)
print(b)
print(b[0].eval())
print("---------------------------------")
print(b[1].eval())
c= tf.split(a,2,0)
# a.shape(0)为4,被2整除,输出2个[2,2,3]的Tensor
print(c)
print(c[0].eval())
print("---------------------------------")
print(c[1].eval())
d= tf.split(a,3,2)
# 分成三个tensor,a.shape(2)为3,整除,不报错,返回3个[4,2,1]的Tensor
print(d)
print(d[0].eval())
print("---------------------------------")
print(d[1].eval())
print("---------------------------------")
print(d[2].eval())
d= tf.split(a,2,2)
# a.shape(2)为3,不被2整除,报错。
tf.squeeze()
参数说明:
tf.squeeze
squeeze(
input,
axis=None,
name=None,
squeeze_dims=None
)
去掉维数为1的维度。
示例
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t)) # [2, 3]
也可以指定去掉哪个维度:
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献22条内容
所有评论(0)