tensorflow从0开始(6)——保存加载模型
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
目的
学习tensorflow的目的是能够训练的模型,并且利用已经训练好的模型对新数据进行预测。下文就是一个简单的保存模型加载模型的过程。
保存模型
import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('summaries_dir', '/tmp/save_graph_logs', 'Summaries directory')
data = np.arange(10,dtype=np.int32)
with tf.Session() as sess:
print("# build graph and run")
input1= tf.placeholder(tf.int32, [10], name="input")
output1= tf.add(input1, tf.constant(100,dtype=tf.int32), name="output") # data depends on the input data
saved_result= tf.Variable(data, name="saved_result")
do_save=tf.assign(saved_result,output1)
tf.initialize_all_variables()
os.system("rm -rf /tmp/save_graph_logs")
merged = tf.merge_all_summaries()
train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir,
sess.graph)
os.system("rm -rf /tmp/load")
tf.train.write_graph(sess.graph_def, "/tmp/load", "test.pb", False) #proto
# now set the data:
result,_=sess.run([output1,do_save], {input1: data}) # calculate output1 and assign to 'saved_result'
saver = tf.train.Saver(tf.all_variables())
saver.save(sess,"checkpoint.data")
模型图示
加载模型
with tf.Session() as persisted_sess:
print("load graph")
with gfile.FastGFile("/tmp/load/test.pb",'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print("map variables")
persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
try:
saver = tf.train.Saver(tf.all_variables()) # 'Saver' misnomer! Better: Persister!
except:pass
print("load data")
saver.restore(persisted_sess, "checkpoint.data") # now OK
print(persisted_result.eval())
print("DONE")
显示结果
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
3 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
3 个月前
更多推荐
已为社区贡献7条内容
所有评论(0)