对LSTM网络的理解

对LSTM网络不理解的请看这篇博客,对新手比较友好,也很容易理解,只有理解了LSTM,才知道下面要讲的参数分别对应什么

LSTM参数列表

Pytorch中创建一个LSTM网络,参数列表如下:

参数解释
input_size输入数据的特征维数
hidden_sizeLSTM中隐层的维度
num_layers循环神经网络的层数
bias用不用偏置,default=True
batch_first这个要注意,通常我们输入的数据shape=(batch_size,seq_length,embedding_dim),而batch_first默认是False,所以我们的输入数据最好送进LSTM之前将batch_size与seq_length这两个维度调换
dropout默认是0,代表不用dropout
bidirectional默认是false,代表不用双向LSTM

其实最重要的参数就前三个,其他参数都可以默认。把网络看成一个黑箱,我们在用是肯定是输入一个向量,然后网络处理后输出一个向量,所以我们必须要告诉网络输入的向量是多少维,输出的为多少维,因此前两个参数就决定了输入和输出向量的维度。当然,hidden_size只是指定从LSTM输出的向量的维度,并不是最后的维度,因为LSTM层之后可能还会接其他层,如全连接层(FC),因此hidden_size对应的维度也就是FC层的输入维度。第三个参数num_layers为隐藏层的层数,这个比较好理解,官方的例程里面建议一般设置为1或者2。

光看上面的文字描述似乎还不够直白,因此我画了一张图。
注:局部图片来自于这篇文章
在这里插入图片描述

Input shape

建好网络之后我们还需要把数据调整整对应的shape,pytorch中LSTM的调用如下:

output,(h_n,c_n) = lstm (x, [ht_1, ct_1])

其中x就是我们喂给网络的数据,它的shape要求如下:

x:[seq_length, batch_size, input_size]

新手可能对着三个参数容易理解错误,这里贴一篇知乎上的文章用「动图」和「举例子」讲讲 RNN,这篇文章讲的比较通俗易懂,适合新手看。
同样,我也画了几张对应的图来解释
在这里插入图片描述

在这里插入图片描述

转载拿图请注明出处。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐