在tensorflow中保存模型参数
·
想要保存训练之后得到的神经网络参数,一般有两种办法。
第一种,可以将tensor对象转换为numpy数组进行保存。
即,
numpy.savetxt('weight.txt', weight.eval())
第二种,是利用tensorflow自带的Saver对象。
import tensorflow as tf
##################################################3
w1 = tf.Variable(tf.constant(1.0), name='w1')
w2 = tf.Variable(tf.constant(2.0), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
w1 = tf.add(w1, w2)
saver.save(sess, './my-model.ckpt')
上面的代码中,创建了容器vars。它收集了tensor变量w1和w2。之后,tensorflow将这一容器保存。
在session中运行,就能将数据保存到tensorflow创建的几个文件中。
上面的代码运行结束后,当前目录下出现四个文件:
my-model.ckpt.meta
my-model.ckpt.data-*
my-model.ckpt.index
checkpoint
利用这四个文件就能恢复出 w1和w2这两个变量。
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-model.ckpt.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
print(all_vars)
for v in all_vars:
print(v)
print(v.name)
v_ = v.eval() # sess.run(v)
print(v_)
运行结果为:
[<tf.Tensor 'w1:0' shape=() dtype=float32_ref>, <tf.Tensor 'w2:0' shape=() dtype=float32_ref>]
Tensor("w1:0", shape=(), dtype=float32_ref)
w1:0
1.0
Tensor("w2:0", shape=(), dtype=float32_ref)
w2:0
2.0
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)