codecamp

PyTorch torch.utils.checkpoint

PyTorch 检查点机制详解:优化显存使用与模型训练效率

一、检查点机制是什么?

检查点(Checkpoint)机制是一种用于优化深度学习模型训练过程中显存使用的技巧。在训练复杂的深度学习模型时,尤其是大型神经网络,显存资源往往非常有限。检查点机制通过在正向传播过程中丢弃某些中间激活结果,然后在反向传播过程中重新计算这些中间结果,从而减少显存占用。

二、PyTorch 检查点函数详解

(一)torch.utils.checkpoint.checkpoint(function, *args, preserve_rng_state=True)

  1. 基本原理
    • 在正向传播阶段,function 会以 torch.no_grad() 模式运行,即不保存中间激活结果。仅保存输入张量和 function 参数。
    • 在反向传播阶段,通过重新运行 function 来重新计算中间激活结果,然后基于这些结果计算梯度。

  1. 参数说明
    • function:定义模型正向传播过程的函数。该函数应能够处理输入元组并正确执行前向计算。
    • args:传递给 function 的输入张量元组。
    • preserve_rng_state:布尔值,默认为 True。如果为 True,则在检查点过程中保存并恢复随机数生成器(RNG)状态,以确保使用随机操作(如 dropout)时结果的确定性。

  1. 注意事项
    • 检查点机制不支持 torch.autograd.grad(),仅支持 torch.autograd.backward()
    • 如果反向传播期间的 function 调用与正向传播期间的调用存在差异(例如由于全局变量的影响),则可能导致结果不一致。

(二)torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, preserve_rng_state=True)

  1. 基本原理
    • 适用于顺序执行的模型或模块列表。将模型划分为多个段,每个段对应一个检查点。
    • 除最后一个段外,其他段均以 torch.no_grad() 模式运行,不保存中间激活结果。每个检查点段的输入会被保存,以便在反向传播时重新计算该段的正向结果。

  1. 参数说明
    • functions:一个 torch.nn.Sequential 对象或包含多个模块 / 函数的列表。
    • segments:模型被划分为的段数。
    • inputs:传递给 functions 的输入张量元组。
    • preserve_rng_state:布尔值,默认为 True。是否在每个检查点期间保存和恢复 RNG 状态。

三、实际应用案例

(一)单个模块的检查点应用

假设我们有一个简单的神经网络模块,我们希望对该模块应用检查点以减少显存占用。

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointModel(nn.Module):
    def __init__(self):
        super(CheckpointModel, self).__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        self.layer3 = nn.Linear(10, 2)


    def forward(self, x):
        # 对 layer2 应用检查点
        x = self.layer1(x)
        x = cp.checkpoint(self.layer2, x)
        x = self.layer3(x)
        return x


model = CheckpointModel()
input_var = torch.randn(1, 10)
output = model(input_var)

(二)顺序模型的检查点应用

对于顺序执行的模型,我们可以使用 checkpoint_sequential 来划分检查点段。

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 2)
)


input_var = torch.randn(1, 10)
segments = 2  # 将模型划分为 2 个段
output = cp.checkpoint_sequential(model, segments, input_var)

四、性能与显存权衡

使用检查点机制虽然可以有效减少显存占用,但会增加计算时间,因为需要在反向传播过程中重新计算中间激活结果。在实际应用中,需要根据模型规模、显存限制和训练时间要求等因素,合理选择是否应用检查点机制以及如何划分检查点段。

五、总结

通过本教程,我们详细介绍了 PyTorch 中的检查点机制及其应用方法。检查点机制在训练大型深度学习模型时,能够有效减少显存占用,提高模型训练的可行性。正确理解和使用检查点机制,可以帮助我们在有限的硬件资源下训练更复杂的模型。

PyTorch torch.utils.bottleneck
PyTorch torch.utils.cpp_extension
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

Pytorch 音频

PyTorch 命名为 Tensor(实验性)

PyTorch 强化学习

PyTorch 用其他语言

PyTorch 语言绑定

PyTorch torchvision参考

PyTorch 音频参考

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }