codecamp

TensorFlow函数教程:tf.nn.ctc_loss

tf.nn.ctc_loss函数

tf.nn.ctc_loss(
    labels,
    inputs,
    sequence_length,
    preprocess_collapse_repeated=False,
    ctc_merge_repeated=True,
    ignore_longer_outputs_than_inputs=False,
    time_major=True
)

定义在:tensorflow/python/ops/ctc_ops.py.

参见指南:神经网络>连接时间分类(CTC)

计算CTC(连接时间分类)loss.

输入要求:

sequence_length(b) <= time for all b

max(labels.indices(labels.indices[:, 1] == b, 2))
  <= sequence_length(b) for all b.

笔记:

此类为您执行softmax操作,因此输入应该是例如LSTM对输出的线性预测.

inputs张量的最内层的维度大小,num_classes,代表num_labels + 1类别,其中num_labels是实际的标签的数量,而最大的值(num_classes - 1)是为空白标签保留的.

例如,对于包含3个标签[a, b, c]的词汇表,num_classes = 4,并且标签索引是{a: 0, b: 1, c: 2, blank: 3}.

关于参数preprocess_collapse_repeatedctc_merge_repeated

如果preprocess_collapse_repeated为True,则在loss计算之前运行预处理步骤,其中传递给loss的重复标签会合并为单个标签.如果训练标签来自,例如强制对齐,并因此具有不必要的重复,则这是有用的.

如果ctc_merge_repeated设置为False,则在CTC计算的深处,重复的非空白标签将不会合并,并被解释为单个标签.这是CTC的简化(非标准)版本.

以下是(大致)预期的第一顺序行为表:

  • preprocess_collapse_repeated=Falsectc_merge_repeated=True

典型的CTC行为:输出实际的重复类,其间有空白,还可以输出中间没有空白的重复类,这需要由解码器折叠.

  • preprocess_collapse_repeated=Truectc_merge_repeated=False

不要得知输出重复的类,因为它们在训练之前在输入标签中折叠.

  • preprocess_collapse_repeated=Falsectc_merge_repeated=False

输出中间有空白的重复类,但通常不需要解码器折叠/合并重复的类.

  • preprocess_collapse_repeated=Truectc_merge_repeated=True

未经测试,很可能不会得知输出重复的类.

ignore_longer_outputs_than_inputs选项允许在处理输出长于输入的序列时指定CTCLoss的行为.如果为true,则CTCLoss将仅为这些项返回零梯度,否则返回InvalidArgument错误,停止训练.

参数:

  • labels:一个int32SparseTensorlabels.indices[i, :] == [b, t]表示labels.values[i]存储(batch b, time t)的id;labels.values[i]必须采用[0, num_labels)中的值.
  • inputs:3-D float Tensor如果time_major == False,这将是一个Tensor,形状:[batch_size, max_time, num_classes]如果time_major == True(默认值),这将是一个Tensor,形状:[max_time, batch_size, num_classes];是logits.
  • sequence_length:1-Dint32向量,大小为[batch_size]序列长度.
  • preprocess_collapse_repeatedBoolean,默认值:False;如果为True,则在CTC计算之前折叠重复的标签.
  • ctc_merge_repeatedBoolean,默认值:True.
  • ignore_longer_outputs_than_inputs:Boolean,默认值:False;如果为True,则输出比输入长的序列将被忽略.
  • time_majorinputs张量的形状格式如果是True,那些Tensors必须具有形状[max_time, batch_size, num_classes]如果为False,则Tensors必须具有形状[batch_size, max_time, num_classes]使用time_major = True(默认)更有效,因为它避免了在ctc_loss计算开始时的转置.但是,大多数TensorFlow数据都是批处理为主的,因此通过此函数还可以接受以批处理为主的形式的输入.

返回:

1-DfloatTensor,大小为[batch]包含负对数概率.

可能引发的异常:

  • TypeError:如果标签不是SparseTensor.
TensorFlow函数教程:tf.nn.ctc_greedy_decoder
TensorFlow函数教程:tf.nn.depthwise_conv2d
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

TensorFlow 函数介绍

TensorFlow 函数模块:tf

TensorFlow的image模块

TensorFlow使用之tf.io

TensorFlow使用之tf.keras

TensorFlow函数教程:tf.keras.applications

TensorFlow函数教程:tf.keras.backend

TensorFlow使用之tf.metrics

TensorFlow使用之tf.nn

TensorFlow使用之tf.python_io

TensorFlow 功能函数

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }