谷歌Pixel 2人像模式代码曝光,你看懂了吗?
2018-03-16 18:34:11
浏览数 (6045)
谷歌将其所使用的 AI 图像分割算法 DeepLab-v3+ 开源,让第三方相机 app 也可以借助这一神经网络实现类似的人像模式效果。
开源代码:
import tensorflow as tf

from deeplab.core import feature_extractor

slim = tf.contrib.slim

# Scope names used by this module. All but _MERGED_LOGITS_SCOPE are reported
# by get_extra_layer_scopes(); _MERGED_LOGITS_SCOPE is the key under which the
# merged logits are looked up in the per-output logits dictionaries.
_LOGITS_SCOPE_NAME = 'logits'
_MERGED_LOGITS_SCOPE = 'merged_logits'
_IMAGE_POOLING_SCOPE = 'image_pooling'
_ASPP_SCOPE = 'aspp'
_CONCAT_PROJECTION_SCOPE = 'concat_projection'
_DECODER_SCOPE = 'decoder'
def get_extra_layer_scopes():
    """Gets the scopes for extra layers.

    Returns:
      A list of scope-name strings: the logits, image-pooling, ASPP,
      concat-projection and decoder scopes, in that order. The merged-logits
      scope is not part of the list.
    """
    return [
        _LOGITS_SCOPE_NAME,
        _IMAGE_POOLING_SCOPE,
        _ASPP_SCOPE,
        _CONCAT_PROJECTION_SCOPE,
        _DECODER_SCOPE,
    ]
def predict_labels_multi_scale(images,
                               model_options,
                               eval_scales=(1.0,),
                               add_flipped_images=False):
    """Predicts segmentation labels by averaging over scales and flips.

    For every scale in `eval_scales` the logits are computed, resized to the
    spatial size of `images`, and turned into softmax probabilities. When
    `add_flipped_images` is set, the same is done for a copy of the input
    reversed along axis 2, and the result is reversed back before being
    accumulated. The probabilities are averaged and the argmax over the
    channel axis is returned.

    Args:
      images: A tensor of size [batch, height, width, channels].
      model_options: A ModelOptions instance to configure models.
      eval_scales: The scales to resize images for evaluation.
      add_flipped_images: Add flipped images for evaluation or not.

    Returns:
      A dictionary with keys specifying the output_type (e.g., semantic
      prediction) and values storing Tensors representing predictions (argmax
      over channels). Each prediction has size [batch, height, width].
    """
    outputs_to_predictions = {
        output: []
        for output in model_options.outputs_to_num_classes
    }

    for i, image_scale in enumerate(eval_scales):
        # Only the first scale may create variables (reuse=None); every later
        # pass must reuse them.
        with tf.variable_scope(tf.get_variable_scope(),
                               reuse=True if i else None):
            outputs_to_scales_to_logits = multi_scale_logits(
                images,
                model_options=model_options,
                image_pyramid=[image_scale],
                is_training=False,
                fine_tune_batch_norm=False)

        if add_flipped_images:
            # The un-flipped pass above has already created the variables, so
            # the flipped pass always reuses them.
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                outputs_to_scales_to_logits_reversed = multi_scale_logits(
                    tf.reverse_v2(images, [2]),
                    model_options=model_options,
                    image_pyramid=[image_scale],
                    is_training=False,
                    fine_tune_batch_norm=False)

        for output in sorted(outputs_to_scales_to_logits):
            scales_to_logits = outputs_to_scales_to_logits[output]
            logits = tf.image.resize_bilinear(
                scales_to_logits[_MERGED_LOGITS_SCOPE],
                tf.shape(images)[1:3],
                align_corners=True)
            # A trailing axis (index 4) is added so that the per-scale and
            # per-flip probabilities can be concatenated and averaged below.
            outputs_to_predictions[output].append(
                tf.expand_dims(tf.nn.softmax(logits), 4))

            if add_flipped_images:
                scales_to_logits_reversed = (
                    outputs_to_scales_to_logits_reversed[output])
                # Undo the flip along axis 2 before resizing so that these
                # logits line up with the un-flipped ones.
                logits_reversed = tf.image.resize_bilinear(
                    tf.reverse_v2(
                        scales_to_logits_reversed[_MERGED_LOGITS_SCOPE], [2]),
                    tf.shape(images)[1:3],
                    align_corners=True)
                outputs_to_predictions[output].append(
                    tf.expand_dims(tf.nn.softmax(logits_reversed), 4))

    for output in sorted(outputs_to_predictions):
        predictions = outputs_to_predictions[output]
        # Compute average prediction across different scales and flipped
        # images.
        predictions = tf.reduce_mean(tf.concat(predictions, 4), axis=4)
        outputs_to_predictions[output] = tf.argmax(predictions, 3)

    return outputs_to_predictions
def predict_labels(images, model_options, image_pyramid=None):
    """Predicts segmentation labels.

    Computes the logits for `images`, resizes the merged logits of each
    output to the spatial size of `images`, and takes the argmax over the
    channel axis.

    Args:
      images: A tensor of size [batch, height, width, channels].
      model_options: A ModelOptions instance to configure models.
      image_pyramid: Input image scales for multi-scale feature extraction.

    Returns:
      A dictionary with keys specifying the output_type (e.g., semantic
      prediction) and values storing Tensors representing predictions (argmax
      over channels). Each prediction has size [batch, height, width].
    """
    outputs_to_scales_to_logits = multi_scale_logits(
        images,
        model_options=model_options,
        image_pyramid=image_pyramid,
        is_training=False,
        fine_tune_batch_norm=False)

    predictions = {}
    for output in sorted(outputs_to_scales_to_logits):
        scales_to_logits = outputs_to_scales_to_logits[output]
        # Bring the merged logits back to the input's height and width before
        # picking the highest-scoring class per pixel.
        logits = tf.image.resize_bilinear(
            scales_to_logits[_MERGED_LOGITS_SCOPE],
            tf.shape(images)[1:3],
            align_corners=True)
        predictions[output] = tf.argmax(logits, 3)

    return predictions
def scale_dimension(dim, scale):
    """Scales the input dimension.

    Computes `(dim - 1) * scale + 1` in floating point and converts the
    result to an integer.

    Args:
      dim: Input dimension (a scalar or a scalar Tensor).
      scale: The amount of scaling applied to the input.

    Returns:
      Scaled dimension: an int32 Tensor when `dim` is a Tensor, otherwise a
      Python int.
    """
    # NOTE(review): the "- 1 ... + 1" form presumably matches the
    # align_corners=True resizing used elsewhere in this file -- confirm.
    if isinstance(dim, tf.Tensor):
        return tf.cast((tf.to_float(dim) - 1.0) * scale + 1.0, dtype=tf.int32)
    else:
        return int((float(dim) - 1.0) * scale + 1.0)
def multi_scale_logits(images,
model_options,
image_pyramid,
weight_decay=0.0001,
is_training=False,
fine_tune_batch_norm=False):
"""Gets the logits for multi-scale inputs.
The returned logits are all downsampled (due to max-pooling layers)
for both training and evaluation.
更多查看:https://github.com/tensorflow/models/tree/master/research/deeplab