tensorflow变量共享——VariableScope的reuse模式、tf.get_variable()、tf.Variable() 探索
文章目录
一、VariableScope的reuse模式的设置
1.1节
1、tf.get_variable_scope()
可以获得当前的变量域,还可以通过其.name
和.reuse
来查看其名称和当前是否为reuse模式。
2、变量域有name,变量也有name。默认变量作用域的name为空白字符串。
3、在变量域内命名的变量的name全称为:“变量域的name+变量定义时传入的name”(就像一个文件有名字作为标识符,但是在前面加上绝对路径就是它在整个文件系统中的全局标识符)。
这三点贯穿本文,如果不太清楚,可以直接看后面的多个例子,会不断地体现在代码中。
1.2节
with tf.variable_scope()
可以打开一个变量域,有两个关键参数。name_or_scope
参数可以是字符串或tf.VariableScope对象,reuse
参数为布尔值,传入True
表示设置该变量域为reuse模式。
还有一种方法可以将变量域设置为reuse模式,即使用VariableScope对象的reuse_variables()方法
,例如tf.get_variable_scope().reuse_variables()
可以将当前变量域设置为reuse模式。
with tf.variable_scope('vs1'):
tf.get_variable_scope().reuse_variables()
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
with tf.variable_scope('vs2',reuse=True):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
'''
"vs1" True
"vs2" True
'''
1.3节
对某变量域设置reuse模式,则reuse模式会被变量域的子域继承
# 注意,默认变量域的名称为空白字符串
tf.get_variable_scope().reuse_variables() # 将默认变量域设置为reuse模式
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse) # 为了显示空白字符串,在名称两边加上双引号
with tf.variable_scope('vs'):
# vs是默认变量域的子域,故虽然没有明确设置vs的模式,但其也更改成了reuse模式
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
'''
输出为:
"" True
"vs" True
'''
1.4节
每次在with块中设置变量域的模式,退出with块就会失效(恢复回原来的模式)。
with tf.variable_scope('vs1'):
tf.get_variable_scope().reuse_variables()
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
with tf.variable_scope('vs1'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
'''
输出为:
"vs1" True
"vs1" False
'''
1.5节
可以使用一个变量域(tf.VariableScope对象)的引用来打开它,这样可以不用准确的记住其name的字符串。下面的例子来自tensorflow官网。
with tf.variable_scope("model") as scope:
output1 = my_image_filter(input1)
with tf.variable_scope(scope, reuse=True):
output2 = my_image_filter(input2)
tf.VariableScope对象作为with tf.variable_scope( name_or_scope ):
的参数时,该with语句块的模式是该scope
对应的模式。(下面的代码同时也展现了前面所说的“继承”和“失效”的现象。)
with tf.variable_scope('vs1'):
tf.get_variable_scope().reuse_variables()
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
with tf.variable_scope('vs2') as scope: # vs2(全称是vs1/vs2)将会继承vs1的reuse模式
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
# 重新用with打开vs1和vs2,他们的reuse模式不受之前with块中的设置的影响
with tf.variable_scope('vs1'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
with tf.variable_scope('vs2'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
# tf.variable_scope也可以传入tf.VariableScope类型的变量,此处的scope是第4行with语句中定义的
with tf.variable_scope(scope):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
# 第二个with中, vs1/vs2的reuse模式为False,即前面所说的退出with块之后reuse模式的设置“失效”
# 但是第三个with中,vs1/vs2的reuse却为True,这是因为当`name_or_scope`参数是tf.VariableScope对象时,
# 其打开的变量域的reuse模式由这个参数scope决定。
# 此处的`scope`在第4行定义,“继承”vs1的reuse,且之后没有改变,所以第三个with打开的就是reuse=True
'''
输出为:
"vs1" True
"vs1/vs2" True
"vs1" False
"vs1/vs2" False
"vs1/vs2" True
'''
二、reuse模式对tf.Variable() 的影响
tf.Variable()
只有新建变量的功能,一个变量域是否为reuse模式不影响tf.Variable()
的作用。如果该变量域中已经有同名的变量,则新建的变量会被重命名,加上数字后缀以区分。
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v1=tf.Variable(tf.constant(1),name='v')
v2=tf.Variable(tf.constant(1),name='v')
tf.get_variable_scope().reuse_variables()
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v3=tf.Variable(tf.constant(1),name='v') # 在reuse模式下使用tf.Variable(),仍然会新建,故v3名称为v_2
print(v1.name)
print(v2.name)
print(v3.name)
'''
输出为:
"" False
"" True
v:0
v_1:0
v_2:0
'''
三、reuse模式对tf.get_variable()的影响
reuse模式会对tf.get_variable()
的实际效果有决定作用。
3.1节
在non-reuse模式下,tf.get_variable()
作用为新建变量(设为v
)。若变量域内已经有同名变量(设为w
),则分两种情况:
- 若
w
是之前通过tf.Variable()
创建的,则v
将被重命名,即加上数字后缀。 - 若
w
是之前通过tf.get_variable()
创建的,则不允许新建同名变量v
。
with tf.variable_scope('vs'):
# 打印当前scope的名称和是否为reuse模式
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v1=tf.Variable(tf.constant(1),name='v')
print(v1.name) # 前缀为scope的名称加一个反斜线,即“vs/”,故全称为“vs/v:0”,“冒号0”的解释见后文。
v2=tf.get_variable('v',shape=())
print(v2.name) # 已经有名为v的变量,故v2的name会在v后面加上数字后缀(从1开始)
v3=tf.get_variable('v',shape=()) # 已经有名为v且由tf.get_variable创建的变量,故v3的创建抛出异常
print(v3.name)
输出为:(题外话,“:0” 指的是该变量是创建它的operation的第一个输出,见 这个链接)
"vs" False
vs/v:0
vs/v_1:0
--------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-2-e0b97b39994d> in <module>()
5 v2=tf.get_variable('v',shape=())
6 print(v2.name)
----> 7 v3=tf.get_variable('v',shape=())
8 print(v3.name)
9
<省略部分输出>
ValueError: Variable vs/v already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
3.2节
在reuse模式下,tf.get_variable()
作用为重用(reuse)变量。注意只能重用之前在本变量域创建的、且使用tf.get_variable()
创建的变量,即不能在本变量域中重用其他变量域中创建的变量,也不能重用那些使用tf.Variable()
创建的变量。
1.重用变量
with tf.variable_scope('vs'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v=tf.get_variable('v',shape=())
print(v.name)
with tf.variable_scope('vs',reuse=True):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
reused_v=tf.get_variable('v',shape=()) # reused_v就是之前的v,他们是共享内存的变量
print(reused_v.name)
'''
输出为:
"vs" False
vs/v:0
"vs" True
vs/v:0
'''
2.不能重用其他变量域中命名的变量(相当于你在A文件夹新建了v.txt,但是不能到B文件夹里面找v.txt)。
# 在vs变量域新建v,尝试到vs1中重用变量
with tf.variable_scope('vs'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v=tf.get_variable('v',shape=())
# v=tf.Variable(tf.constant(1),name='v')
print(v.name)
with tf.variable_scope('vs1',reuse=True):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
# 下一行会报错,因为vs1这个变量域并没有用get_variable()创建过名为v的变量
reused_v=tf.get_variable('v',shape=())
print(reused_v.name)
'''
报错:
ValueError: Variable vs1/v does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
'''
3.只能重用那些使用tf.get_variable()
创建的变量,而不能重用那些使用tf.Variable()
创建的变量。
with tf.variable_scope('vs'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v=tf.Variable(tf.constant(1),name='v')
print(v.name)
with tf.variable_scope('vs',reuse=True):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
reused_v=tf.get_variable('v',shape=())
print(reused_v.name)
输出为:
"vs" False
vs/v:0
"vs" True
--------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-2-63ddfa598083> in <module>()
6 with tf.variable_scope('vs',reuse=True):
7 print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
----> 8 reused_v=tf.get_variable('v',shape=())
9 print(reused_v.name)
10
<省略部分输出>
ValueError: Variable vs/v does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
附加1:tf.name_scope()与tf.variable_scope()的区别
tf.name_scope()
与tf.variable_scope()
的功能很像,这里也顺便探讨一下他们的区别,以助于加深对两个方法的理解。
本小节参考自 这篇知乎文章
with tf.name_scope('ns'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v1=tf.get_variable('v',shape=())
v2=tf.Variable(tf.constant(1),name='v')
print(v1.name)
print(v2.name)
with tf.variable_scope('vs'):
print('"'+tf.get_variable_scope().name+'"', tf.get_variable_scope().reuse)
v3=tf.get_variable('v',shape=())
v4=tf.Variable(tf.constant(1),name='v')
print(v3.name)
print(v4.name)
with tf.variable_scope('vs'):
with tf.name_scope('ns'):
v5=tf.Variable(tf.constant(1),name='v')
print(v5.name)
v6=tf.get_variable('v',shape=()) # 这里将会抛出异常
print(v6.name)
输出如下,解释见对应的注释:
"" False # 1.with打开NameScope并不影响所在的VariableScope
v:0 # 2.NameScope对于以tf.get_variable()新建的变量的命名不会有影响
ns/v:0 # 3.对于以tf.Variable()方式新建的变量的命名,会加上NameScope的名字作为前缀
"vs" False # 4.印证了第1点
vs/v:0 # 5.印证了第2点
ns/vs/v:0 # 6.对于被多层NameScope和VariableScope包围的、以tf.Variable()新建的变量,其命名以嵌套顺序来确定前缀
vs/ns/v:0 # 7.印证了第6点
# 下面的异常是由v6=tf.get_variable('v',shape=())导致的
# 因为tf.get_variable()获得的变量的命名不受NameScope影响,所以这里其实对应了3.1节第2点的情况
# 即在相同的VariableScope中使用tf.get_variable()定义了重名的变量
Traceback (most recent call last):
File "scope.py", line 79, in <module>
v6=tf.get_variable('v',shape=())
File "C:\Users\pyxies\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1203, in get_variable
constraint=constraint)
File "C:\Users\pyxies\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1092, in get_variable
constraint=constraint)
File "C:\Users\pyxies\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 425, in get_variable
constraint=constraint)
File "C:\Users\pyxies\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 394, in _true_getter
use_resource=use_resource, constraint=constraint)
File "C:\Users\pyxies\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 742, in _get_single_variable
name, "".join(traceback.format_list(tb))))
ValueError: Variable vs/v already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
附加2:单机多GPU下的变量共享/复用
见 https://blog.csdn.net/xpy870663266/article/details/99330338
更多推荐
所有评论(0)