原代码的逻辑是train函数构造图,并训练。val_train函数只负责切分训练集。跑代码之后遇到

InvalidArgumentError: You must feed a value for placeholder tensor '*' with dtype float

后来发现是因为每次train函数都是在default_graph上修改,所以两次train的调用,使得sess重复使用了其内部的变量,并且之前定义的placeholder也没有被feed进值。解决方法是使得每次train函数内部都在其新建的Graph中修改构造图。代码如下:

with tf.Graph().as_default():

详情请参考tf.Graph


补充一些我个人的理解:对于我需要交叉验证的问题而言,其实我想要的是每个训练都是在各自单独的图上进行的。Session是进行资源调度分配的模块;Session可以调用Graph,然后按照Graph的路线和输入的数据进行相应的数据流动和更改;Graph里面定义的各种Variable和Operator;如果Graph没有显示定义(如with tf.Graph().as_default(): 就是指在接下来的部分使用一个新的Graph),那都是在session上的默认Graph上操作,即各种定义Variable, Placeholder等操作。那我的情况是多次交叉验证,每次训练都是在默认图上做的操作(构建模型),结果定义了好多个Placeholder等,就遇到了这个问题。解决方式就是文中提到的方式了。
接触tensorflow时间不长,多多指教啦。

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐