codecamp

PyTorch torch.hub

原文:PyTorch torch.hub

Pytorch Hub 是经过预先训练的模型资料库,旨在促进研究的可重复性。

发布模型

Pytorch Hub 支持通过添加简单的hubconf.py文件将预训练的模型(模型定义和预训练的权重)发布到 github 存储库;

hubconf.py可以有多个入口点。 每个入口点都定义为 python 函数(例如:您要发布的经过预先训练的模型)。

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

如何实现入口点?

如果我们扩展pytorch/vision/hubconf.py中的实现,则以下代码段指定了resnet18模型的入口点。 在大多数情况下,在hubconf.py中导入正确的功能就足够了。 在这里,我们仅以扩展版本为例来说明其工作原理。

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18


## resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model

  • dependencies变量是加载模型所需的软件包名称的列表。 请注意,这可能与训练模型所需的依赖项稍有不同。
  • argskwargs传递给实际的可调用函数。
  • 该函数的文档字符串用作帮助消息。 它解释了模型做什么以及允许的位置/关键字参数是什么。 强烈建议在此处添加一些示例。
  • Entrypoint 函数可以返回模型(nn.module),也可以返回辅助工具以使用户工作流程更流畅,例如 标记器。
  • 带下划线前缀的可调用项被视为辅助功能,不会在torch.hub.list()中显示。
  • 预训练的权重既可以存储在 github 存储库中,也可以由torch.hub.load_state_dict_from_url()加载。 如果少于 2GB,建议将其附加到项目版本,并使用该版本中的网址。 在上面的示例中,torchvision.models.resnet.resnet18处理pretrained,或者,您可以在入口点定义中添加以下逻辑。

if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)


    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 发布的模型应至少在分支/标签中。 不能是随机提交。

从集线器加载模型

Pytorch Hub 提供了便捷的 API,可通过torch.hub.list()浏览集线器中的所有可用模型,通过torch.hub.help()显示文档字符串和示例,并使用torch.hub.load()加载经过预先训练的模型

torch.hub.list(github, force_reload=False)¶

列出 <cite>github</cite> hubconf 中可用的所有入口点。

参数

  • github (字符串)–格式为“ repo_owner / repo_name [:tag_name]”的字符串,带有可选的标记/分支。 如果未指定,则默认分支为<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
  • force_reload (bool 可选)–是否放弃现有缓存并强制重新下载。 默认值为<cite>否</cite>。

退货

可用入口点名称的列表

返回类型

入口点

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)

torch.hub.help(github, model, force_reload=False)¶

显示入口点<cite>模型</cite>的文档字符串。

Parameters

  • github (字符串)–格式为< repo_owner / repo_name [:tag_name] [:HT_7]的字符串,带有可选的标记/分支。 如果未指定,则默认分支为<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
  • 模型(字符串)–在存储库的 hubconf.py 中定义的入口点名称字符串
  • force_reload (bool__, optional) – whether to discard the existing cache and force a fresh download. Default is <cite>False</cite>.

Example

>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))

torch.hub.load(github, model, *args, **kwargs)¶

使用预训练的权重从 github 存储库加载模型。

Parameters

  • github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is <cite>master</cite> if not specified. Example: 'pytorch/vision[:hub]'
  • model (string) – a string of entrypoint name defined in repo's hubconf.py
  • args (可选*)–可调用<cite>模型</cite>的相应 args。
  • force_reload (bool 可选)–是否无条件强制重新下载 github 存储库。 默认值为<cite>否</cite>。
  • 详细 (bool 可选)–如果为 False,则忽略有关命中本地缓存的消息。 请注意,有关首次下载的消息不能被静音。 默认值为<cite>为真</cite>。
  • \ kwargs (可选)–可调用<cite>模型</cite>的相应 kwargs。

Returns

具有相应预训练权重的单个模型。

Example

>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)

torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)¶

将给定 URL 上的对象下载到本地路径。

Parameters

  • url (字符串)–要下载的对象的 URL
  • dst (字符串)–保存对象的完整路径,例如 <cite>/ tmp / temporary_file</cite>
  • hash_prefix (字符串 可选))–如果不是 None,则下载的 SHA256 文件应以 <cite>hash_prefix</cite> 开头。 默认值:无
  • 进度 (bool 可选)–是否显示 stderr 的进度条默认值:True

Example

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)¶

将 Torch 序列化对象加载到给定的 URL。

如果下载的文件是 zip 文件,它将被自动解压缩。

如果 <cite>model_dir</cite> 中已经存在该对象,则将其反序列化并返回。 <cite>model_dir</cite> 的默认值为$TORCH_HOME/checkpoints,其中环境变量$TORCH_HOME的默认值为$XDG_CACHE_HOME/torch$XDG_CACHE_HOME遵循 Linux 文件系统布局的 X 设计组规范,如果未设置,则默认值为~/.cache

Parameters

  • url (string) – URL of the object to download
  • model_dir (字符串 可选)–保存对象的目录
  • map_location (可选)–指定如何重新映射存储位置的函数或命令(请参见 torch.load)
  • 进度 (bool 可选)–是否显示 stderr 进度条。 默认值:True
  • check_hash (bool 可选)–如果为 True,则 URL 的文件名部分应遵循命名约定filename-<sha256>.ext,其中[ <sha256>是文件内容的 SHA256 哈希值的前 8 位或更多位。 哈希用于确保唯一的名称并验证文件的内容。 默认值:False

Example

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

运行加载的模型:

注意,torch.load()中的*args, **kwargs用于实例化模型。 加载模型后,如何找到可以使用该模型的功能? 建议的工作流程是

  • dir(model)查看模型的所有可用方法。
  • help(model.foo)检查model.foo需要执行哪些参数

为了帮助用户探索而又不来回参考文档,我们强烈建议回购所有者使功能帮助消息清晰明了。 包含一个最小的工作示例也很有帮助。

我下载的模型保存在哪里?

这些位置按以下顺序使用

  • 呼叫hub.set_dir(<PATH_TO_HUB_DIR>)
  • $TORCH_HOME/hub,如果设置了环境变量TORCH_HOME
  • $XDG_CACHE_HOME/torch/hub,如果设置了环境变量XDG_CACHE_HOME
  • ~/.cache/torch/hub

torch.hub.set_dir(d)¶

(可选)将 hub_dir 设置为本地目录,以保存下载的模型&权重。

如果未调用set_dir,则默认路径为$TORCH_HOME/hub,其中环境变量$TORCH_HOME默认为$XDG_CACHE_HOME/torch$XDG_CACHE_HOME遵循 Linux 文件系统布局的 X 设计组规范,如果未设置环境变量,则默认值为~/.cache

Parameters

d (字符串)–本地文件夹的路径,用于保存下载的模型&权重。

缓存逻辑

默认情况下,加载文件后我们不会清理文件。 如果hub_dir中已经存在,则集线器默认使用缓存。

用户可以通过调用hub.load(..., force_reload=True)来强制重新加载。 这将删除现有的 github 文件夹和下载的权重,重新初始化新的下载。 当更新发布到同一分支时,此功能很有用,用户可以跟上最新版本。

已知限制:

Torch 集线器通过导入软件包来进行工作,就像安装软件包一样。 在 Python 中导入会带来一些副作用。 例如,您可以在 Python 缓存sys.modulessys.path_importer_cache中看到新项目,这是正常的 Python 行为。

在这里值得一提的已知限制是用户无法相同的 python 进程中加载同一存储库的两个不同分支。 就像在 Python 中安装两个具有相同名称的软件包一样,这是不好的。 快取可能会加入聚会,如果您实际尝试的话会给您带来惊喜。 当然,将它们分别加载是完全可以的。

PyTorch 概率分布-torch分布
PyTorch torch脚本
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

Pytorch 音频

PyTorch 命名为 Tensor(实验性)

PyTorch 强化学习

PyTorch 用其他语言

PyTorch 语言绑定

PyTorch torchvision参考

PyTorch 音频参考

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }