TensorFlow-模型的保存和调用(ckpt方式)
·
TensorFlow-模型的保存和调用(ckpt方式)
硬件:NVIDIA-GTX1080
软件:Windows7、python3.6.5、tensorflow-gpu-1.4.0
一、基础知识
1、checkpoint:模型文本信息
2、meta:模型graph,调用时可重载入
3、index、data:模型数据
二、代码展示
1、保存模型
import tensorflow as tf
import numpy as np
# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
init = tf.global_variables_initializer()
#define saver
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
#save ckpt
save_path = saver.save(sess, "my_net/save_net.ckpt")
print("Save to path: ", save_path)
2、调用模型
import tensorflow as tf
import numpy as np
with tf.Session() as sess:
# restore graph
saver = tf.train.import_meta_graph('my_net/save_net.ckpt.meta')
#restore ckpt
saver.restore(sess, "my_net/save_net.ckpt")
# check variable W and b, like weight or bias
print("weights:", sess.run('weights:0'))
print("biases:", sess.run('biases:0'))
任何问题请加唯一QQ2258205918(名称samylee)!
或唯一VX:samylee_csdn
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)