PyTorch torch.utils.checkpoint
注意
通过在反向过程中为每个检查点段重新运行一个正向通过段来实现检查点。 这可能会导致像 RNG 状态这样的持久状态比没有检查点的状态更先进。 默认情况下,检查点包括处理 RNG 状态的逻辑,以便与非检查点通过相比,使用 RNG(例如,通过丢弃)的检查点通过具有确定的输出。 根据检查点操作的运行时间,存储和恢复 RNG 状态的逻辑可能会导致性能下降。 如果不需要与非检查点通过相比确定的输出,则在每个检查点期间向checkpoint
或checkpoint_sequential
提供preserve_rng_state=False
,以忽略存储和恢复 RNG 状态。
隐藏逻辑将当前设备以及所有 cuda Tensor 参数的设备的 RNG 状态保存并恢复到run_fn
。 但是,该逻辑无法预料用户是否将张量移动到run_fn
本身内的新设备。 因此,如果在run_fn
中将张量移动到新设备(“新”表示不属于[当前设备+张量参数的设备的集合]),则与非检查点传递相比,确定性输出将永远无法保证。
torch.utils.checkpoint.checkpoint(function, *args, **kwargs)¶
检查点模型或模型的一部分
检查点通过将计算交换为内存来工作。 检查点部分没有存储整个计算图的所有中间激活以进行向后计算,而是由而不是保存中间激活,而是在向后传递时重新计算它们。 它可以应用于模型的任何部分。
具体而言,在前向传递中,function
将以torch.no_grad()
方式运行,即不存储中间激活。 相反,前向传递保存输入元组和function
参数。 在向后遍历中,检索保存的输入和function
,并再次在function
上计算正向遍历,现在跟踪中间激活,然后使用这些激活值计算梯度。
警告
检查点不适用于 torch.autograd.grad()
,而仅适用于 torch.autograd.backward()
。
Warning
如果后退期间的function
调用与前退期间的调用有任何不同,例如,由于某些全局变量,则检查点版本将不相等,很遗憾,无法检测到该版本。
参数
- 函数 –描述在模型的正向传递中或模型的一部分中运行的内容。 它还应该知道如何处理作为元组传递的输入。 例如,在 LSTM 中,如果用户通过
(activation, hidden)
,则function
应正确使用第一个输入作为activation
,第二个输入作为hidden
- reserve_rng_state (bool , 可选 , 默认= True 在每个检查点期间恢复 RNG 状态。
- args –包含
function
输入的元组
退货
在*args
上运行function
的输出
torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, **kwargs)¶
用于检查点顺序模型的辅助功能。
顺序模型按顺序(依次)执行模块/功能列表。 因此,我们可以将这样的模型划分为不同的段,并在每个段上检查点。 除最后一个段外,所有段都将以torch.no_grad()
方式运行,即不存储中间激活。 将保存每个检查点线段的输入,以便在后向传递中重新运行该线段。
有关检查点的工作方式,请参见 checkpoint()
。
Warning
Checkpointing doesn't work with torch.autograd.grad()
, but only with torch.autograd.backward()
.
Parameters
- 功能 –一个
torch.nn.Sequential
或要顺序运行的模块或功能列表(包含模型)。 - 段 –在模型中创建的块数
- 输入 –张量元组,它们是
functions
的输入 - preserve_rng_state (bool__, optional__, default=True) – Omit stashing and restoring the RNG state during each checkpoint.
Returns
在*inputs
上顺序运行functions
的输出
例
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)