PyTorch torch脚本
原文: PyTorch torch脚本
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 进程中保存并加载到没有 Python 依赖项的进程中。
我们提供了将模型从纯 Python 程序逐步过渡到可以独立于 Python 运行的 TorchScript 程序的工具,例如在独立的 C ++程序中。 这样就可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到生产环境中,在该生产环境中 Python 程序可能由于性能和多线程原因而处于不利地位。
有关 TorchScript 的简要介绍,请参见 TorchScript 简介教程。
有关将 PyTorch 模型转换为 TorchScript 并在 C ++中运行的端到端示例,请参见在 C ++中加载 PyTorch 模型教程。
创建 TorchScript 代码
class torch.jit.ScriptModule¶
property code¶
返回forward
方法的内部图的漂亮打印表示形式(作为有效的 Python 语法)。
property graph¶
返回forward
方法的内部图形的字符串表示形式。
save(f, _extra_files=ExtraFilesMap{})¶
torch.jit.save
。
class torch.jit.ScriptFunction¶
功能上与 ScriptModule
等效,但是代表单个功能,没有任何属性或参数。
torch.jit.script(obj)¶
为函数或nn.Module
编写脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,然后返回 ScriptModule
或 ScriptFunction
。 TorchScript 本身是 Python 语言的子集,因此 Python 并非所有功能都可以使用,但是我们提供了足够的功能来在张量上进行计算并执行与控制有关的操作。
torch.jit.script
可用作模块和功能的函数,以及 TorchScript 类和功能的修饰器@torch.jit.script
。
Scripting a function
@torch.jit.script
装饰器将通过编译函数的主体来构造 ScriptFunction
。
示例(编写函数):
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
print(type(foo)) # torch.jit.ScriptFuncion
## See the compiled graph as Python code
print(foo.code)
## Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
Scripting an nn.Module
默认情况下,为nn.Module
编写脚本将编译forward
方法,并递归编译forward
调用的任何方法,子模块和函数。 如果nn.Module
仅使用 TorchScript 支持的功能,则无需更改原始模块代码。 script
将构建 ScriptModule
,该副本具有原始模块的属性,参数和方法的副本。
示例(使用参数编写简单模块的脚本):
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
# This parameter will be copied to the new ScriptModule
self.weight = torch.nn.Parameter(torch.rand(N, M))
# When this submodule is used, it will be compiled
self.linear = torch.nn.Linear(N, M)
def forward(self, input):
output = self.weight.mv(input)
# This calls the `forward` method of the `nn.Linear` module, which will
# cause the `self.linear` submodule to be compiled to a `ScriptModule` here
output = self.linear(output)
return output
scripted_module = torch.jit.script(MyModule(2, 3))
示例(使用跟踪的子模块编写模块脚本):
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# torch.jit.trace produces a ScriptModule's conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
scripted_module = torch.jit.script(MyModule())
要编译除forward
以外的方法(并递归编译其调用的任何内容),请将 @torch.jit.export
装饰器添加到该方法。 要选择退出编译,请使用 @torch.jit.ignore
。
示例(模块中的导出方法和忽略方法):
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
@torch.jit.export
def some_entry_point(self, input):
return input + 10
@torch.jit.ignore
def python_only_fn(self, input):
# This function won't be compiled, so any
# Python APIs can be used
import pdb
pdb.set_trace()
def forward(self, input):
if self.training:
self.python_only_fn(input)
return input * 99
scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))
torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)¶
跟踪一个函数并返回将使用即时编译进行优化的可执行文件或 ScriptFunction
。 对于仅在Tensor
和Tensor
的列表,字典和元组上运行的代码,跟踪是理想的选择。
使用torch.jit.trace
和 torch.jit.trace_module
,您可以将现有模块或 Python 函数转换为 TorchScript ScriptFunction
或 ScriptModule
。 您必须提供示例输入,然后我们运行该函数,记录在所有张量上执行的操作。
- 独立功能的最终记录将产生
ScriptFunction
。 nn.Module
或nn.Module
的forward
功能的所得记录产生ScriptModule
。
该模块还包含原始模块也具有的任何参数。
警告
跟踪仅正确记录不依赖数据的功能和模块(例如,对张量中的数据没有条件)并且不包含任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。 跟踪仅记录在给定张量上运行给定函数时执行的操作。 因此,返回的 ScriptModule
将始终在任何输入上运行相同的跟踪图。 当期望模块根据输入和/或模块状态运行不同的操作集时,这具有重要意义。 例如,
- 跟踪将不会记录任何控制流,例如 if 语句或循环。 当整个模块的控制流恒定时,这很好,并且通常内联控制流决策。 但是有时控制流实际上是模型本身的一部分。 例如,循环网络是输入序列(可能是动态)长度上的循环。
- 在返回的
ScriptModule
中,在training
和eval
模式下具有不同行为的操作将始终像在跟踪过程中一样处于运行状态,无论是哪种模式 ]ScriptModule
已插入。
在这种情况下,跟踪是不合适的, scripting
是更好的选择。 如果跟踪此类模型,则可能在随后的模型调用中静默地得到不正确的结果。 在执行可能会导致产生不正确跟踪的操作时,跟踪器将尝试发出警告。
参数
- 函数(可调用的或 torch.nn.Module)– Python 函数或
torch.nn.Module
与example_inputs
一起运行。func
的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。 将模块传递到torch.jit.trace
时,仅运行并跟踪forward
方法(有关详细信息,参见torch.jit.trace
)。 - example_inputs (tuple )–示例输入的元组,将在跟踪时传递给函数。 假设跟踪的操作支持这些类型和形状,则可以使用不同类型和形状的输入来运行结果跟踪。
example_inputs
也可以是单个张量,在这种情况下,它会自动包装在元组中。
Keyword Arguments
- check_trace (
bool
,可选)–检查通过跟踪代码运行的相同输入是否产生相同的输出。 默认值:True
。 例如,如果您的网络包含不确定性操作,或者即使检查程序失败,但您确定网络正确,则可能要禁用此功能。 - check_inputs (元组列表 , 可选)–输入参数的元组列表,应使用这些元组来检查跟踪内容 是期待。 每个元组等效于
example_inputs
中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的example_inputs
进行检查 - check_tolerance (python:float , 可选)–在检查程序中使用的浮点比较公差。 如果结果由于已知原因(例如操作员融合)而在数值上出现差异,则可以使用此方法来放松检查器的严格性。
退货
如果callable
是nn.Module
的nn.Module
或forward
,则trace
将使用包含跟踪代码的单个forward
方法返回 ScriptModule
对象。 返回的 ScriptModule
将具有与原始nn.Module
相同的子模块和参数集。 如果callable
是独立功能,则trace
返回 ScriptFunction
示例(跟踪函数):
import torch
def foo(x, y):
return 2 * x + y
## Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
## `traced_foo` can now be run with the TorchScript interpreter or saved
## and loaded in a Python-free environment
示例(跟踪现有模块):
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, x):
return self.conv(x)
n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
## Trace a specific method and construct `ScriptModule` with
## a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)
## Trace a module (implicitly traces `forward`) and construct a
## `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)
torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)¶
跟踪模块并返回可执行文件 ScriptModule
,该文件将使用即时编译进行优化。 将模块传递到 torch.jit.trace
时,仅运行并跟踪forward
方法。 使用trace_module
,您可以指定方法名称的字典作为示例输入,以跟踪下面的参数(请参见example_inputs
)。
有关跟踪的更多信息,参见 torch.jit.trace
。
Parameters
- mod (Torch.nn.Module)–一种
torch.nn.Module
,其中包含名称在example_inputs
中指定的方法。 给定的方法将被编译为单个 <cite>ScriptModule</cite> 的一部分。 - example_inputs (dict )–包含样本输入的字典,该样本输入由
mod
中的方法名称索引。 输入将在跟踪时传递给名称与输入键对应的方法。{ 'forward' : example_forward_input, 'method2': example_method2_input}
Keyword Arguments
- check_trace (
bool
, optional) – Check if the same inputs run through traced code produce the same outputs. Default:True
. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite a checker failure. - check_inputs (字典列表 , 可选)–输入参数的字典列表,用于检查跟踪内容 是期待。 每个元组等效于
example_inputs
中指定的一组输入参数。 为了获得最佳结果,请传递一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。 如果未指定,则使用原始的example_inputs
进行检查 - check_tolerance (python:float__, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
Returns
具有单个forward
方法的 ScriptModule
对象,其中包含跟踪的代码。 当func
是torch.nn.Module
时,返回的 ScriptModule
将具有与func
相同的子模块和参数集。
示例(使用多种方法跟踪模块):
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, x):
return self.conv(x)
def weighted_kernel_sum(self, weight):
return weight * self.conv.weight
n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
## Trace a specific method and construct `ScriptModule` with
## a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)
## Trace a module (implicitly traces `forward`) and construct a
## `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)
## Trace specific methods on a module (specified in `inputs`), constructs
## a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)
torch.jit.save(m, f, _extra_files=ExtraFilesMap{})¶
保存此模块的脱机版本以在单独的过程中使用。 保存的模块将序列化此模块的所有方法,子模块,参数和属性。 可以使用torch::jit::load(filename)
将其加载到 C ++ API 中,或者使用 torch.jit.load
加载到 Python API 中。
为了能够保存模块,它不得对本地 Python 函数进行任何调用。 这意味着所有子模块也必须是torch.jit.ScriptModule
的子类。
危险
所有模块,无论使用哪种设备,都始终在加载期间加载到 CPU 中。 这与 load
的语义不同,并且将来可能会发生变化。
Parameters
- m –要保存的 ScriptModule。
- f –类似于文件的对象(必须实现写入和刷新)或包含文件名的字符串。
- _extra_files -从文件名映射到将作为“ f”的一部分存储的内容。
Warning
如果您使用的是 Python 2,torch.jit.save
不支持StringIO.StringIO
作为有效的类似文件的对象。 这是因为 write 方法应返回写入的字节数; StringIO.write()
不这样做。
请改用io.BytesIO
之类的东西。
例:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
m = torch.jit.script(MyModule())
## Save to file
torch.jit.save(m, 'scriptmodule.pt')
## This line is equivalent to the previous
m.save("scriptmodule.pt")
## Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)
## Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
torch.jit.load(f, map_location=None, _extra_files=ExtraFilesMap{})¶
加载先前用 torch.jit.save
保存的 ScriptModule
或 ScriptFunction
之前保存的所有模块,无论使用何种设备,都首先加载到 CPU 中,然后再移动到保存它们的设备上。 如果失败(例如,因为运行时系统没有某些设备),则会引发异常。
Parameters
- f –类似于文件的对象(必须实现读取,读取行,告诉和查找),或包含文件名的字符串
- map_location (字符串 或 torch设备)–
torch.save
中map_location
的简化版本 用于动态地将存储重新映射到另一组设备。 - _extra_files (文件名到内容的字典)–映射中给定的多余文件名将被加载,其内容将存储在提供的映射中。
Returns
ScriptModule
对象。
Example:
import torch
import io
torch.jit.load('scriptmodule.pt')
## Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
## Load all tensors to the original device
torch.jit.load(buffer)
## Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))
## Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')
## Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])
混合跟踪和脚本编写
在许多情况下,将模型转换为 TorchScript 都可以使用跟踪或脚本编写。 可以组成跟踪和脚本以适合模型一部分的特定要求。
脚本函数可以调用跟踪函数。 当您需要在简单的前馈模型周围使用控制流时,这特别有用。 例如,序列到序列模型的波束搜索通常将用脚本编写,但是可以调用使用跟踪生成的编码器模块。
示例(在脚本中调用跟踪的函数):
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
跟踪的函数可以调用脚本函数。 即使大部分模型只是前馈网络,当模型的一小部分需要一些控制流时,这也很有用。 跟踪函数调用的脚本函数内部的控制流已正确保留。
示例(在跟踪函数中调用脚本函数):
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
此组合也适用于nn.Module
,在这里它可用于通过跟踪来生成子模块,该跟踪可以从脚本模块的方法中调用。
示例(使用跟踪模块):
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
迁移到 PyTorch 1.2 递归脚本 API
本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。 如果您不熟悉 TorchScript,则可以跳过本节。 PyTorch 1.2 对 TorchScript API 进行了两个主要更改。
\1. torch.jit.script
现在将尝试递归编译遇到的函数,方法和类。 调用torch.jit.script
后,编译将是“选择退出”,而不是“选择加入”。
2.现在torch.jit.script(nn_module_instance)
是创建 ScriptModule
的首选方法,而不是从torch.jit.ScriptModule
继承。 这些更改组合在一起,提供了一个更简单易用的 API,可将您的nn.Module
转换为 ScriptModule
,可以在非 Python 环境中进行优化和执行。
新用法如下所示:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
my_model = Model()
my_scripted_model = torch.jit.script(my_model)
- 该模块的
forward
是默认编译的。 从forward
调用的方法将按照在forward
中使用的顺序进行延迟编译。 - 要编译未从
forward
调用的forward
以外的方法,请添加@torch.jit.export
。 - 要停止编译器编译方法,请添加
@torch.jit.ignore
或@torch.jit.unused
。@ignore
离开 - 方法作为对 python 的调用,并且
@unused
将其替换为异常。@ignored
无法导出;@unused
可以。 - 可以推断大多数属性类型,因此不需要
torch.jit.Attribute
。 对于空容器类型,请使用 PEP 526 样式类注释对其类型进行注释。 - 可以使用
Final
类注释来标记常量,而不是将成员的名称添加到__constants__
中。 - 可以使用 Python 3 类型提示代替
torch.jit.annotate
As a result of these changes, the following items are considered deprecated and should not appear in new code:
@torch.jit.script_method
装饰器- 继承自
torch.jit.ScriptModule
的类 torch.jit.Attribute
包装器类__constants__
数组torch.jit.annotate
功能
模块
Warning
@torch.jit.ignore
注释的行为在 PyTorch 1.2 中发生了变化。 在 PyTorch 1.2 之前,@ ignore 装饰器用于使函数或方法可从导出的代码中调用。 要恢复此功能,请使用@torch.jit.unused()
。 @torch.jit.ignore
现在等同于@torch.jit.ignore(drop=False)
。 有关详细信息,参见 @torch.jit.ignore
和 @torch.jit.unused
。
当传递给 torch.jit.script
函数时,torch.nn.Module
的数据将复制到 ScriptModule
,然后 TorchScript 编译器将编译该模块。 该模块的forward
默认为编译状态。 从forward
调用的方法以及它们在forward
中使用的顺序都是按延迟顺序编译的。
torch.jit.export(fn)¶
此修饰符指示nn.Module
上的方法用作 ScriptModule
的入口点,应进行编译。
forward
隐式地假定为入口点,因此不需要此装饰器。 从forward
调用的函数和方法在编译器看到的情况下进行编译,因此它们也不需要此装饰器。
示例(在方法上使用@torch.jit.export
):
import torch
import torch.nn as nn
class MyModule(nn.Module):
def implicitly_compiled_method(self, x):
return x + 99
# `forward` is implicitly decorated with `@torch.jit.export`,
# so adding it here would have no effect
def forward(self, x):
return x + 10
@torch.jit.export
def another_forward(self, x):
# When the compiler sees this call, it will compile
# `implicitly_compiled_method`
return self.implicitly_compiled_method(x)
def unused_method(self, x):
return x - 20
## `m` will contain compiled methods:
## `forward`
## `another_forward`
## `implicitly_compiled_method`
## `unused_method` will not be compiled since it was not called from
## any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
功能
功能没有太大变化,可以根据需要用 @torch.jit.ignore
或 torch.jit.unused
装饰。
## Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
return 2
## Marks a function as ignored, if nothing
## ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
return 2
## As with ignore, if nothing calls it then it has no effect.
## If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
import pdb; pdb.set_trace()
return 4
## Doesn't do anything, this function is already
## the main entry point
@torch.jit.export
def some_fn4():
return 2
TorchScript 类
默认情况下,将导出用户定义的 TorchScript 类中的所有内容,可以根据需要用 @torch.jit.ignore
修饰功能。
属性
TorchScript 编译器需要知道模块属性的类型。 大多数类型可以从成员的值推断出来。 空列表和字典不能推断其类型,而必须使用 PEP 526 样式类注释来注释其类型。 如果无法推断类型并且未对显式类型进行注释,则不会将其作为属性添加到结果 ScriptModule
旧 API:
from typing import Dict
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self):
super(MyModule, self).__init__()
self.my_dict = torch.jit.Attribute({}, Dict[str, int])
self.my_int = torch.jit.Attribute(20, int)
m = MyModule()
新 API:
from typing import Dict
class MyModule(torch.nn.Module):
my_dict: Dict[str, int]
def __init__(self):
super(MyModule, self).__init__()
# This type cannot be inferred and must be specified
self.my_dict = {}
# The attribute type here is inferred to be `int`
self.my_int = 20
def forward(self):
pass
m = torch.jit.script(MyModule())
Python 2
如果您受制于 Python 2 并且无法使用类注释语法,则可以使用__annotations__
类成员直接应用类型注释。
from typing import Dict
class MyModule(torch.jit.ScriptModule):
__annotations__ = {'my_dict': Dict[str, int]}
def __init__(self):
super(MyModule, self).__init__()
self.my_dict = {}
self.my_int = 20
常数
Final
类型的构造函数可用于将成员标记为常量。 如果成员未标记为常量,则将其复制为结果 ScriptModule
作为属性。 如果已知该值是固定的,则使用Final
可以进行优化,并提供附加的类型安全性。
Old API:
class MyModule(torch.jit.ScriptModule):
__constants__ = ['my_constant']
def __init__(self):
super(MyModule, self).__init__()
self.my_constant = 2
def forward(self):
pass
m = MyModule()
New API:
try:
from typing_extensions import Final
except:
# If you don't have `typing_extensions` installed, you can use a
# polyfill from `torch.jit`.
from torch.jit import Final
class MyModule(torch.nn.Module):
my_constant: Final[int]
def __init__(self):
super(MyModule, self).__init__()
self.my_constant = 2
def forward(self):
pass
m = torch.jit.script(MyModule())
变量
假定容器的类型为Tensor
,并且是非可选的(有关更多信息,请参见默认类型)。 以前,torch.jit.annotate
用来告诉 TorchScript 编译器类型是什么。 现在支持 Python 3 样式类型提示。
import torch
from typing import Dict, Optional
@torch.jit.script
def make_dict(flag: bool):
x: Dict[str, int] = {}
x['hi'] = 2
b: Optional[int] = None
if flag:
b = 2
return x, b
TorchScript 语言参考
TorchScript 是 Python 的静态类型子集,可以直接编写(使用 @torch.jit.script
装饰器),也可以通过跟踪从 Python 代码自动生成。 使用跟踪时,通过仅在张量上记录实际的运算符并简单地执行和丢弃其他周围的 Python 代码,代码会自动转换为 Python 的此子集。
使用@torch.jit.script
装饰器直接编写 TorchScript 时,程序员只能使用 TorchScript 支持的 Python 子集。 本节记录了 TorchScript 支持的功能,就像它是独立语言的语言参考一样。 本参考中未提及的 Python 的任何功能都不属于 TorchScript。 有关可用的 Pytorch 张量方法,模块和功能的完整参考,请参见内置函数。
作为 Python 的子集,任何有效的 TorchScript 函数也是有效的 Python 函数。 这样就可以禁用 TorchScript 并使用pdb
之类的标准 Python 工具调试该功能。 反之则不成立:有许多有效的 Python 程序不是有效的 TorchScript 程序。 相反,TorchScript 特别专注于表示 PyTorch 中的神经网络模型所需的 Python 功能。
类型
TorchScript 与完整的 Python 语言之间的最大区别是 TorchScript 仅支持表达神经网络模型所需的一小部分类型。 特别是,TorchScript 支持:
|
类型
|
描述
| | --- | --- | | Tensor
| 任何 dtype,尺寸或后端的 PyTorch 张量 | | Tuple[T0, T1, ...]
| 包含子类型T0
,T1
等(例如Tuple[Tensor, Tensor]
)的元组 | | bool
| 布尔值 | | int
| 标量整数 | | float
| 标量浮点数 | | str
| 一串 | | List[T]
| 所有成员均为T
类型的列表 | | Optional[T]
| 无或输入T
的值 | | Dict[K, V]
| 键类型为K
而值类型为V
的字典。 只能将str
,int
和float
作为密钥类型。 | | T
| 一个 TorchScript 类 | | NamedTuple[T0, T1, ...]
| collections.namedtuple
元组类型 |
与 Python 不同,TorchScript 函数中的每个变量都必须具有一个静态类型。 这使优化 TorchScript 函数变得更加容易。
示例(类型不匹配)
import torch
@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r
Traceback (most recent call last):
...
RuntimeError: ...
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
if x:
~~~~~... <--- HERE
r = torch.rand(1)
else:
and was used here:
else:
r = 4
return r
~ <--- HERE
...
默认类型
默认情况下,TorchScript 函数的所有参数均假定为 Tensor。 要指定 TorchScript 函数的参数是其他类型,可以使用上面列出的类型使用 MyPy 样式的类型注释。
import torch
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
注意
也可以使用typing
模块中的 Python 3 类型提示来注释类型。
import torch
from typing import Tuple
@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
在我们的示例中,我们使用基于注释的类型提示来确保 Python 2 的兼容性。
假定空列表为List[Tensor]
,空字典为Dict[str, Tensor]
。 要实例化其他类型的空列表或字典,请使用 Python 3 类型提示。 如果您使用的是 Python 2,则可以使用torch.jit.annotate
。
示例(Python 3 的类型注释):
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
def __init__(self):
super(EmptyDataStructures, self).__init__()
def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
# This annotates the list to be a `List[Tuple[int, float]]`
my_list: List[Tuple[int, float]] = []
for i in range(10):
my_list.append((i, x.item()))
my_dict: Dict[str, int] = {}
return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
示例(适用于 Python 2 的torch.jit.annotate
):
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
def __init__(self):
super(EmptyDataStructures, self).__init__()
def forward(self, x):
# type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]
# This annotates the list to be a `List[Tuple[int, float]]`
my_list = torch.jit.annotate(List[Tuple[int, float]], [])
for i in range(10):
my_list.append((i, float(x.item())))
my_dict = torch.jit.annotate(Dict[str, int], {})
return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
可选类型细化
在 if 语句的条件内或在assert
中检查与None
的比较时,TorchScript 将优化Optional[T]
类型的变量的类型。 编译器可以推理与and
,or
和not
结合的多个None
检查。 对于未明确编写的 if 语句的 else 块,也会进行优化。
None
检查必须在 if 语句的条件内; 将None
检查分配给变量,并在 if 语句的条件下使用它,将不会优化检查中的变量类型。 仅局部变量将被细化,self.x
之类的属性将不会且必须分配给要细化的局部变量。
示例(优化参数和局部变量的类型):
import torch
import torch.nn as nn
from typing import Optional
class M(nn.Module):
z: Optional[int]
def __init__(self, z):
super(M, self).__init__()
# If `z` is None, its type cannot be inferred, so it must
# be specified (above)
self.z = z
def forward(self, x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
x = x + 1
# Refinement for an attribute by assigning it to a local
z = self.z
if y is not None and z is not None:
x = y + z
# Refinement via an `assert`
assert z is not None
x += z
return x
module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
TorchScript 类
如果 Python 类使用 @torch.jit.script
注释,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式:
@torch.jit.script
class Foo:
def __init__(self, x, y):
self.x = x
def aug_add_x(self, inc):
self.x += inc
此子集受限制:
- 所有函数必须是有效的 TorchScript 函数(包括
__init__()
)。
- 这些类必须是新型类,因为我们使用
__new__()
和 pybind11 来构造它们。
- TorchScript 类是静态类型的。 只能通过在
__init__()
方法中分配给 self 来声明成员。
\> 例如,在__init__()
方法之外分配给self
: > > > @torch.jit.script > class Foo: > def assign_x(self): > self.x = torch.rand(2, 3) > >
> > 将导致: > > > RuntimeError: > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: > def assign_x(self): > self.x = torch.rand(2, 3) > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE > >
- 类的主体中不允许使用除方法定义之外的任何表达式。
- 除了从
object
继承以指定新样式类外,不支持继承或任何其他多态策略。
定义了一个类之后,就可以像其他任何 TorchScript 类型一样在 TorchScript 和 Python 中互换使用该类:
## Declare a TorchScript class
@torch.jit.script
class Pair:
def __init__(self, first, second):
self.first = first
self.second = second
@torch.jit.script
def sum_pair(p):
# type: (Pair) -> Tensor
return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
命名为元组
collections.namedtuple
产生的类型可以在 TorchScript 中使用。
import torch
import collections
Point = collections.namedtuple('Point', ['x', 'y'])
@torch.jit.script
def total(point):
# type: (Point) -> Tensor
return point.x + point.y
p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
表达式
支持以下 Python 表达式。
文字
True
False
None
'string literals'
"string literals"
3 # interpreted as int
3.4 # interpreted as a float
列表结构
假定一个空列表具有List[Tensor]
类型。 其他列表文字的类型是从成员的类型派生的。 有关更多详细信息,请参见默认类型。
[3, 4]
[]
[torch.rand(3), torch.rand(4)]
元组结构
(3, 4)
(3,)
字典结构
假定一个空字典为Dict[str, Tensor]
类型。 其他 dict 文字的类型是从成员的类型派生的。 有关更多详细信息,请参见默认类型。
{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}
变量
有关如何解析变量的信息,请参见变量分辨率。
my_variable_name
算术运算符
a + b
a - b
a * b
a / b
a ^ b
a @ b
比较运算符
a == b
a != b
a < b
a > b
a <= b
a >= b
逻辑运算符
a and b
a or b
not b
下标和切片
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]
函数调用
调用内置函数
torch.rand(3, dtype=torch.int)
调用其他脚本函数:
import torch
@torch.jit.script
def foo(x):
return x + 1
@torch.jit.script
def bar(x):
return foo(x)
方法调用
调用诸如张量之类的内置类型的方法:x.mm(y)
在模块上,必须先编译方法才能调用它们。 TorchScript 编译器以递归方式编译在编译其他方法时看到的方法。 默认情况下,编译从forward
方法开始。 将编译forward
调用的任何方法,以及这些方法调用的任何方法,依此类推。 要以forward
以外的方法开始编译,请使用 @torch.jit.export
装饰器(forward
隐式标记为@torch.jit.export
)。
直接调用子模块(例如self.resnet(input)
)等效于调用其forward
方法(例如self.resnet.forward(input)
)。
import torch
import torch.nn as nn
import torchvision
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
means = torch.tensor([103.939, 116.779, 123.68])
self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
resnet = torchvision.models.resnet18()
self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))
def helper(self, input):
return self.resnet(input - self.means)
def forward(self, input):
return self.helper(input)
# Since nothing in the model calls `top_level_method`, the compiler
# must be explicitly told to compile this method
@torch.jit.export
def top_level_method(self, input):
return self.other_helper(input)
def other_helper(self, input):
return input + 10
## `my_script_module` will have the compiled methods `forward`, `helper`,
## `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())
三元表达式
x if x > y else y
演员表
float(ten)
int(3.5)
bool(ten)
str(2)``
访问模块参数
self.my_parameter
self.my_submodule.my_parameter
语句
TorchScript 支持以下类型的语句:
简单分配
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b
模式匹配分配
a, b = tuple_or_list
a, b, *c = a_tuple
多项分配
a = b, c = tup
打印报表
print("the result of an add:", a + b)
If 语句
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
除布尔值外,浮点数,整数和张量还可以在条件中使用,并将隐式转换为布尔值。
While 循环
a = 0
while a < 4:
print(a)
a += 1
适用于范围为的循环
x = 0
for i in range(10):
x *= i
用于遍历元组的循环
这些展开循环,为元组的每个成员生成一个主体。 主体必须对每个成员进行正确的类型检查。
tup = (3, torch.rand(4))
for x in tup:
print(x)
用于在常量 nn.ModuleList 上循环
要在已编译方法中使用nn.ModuleList
,必须通过将属性名称添加到__constants__
列表中的类型来将其标记为常量。 nn.ModuleList
上的 for 循环将在编译时展开循环的主体,并使用常量模块列表的每个成员。
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
__constants__ = ['mods']
def __init__(self):
super(MyModule, self).__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
m = torch.jit.script(MyModule())
中断并继续
for i in range(5):
if i == 1:
continue
if i == 3:
break
print(i)
返回
return a, b
可变分辨率
TorchScript 支持 Python 的可变分辨率(即作用域)规则的子集。 局部变量的行为与 Python 中的相同,不同之处在于,在通过函数的所有路径上,变量必须具有相同的类型。 如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它是错误的。
同样,如果沿函数的某些路径仅将定义为,则不允许使用该变量。
Example:
@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y)
Traceback (most recent call last):
...
RuntimeError: ...
y is not defined in the false branch...
@torch.jit.script...
def foo(x):
if x < 0:
~~~~~~~~~... <--- HERE
y = 4
print(y)
...
定义函数时,会在编译时将非局部变量解析为 Python 值。 然后使用 Python 值使用中描述的规则将这些值转换为 TorchScript 值。
使用 Python 值
为了使编写 TorchScript 更加方便,我们允许脚本代码引用周围范围中的 Python 值。 例如,任何时候只要引用torch
,当声明函数时,TorchScript 编译器实际上就会将其解析为torch
Python 模块。 这些 Python 值不是 TorchScript 的一流部分。 而是在编译时将它们分解为 TorchScript 支持的原始类型。 这取决于编译发生时引用的 Python 值的动态类型。 本节介绍在 TorchScript 中访问 Python 值时使用的规则。
功能
TorchScript 可以调用 Python 函数。 当将模型逐步转换为 TorchScript 时,此功能非常有用。 可以将模型逐函数移至 TorchScript,而对 Python 函数的调用保留在原处。 这样,您可以在进行过程中逐步检查模型的正确性。
torch.jit.ignore(drop=False, **kwargs)¶
该装饰器向编译器指示应忽略函数或方法,而将其保留为 Python 函数。 这使您可以将代码保留在尚未与 TorchScript 兼容的模型中。 具有忽略功能的模型无法导出; 请改用 torch.jit.unused。
示例(在方法上使用@torch.jit.ignore
):
import torch
import torch.nn as nn
class MyModule(nn.Module):
@torch.jit.ignore
def debugger(self, x):
import pdb
pdb.set_trace()
def forward(self, x):
x += 10
# The compiler would normally try to compile `debugger`,
# but since it is `@ignore`d, it will be left as a call
# to Python
self.debugger(x)
return x
m = torch.jit.script(MyModule())
## Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")
示例(在方法上使用@torch.jit.ignore(drop=True)
):
import torch
import torch.nn as nn
class MyModule(nn.Module):
@torch.jit.ignore(drop=True)
def training_method(self, x):
import pdb
pdb.set_trace()
def forward(self, x):
if self.training:
self.training_method(x)
return x
m = torch.jit.script(MyModule())
## This is OK since `training_method` is not saved, the call is replaced
## with a `raise`.
m.save("m.pt")
torch.jit.unused(fn)¶
此装饰器向编译器指示应忽略函数或方法,并用引发异常的方法代替。 这样,您就可以在尚不兼容 TorchScript 的模型中保留代码,并仍然可以导出模型。
示例(在方法上使用
@torch.jit.unused
):
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self, use_memory_efficent):
super(MyModule, self).__init__()
self.use_memory_efficent = use_memory_efficent
@torch.jit.unused
def memory_efficient(self, x):
import pdb
pdb.set_trace()
return x + 10
def forward(self, x):
# Use not-yet-scriptable memory efficient mode
if self.use_memory_efficient:
return self.memory_efficient(x)
else:
return x + 10
m = torch.jit.script(MyModule(use_memory_efficent=False))
m.save("m.pt")
m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))
torch.jit.is_scripting()¶
在编译时返回 True 的函数,否则返回 False 的函数。 这对于使用@unused 装饰器尤其有用,可以将尚不兼容 TorchScript 的代码保留在模型中。 .. testcode:
import torch
@torch.jit.unused
def unsupported_linear_op(x):
return x
def linear(x):
if not torch.jit.is_scripting():
return torch.linear(x)
else:
return unsupported_linear_op(x)
Python 模块上的属性查找
TorchScript 可以在模块上查找属性。 像torch.add
这样的内置功能可以通过这种方式访问。 这使 TorchScript 可以调用其他模块中定义的函数。
Python 定义的常量
TorchScript 还提供了一种使用 Python 中定义的常量的方法。 这些可用于将超参数硬编码到函数中,或定义通用常量。 有两种指定 Python 值应视为常量的方式。
- 查找为模块属性的值假定为常量:
import math
import torch
@torch.jit.script
def fn():
return math.pi
- 可以通过使用
Final[T]
注释 ScriptModule 的属性来将其标记为常量。
import torch
import torch.nn as nn
class Foo(nn.Module):
# `Final` from the `typing_extensions` module can also be used
a : torch.jit.Final[int]
def __init__(self):
super(Foo, self).__init__()
self.a = 1 + 4
def forward(self, input):
return self.a + input
f = torch.jit.script(Foo())
支持的常量 Python 类型是
int
float
bool
torch.device
torch.layout
torch.dtype
- 包含受支持类型的元组
torch.nn.ModuleList
可以在 TorchScript for 循环中使用
Note
如果您使用的是 Python 2,则可以通过将属性名称添加到类的__constants__
属性中来将其标记为常量:
import torch
import torch.nn as nn
class Foo(nn.Module):
__constants__ = ['a']
def __init__(self):
super(Foo, self).__init__()
self.a = 1 + 4
def forward(self, input):
return self.a + input
f = torch.jit.script(Foo())
模块属性
torch.nn.Parameter
包装器和register_buffer
可用于将张量分配给模块。 如果可以推断出其他类型的值,则分配给已编译模块的其他值将添加到已编译模块中。 TorchScript 中可用的所有类型都可以用作模块属性。 张量属性在语义上与缓冲区相同。 空列表和字典的类型以及None
值无法推断,必须通过 PEP 526 样式类注释指定。 如果无法推断出类型并且未对其进行显式注释,则不会将其作为属性添加到结果 ScriptModule
中。
Example:
from typing import List, Dict
class Foo(nn.Module):
# `words` is initialized as an empty list, so its type must be specified
words: List[str]
# The type could potentially be inferred if `a_dict` (below) was not
# empty, but this annotation ensures `some_dict` will be made into the
# proper type
some_dict: Dict[str, int]
def __init__(self, a_dict):
super(Foo, self).__init__()
self.words = []
self.some_dict = a_dict
# `int`s can be inferred
self.my_int = 10
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))
Note
如果您使用的是 Python 2,则可以通过将属性的类型添加到__annotations__
类属性中作为属性名字典来标记属性的类型
from typing import List, Dict
class Foo(nn.Module):
__annotations__ = {'words': List[str], 'some_dict': Dict[str, int]}
def __init__(self, a_dict):
super(Foo, self).__init__()
self.words = []
self.some_dict = a_dict
# `int`s can be inferred
self.my_int = 10
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))
调试
禁用用于调试的 JIT
PYTORCH_JIT¶
设置环境变量PYTORCH_JIT=0
将禁用所有脚本和跟踪注释。 如果您的 TorchScript 模型之一存在难以调试的错误,则可以使用此标志来强制一切都使用本机 Python 运行。 由于此标志禁用了 TorchScript(脚本编写和跟踪),因此可以使用pdb
之类的工具来调试模型代码。
给定一个示例脚本:
@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x
def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)
traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
除调用,@torch.jit.script
,函数外,使用pdb
调试此脚本是可行的。 我们可以全局禁用 JIT,以便我们可以将 @torch.jit.script
函数作为普通的 Python 函数调用,而不进行编译。 如果上述脚本称为disable_jit_example.py
,我们可以这样调用它:
$ PYTORCH_JIT=0 python disable_jit_example.py
并且我们将能够像普通的 Python 函数一样进入 @torch.jit.script
函数。 要为特定功能禁用 TorchScript 编译器,请参见 @torch.jit.ignore
。
检查码
TorchScript 为所有 ScriptModule
实例提供了代码漂亮的打印机。 这个漂亮的打印机可以将脚本方法的代码解释为有效的 Python 语法。 例如:
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.code)
具有单个forward
方法的 ScriptModule
将具有属性code
,您可以使用该属性检查 ScriptModule
的代码。 如果 ScriptModule
具有多个方法,则需要在方法本身而非模块上访问.code
。 我们可以通过访问.foo.code
在 ScriptModule 上检查名为foo
的方法的代码。 上面的示例产生以下输出:
def foo(len: int) -> Tensor:
rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
rv0 = rv
for i in range(len):
if torch.lt(i, 10):
rv1 = torch.sub(rv0, 1., 1)
else:
rv1 = torch.add(rv0, 1., 1)
rv0 = rv1
return rv0
这是 TorchScript 对forward
方法的代码的编译。 您可以使用它来确保 TorchScript(跟踪或脚本)正确捕获了模型代码。
解释图
TorchScript 还以 IR 图的形式在比代码漂亮打印机更低的层次上进行表示。
TorchScript 使用静态单分配(SSA)中间表示(IR)表示计算。 这种格式的指令由 ATen(PyTorch 的 C ++后端)运算符和其他原始运算符组成,包括用于循环和条件的控制流运算符。 举个例子:
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.graph)
graph
遵循检查代码部分中关于forward
方法查找所述的相同规则。
上面的示例脚本生成图形:
graph(%len.1 : int):
%24 : int = prim::Constant[value=1]()
%17 : bool = prim::Constant[value=1]() # test.py:10:5
%12 : bool? = prim::Constant()
%10 : Device? = prim::Constant()
%6 : int? = prim::Constant()
%1 : int = prim::Constant[value=3]() # test.py:9:22
%2 : int = prim::Constant[value=4]() # test.py:9:25
%20 : int = prim::Constant[value=10]() # test.py:11:16
%23 : float = prim::Constant[value=1]() # test.py:12:23
%4 : int[] = prim::ListConstruct(%1, %2)
%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
block0(%i.1 : int, %rv.14 : Tensor):
%21 : bool = aten::lt(%i.1, %20) # test.py:11:12
%rv.13 : Tensor = prim::If(%21) # test.py:11:9
block0():
%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
-> (%rv.3)
block1():
%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
-> (%rv.6)
-> (%17, %rv.13)
return (%rv)
以指令%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
为例。
%rv.1 : Tensor
表示我们将输出分配给一个名为rv.1
的(唯一)值,该值是Tensor
类型,并且我们不知道其具体形状。aten::zeros
是运算符(与torch.zeros
等效),输入列表(%4, %6, %6, %10, %12)
指定范围中的哪些值应作为输入传递。 可以在内置函数中找到aten::zeros
等内置函数的模式。# test.py:9:10
是生成此指令的原始源文件中的位置。 在这种情况下,它是第 9 行和字符 10 处名为 <cite>test.py</cite> 的文件。
请注意,运算符也可以具有关联的blocks
,即prim::Loop
和prim::If
运算符。 在图形打印输出中,这些运算符被格式化以反映其等效的源代码形式,以方便进行调试。
如下图所示,可以检查图表以确认 ScriptModule
所描述的计算是正确的,无论是自动方式还是手动方式。
追踪案例
在某些极端情况下,给定 Python 函数/模块的跟踪不会代表基础代码。 这些情况可以包括:
- 跟踪取决于输入的控制流(例如张量形状)
- 跟踪张量视图的就地操作(例如,分配左侧的索引)
请注意,这些情况实际上将来可能是可追溯的。
自动跟踪检查
自动捕获跟踪中许多错误的一种方法是使用torch.jit.trace()
API 上的check_inputs
。 check_inputs
提取输入元组的列表,这些列表将用于重新追踪计算并验证结果。 例如:
def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
为我们提供以下诊断信息:
ERROR: Graphs differed across invocations!
Graph diff:
graph(%x : Tensor) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Tensor = aten::select(%x, %4, %5)
%result.2 : Tensor = aten::mul(%result.1, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Tensor = aten::select(%x, %8, %9)
- %result : Tensor = aten::mul(%result.2, %10)
+ %result.3 : Tensor = aten::mul(%result.2, %10)
? ++
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Tensor = aten::select(%x, %12, %13)
+ %result : Tensor = aten::mul(%result.3, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Tensor = aten::select(%x, %16, %17)
- %15 : Tensor = aten::mul(%result, %14)
? ^ ^
+ %19 : Tensor = aten::mul(%result, %18)
? ^ ^
- return (%15);
? ^
+ return (%19);
? ^
}
此消息向我们表明,在我们第一次追踪它和使用check_inputs
追踪它之间,计算有所不同。 实际上,loop_in_traced_fn
主体内的循环取决于输入x
的形状,因此,当我们尝试另一种形状不同的x
时,迹线会有所不同。
在这种情况下,可以使用 torch.jit.script()
来捕获类似于数据的控制流:
def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:
torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))
产生:
graph(%x : Tensor) {
%5 : bool = prim::Constant[value=1]()
%1 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %1)
%4 : int = aten::size(%x, %1)
%result : Tensor = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Tensor) {
%10 : Tensor = aten::select(%x, %1, %i)
%result.2 : Tensor = aten::mul(%7, %10)
-> (%5, %result.2)
}
return (%result);
}
跟踪器警告
跟踪器会针对跟踪计算中的几种有问题的模式生成警告。 举个例子,追踪一个在 Tensor 的切片(视图)上包含就地分配的函数:
def fill_row_zero(x):
x[0] = torch.rand(*x.shape[1:2])
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
产生几个警告和一个仅返回输入的图形:
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1\. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
我们可以通过修改代码来解决此问题,使其不使用就地更新,而是使用torch.cat
来错位构建结果张量:
def fill_row_zero(x):
x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
内置函数
TorchScript 支持 PyTorch 提供的内置张量和神经网络功能的子集。 Tensor 上的大多数方法以及torch
名称空间中的函数,torch.nn.functional
中的所有函数以及torch.nn
中的所有模块在 TorchScript 中均受支持,下表中没有列出。 对于不支持的模块,建议使用 torch.jit.trace()
。
不支持的torch.nn
模块
torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
torch.nn.modules.normalization.CrossMapLRN2d
torch.nn.modules.rnn.RNN
有关支持的功能的完整参考,请参见 TorchScript 内置函数。
常见问题解答
问:我想在 GPU 上训练模型并在 CPU 上进行推理。 最佳做法是什么?
首先将模型从 GPU 转换为 CPU,然后将其保存,如下所示:
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")
traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")
# ... later, when using the model:
if use_gpu:
model = torch.jit.load("gpu.pth")
else:
model = torch.jit.load("cpu.pth")
model(input)
推荐这样做是因为跟踪器可能会在特定设备上见证张量的创建,因此强制转换已加载的模型可能会产生意想不到的效果。 在保存之前对模型进行转换可确保跟踪器具有正确的设备信息。
问:如何在 ScriptModule
上存储属性?
说我们有一个像这样的模型:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.x = 2
def forward(self):
return self.x
m = torch.jit.script(Model())
如果实例化Model
,则将导致编译错误,因为编译器不了解x
。 有四种方法可以通知编译器ScriptModule
的属性:
\1.nn.Parameter
-包装在nn.Parameter
中的值将像在nn.Module
上一样工作
\2.register_buffer
-包装在register_buffer
中的值将像在nn.Module
上一样工作。 这等效于Tensor
类型的属性(请参见 4)。
3.常量-将类成员注释为Final
(或在类定义级别将其添加到名为__constants__
的列表中)会将包含的名称标记为常量。 常数直接保存在模型代码中。 有关详细信息,请参见 Python 定义的常量。
4.属性-可以将支持的类型的值添加为可变属性。 可以推断大多数类型,但可能需要指定一些类型,有关详细信息,请参见模块属性。
问:我想跟踪模块的方法,但一直出现此错误:
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
此错误通常表示您要跟踪的方法使用模块的参数,并且您正在传递模块的方法而不是模块实例(例如
my_module_instance.forward
与my_module_instance
)。
\& 使用模块的方法调用trace
会将模块参数(可能需要渐变)捕获为常量。 &
\&
\& 另一方面,使用模块实例(例如my_module
)调用trace
会创建一个新模块,并将参数正确复制到新模块中,以便在需要时可以累积梯度。
& 要跟踪模块上的特定方法,请参见torch.jit.trace_module