pytorch报错详解:RuntimeError: Trying to backward through the graph a second time
·
代码如下,当我尝试对y3进行第二次梯度计算时,报了这个错
import torch
x = torch.tensor([1], dtype=torch.float32, requires_grad=True)
y1 = x ** 2
y2 = x ** 3
y3 = y1 + y2
y3.backward()
print(x.grad)
x.grad.data.zero_()
y3.backward()
print(x.grad)
百度了一下问题原因和解决方法,
解决方法:在第一次backward中加一句retain_grad=True,意思为一直保留计算图,问题解决。
y3.backward(retain_graph=True)
报错原因就是pytorch的计算图在第一次执行完backward计算梯度的时候就已经被释放了。第二次想要再用计算图计算时,计算图已经没了,自然报错。
那么,计算图是什么?
针对上述例子,计算图可以画成这样:
当第二次backward时,无法从y3回溯到y1/y2,这是报错的根源。
此外,我还尝试了另一个例子:
x = torch.randn(3,3,requires_grad=True)
print(x)
y = x + x
out2 = y.sum()
out2.backward()
print(x.grad)
x.grad.data.zero_()
out2.backward()
print(x.grad)
神奇的是,就算连续调用两次backward,依然不报错。当时我的猜想是这根据计算图的计算方法复杂程度而定,第一个例子变量都来自x,一个是二次方,一个是三次方,计算方法不同,计算图必须保留;第二个例子只用了加法一种计算方法,但随后我将这种想法推翻了,因为这个例子:
y = x*x
z = y*y
a = z*z
a = a.sum()
a.backward()
a.backward()
计算方法单一,但结果报错。我又将乘法都换成了加法,就没错了O_o
总结一下,只要是 只有加法存在 这种情况,不管套了多少层,backward()即可,剩下的情况都要backward(retain_graph=True)
更多推荐
已为社区贡献1条内容
所有评论(0)