tensorflow2.0中Layer的__init__(),build(), call()函数
·
最近在实验中,需要用到tensorflow建立一个简单的模型,但鉴于部分要求比较苛刻,不能直接使用其内置的layer,因此需要自定义一个layer类,这便涉及到了对__init__()
, build()
, call()
这三个函数的理解
先看官方手册中使用了Layer中的这三个关键函数的一个简单的实例:
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
def call(self, input):
return tf.matmul(input, self.kernel)
layer = MyDenseLayer(10)
从直观上理解,似乎__init__()
和build()
函数都在对Layer进行初始化,都初始化了一些成员函数,而call()
函数则是在该layer被调用时执行。
显然,这三个函数都是从tf.keras.layers.Layer
处继承而来的,那么不妨看一下官方对这几个函数作何解释。
下图为tf.keras.layers.Layer
的官方文档
简单翻译,就是说官方推荐凡是tf.keras.layers.Layer
的派生类都要实现__init__()
,build()
, call()
这三个方法
__init__()
:保存成员变量的设置
build()
:在call()
函数第一次执行时会被调用一次,这时候可以知道输入数据的shape
。返回去看一看,果然是__init__()
函数中只初始化了输出数据的shape
,而输入数据的shape
需要在build()
函数中动态获取,这也解释了为什么在有__init__()
函数时还需要使用build()
函数
call()
: call()
函数就很简单了,即当其被调用时会被执行。
下面附上这几个函数的文档,就不做详细介绍了,有兴趣可以自己看看:
更多推荐
已为社区贡献10条内容
所有评论(0)