【Tensorflow+Keras】tf.keras.layers.Bidirectional()的解析与使用
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
1 作用
实现RNN类型神经网络的双向构造
RNN类型神经网络比如LSTM、GRU等等
2 参数
tf.keras.layers.Bidirectional(
layer,
merge_mode=‘concat’,
weights=None,
backward_layer=None
)
- layer 神经网络,如RNN、LSTM、GRU
- merge_mode 前向和后向RNN的输出将被组合的模式。{‘sum’,‘mul’,‘concat’,‘ave’,None}中的一个。如果为None,则将不合并输出,它们将作为列表返回。默认值为“ concat”。
- weights
- backward_layer 处理向后输入处理的神经网络,如果未提供,则作为参数传递的图层实例 将用于自动生成后向图层
注意
该层的调用参数与包装的RNN层的调用参数相同。请注意,在initial_state此层的调用期间传递参数时,列表中元素列表的前半部分initial_state 将传递给正向RNN调用,而元素列表中的后半部分将传递给后向RNN调用。
3 举例使用
三十个tf.keras.layers.Bidirectional的例子
导入模块和定义参数
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
max_features = 20000 # Only consider the top 20k words
maxlen = 200 # Only consider the first 200 words of each movie review
构建模型
# Input for variable-length sequences of integers
inputs = keras.Input(shape=(None,), dtype="int32")
# Embed each integer in a 128-dimensional vector
x = layers.Embedding(max_features, 128)(inputs)
# Add 2 bidirectional LSTMs
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
x = layers.Bidirectional(layers.LSTM(64))(x)
# Add a classifier
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.summary()
加载数据集
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
num_words=max_features
)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)
训练和评估模型
model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val))
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献13条内容
所有评论(0)