codecamp

TensorFlow函数教程:tf.nn.weighted_cross_entropy_with_logits

tf.nn.weighted_cross_entropy_with_logits函数

tf.nn.weighted_cross_entropy_with_logits(
    targets,
    logits,
    pos_weight,
    name=None
)

定义在:tensorflow/python/ops/nn_impl.py。

计算加权交叉熵。

类似于sigmoid_cross_entropy_with_logits(),除了pos_weight,允许人们通过向上或向下加权相对于负误差的正误差的成本来权衡召回率和精确度。

通常的交叉熵成本定义为:

targets * -log(sigmoid(logits)) +
    (1 - targets) * -log(1 - sigmoid(logits))

值pos_weights > 1减少了假阴性计数,从而增加了召回率。相反设置pos_weights < 1会减少假阳性计数并提高精度。从一下内容可以看出pos_weight是作为损失表达式中的正目标项的乘法系数引入的:

targets * -log(sigmoid(logits)) * pos_weight +
    (1 - targets) * -log(1 - sigmoid(logits))

为了简便起见,让x = logits,z = targets,q = pos_weight。损失是:

  qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

设置l = (1 + (q - 1) * z),确保稳定性并避免溢出,使用一下内容来实现:

(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

logits和targets必须具有相同的类型和形状。

参数:

  • targets:一个Tensor,与logits具有相同的类型和形状。
  • logits:一个Tensor,类型为float32或float64。
  • pos_weight:正样本中使用的系数。
  • name:操作的名称(可选)。

返回:

与具有分量加权逻辑损失的logits具有相同形状的Tensor。

可能引发的异常:

  • ValueError:如果logits和targets没有相同的形状。
TensorFlow函数教程:tf.nn.sufficient_statistics
TensorFlow函数教程:tf.nn.weighted_moments
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

TensorFlow 函数介绍

TensorFlow 函数模块:tf

TensorFlow的image模块

TensorFlow使用之tf.io

TensorFlow使用之tf.keras

TensorFlow函数教程:tf.keras.applications

TensorFlow函数教程:tf.keras.backend

TensorFlow使用之tf.metrics

TensorFlow使用之tf.nn

TensorFlow使用之tf.python_io

TensorFlow 功能函数

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }