TensorFlow笔记——(1)理解tf.control_dependencies与control_flow_ops.with_dependencies
引言
我们在实现神经网络的时候经常会看到tf.control_dependencies的使用,但是这个函数究竟是什么作用,我们应该在什么情况下使用呢?今天我们就来一探究竟。
理解
其实从字面上看,control_dependencies 是控制依赖的意思,我们可以大致推测出来,这个函数应该使用来控制就算图节点之间的依赖的。其实正是如此,tf.control_dependencies()是用来控制计算流图的,给图中的某些节点指定计算的顺序。
原型分析
tf.control_dependencies(self, control_inputs)
arguments:control_inputs: A list of `Operation` or `Tensor` objects
which must be executed or computed before running the operations
defined in the context. (注意这里control_inputs是list)
return: A context manager that specifies control dependencies
for all operations constructed within the context.
通过以上的解释,我们可以知道,该函数接受的参数control_inputs,是Operation或者Tensor构成的list。返回的是一个上下文管理器,该上下文管理器用来控制在该上下文中的操作的依赖。也就是说,上下文管理器下定义的操作是依赖control_inputs中的操作的,control_dependencies用来控制control_inputs中操作执行后,才执行上下文管理器中定义的操作。
例子1
如果我们想要确保获取更新后的参数,name我们可以这样组织我们的代码。
opt = tf.train.Optimizer().minize(loss)
with tf.control_dependencies([opt]): #先执行opt
updated_weight = tf.identity(weight) #再执行该操作
with tf.Session() as sess:
tf.global_variables_initializer().run()
sess.run(updated_weight, feed_dict={...}) # 这样每次得到的都是更新后的weight
可以看到以上的例子用到了tf.identity(),至于为什么要使用tf.identity(),我在下一篇博客:TensorFlow笔记——(1)理解tf.control_dependencies与control_flow_ops.with_dependencies中有详细的解释,不懂的可以移步了解。
control_flow_ops.with_dependencies
除了常用tf.control_dependencies()我们还会看到,control_flow_ops.with_dependencies(),其实连个函数都可以实现依赖的控制,只是实现的方式不太一样。
with_dependencies(dependencies, output_tensor, name=None)
Produces the content of `output_tensor` only after `dependencies`.
所有的依赖操作完成后,计算output_tensor并返回
In some cases, a user may want the output of an operation to be
consumed externally only after some other dependencies have run
first. This function ensures returns `output_tensor`, but only after all
operations in `dependencies` have run. Note that this means that there is
no guarantee that `output_tensor` will be evaluated after any `dependencies`
have run.
See also @{tf.tuple$tuple} and @{tf.group$group}.
Args:
dependencies: Iterable of operations to run before this op finishes.
output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
name: (Optional) A name for this operation.
Returns:
Same as `output_tensor`.
Raises:
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
例子2
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) #从一个集合中取出变量,返回的是一个列表
......
total_loss, clones_gradients = model_deploy.optimize_clones(
clones,
optimizer,
var_list=variables_to_train)
......
# tf.group()将多个tensor或者op合在一起,然后进行run,返回的是一个op
update_op = tf.group(*update_ops)
train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
name='train_op')
可以看到以上的例子用到了tf.group(),至于为什么要使用tf.identity(),我在下一篇博客:TensorFlow笔记——(2) tf.group(), tf.tuple 和 tf.identity()中有详细的解释,不懂的可以移步了解。
参考文档
1、tensorflow学习笔记(四十一):control dependencies
2、tf.control_dependencies与tf.identity组合详解
更多推荐
所有评论(0)