TensorFlow 变量管理

摘自:《TensorFlow实战Google深度学框架》一书,5.3节。

Tensorflow提供了通过变量名称来创建或者获取一个变量的机制。通过这个机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要将变量通过参数的形式到处传递。TensorFlow中通过变量名获取变量的机制主要是通过tf.get_variable和tf.variable_scope函数实现的。下面将分别介绍如何使用这两个函数。

通过tf.Variable()函数可以创建一个变量。除了tf.Variable()函数,TensorFlow还提供了tf.get_variable函数来创建或者获取变量。当tf.get_variable用于创建变量时,它和tf.Variable()的功能是基本等价的。下面的代码给出通过这两个函数创建同一个变量的示例:

1
2
3
# 下面这两个定义是等价的
v = tf.get_variable("v",shape=[1],initializer=tf.constant_initializer(1.0))
v = tf.Variable(tf.constant(1.0, shape=[1]),name='v')

从上面的代码可以看出,通过tf.Variable和tf.get_variable函数创建变量的过程基本上是一样的。tf.get_variable函数调用时提供的维度(shape)信息以及初始化方法(initializer)的参数和tf.Variable函数调用时提供的初始化过程中的参数也类似。TensorFlow中提供的initializer函数和随机数以及常量生成函数大部分是一一对应的。比如,在上面的样例程序中使用的常数初始化函数tf.constant_initializer和常数生成的函数tf.constant功能上是一致的。Tensorflow提供了7种不同的初始化函数,如下表所示:

初始化函数 功能 主要参数
tf.constant_initializer 将变量初始化为给定常量 常量的取值
tf.random_normal_initializer 将变量初始化为满足正态分布的随机值 正态分布的均值和标准差
tf.truncated_normal_initializer 将变量初始化为满足正态分布的随机值,但如果随机出来的值
偏离平均值超过2个标准差,那么这个数将会被重新随机
正态分布的均值和标准差
tf.random_uniform_initializer 将变量初始化为满足均匀分布的随机值 最大值、最小值
tf.uniform_unit_scaling_initializer 将变量初始化为满足均匀分布但不影响输出数量级的随机值 factor(产生随机数时
乘以的系数)
tf.zeros_initializer 将被变量设置为0 变量维度
tf.ones_initializer 将变量设置为1 变量的维度

tf.get_variable函数与tf.Variable函数最大的区别在于指定变量名称的参数。对于tf.Variable函数, 变量名称是一个可选的参数,通过name=“v”的形式给出。但是对于tf.get_variable函数,变量名称是一个必填的参数。tf.get_variable会根据这个名字去创建或者获取变量。在上面的示例程序中,tf.get_variable首先会试图创建一个名字为v的参数,如果创建失败(比如已经有同名的参数),那么这个程序会报错。这是为了避免无意识的变量复用造成的错误。比如在定义神经网络参数时,第一层网络的权重已经叫weights了,如果创建第二层的神经网络时,如果参数名仍然叫weights,那么就会触发变量重用的错误。否则两层神经网络公用一个权重会出现一些比较难以发现的错误。如果需要通过tf.get_variable获取一个已经创建的变量,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable将直接获取已经生成的变量。下面给出一段代码说明如何通过tf.variable_scope函数来控制tf.get_variable函数获取已经创建过的变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 在名字为foo的命名空间内创建名字为v的变量
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1], initializer=tf.constant_initializer(1.0))
# 因为在命名空间foo中已经存在名为v的变量,所以下面的代码将会报错
with tf.variable_scope("foo"):
v = tf.get_variable("v",[1])
# 在生成上下文管理器时,将参数reuse设置为True。这样tf.get_vaiable函数将直接获取
# 已经声明的变量。
with tf.variable_scope("foo",reuse=True):
v1 = tf.get_variable("v",[1])
print v==v1 #输出为True,v和v1代表的是相同的变量
# 将参数reuse设置为True时,tf.variable_scope将只能获取已经创建过的变量。因为在命名
# 空间bar中还没有创建变量v,所以下面的代码将会报错
with tf.variable_scope("bar",reuse=True):
v = tf.get_variable("v",[1])

上面的样例简单地说明了通过tf.variable_scope函数可以控制tf.get_variable函数的语义。当tf.variable_scope函数使用参数reuse=True生成上下文管理器时,这个上下文管理器内所有的tf.get_variable函数会直接获取已经创建的变量。如果变量没有被创建,则tf.get_variable将会报错;相反如果tf.variable_scope函数使用参数reuse=None或者reuse=False创建上下文管理器,tf.get_variable操作将创建新的变量。如果同名变量已经存在,则tf.get_variable函数将会报错。TensorFlow中tf.variable_scope函数是可以嵌套的。下面的程序说明了当tf.variable_scope函数嵌套时,reuse参数的取值时如何确定的。

1
2
3
4
5
6
7
8
9
10
11
12
with tf.variable_scope("root"):
# 可以通过tf.get_variable_scope().reuse函数来获取当前上下文管理器中reuse参数的取值
print tf.get_variable_scope().reuse #输出False,即最外层reuse是False
with tf.variable_scope("foo",reuse=True): # 新建一个嵌套的上下文管理器,
# 并指定reuse为True
print tf.get_variable_scope().reuse # 输出为True
with tf.variable_scope("bar"): # 新建一个嵌套的上下文管理器
# 但不指定reuse的取值,和外层的保持一致
print tf.get_variable_scope().reuse # 输出为True
print tf.get_variable_scope().reuse # 输出False,退出reuse设置为True
# 的上下文之后,reuse的值又回到了False

tf.variable_scope函数生成的上下文管理器也会创建一个TensorFlow中的命名空间,在命名空间内创建的变量名称都会带上这个命名空间名作为前缀。所以,tf.variable_scope函数除了控制tf.get_variable执行的功能之外,这个函数也提供了一个管理变量命名空间的方式。下面的代码显示如何通过tf.variable_scope来管理变量的名称。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
v1 = tf.get_variable("v",[1])
print v1.name # 输出v:0,"v"为变量名称,“:0”表示这个变量时生成变量这个运算的第一个结果
with tf.variable_scope("foo"):
v2 = tf.get_variable("v",[1])
print v2.name # 输出为foo/v:0。在tf.variable_scope中创建的变量,名称前面会
# 加入命名空间的名称,通过/来分隔命名空间的名称和变量的名称。
with tf.variable_scope("foo"):
with tf.variable_scope("bar"):
v3 = tf.get_variable("v",[1])
print v3.name # 输出为foo/bar/v:0。命名空间可以嵌套,同时变量的名称也会
# 加入所有命名空间的名称作为前缀。
v4 = tf.get_variable("v1",[1])
print v4.name # 输出foo/v1:0。当命名空间退出之后,变量名称也就不会再被加入其前缀了。
# 创建一个名称为空的命名空间,并设置为reuse=True
with tf.variable_scope("", reuse=True):
v5 = tf.get_variable("foo/bar/v",[1]) # 可以直接通过带命名空间名称的变量名来
# 获取其他命名空间下的变量
print v5 == v3 # 输出为True
v6 = tf.get_variable("foo/v1",[1])
print v6 == v4 # 输出为True

坚持原创技术分享,您的支持将鼓励我继续创作!