codecamp

PyTorch 命名为 Tensors 操作员范围

原文:PyTorch 命名为 Tensors 操作员范围

请首先阅读命名张量,以了解命名张量。

本文档是名称推断的参考,HTH1 是一个定义张量命名方式的过程:

  1. 使用名称提供其他自动运行时正确性检查
  2. 将名称从输入张量传播到输出张量

以下是命名张量及其关联的名称推断规则支持的所有操作的列表。

如果此处未列出操作,但对您的用例有帮助,请搜索问题是否已提交,否则请提交一个问题

警告

命名的张量 API 是实验性的,随时可能更改。

Supported Operations

API 名称推断规则
Tensor.abs() , torch.abs() 保留输入名称
Tensor.abs_() Keeps input names
Tensor.acos() , torch.acos() Keeps input names
Tensor.acos_() Keeps input names
Tensor.add() , torch.add() 统一输入的名称
Tensor.add_() Unifies names from inputs
Tensor.addmm() , torch.addmm() 缩小暗淡
Tensor.addmm_() Contracts away dims
Tensor.addmv() , torch.addmv() Contracts away dims
Tensor.addmv_() Contracts away dims
Tensor.align_as() 查看文件
Tensor.align_to() See documentation
Tensor.all()torch.all() 没有
Tensor.any()torch.any() None
Tensor.asin() , torch.asin() Keeps input names
Tensor.asin_() Keeps input names
Tensor.atan() , torch.atan() Keeps input names
Tensor.atan2() , torch.atan2() Unifies names from inputs
Tensor.atan2_() Unifies names from inputs
Tensor.atan_() Keeps input names
Tensor.bernoulli() , torch.bernoulli() Keeps input names
Tensor.bernoulli_() None
Tensor.bfloat16() Keeps input names
Tensor.bitwise_not() , torch.bitwise_not() Keeps input names
Tensor.bitwise_not_() None
Tensor.bmm() , torch.bmm() Contracts away dims
Tensor.bool() Keeps input names
Tensor.byte() Keeps input names
torch.cat() Unifies names from inputs
Tensor.cauchy_() None
Tensor.ceil() , torch.ceil() Keeps input names
Tensor.ceil_() None
Tensor.char() Keeps input names
Tensor.chunk() , torch.chunk() Keeps input names
Tensor.clamp() , torch.clamp() Keeps input names
Tensor.clamp_() None
Tensor.copy_() 输出功能和就地变体
Tensor.cos() , torch.cos() Keeps input names
Tensor.cos_() None
Tensor.cosh() , torch.cosh() Keeps input names
Tensor.cosh_() None
Tensor.cpu() Keeps input names
Tensor.cuda() Keeps input names
Tensor.cumprod() , torch.cumprod() Keeps input names
Tensor.cumsum() , torch.cumsum() Keeps input names
Tensor.data_ptr() None
Tensor.detach() ,torch.detach() Keeps input names
Tensor.detach_() None
Tensor.device , torch.device() None
Tensor.digamma() , torch.digamma() Keeps input names
Tensor.digamma_() None
Tensor.dim() None
Tensor.div() , torch.div() Unifies names from inputs
Tensor.div_() Unifies names from inputs
Tensor.dot() , torch.dot() None
Tensor.double() Keeps input names
Tensor.element_size() None
torch.empty() 工厂功能
torch.empty_like() Factory functions
Tensor.eq() , torch.eq() Unifies names from inputs
Tensor.erf() , torch.erf() Keeps input names
Tensor.erf_() None
Tensor.erfc() , torch.erfc() Keeps input names
Tensor.erfc_() None
Tensor.erfinv() , torch.erfinv() Keeps input names
Tensor.erfinv_() None
Tensor.exp() , torch.exp() Keeps input names
Tensor.exp_() None
Tensor.expand() Keeps input names
Tensor.expm1() , torch.expm1() Keeps input names
Tensor.expm1_() None
Tensor.exponential_() None
Tensor.fill_() None
Tensor.flatten() , torch.flatten() See documentation
Tensor.float() Keeps input names
Tensor.floor() , torch.floor() Keeps input names
Tensor.floor_() None
Tensor.frac() , torch.frac() Keeps input names
Tensor.frac_() None
Tensor.ge() , torch.ge() Unifies names from inputs
Tensor.get_device() ,torch.get_device() None
Tensor.grad None
Tensor.gt() , torch.gt() Unifies names from inputs
Tensor.half() Keeps input names
Tensor.has_names() See documentation
Tensor.index_fill() ,torch.index_fill() Keeps input names
Tensor.index_fill_() None
Tensor.int() Keeps input names
Tensor.is_contiguous() None
Tensor.is_cuda None
Tensor.is_floating_point() , torch.is_floating_point() None
Tensor.is_leaf None
Tensor.is_pinned() None
Tensor.is_shared() None
Tensor.is_signed() ,torch.is_signed() None
Tensor.is_sparse None
torch.is_tensor() None
Tensor.item() None
Tensor.kthvalue() , torch.kthvalue() 移除尺寸
Tensor.le() , torch.le() Unifies names from inputs
Tensor.log() , torch.log() Keeps input names
Tensor.log10() , torch.log10() Keeps input names
Tensor.log10_() None
Tensor.log1p() , torch.log1p() Keeps input names
Tensor.log1p_() None
Tensor.log2() , torch.log2() Keeps input names
Tensor.log2_() None
Tensor.log_() None
Tensor.log_normal_() None
Tensor.logical_not() , torch.logical_not() Keeps input names
Tensor.logical_not_() None
Tensor.logsumexp() , torch.logsumexp() Removes dimensions
Tensor.long() Keeps input names
Tensor.lt() , torch.lt() Unifies names from inputs
torch.manual_seed() None
Tensor.masked_fill() ,torch.masked_fill() Keeps input names
Tensor.masked_fill_() None
Tensor.masked_select() , torch.masked_select() 将遮罩对齐到输入,然后 unified_names_from_input_tensors
Tensor.matmul() , torch.matmul() Contracts away dims
Tensor.mean() , torch.mean() Removes dimensions
Tensor.median() , torch.median() Removes dimensions
Tensor.mm() , torch.mm() Contracts away dims
Tensor.mode() , torch.mode() Removes dimensions
Tensor.mul() , torch.mul() Unifies names from inputs
Tensor.mul_() Unifies names from inputs
Tensor.mv() , torch.mv() Contracts away dims
Tensor.names See documentation
Tensor.narrow() , torch.narrow() Keeps input names
Tensor.ndim None
Tensor.ndimension() None
Tensor.ne() , torch.ne() Unifies names from inputs
Tensor.neg() , torch.neg() Keeps input names
Tensor.neg_() None
torch.normal() Keeps input names
Tensor.normal_() None
Tensor.numel() , torch.numel() None
torch.ones() Factory functions
Tensor.pow() , torch.pow() Unifies names from inputs
Tensor.pow_() None
Tensor.prod() , torch.prod() Removes dimensions
torch.rand() Factory functions
torch.rand() Factory functions
torch.randn() Factory functions
torch.randn() Factory functions
Tensor.random_() None
Tensor.reciprocal() , torch.reciprocal() Keeps input names
Tensor.reciprocal_() None
Tensor.refine_names() See documentation
Tensor.register_hook() None
Tensor.rename() See documentation
Tensor.rename_() See documentation
Tensor.requires_grad None
Tensor.requires_grad_() None
Tensor.resize_() 只允许不改变形状的调整大小
Tensor.resize_as_() Only allow resizes that do not change shape
Tensor.round() , torch.round() Keeps input names
Tensor.round_() None
Tensor.rsqrt() , torch.rsqrt() Keeps input names
Tensor.rsqrt_() None
Tensor.select() ,torch.select() Removes dimensions
Tensor.short() Keeps input names
Tensor.sigmoid() , torch.sigmoid() Keeps input names
Tensor.sigmoid_() None
Tensor.sign() , torch.sign() Keeps input names
Tensor.sign_() None
Tensor.sin() , torch.sin() Keeps input names
Tensor.sin_() None
Tensor.sinh() , torch.sinh() Keeps input names
Tensor.sinh_() None
Tensor.size() None
Tensor.split() , torch.split() Keeps input names
Tensor.sqrt() , torch.sqrt() Keeps input names
Tensor.sqrt_() None
Tensor.squeeze() , torch.squeeze() Removes dimensions
Tensor.std() , torch.std() Removes dimensions
torch.std_mean() Removes dimensions
Tensor.stride() None
Tensor.sub() ,torch.sub() Unifies names from inputs
Tensor.sub_() Unifies names from inputs
Tensor.sum() , torch.sum() Removes dimensions
Tensor.tan() , torch.tan() Keeps input names
Tensor.tan_() None
Tensor.tanh() , torch.tanh() Keeps input names
Tensor.tanh_() None
torch.tensor() Factory functions
Tensor.to() Keeps input names
Tensor.topk() , torch.topk() Removes dimensions
Tensor.transpose() , torch.transpose() 排列尺寸
Tensor.trunc() , torch.trunc() Keeps input names
Tensor.trunc_() None
Tensor.type() None
Tensor.type_as() Keeps input names
Tensor.unbind() , torch.unbind() Removes dimensions
Tensor.unflatten() See documentation
Tensor.uniform_() None
Tensor.var() , torch.var() Removes dimensions
torch.var_mean() Removes dimensions
Tensor.zero_() None
torch.zeros() Factory functions

保留输入名称

所有逐点一元函数以及其他一些一元函数都遵循此规则。

  • 检查姓名:无
  • 传播名称:输入张量的名称会传播到输出。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')

移除尺寸

所有缩小操作,例如 sum() ,都会通过缩小所需尺寸来删除尺寸。 select()squeeze() 等其他操作会删除尺寸。

只要有人可以将整数维度索引传递给运算符,就可以传递维度名称。 包含维索引列表的函数也可以包含维名称列表。

  • 检查名称:如果dimdims作为名称列表传入,请检查self中是否存在这些名称。
  • 传播名称:如果在输出张量中不存在dimdims指定的输入张量的尺寸,则这些尺寸的相应名称不会出现在output.names中。

>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')


>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')


## Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')

统一输入中的名称

所有二进制算术运算都遵循此规则。 广播操作仍然从右侧进行位置广播,以保持与未命名张量的兼容性。 要通过名称执行显式广播,请使用 Tensor.align_as()

  • 检查名称:所有名称都必须从右侧位置匹配。 即,在tensor + other中,对于(-min(tensor.dim(), other.dim()) + 1, -1]中的所有imatch(tensor.names[i], other.names[i])必须为 true。
  • 检查名称:此外,所有命名的尺寸必须从右对齐。 在匹配期间,如果我们将命名尺寸A与未命名尺寸None匹配,则A不得出现在具有未命名尺寸的张量中。
  • 传播名称:从两个张量的右边开始统一名称对,以产生输出名称。

例如,

## tensor: Tensor[   N, None]
## other:  Tensor[None,    C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')

检查姓名:

  • match(tensor.names[-1], other.names[-1])True
  • match(tensor.names[-2], tensor.names[-2])True
  • 由于我们将 tensor 中的None'C'匹配,因此请确保 tensor 中不存在'C'
  • 检查以确保other中不存在'N'(不存在)。

最后,使用[unify('N', None), unify(None, 'C')] = ['N', 'C']计算输出名称

更多示例:

## Dimensions don't match from the right:
## tensor: Tensor[N, C]
## other:  Tensor[   N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.


## Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
## tensor: Tensor[N, None]
## other:  Tensor[      N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.

注意

在最后两个示例中,可以通过名称对齐张量,然后执行加法。 使用 Tensor.align_as() 按名称对齐张量,或使用 Tensor.align_to() 将张量对齐到自定义尺寸顺序。

排列尺寸

某些操作(例如 Tensor.t())会置换尺寸顺序。 维度名称附加到各个维度,因此也可以排列。

如果操作员输入位置索引dim,它也可以采用尺寸名称作为dim

  • 检查名称:如果将dim作为名称传递,请检查其是否在张量中存在。
  • 传播名称:以与要排列的维相同的方式排列维名称。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')

收缩消失

矩阵乘法函数遵循此方法的某些变体。 让我们先通过 torch.mm() ,然后概括一下批矩阵乘法的规则。

对于torch.mm(tensor, other)

  • Check names: None
  • 传播名称:结果名称为(tensor.names[-2], other.names[-1])

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')

本质上,矩阵乘法在二维上执行点积运算,使它们折叠。 当两个张量矩阵相乘时,收缩尺寸消失,并且不出现在输出张量中。

torch.mv()torch.dot() 的工作方式类似:名称推断不会检查输入名称,并且会删除点积所涉及的尺寸:

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)

现在,让我们看一下torch.matmul(tensor, other)。 假设tensor.dim() >= 2other.dim() >= 2

  • 检查名称:检查输入的批次尺寸是否对齐并可以广播。 请参见统一输入的名称,以了解对齐输入的含义。
  • 传播名称:结果名称是通过统一批次尺寸并删除合同规定的尺寸获得的:unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])

例子:

## Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
## 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')

最后,还有许多功能的融合add版本。 即 addmm()addmv() 。 这些被视为构成 mm() 的名称推断和 add() 的命名推断。

工厂功能

现在,工厂函数采用新的names参数,该参数将名称与每个维度相关联。

>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))

输出功能和就地变型

指定为out=张量的张量具有以下行为:

  • 如果没有命名维,则将从操作中计算出的名称传播到其中。
  • 如果它具有任何命名维,则从该操作计算出的名称必须与现有名称完全相同。 否则,操作错误。

所有就地方法都会将输入修改为具有与根据名称推断计算出的名称相同的名称。 例如,

>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)


>>> x += y
>>> x.names
('N', 'C')
PyTorch 命名张量
PyTorch torchvision
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

Pytorch 音频

PyTorch 命名为 Tensor(实验性)

PyTorch 强化学习

PyTorch 用其他语言

PyTorch 语言绑定

PyTorch torchvision参考

PyTorch 音频参考

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }