今天在测试一个模型的时候遇到了下面的问题,在网上没找到解决方法,说一下解决思路。
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_1_1' with dtype float and shape [?,224,224,3] [[Node: input_1_1 = Placeholder[dtype=DT_FLOAT, shape=[?,224,224,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

1、验证输入

错误信息给的提示很明显,输入不对。首先查了一下输入的shape和dtype,因为是测试,输入[1,224,224,3],元素类型float32。没问题呀!次奥,错误信息没用了。

2、验证网络结构

查看你的预测和训练网络结构是否一样,不仅仅是大小,包括卷积类型"same'和"valid"也要一致,因为用的是pb模型,肯定不是这个原因。当然我在找原因的时候也是直接不考虑这个原因,不过网上貌似有人是这个原因。

3、是否多模型结合

如果不是前面两个原因,百分之九十这个原因了。你的模型是不是含有多个模型?tensorflow、keras在多个模型存在情况下可能会造成graph冲突报错。换句话说就是,大家都在一个graph里面混乱了,因为之前遇到过这种问题,但是错误信息不一样,在排除第一种情况以后,内心里面已经在猜测是这个原因了。然后把两个模型单独测试,果然都没问题。找到原因了。

4、解决

既然是graph冲突,那就新建一个graph好了。
如果是pb模型,可以按照下面方法解决:

with tf.Graph().as_default():
    with tf.Session(config=config) as sess:
        with gfile.FastGFile(AngleModelPb, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')

如果是keras、tensorflow其他模型,可以直接在session.run之前获取默认graph即可:

graph = tf.get_default_graph()
......
with graph.as_default():
    sess.run()

总之,解决思路就是把graph隔离。

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐