TensorArray可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用。

例子1:将[2.4, 3.5]写进TensorArray三次

import tensorflow as tf


def condition(time, output_ta_l):
    return tf.less(time, 3)


def body(time, output_ta_l):
    output_ta_l = output_ta_l.write(time, [2.4, 3.5])
    return time + 1, output_ta_l


time = tf.constant(0)
output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)

result = tf.while_loop(condition, body, loop_vars=[time, output_ta])

last_time, last_out = result

final_out = last_out.stack()

with tf.Session():
    print(last_time.eval())
    print(final_out.eval())
Out:
3
[[ 2.4000001  3.5      ]
 [ 2.4000001  3.5      ]
 [ 2.4000001  3.5      ]]

重要函数:

ta.stack(name=None) 将TensorArray中元素叠起来当做一个Tensor输出

ta.unstack(value, name=None) 可以看做是stack的反操作,输入Tensor,输出一个新的TensorArray对象

ta.write(index, value, name=None) 指定index位置写入Tensor

ta.read(index, name=None) 读取指定index位置的Tensor

以上所有函数的参数name=None均可用来指定当前操作的名称。

GitHub 加速计划 / te / tensorflow
184.54 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:26 天前 )
a49e66f2 PiperOrigin-RevId: 663726708 1 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 1 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐