pytorch中LSTM参数详解(一张图帮你更好的理解每一个参数)
·
对LSTM网络的理解
对LSTM网络不理解的请看这篇博客,对新手比较友好,也很容易理解,只有理解了LSTM,才知道下面要讲的参数分别对应什么
LSTM参数列表
Pytorch中创建一个LSTM网络,参数列表如下:
参数 | 解释 |
---|---|
input_size | 输入数据的特征维数 |
hidden_size | LSTM中隐层的维度 |
num_layers | 循环神经网络的层数 |
bias | 用不用偏置,default=True |
batch_first | 这个要注意,通常我们输入的数据shape=(batch_size,seq_length,embedding_dim),而batch_first默认是False,所以我们的输入数据最好送进LSTM之前将batch_size与seq_length这两个维度调换 |
dropout | 默认是0,代表不用dropout |
bidirectional | 默认是false,代表不用双向LSTM |
其实最重要的参数就前三个,其他参数都可以默认。把网络看成一个黑箱,我们在用是肯定是输入一个向量,然后网络处理后输出一个向量,所以我们必须要告诉网络输入的向量是多少维,输出的为多少维,因此前两个参数就决定了输入和输出向量的维度。当然,hidden_size只是指定从LSTM输出的向量的维度,并不是最后的维度,因为LSTM层之后可能还会接其他层,如全连接层(FC),因此hidden_size对应的维度也就是FC层的输入维度。第三个参数num_layers为隐藏层的层数,这个比较好理解,官方的例程里面建议一般设置为1或者2。
光看上面的文字描述似乎还不够直白,因此我画了一张图。
注:局部图片来自于这篇文章
Input shape
建好网络之后我们还需要把数据调整整对应的shape,pytorch中LSTM的调用如下:
output,(h_n,c_n) = lstm (x, [ht_1, ct_1])
其中x就是我们喂给网络的数据,它的shape要求如下:
x:[seq_length, batch_size, input_size]
新手可能对着三个参数容易理解错误,这里贴一篇知乎上的文章用「动图」和「举例子」讲讲 RNN,这篇文章讲的比较通俗易懂,适合新手看。
同样,我也画了几张对应的图来解释
转载或拿图请注明出处。
更多推荐
已为社区贡献5条内容
所有评论(0)