TensorFlow中变量管理reuse参数的使用
TensorFlow用于变量管理的函数主要有两个: tf. get_variable()和tf.variable_scope(),
前者用于创建或获取变量的值,后者用于生成上下文管理器,创建命名空间,命名空间可以嵌套。
函数tf.get_variable()既可以创建变量,也可以获取变量。控制创建还是获取的开关来自函数tf.variable.scope()中的参数reuse为“True”还是"False",分两种情况进行说明:
1. 设置reuse=False时,函数get_variable()表示创建变量
如下面的例子:
with tf.variable_scope("foo",reuse=False):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
在tf.variable_scope()函数中,设置reuse=False时,在其命名空间"foo"中执行函数get_variable()时,表示创建变量"v",若在该命名空间中已经有了变量"v",则在创建时会报错,如下面的例子
import tensorflow as tf
with tf.variable_scope("foo"):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
v1=tf.get_variable("v",[1])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-1-eaed46cad84f> in <module>()
3 with tf.variable_scope("foo"):
4 v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
----> 5 v1=tf.get_variable("v",[1])
6
ValueError: Variable foo/v already exists, disallowed.
Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
2. 设置reuse=True时,函数get_variable()表示获取变量
如下面的例子:
import tensorflow as tf
with tf.variable_scope("foo"):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
with tf.variable_scope("foo",reuse=True):
v1=tf.get_variable("v",[1])
print(v1==v)
结果为:
True
在tf.variable_scope()函数中,设置reuse=True时,在其命名空间"foo"中执行函数get_variable()时,表示获取变量"v"。若在该命名空间中还没有该变量,则在获取时会报错,如下面的例子
import tensorflow as tf
with tf.variable_scope("foo",reuse=True):
v1=tf.get_variable("v",[1])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-1-019a05c4b9a4> in <module>()
2
3 with tf.variable_scope("foo",reuse=True):
----> 4 v1=tf.get_variable("v",[1])
5
ValueError: Variable foo/v does not exist, or was not created with tf.get_variable().
Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
3. 结论
TensorFlow通过tf. get_variable()和tf.variable_scope()两个函数,可以创建多个并列的或嵌套的命名空间,用于存储神经网络中的各层的权重、偏置、学习率、滑动平均衰减率、正则化系数等参数值,神经网络不同层的参数可放置在不同的命名空间中。同时,变量重用检错和读取不存在变量检错两种机制保证了数据存放的安全性。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)