PyTorch torch.hub
Pytorch Hub 是经过预先训练的模型资料库,旨在促进研究的可重复性。
发布模型
Pytorch Hub 支持通过添加简单的hubconf.py
文件将预训练的模型(模型定义和预训练的权重)发布到 github 存储库;
hubconf.py
可以有多个入口点。 每个入口点都定义为 python 函数(例如:您要发布的经过预先训练的模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现入口点?
如果我们扩展pytorch/vision/hubconf.py
中的实现,则以下代码段指定了resnet18
模型的入口点。 在大多数情况下,在hubconf.py
中导入正确的功能就足够了。 在这里,我们仅以扩展版本为例来说明其工作原理。
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
## resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
变量是加载模型所需的软件包名称的列表。 请注意,这可能与训练模型所需的依赖项稍有不同。args
和kwargs
传递给实际的可调用函数。- 该函数的文档字符串用作帮助消息。 它解释了模型做什么以及允许的位置/关键字参数是什么。 强烈建议在此处添加一些示例。
- Entrypoint 函数可以返回模型(nn.module),也可以返回辅助工具以使用户工作流程更流畅,例如 标记器。
- 带下划线前缀的可调用项被视为辅助功能,不会在
torch.hub.list()
中显示。 - 预训练的权重既可以存储在 github 存储库中,也可以由
torch.hub.load_state_dict_from_url()
加载。 如果少于 2GB,建议将其附加到项目版本,并使用该版本中的网址。 在上面的示例中,torchvision.models.resnet.resnet18
处理pretrained
,或者,您可以在入口点定义中添加以下逻辑。
if pretrained:
# For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知
- 发布的模型应至少在分支/标签中。 不能是随机提交。
从集线器加载模型
Pytorch Hub 提供了便捷的 API,可通过torch.hub.list()
浏览集线器中的所有可用模型,通过torch.hub.help()
显示文档字符串和示例,并使用torch.hub.load()
加载经过预先训练的模型
torch.hub.list(github, force_reload=False)¶
列出 <cite>github</cite> hubconf 中可用的所有入口点。
参数
- github (字符串)–格式为“ repo_owner / repo_name [:tag_name]”的字符串,带有可选的标记/分支。 如果未指定,则默认分支为<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
- force_reload (bool , 可选)–是否放弃现有缓存并强制重新下载。 默认值为<cite>否</cite>。
退货
可用入口点名称的列表
返回类型
入口点
例
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False)¶
显示入口点<cite>模型</cite>的文档字符串。
Parameters
- github (字符串)–格式为< repo_owner / repo_name [:tag_name] [:HT_7]的字符串,带有可选的标记/分支。 如果未指定,则默认分支为<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
- 模型(字符串)–在存储库的 hubconf.py 中定义的入口点名称字符串
- force_reload (bool__, optional) – whether to discard the existing cache and force a fresh download. Default is <cite>False</cite>.
Example
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(github, model, *args, **kwargs)¶
使用预训练的权重从 github 存储库加载模型。
Parameters
- github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is <cite>master</cite> if not specified. Example: 'pytorch/vision[:hub]'
- model (string) – a string of entrypoint name defined in repo's hubconf.py
- args (可选*)–可调用<cite>模型</cite>的相应 args。
- force_reload (bool , 可选)–是否无条件强制重新下载 github 存储库。 默认值为<cite>否</cite>。
- 详细 (bool , 可选)–如果为 False,则忽略有关命中本地缓存的消息。 请注意,有关首次下载的消息不能被静音。 默认值为<cite>为真</cite>。
- \ kwargs (可选)–可调用<cite>模型</cite>的相应 kwargs。
Returns
具有相应预训练权重的单个模型。
Example
>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)¶
将给定 URL 上的对象下载到本地路径。
Parameters
- url (字符串)–要下载的对象的 URL
- dst (字符串)–保存对象的完整路径,例如 <cite>/ tmp / temporary_file</cite>
- hash_prefix (字符串 , 可选))–如果不是 None,则下载的 SHA256 文件应以 <cite>hash_prefix</cite> 开头。 默认值:无
- 进度 (bool , 可选)–是否显示 stderr 的进度条默认值:True
Example
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)¶
将 Torch 序列化对象加载到给定的 URL。
如果下载的文件是 zip 文件,它将被自动解压缩。
如果 <cite>model_dir</cite> 中已经存在该对象,则将其反序列化并返回。 <cite>model_dir</cite> 的默认值为$TORCH_HOME/checkpoints
,其中环境变量$TORCH_HOME
的默认值为$XDG_CACHE_HOME/torch
。 $XDG_CACHE_HOME
遵循 Linux 文件系统布局的 X 设计组规范,如果未设置,则默认值为~/.cache
。
Parameters
- url (string) – URL of the object to download
- model_dir (字符串 , 可选)–保存对象的目录
- map_location (可选)–指定如何重新映射存储位置的函数或命令(请参见 torch.load)
- 进度 (bool , 可选)–是否显示 stderr 进度条。 默认值:True
- check_hash (bool , 可选)–如果为 True,则 URL 的文件名部分应遵循命名约定
filename-<sha256>.ext
,其中[<sha256>
是文件内容的 SHA256 哈希值的前 8 位或更多位。 哈希用于确保唯一的名称并验证文件的内容。 默认值:False
Example
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
运行加载的模型:
注意,torch.load()
中的*args, **kwargs
用于实例化模型。 加载模型后,如何找到可以使用该模型的功能? 建议的工作流程是
dir(model)
查看模型的所有可用方法。help(model.foo)
检查model.foo
需要执行哪些参数
为了帮助用户探索而又不来回参考文档,我们强烈建议回购所有者使功能帮助消息清晰明了。 包含一个最小的工作示例也很有帮助。
我下载的模型保存在哪里?
这些位置按以下顺序使用
- 呼叫
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
,如果设置了环境变量TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果设置了环境变量XDG_CACHE_HOME
。~/.cache/torch/hub
torch.hub.set_dir(d)¶
(可选)将 hub_dir 设置为本地目录,以保存下载的模型&权重。
如果未调用set_dir
,则默认路径为$TORCH_HOME/hub
,其中环境变量$TORCH_HOME
默认为$XDG_CACHE_HOME/torch
。 $XDG_CACHE_HOME
遵循 Linux 文件系统布局的 X 设计组规范,如果未设置环境变量,则默认值为~/.cache
。
Parameters
d (字符串)–本地文件夹的路径,用于保存下载的模型&权重。
缓存逻辑
默认情况下,加载文件后我们不会清理文件。 如果hub_dir
中已经存在,则集线器默认使用缓存。
用户可以通过调用hub.load(..., force_reload=True)
来强制重新加载。 这将删除现有的 github 文件夹和下载的权重,重新初始化新的下载。 当更新发布到同一分支时,此功能很有用,用户可以跟上最新版本。
已知限制:
Torch 集线器通过导入软件包来进行工作,就像安装软件包一样。 在 Python 中导入会带来一些副作用。 例如,您可以在 Python 缓存sys.modules
和sys.path_importer_cache
中看到新项目,这是正常的 Python 行为。
在这里值得一提的已知限制是用户无法在相同的 python 进程中加载同一存储库的两个不同分支。 就像在 Python 中安装两个具有相同名称的软件包一样,这是不好的。 快取可能会加入聚会,如果您实际尝试的话会给您带来惊喜。 当然,将它们分别加载是完全可以的。