codecamp

TensorFlow聚合梯度的条件累加器

tf.ConditionalAccumulator

tf.ConditionalAccumulator 类

定义在:tensorflow/python/ops/data_flow_ops.py

用于聚合梯度的条件累加器.

最新的梯度(即计算梯度的时间步长等于累加器的时间步长)被加到累加器中.

平均梯度的提取被阻塞,直到所需数量的梯度被累积为止.

属性

  • accumulator_ref:底层累加器参考.
  • dtype:该累加器积累的梯度的数据类型.
  • name:底层累加器的名称.

方法

__init__

__init__ (  
    dtype ,  
    shape = None ,  
    shared_name = None ,  
    name = 'conditional_accumulator' 
)

创建一个新的 ConditionalAccumulator.

ARGS:
  • dtype:累积梯度的数据类型.
  • shape:累积梯度的形状.
  • shared_name:可选.如果非空,这个累加器将在多个会话的给定名称下共享.
  • name:累加器的可选名称.

apply_grad

apply_grad (  
    grad ,  
    local_step = 0 ,  
    name = None
  )

尝试向累加器应用梯度.

如果梯度是陈旧的,即 local_step 小于累加器的全局时间步长,则该尝试将被静默地丢弃.

ARGS:

  • grad:要应用的梯度张量.
  • local_step:计算梯度的时间步长.
  • name:操作的可选名称.

返回:

(有条件地) 将梯度应用于累加器的操作.

注意:

  • ValueError:如果 grad 是错误的形状

num_accumulated

num_accumulated ( name = None )

目前在累加器中聚合的梯度数.

ARGS:

  • name:操作的可选名称.

返回:

累加器中当前累积的梯度数.

set_global_step

set_global_step (  
    new_global_step ,  
    name = None
  )

设置累加器的全局时间步长.

如果尝试设置的时间步长低于累加器自己的时间步长, 则操作会记录一个警告.

ARGS:

  • new_global_step:新的时间步长的值,可以是变量或常量.
  • name:操作的可选名称.

返回:

设置累加器时间步长的操作.

take_grad

take_grad (  
    num_required ,  
    name = None
  )

尝试从累加器中提取平均梯度.

操作阻止直到足够数量的梯度已成功应用于累加器.

一旦成功,还会触发以下操作:

  • 累加梯度的计数器复位为0.
  • 聚合梯度被重置为0张量.
  • 累加器的内部时间步长增加1.

ARGS:

  • num_required:需要聚合的梯度次数
  • name:操作的可选名称

返回:

一个持续平均梯度值的张量.

注意:

  • InvalidArgumentError:如果 num_required <1


tf.cond函数的使用
如何使用ConditionalAccumulatorBase
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

TensorFlow 函数介绍

TensorFlow 函数模块:tf

TensorFlow的image模块

TensorFlow使用之tf.io

TensorFlow使用之tf.keras

TensorFlow函数教程:tf.keras.applications

TensorFlow函数教程:tf.keras.backend

TensorFlow使用之tf.metrics

TensorFlow使用之tf.nn

TensorFlow使用之tf.python_io

TensorFlow 功能函数

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }