[webAI] Loading a pretrained model in TensorFlow.js
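Environment setup
- Windows 10
- Python 3.6
- pip install tensorflow (a 1.x release; the training script relies on TensorFlow 1.x APIs such as tf.placeholder, tf.InteractiveSession and tensorflow.examples.tutorials)
- pip install tensorflowjs (a pre-1.0 release; the --output_node_names flag and the tensorflowjs_model.pb output used below come from the pre-1.0 converter)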
Train the TensorFlow model and save it as a SavedModel
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # TensorFlow 1.x
# Download the MNIST dataset (TF 1.x tutorial helper)
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Initialize an interactive session
sess = tf.InteractiveSession()


def weight_variable(shape):
    # Small random weights drawn from a truncated normal (stddev 0.1)
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    # Biases start at a small positive constant
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
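
# Network parameters
n_input = 784  # 28x28 pixels, flattened
n_node = 256   # hidden units
n_out = 10     # digit classes 0-9
x = tf.placeholder(tf.float32, [None, n_input], name="x")
y_ = tf.placeholder(tf.float32, [None, n_out])
# First layer (hidden, ReLU)
W = weight_variable([n_input, n_node])
b = bias_variable([n_node])
layer_h = tf.nn.relu(tf.matmul(x, W) + b)
# Second layer (output). The weights are randomly initialized:
# a constant initializer would give every output weight the same value.
W_out = weight_variable([n_node, n_out])
b_out = bias_variable([n_out])
# The logits stay linear: the loss applies softmax itself,
# and a ReLU here would clip every negative score to 0.
y = tf.matmul(layer_h, W_out) + b_out
softmax = tf.nn.softmax(y, name="softmax")
# Loss: softmax cross-entropy computed from the logits
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y))
correct_prediction = tf.equal(tf.argmax(softmax, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))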

# Train the model
train_step = tf.train.AdamOptimizer().minimize(cross_entropy)
tf.global_variables_initializer().run()
for i in range(2000):
    batch = mnist.train.next_batch(50)
    if i % 200 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1]})
        print('step %d, training accuracy %g' % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
print('test accuracy %g' % accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels}))

# Save the model as a SavedModel: "x" is the input, "softmax" the output.
# simple_save fails if ./saved_model already exists, so remove it before re-running.
tf.saved_model.simple_save(sess, "./saved_model",
                           inputs={"x": x}, outputs={"softmax": softmax})
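For reference, the forward pass and loss that the script builds can be written out by hand. The following is a minimal pure-Python sketch and is not part of the original pipeline. It shrinks the layers from 784/256/10 to 3/2/2 with made-up weights, and the helper names (`dense`, `softmax`, `cross_entropy`) are hypothetical. It mirrors what `tf.nn.softmax_cross_entropy_with_logits_v2` computes from logits that carry no activation.

```python
import math

def dense(x, W, b, relu=False):
    """Compute x @ W + b for one example, optionally followed by ReLU."""
    out = [sum(xi * W[i][j] for i, xi in enumerate(x)) + b[j] for j in range(len(b))]
    return [max(0.0, v) for v in out] if relu else out

def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(one_hot, logits):
    """Per-example softmax cross-entropy: -sum(label * log(softmax(logits)))."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))

# Hypothetical toy parameters with shapes [n_input, n_node] and [n_node, n_out]
x = [0.0, 0.5, 1.0]
W, b = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]], [0.1, 0.1]
W_out, b_out = [[0.7, -0.3], [-0.1, 0.2]], [0.1, 0.1]

layer_h = dense(x, W, b, relu=True)  # [0.0, ~0.9]: ReLU zeroes the first hidden unit
y = dense(layer_h, W_out, b_out)     # logits, no activation
probs = softmax(y)                   # what the "softmax" node outputs
loss = cross_entropy([0.0, 1.0], y)  # one-hot label for class 1
```

The sketch keeps the logits linear. A ReLU on the logits would clip every negative score to 0 before the softmax, so the network could not push wrong classes below zero.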
Convert the TensorFlow model
tensorflowjs_converter --input_format=tf_saved_model \
--output_node_names="softmax" \
--saved_model_tags=serve ./saved_model \
./web_model
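Files produced by the conversion in ./web_model:
- tensorflowjs_model.pb: the model graph, in a format TensorFlow.js can load
- weights_manifest.json: the weight manifest, which lists the binary weight shard files written next to it; serve those shard files together with the manifest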
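To sanity-check a conversion, you can read weights_manifest.json and list the tensors it describes. This is a minimal sketch that assumes the tfjs manifest layout: a JSON list of weight groups, each with `paths` (the binary shard files) and `weights` (entries with `name`, `shape` and `dtype`). The helper `summarize_manifest` is hypothetical. The demo writes a made-up manifest to a temporary directory instead of reading ./web_model. Its weight names are invented, and its shapes assume the four variables of the network above are stored unquantized.

```python
import json
import os
import tempfile

def summarize_manifest(manifest_path):
    """Return (shard_files, {name: (shape, dtype)}, total number of values)."""
    with open(manifest_path) as f:
        groups = json.load(f)
    shards, weights, total = [], {}, 0
    for group in groups:
        shards.extend(group["paths"])
        for w in group["weights"]:
            weights[w["name"]] = (tuple(w["shape"]), w["dtype"])
            count = 1
            for dim in w["shape"]:
                count *= dim
            total += count
    return shards, weights, total

# Hypothetical manifest whose shapes match the 784-256-10 network
example = [{
    "paths": ["group1-shard1of1"],
    "weights": [
        {"name": "Variable", "shape": [784, 256], "dtype": "float32"},
        {"name": "Variable_1", "shape": [256], "dtype": "float32"},
        {"name": "Variable_2", "shape": [256, 10], "dtype": "float32"},
        {"name": "Variable_3", "shape": [10], "dtype": "float32"},
    ],
}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights_manifest.json")
    with open(path, "w") as f:
        json.dump(example, f)
    shards, weights, total = summarize_manifest(path)

# total = 784*256 + 256 + 256*10 + 10 values
print(shards, total)
```

Here the total is 203,530 float32 values, about 0.8 MB. Pointed at ./web_model/weights_manifest.json, the same function lists the shard files that must be served alongside the manifest.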
Load the converted model in TensorFlow.js
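import * as tf from '@tensorflow/tfjs'
// loadFrozenModel is the pre-1.0 tfjs-converter API (1.x replaced it with tf.loadGraphModel)
import {loadFrozenModel} from '@tensorflow/tfjs-converter'

// Files produced by tensorflowjs_converter, served over HTTP
const MODEL_URL = 'tensorflowjs_model.pb'
const WEIGHTS_URL = 'weights_manifest.json'

// pixels: a flat array of 784 grayscale values in [0, 1] (one 28x28 MNIST-style image)
async function predict(pixels) {
  try {
    const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL)
    // Shape [1, 784]: a batch of one image, fed to the placeholder named "x"
    const xs = tf.tensor2d([pixels])
    // Runs the graph up to the "softmax" output node: 10 class probabilities
    const output = model.execute({x: xs})
    console.log(output.dataSync())
    xs.dispose()
    return output
  } catch (e) {
    console.error(e)
  }
}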
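`predict` expects `pixels` in the same format as the training data. `read_data_sets` yields each image as 784 floats scaled from 0-255 to [0, 1], with 0 as the background. The model returns 10 softmax probabilities, and the predicted digit is the index of the largest one. The sketch below shows that arithmetic in Python, the language of the training script. `to_model_input` and `predicted_digit` are hypothetical helpers, the input is assumed to be an 8-bit grayscale 28x28 image with a dark background, and the probabilities stand in for `output.dataSync()`.

```python
def to_model_input(rows):
    """Flatten a 28x28 grid of 0-255 grayscale values into 784 floats in [0, 1]."""
    if len(rows) != 28 or any(len(r) != 28 for r in rows):
        raise ValueError("expected a 28x28 image")
    return [v / 255.0 for row in rows for v in row]

def predicted_digit(probs):
    """Return the index of the largest probability, i.e. the predicted class."""
    return max(range(len(probs)), key=lambda i: probs[i])

# Hypothetical input: a blank image with a vertical stroke in column 14
image = [[255 if c == 14 else 0 for c in range(28)] for r in range(28)]
pixels = to_model_input(image)

# Hypothetical probabilities standing in for output.dataSync()
probs = [0.01, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.02, 0.01]
digit = predicted_digit(probs)
```

On the TensorFlow.js side, `tf.tensor2d([pixels])` builds the `[1, 784]` batch, and `output.argMax(1)` gives the same index as `predicted_digit`.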