TensorFlow使用初始化器生成具有常量值的张量
tf.constant_initializer
tf.constant_initializer 类
tf.contrib.keras.initializers.Constant 类
定义在:tensorflow/python/ops/init_ops.py.
参见指南:变量>共享变量
初始化器用于生成具有常量值的张量.
生成的张量由 dtype 类型的值填充,参数值按照新张量的期望形状来指定(参见下面的例子).
参数值可以是常量值,或者是 dtype 类型的值的列表.如果值是一个列表,则列表的长度必须小于或等于所需的张量形状所隐含的元素数.如果值中的元素总数小于张量形状所需的元素数,则值中的最后一个元素将用于填充剩余的项.如果值中的元素总数大于张量形状所需的元素数,则初始值设定项将引发 ValueError.
ARGS:
- value:一个 Python 标量、值列表或者 N 维 numpy 数组.初始化变量的所有元素都将设置为值参数中的相应值.
- dtype:数据类型.
- verify_shape:布尔值,用于验证数值形状.为 True 时,如果值的形状与初始张量的形状不兼容,则初始值设定项将引发错误.
示例:可以使用 numpy.ndarray 而不是 value 列表重写以下示例,甚至重新映射,如 value 列表初始化下面的两个注释行所示.
>>> import numpy as np
>>> import tensorflow as tf
>>> value = [0, 1, 2, 3, 4, 5, 6, 7]
>>> # value = np.array(value)
>>> # value = value.reshape([2, 4])
>>> init = tf.constant_initializer(value)
>>> print('fitting shape:')
>>> with tf.Session():
>>> x = tf.get_variable('x', shape=[2, 4], initializer=init)
>>> x.initializer.run()
>>> print(x.eval())
fitting shape:
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]]
>>> print('larger shape:')
>>> with tf.Session():
>>> x = tf.get_variable('x', shape=[3, 4], initializer=init)
>>> x.initializer.run()
>>> print(x.eval())
larger shape:
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 7. 7. 7. 7.]]
>>> print('smaller shape:')
>>> with tf.Session():
>>> x = tf.get_variable('x', shape=[2, 3], initializer=init)
* <b>`ValueError`</b>: Too many elements provided. Needed at most 6, but received 8
>>> print('shape verification:')
>>> init_verify = tf.constant_initializer(value, verify_shape=True)
>>> with tf.Session():
>>> x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
* <b>`TypeError`</b>: Expected Tensor's shape: (3, 4), got (8,).
方法
__init__
__init__(
value=0,
dtype=tf.float32,
verify_shape=False
)
__call__
__call__(
shape,
dtype=None,
partition_info=None,
verify_shape=None
)
from_config
from_config(
cls,
config
)
从配置字典中实例化一个初始化程序.
例:
initializer = RandomUniform(-1, 1)
config = initializer.get_config()
initializer = RandomUniform.from_config(config)
参数:
- config:Python 字典.它通常是 get_config 的输出.
返回:
一个初始化程序实例.
get_config
get_config ()