此篇博客重在总结Tensorflow,Keras模型训练的模型文件转换为pb结构的方式,节省寻找转换工具的时间。

1. Tensorflow ckpt模型转换pb模型

我们在Tensorflow中训练出来的模型一般是ckpt格式的,一个ckpt文件对应有xxx.ckpt.dataxxx.ckpt.metaxxx.ckpt.index三个内容。

而在生产环境中,一般C++只能加载pb的模型,即将ckpt的结构3合1,一个模型只对应一个pb(当然甚至可能多个模型也能合成为一个pb,这里不进行展开)。

废话不说了,上代码

def freeze_graph(input_checkpoint, output_graph):
    '''

    :param input_checkpoint: xxx.ckpt(千万不要加后面的xxx.ckpt.data这种,到ckpt就行了!)
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "softmax" # 模型输入节点,根据情况自定义
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

用法

input_checkpoint = 'xxx.ckpt'
out_graph = 'froze_xxx.pb'
freeze_graph(input_checkpoint, out_graph) 

2. Keras h5模型转换pb模型

现在keras和Tensorflow的集成也越来越紧密了,用户可以通过tf.contrib.keras在tensorflow中引入keras使用,即keras和tensorflow相互耦合,而非之前那样,只是tensorflow的高层封装。

因为keras的很多ops封装的很简单,所以现在一般用keras搭模型的人很多,那么问题来了,如果想在生产环境中使用keras框架产生的hdf5格式的模型文件,也需要将其转换为pb格式,怎么做呢?Let’s roll it!

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a prunned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    prunned so subgraphs that are not neccesary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph


input_fld = sys.path[0]
weight_file = 'vgg16_without_dropout.h5'
output_graph_name = 'vgg16_without_dropout.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
    os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)


print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

3. 参考资料

[1] Eileng: keras模型保存为tensorflow的二进制模型
[2] 嘿芝麻:tensorflow框架.ckpt .pb模型节点tensor_name打印及ckpt模型转.pb模型

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐