引言

最近在读别人写的代码的时候看到下面的代码。

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
......
update_op = tf.group(*update_ops) 

不太明白这个group()函数的作用是什么,于是上网进行了一番学习,在此记录一下,也分享出来供新手参考。

理解

其实上网学习了之后,我才后知后觉的发现这个函数的作用也可以由它的名字猜测出来。group是分组的意思,其实该函数就是把update_ops(list)中的操作作为一个组,把这些操作和成为一个操作。除了group函数之外,tuple函数也有类似的功能,但是有一点细微的差别。
1. 首先是接受参数的形式不同。
group的参数是一个一个operation,而不是一个列表。这也是引言中的update_ops列表前需要加*的原因。
2. 返回值不同。
tf.group()返回的是op,tf.tuple()返回的是list of tensor。
如果还是不太理解的话可以看后面的例子,和下面的英文的api,英文的内容比较简单,我就不翻译了。
* 可以发现,如果我们有很多 tensor 或 op想要一起run,tf.group() 与 tf.tuple()两个函数就是一个很好的帮手了。*

# TODO(touts): Accept "inputs" as a list.
def group(*inputs, **kwargs):
  """Create an op that groups multiple operations.

  When this op finishes, all ops in `inputs` have finished. This op has no
  output.

  See also @{tf.tuple$tuple} and
  @{tf.control_dependencies$control_dependencies}.

  Args:
    *inputs: Zero or more tensors to group.
    name: A name for this operation (optional).

  Returns:
    An Operation that executes all its inputs.

  Raises:
    ValueError: If an unknown keyword argument is provided.
  """
def tuple(tensors, name=None, control_inputs=None)
Group tensors together.

This creates a tuple of tensors with the same values as the tensors argument, 
except that the value of each tensor is only returned after the values of all 
tensors have been computed.

control_inputs contains additional ops that have to finish before this op 
finishes, but whose outputs are not returned.

This can be used as a "join" mechanism for parallel computations: all the
argument tensors can be computed in parallel, but the values of any tensor
returned by tuple are only available after all the parallel computations 
are done.

See also group and with_dependencies.

Args:
tensors: A list of Tensors or IndexedSlices, some entries can be None.
name: (optional) A name to use as a name_scope for the operation.
control_inputs: List of additional ops to finish before returning.
Returns:
Same as tensors.

例子1

w = tf.Variable(1)
mul = tf.multiply(w, 2)
add = tf.add(w, 2)
group = tf.group(mul, add)
tuple = tf.tuple([mul, add])
# sess.run(group)和sess.run(tuple)都会求Tensor(add)
#Tensor(mul)的值。区别是,tf.group()返回的是`op`
#tf.tuple()返回的是list of tensor。
#这样就会导致,sess.run(tuple)的时候,会返回 Tensor(mul),Tensor(add)的值.
#而 sess.run(group)不会

tf.identity()

在学习tf.group()的时候,看到很多文章都是将tf.identity()和tf.group()放在一起辨析,就一起学习了。字面上来理解identity是恒等的意思,其实这就是一个赋值操作。在一般的情况下,我们使用赋值操作符=来进行赋值,例如y=x,表示将x的值赋值给y。但是在TensorFlow某些特殊的情况下是不支持这么做的,原因是TensorFlow中的计算都是基于计算图中的,计算图的每个节点都是一个operation对象,所有以上的赋值操作也需要用一个赋值操作来表示,才能在计算图中进行计算(赋值)。所以需要写成y=tf.identity(x)。

例子2

下面程序要做的是,5次循环,每次循环给x加1,赋值给y,然后打印出来

x = tf.Variable(0.0)
#返回一个op,表示给变量x加1的操作
x_plus_1 = tf.assign_add(x, 1)

#control_dependencies的意义是,在执行with包含的内容(在这里就是 y = x)前
#先执行control_dependencies中的内容(在这里就是 x_plus_1)
with tf.control_dependencies([x_plus_1]):
    y = x
init = tf.initialize_all_variables()

with tf.Session() as session:
    init.run()
    for i in xrange(5):
        print(y.eval())#相当于sess.run(y),由于control_dependencies的所以执行print前都会先执行x_plus_1

这个打印的是

0,0,0,0,0 

也就是说没有达到我们预期的效果。
如果改成这样:

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    y = tf.identity(x)#修改部分
init = tf.initialize_all_variables()

with tf.Session() as session:
    init.run()
    for i in xrange(5):
        print(y.eval())
This works: it prints 1, 2, 3, 4, 5. 

这时候打印的是

1,2,3,4,5

可以看到,tf.identity的左右的是将普通的赋值语句变成一个操作。
但是第一种写法为什么不work呢?
可以这样解释,虽然tf.control_dependencies参数中的op列表会在with包含的操作op执行之前先执行,但是y=x这个语句并不是一个op,而是一个tensor,所以执行y=x时,并不会执行tf.control_dependencies参数中的操作op。
所以可以将 y=x 修改为 y=tf.identity(x),此时这个语句就是一个操作op,要先执行tf.control_dependencies参数中的op列表,再执行y=tf.identity(x)操作,最终输出结果为1.0 2.0 3.0 4.0 5.0,最终变量x的结果也为5.0。

例子3

其实明白了上面例子为什么不能work之后,和group()函数的作用后。我们还有另一种改写方法。

import tensorflow as tf  
x = tf.Variable(0.0)  
x_plus = tf.assign_add(x, 1)  
with tf.control_dependencies([x_plus]):#只有当内部为操作时以来才会生效  
    #y = tf.identity(x)#将该语句变为操作  
    y = x  
    update = tf.group(y)#将该语句变为操作  
init = tf.global_variables_initializer()  
with tf.Session() as session:  
    init.run()  
    for i in range(5):  
        session.run(update)  
        print(y.eval())  
    print(x.eval())#5  

参考文章

  1. tf.identity()与tf.group()
  2. tf.identity的意义以及用例
  3. tensorflow学习笔记(四十一):control dependencies
  4. tensorflow学习笔记(三十五):control flow
  5. tf.control_dependencies与tf.identity组合详解
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐