codecamp

PyTorch 自动求导机制

原文: PyTorch 自动求导机制

本说明将概述 autograd 的工作方式并记录操作。 不一定要完全了解所有这些内容,但我们建议您熟悉它,因为它可以帮助您编写更高效,更简洁的程序,并可以帮助您进行调试。

从向后排除子图

每个张量都有一个标志:requires_grad,允许从梯度计算中细粒度地排除子图,并可以提高效率。

requires_grad

如果某个操作的单个输入需要进行渐变,则其输出也将需要进行渐变。 相反,仅当所有输入都不需要渐变时,输出才不需要。 在所有张量都不要求渐变的子图中,永远不会执行向后计算。

>>> x = torch.randn(5, 5)  # requires_grad=False by default
>>> y = torch.randn(5, 5)  # requires_grad=False by default
>>> z = torch.randn((5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True

当您要冻结部分模型,或者事先知道您将不使用渐变色时,此功能特别有用。 一些参数。 例如,如果您想微调预训练的 CNN,只需在冻结的基数中切换requires_grad标志,就不会保存任何中间缓冲区,直到计算到达最后一层,仿射变换将使用权重为 需要梯度,网络的输出也将需要它们。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
## Replace the last fully-connected layer
## Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)


## Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

autograd 如何编码历史

Autograd 是反向自动分化系统。 从概念上讲,autograd 会记录一个图形,记录执行操作时创建数据的所有操作,从而为您提供一个有向无环图,其叶子为输入张量,根为输出张量。 通过从根到叶跟踪该图,您可以使用链式规则自动计算梯度。

在内部,autograd 将该图表示为Function对象(真正的表达式)的图,可以将其apply()编辑以计算评估图的结果。 在计算前向通过时,autograd 同时执行请求的计算,并建立一个表示表示计算梯度的函数的图形(每个 torch.Tensor.grad_fn属性是该图形的入口)。 完成前向遍历后,我们在后向遍历中评估此图以计算梯度。

需要注意的重要一点是,每次迭代都会从头开始重新创建图形,这正是允许使用任意 Python 控制流语句的原因,它可以在每次迭代时更改图形的整体形状和大小。 在开始训练之前,您不必编码所有可能的路径-跑步就是您的与众不同。

使用 autograd 进行就地操作

在 autograd 中支持就地操作很困难,并且在大多数情况下,我们不鼓励使用它们。 Autograd 积极的缓冲区释放和重用使其非常高效,就地操作实际上很少显着降低内存使用量的情况很少。 除非您在高内存压力下进行操作,否则可能永远不需要使用它们。

限制就地操作的适用性的主要原因有两个:

  1. 就地操作可能会覆盖计算梯度所需的值。
  2. 实际上,每个就地操作都需要实现来重写计算图。 异地版本仅分配新对象并保留对旧图形的引用,而就地操作则需要更改表示此操作的Function的所有输入的创建者。 这可能很棘手,特别是如果有许多张量引用相同的存储(例如通过索引或转置创建的),并且如果修改后的输入的存储被任何其他Tensor引用,则就地函数实际上会引发错误。

就地正确性检查

每个张量都有一个版本计数器,每次在任何操作中被标记为脏时,该计数器都会增加。 当函数保存任何张量以供向后时,也会保存其包含 Tensor 的版本计数器。 访问self.saved_tensors后,将对其进行检查,如果该值大于保存的值,则会引发错误。 这样可以确保,如果您使用的是就地函数并且没有看到任何错误,则可以确保计算出的梯度是正确的。

PyTorch torch.nn 到底是什么?
PyTorch 广播语义
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

Pytorch 音频

PyTorch 命名为 Tensor(实验性)

PyTorch 强化学习

PyTorch 用其他语言

PyTorch 语言绑定

PyTorch torchvision参考

PyTorch 音频参考

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }