tensorflow 中导出/恢复模型Graph数据Saver
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
不得不说,在tensorflow中,这个问题一直困扰我好几天了,没有弄清graph个saver的关系。
下面我就记录一下两者的用法以及应用场景:
Graph
图是tensorflow的核心,所有的操作都是基于图进行的,图中有很多的op,一个op又有一个或则多个的Tensor构成。
Saver
在训练的中可以保存数据比如得到一个Weights值后,需要保存下来,以便下次再使用。
应用场景
graph 和saver可以相互配合使用。可以说graph提供模型,saver提供数据。下面通过训练手写字识别来进行保存graph和saver:
#coding=utf-8
#保存soft.ph和soft.ckpt
#created by tengxing on 2017.2.22
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)
#create model
with tf.name_scope('input'):
x = tf.placeholder(tf.float32,[None,784],name='x_input')
y_ = tf.placeholder(tf.float32,[None,10],name='y_input')
with tf.name_scope('layer'):
with tf.name_scope('W'):
#tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
W = tf.Variable(tf.zeros([784,10]),name='Weights')
with tf.name_scope('b'):
b = tf.Variable(tf.zeros([10]),name='biases')
with tf.name_scope('W_p_b'):
Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')
y = tf.nn.softmax(Wx_plus_b, name='final_result')
print y
#define loss and optimizer
with tf.name_scope('loss'):
loss = -tf.reduce_sum(y_ * tf.log(y))
with tf.name_scope('train_step'):
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
print train_step
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
# important step
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
sess.run(init)
writer = tf.summary.FileWriter("logs/", sess.graph)
#train
for step in range(100):
batch_xs,batch_ys =mnist.train.next_batch(100)
train_step.run({x:batch_xs,y_:batch_ys})
print step
variables = tf.all_variables()
saver = tf.train.Saver(variables)
print len(variables)
print sess.run(b)
#print W.get_shape(),b.get_shape()
saver.save(sess, "data/soft.ckpt")
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
print '最终的测试正确率:{0}'.format(a)
tf.train.write_graph(sess.graph_def,'graph','soft.ph',False)
通过以上就可以保存起来了,我的代码可能有点乱,自行整理吧,下面开始恢复
#coding=utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)
# 加载Graph
def loadGraph(dir):
f = tf.gfile.FastGFile(dir,'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_graph =tf.import_graph_def(graph_def,name='')
return persisted_graph
graph = loadGraph('graph/soft.ph')
with tf.Session(graph=graph) as sess:
#sess.run(tf.initialize_all_variables())
#sess.run(init) #加载时候不需要进行初始化
softmax_tensor = sess.graph.get_tensor_by_name('layer/final_result:0')
x = sess.graph.get_tensor_by_name('input/x_input:0')
y_ = sess.graph.get_tensor_by_name('input/y_input:0')
name = sess.graph.get_tensor_by_name('tengxing:0')
Weights = sess.graph.get_tensor_by_name('layer/W/Weights:0')
biases = sess.graph.get_tensor_by_name('layer/b/biases:0')
#W = tf.Variable(tf.zeros([784, 10]), name='Weights')
#b = tf.Variable(tf.zeros([10]), name='biases')
tf.add_to_collection(tf.GraphKeys.VARIABLES, name)
tf.add_to_collection(tf.GraphKeys.VARIABLES, Weights)
tf.add_to_collection(tf.GraphKeys.VARIABLES, biases)
try:
saver = tf.train.Saver(tf.global_variables()) # 'Saver' misnomer! Better: Persister!
except:
pass
print("load data")
#print sess.run(name) 此时才有一个Tensor获取变量还要进行赋值
saver.restore(sess, "./data/soft.ckpt") # now OK creted by tengxing
#test
correct_prediction = tf.equal(tf.argmax(softmax_tensor, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
通过以上两个代码可以实现训练模型的保存和继续使用。大家使用时候有问题发我邮箱:tengxing7452@163.com
后记:这篇文章写的时间不多,但是确实解决了我的很多问题,我相信这这种问题使我们在开发过程必然面临的。所以我才会花时间取解决。总的来说,结果还是令人满意的,毕竟弄出来了,但是我的代码比较乱,稍后我会整理上传。
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献7条内容
所有评论(0)