TensorFlow函数教程:tf.nn.static_rnn
tf.nn.static_rnn函数
别名:
- tf.contrib.rnn.static_rnn
- tf.nn.static_rnn
tf.nn.static_rnn(
cell,
inputs,
initial_state=None,
dtype=None,
sequence_length=None,
scope=None
)
定义在:tensorflow/python/ops/rnn.py。
创建由RNNCell cell
指定的循环神经网络。
生成的最简单的RNN网络形式是:
state = cell.zero_state(...)
outputs = []
for input_ in inputs:
output, state = cell(input_, state)
outputs.append(output)
return (outputs, state)
但是,还有一些其他选项:
可以提供初始状态。如果提供sequence_length向量,则执行动态计算。这种计算方法不计算超过最小批处理的最大序列长度的RNN步骤(从而节省计算时间),并且将示例的序列长度的状态适当地传播到最终状态输出。
在批处理行b的时间t上执行的动态计算:
(output, state)(b, t) =
(t >= sequence_length(b))
? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
: cell(input(b, t), state(b, t - 1))
参数:
- cell:RNNCell的一个实例。
- inputs:输入的长度为T的列表,每个Tensor具有shape [batch_size, input_size];或这些元素的嵌套元组。
- initial_state:(可选)RNN的初始状态。如果cell.state_size是整数,则必须是具有适当的类型和shape为[batch_size, cell.state_size]的Tensor。如果cell.state_size是一个元组,这应该是具有shape [batch_size, s]的张量元组,其中s位于cell.state_size。
- dtype:(可选)初始状态和预期输出的数据类型。如果未提供initial_state或RNN状态具有异构类型,则为必需。
- sequence_length:指定输入中每个序列的长度。int32或int64向量(张量),大小为[batch_size],值位于[0, T)。
- scope:用于创建子图的VariableScope;默认为“rnn”。
返回:
(outputs, state)对,其中:
- outputs的长度为T的列表(每个输入一个),或这些元素的嵌套元组。
- state是最终状态
可能引发的异常:
- TypeError:如果cell不是RNNCell的实例。
- ValueError:如果inputs为None或是一个空列表,或者无法通过形状推断从输入推断输入深度(列大小)。
实例:
import tensorflow as tf
x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim]
x=tf.unstack(x,axis=1) #按时间步展开
n_neurons = 5 #输出神经元数量
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)
print(len(output_seqs)) #四个时间步
print(output_seqs[0]) #每个时间步输出一个张量
print(output_seqs[1]) #每个时间步输出一个张量
print(states) #隐藏状态
输出结果如下:
4
Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)