codecamp

TensorFlow.js 导入Keras模型

Keras模型(通常通过Python API创建)的存储格式可以有多种, 其中“整个模型”(即架构+权重+优化器状态都存在一个文件内)格式可以转换为 TensorFlow.js Layers格式,可以直接加载到TensorFlow.js中进行推理或进一步训练。

TensorFlow.js Layers格式是一个包含model.json文件和一组二进制格式的权重文件分片的目录。 model.json文件包含模型拓扑结构(又名“体系结构”或“图”:层的描述以及它们如何连接)以及权重文件的清单。


要求

运行模型格式转换程序需要Python环境;如果你想要一个独立的环境,你可以使用pipenv或virtualenv。要安装转换器,请使用这个命令 

pip install tensorflowjs

将Keras模型导入TensorFlow.js可以分为两个步骤。首先,将现有的Keras模型转换为TF.js Layers格式,然后将其加载到TensorFlow.js中。


1、将Keras模型转换为TF.js Layers格式

Keras模型通常使用model.save(filepath)保存,它生成一个包含模型拓扑结构和权重的HDF5(.h5)文件。要将这样的文件转换为TF.js Layer格式,请运行以下命令,其中path / to / my_model.h5是Keras .h5源文件,path / to / tfjs_target_dir是TF.js文件的输出目录:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

这一步骤的另一个方案:使用Python API直接导出为TF.js Layers格式

如果您在Python中使用Keras模型,则可以将其直接导出为TensorFlow.js Layers格式,如下所示:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)


2、将模型加载到TensorFlow.js中

使用Web服务器来为您在步骤1中生成的转换模型文件提供服务。请注意,为了允许JavaScript获取文件,您可能需要配置服务器以允许跨源资源共享(CORS)。

然后通过提供model.json文件的URL将模型加载到TensorFlow.js中:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadModel('https://foo.bar/tfjs_artifacts/model.json');

现在该模型已准备好进行推理,评估或重新训练。例如,加载的模型可以立即用于预测:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

TensorFlow.js Examples中的很多例子都采用此方法,使用已在Google云端存储上转换并托管的预训练模型。

请注意,您使用model.json文件名来引用整个模型。 loadModel(...)获取model.json,然后发出额外的HTTP(S)请求以获取model.json权重清单中引用的权重文件分片。这种方法允许所有这些文件被浏览器缓存(也可能通过互联网上的其他缓存服务器),因为model.json和weight shard分别小于典型的缓存文件大小限制。因此,模型可能会在随后的场合更快加载。


支持的功能

TensorFlow.js Layers目前仅支持使用标准Keras构造的Keras模型。使用了不受支持的操作或层(例如自定义层,Lambda层,自定义损失或自定义指标)的模型将无法自动导入,因为它们所依赖的Python代码无法正确地转换为JavaScript。


Tensorflow.js 图片训练
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }