codecamp

TensorFlow函数教程:tf.nn.dynamic_rnn

tf.nn.dynamic_rnn函数

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

定义在:tensorflow/python/ops/rnn.py.

请参阅指南:神经网络>递归神经网络

创建由 RNNCellcell指定的递归神经网络

执行inputs的完全动态展开.

示例:

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)

参数:

  • cell:RNNCell的一个实例.
  • inputs:RNN输入.如果time_major == False(默认),则是一个shape为[batch_size, max_time, ...]Tensor,或者这些元素的嵌套元组.如果time_major == True,则是一个shape为[max_time, batch_size, ...]Tensor,或这些元素的嵌套元组.这也可能是满足此属性的Tensors(可能是嵌套的)元组.前两个维度必须匹配所有输入,否则秩和其他形状组件可能不同.在这种情况下,在每个时间步输入到cell将复制这些元组的结构,时间维度(从中获取时间)除外.在每个时间步输入到个cell将是一个Tensor或(可能是嵌套的)Tensors元组,每个元素都有维度[batch_size, ...].
  • sequence_length:(可选)大小为[batch_size]的int32/int64的向量.超过批处理元素的序列长度时用于复制状态和零输出.所以它更多的是正确性而不是性能.
  • initial_state:(可选)RNN的初始状态.如果cell.state_size是整数,则必须是具有适当类型和shape为[batch_size, cell.state_size]Tensor.如果cell.state_size是一个元组,则应该是张量元组,cell.state_size中为s设置shape[batch_size, s].
  • dtype:(可选)初始状态和预期输出的数据类型.如果未提供initial_state或RNN状态具有异构dtype,则是必需的.
  • parallel_iterations:(默认值:32).并行运行的迭代次数.适用于那些没有任何时间依赖性并且可以并行运行的操作.该参数用于交换空间的时间.远大于1的值会使用更多内存但占用更少时间,而较小值使用较少内存但计算时间较长.
  • swap_memory:透明地交换推理中产生的张量,但是需要从GPU到CPU的支持.这允许训练通常不适合单个GPU的RNN,具有非常小的(或没有)性能损失.
  • time_majorinputsoutputsTensor的形状格式.如果是true,则这些 Tensors的shape必须为[max_time, batch_size, depth].如果是false,则这些Tensors的shape必须为[batch_size, max_time, depth].使用time_major = True更有效,因为它避免了RNN计算开始和结束时的转置.但是,大多数TensorFlow数据都是batch-major,因此默认情况下,此函数接受输入并以batch-major形式发出输出.
  • scope:用于创建子图的VariableScope;默认为“rnn”.

返回:

一对(outputs, state),其中:

  • outputs:RNN输出Tensor.

    如果time_major == False(默认),这将是shape为[batch_size, max_time, cell.output_size]Tensor.

    如果time_major == True,这将是shape为[max_time, batch_size, cell.output_size]Tensor.

    注意,如果cell.output_size是整数或TensorShape对象的(可能是嵌套的)元组,那么outputs将是一个与cell.output_size具有相同结构的元祖,它包含与cell.output_size中的形状数据有对应shape的Tensors.

  • state:最终的状态.如果cell.state_size是int,则会形成[batch_size, cell.state_size].如果它是TensorShape,则将形成[batch_size] + cell.state_size.如果它是一个(可能是嵌套的)int或TensorShape元组,那么这将是一个具有相应shape的元组.如果单元格是LSTMCells,则state将是包含每个单元格的LSTMStateTuple的元组.

可能引发的异常:

  • TypeError:如果cell不是RNNCell的实例.
  • ValueError:如果输入为None或是空列表.
TensorFlow函数教程:tf.nn.dropout
TensorFlow函数教程:tf.nn.elu
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

TensorFlow 函数介绍

TensorFlow 函数模块:tf

TensorFlow的image模块

TensorFlow使用之tf.io

TensorFlow使用之tf.keras

TensorFlow函数教程:tf.keras.applications

TensorFlow函数教程:tf.keras.backend

TensorFlow使用之tf.metrics

TensorFlow使用之tf.nn

TensorFlow使用之tf.python_io

TensorFlow 功能函数

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }