TensorFlow函数:tf.estimator.EstimatorSpec
tf.estimator.EstimatorSpec函数
EstimatorSpec类
定义在:tensorflow/python/estimator/model_fn.py.
从model_fn返回的操作和对象并传递给Estimator.
EstimatorSpec完全定义了由Estimator运行的模型.
属性
- eval_metric_ops
字段号4的别名 - evaluation_hooks
字段号9的别名 - export_outputs
字段号5的别名 - loss
字段号2的别名 - mode
字段号0的别名 - prediction_hooks
字段号10的别名 - predictions
字段号1的别名 - scaffold
字段号8的别名 - train_op
字段号3的别名 - training_chief_hooks
字段号6的别名 - training_hooks
字段号7的别名
方法
__new__
@ staticmethod __new__ ( cls , mode , predictions = None , loss = None , train_op = None , eval_metric_ops = None , export_outputs = None , training_chief_hooks = None , training_hooks = None , scaffold = None , evaluation_hooks = None , prediction_hooks= 无 )
创建一个已经验证的EstimatorSpec实例.
根据mode的值的不同,需要不同的参数,即:
- 对于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
- 对于mode == ModeKeys.EVAL:必填字段是loss.
- 为mode == ModeKeys.PREDICT:必填字段是predictions.
model_fn可以填充独立于模式的所有参数.在这种情况下,Estimator将忽略某些参数.在eval和infer模式中,train_op将被忽略.例子如下:
def my_model_fn(mode, features, labels):
predictions = ...
loss = ...
train_op = ...
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op)
或者,model_fn可以填充适合给定模式的参数.例:
def my_model_fn(mode, features, labels):
if (mode == tf.estimator.ModeKeys.TRAIN or
mode == tf.estimator.ModeKeys.EVAL):
loss = ...
else:
loss = None
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = ...
else:
train_op = None
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = ...
else:
predictions = None
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op)
函数参数:
- mode:一个ModeKeys,指定是training(训练)、evaluation(计算)还是prediction(预测).
- predictions:预测Tensor或字典Tensor.
- loss:训练损失Tensor,必须是标量或形状[1].
- train_op:适用于训练的步骤.
- eval_metric_ops:按名称键入的度量结果字典.字典的值是调用度量函数的结果,即(metric_tensor, update_op)元组.应该在没有任何状态影响的情况下进行metric_tensor计算(通常是基于变量的纯计算结果).例如,它不应该触发update_op或需要任何输入提取.
- export_outputs:描述要在服务期间导出到SavedModel并使用的输出签名.在字典{name: output}中:name:此输出的任意名称.output:一个ExportOutput对象,如ClassificationOutput,RegressionOutput或PredictOutput.Single-headed模型只需要在本字典中指定一个条目.Multi-headed模型应为每个头指定一个条目,其中之一必须使用signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY进行命名.
- training_chief_hooks:在训练期间可以在主要工作人员中运行的tf.train.SessionRunHook对象的迭代.
- training_hooks:在训练过程中可以对所有工作人员运行的tf.train.SessionRunHook对象.
- scaffold:可用于设置初始化,保护程序等用于训练的tf.train.Scaffold对象.
- evaluation_hooks:评估期间要运行的tf.train.SessionRunHook对象的可迭代性.
- prediction_hooks:在预测期间可以运行的tf.train.SessionRunHook对象的可迭代性.
返回值:
一个经过验证的EstimatorSpec对象.
可能引发的异常:
- ValueError:如果验证失败,则会引发此异常.
- TypeError:如果任何参数不是预期的类型.