关于loss.backward()以及其参数retain_graph的一些坑

首先,loss.backward()这个函数很简单,就是计算与图中叶子结点有关的当前张量的梯度
使用呢,当然可以直接如下使用

    optimizer.zero_grad() 清空过往梯度;
    loss.backward() 反向传播,计算当前梯度;
    optimizer.step() 根据梯度更新网络参数

or这种情况
    for i in range(num):
        loss+=Loss(input,target)
    optimizer.zero_grad() 清空过往梯度;
    loss.backward() 反向传播,计算当前梯度;
    optimizer.step() 根据梯度更新网络参数

 但是,有些时候会出现这样的错误:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

这个错误的意思就是,Pytorch的机制是每次调用.backward()都会free掉所有buffers,模型中可能有多次backward(),而前一次backward()存储在buffer中的梯度,会因为后一次调用backward()被free掉,因此,这里需要用到retain_graph=True这个参数
使用这个参数,可以让前一次的backward()的梯度保存在buffer内,直到更新完成,但是要注意,如果你是这样写的:

    optimizer.zero_grad() 清空过往梯度;
    loss1.backward(retain_graph=True) 反向传播,计算当前梯度;
    loss2.backward(retain_graph=True) 反向传播,计算当前梯度;
    optimizer.step() 根据梯度更新网络参数

 那么你可能会出现内存溢出的情况,并且,每一次迭代会比上一次更慢,越往后越慢(因为你的梯度都保存了,没有free)
解决的方法,当然是这样:

    optimizer.zero_grad() 清空过往梯度;
    loss1.backward(retain_graph=True) 反向传播,计算当前梯度;
    loss2.backward() 反向传播,计算当前梯度;
    optimizer.step() 根据梯度更新网络参数

即:最后一个backward()不要加retain_graph参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了。

这里有人就会问了,我又没有这么多 loss,怎么还会出现这种错误呢?这里可能是因为,你用的模型本身有问题,LSTM和GRU都会出现这样的问题,问题存在与hidden unit,这个东东也参与了反向传播,所以导致了有多个backward(),
这里其实我也挺费解,为什么存在多个backward()呢?难道是,我的LSTM网络是N to N,即输入N和,输出N个,然后和N个label进行计算loss,再进行回传,这里,可以思考一下BPTT,即,如果是N to 1,那么梯度更新需要时间序列所有的输入以及隐藏变量计算梯度,然后从最后一个向前传,所以只有一个backward(), 而N to N 以及 N to M 都会出现多个loss需要进行backward()的情况,如果还是两个方向(一个从输出到输入,一个沿着时间)一直进行传播,那么其实会有重叠的部分,所以解决的方法也就很明了了,利用detach()函数,切断其中重叠的反向传播,(这里仅是我的个人理解,若有误还请评论指出,大家共同探讨)切断的方式有三种,如下:

hidden.detach_()
hidden = hidden.detach()
hidden = Variable(hidden.data, requires_grad=True) 

任选其一即可, 这里附一些我参考的解释,大家可以看看

Help clarifying repackage_hidden in word_language_model
Pytorch 如何实现训练LSTM的BPTT算法?
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed
Pytorch中retain_graph参数的作用

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐