tensorflow做交叉验证遇到InvalidArgumentError
·
原代码的逻辑是train函数构造图,并训练。val_train函数只负责切分训练集。跑代码之后遇到
InvalidArgumentError: You must feed a value for placeholder tensor '*' with dtype float
后来发现是因为每次train函数都是在default_graph上修改,所以两次train的调用,使得sess重复使用了其内部的变量,并且之前定义的placeholder也没有被feed进值。解决方法是使得每次train函数内部都在其新建的Graph中修改构造图。代码如下:
with tf.Graph().as_default():
详情请参考tf.Graph。
补充一些我个人的理解:对于我需要交叉验证的问题而言,其实我想要的是每个训练都是在各自单独的图上进行的。Session是进行资源调度分配的模块;Session可以调用Graph,然后按照Graph的路线和输入的数据进行相应的数据流动和更改;Graph里面定义的各种Variable和Operator;如果Graph没有显示定义(如with tf.Graph().as_default(): 就是指在接下来的部分使用一个新的Graph),那都是在session上的默认Graph上操作,即各种定义Variable, Placeholder等操作。那我的情况是多次交叉验证,每次训练都是在默认图上做的操作(构建模型),结果定义了好多个Placeholder等,就遇到了这个问题。解决方式就是文中提到的方式了。
接触tensorflow时间不长,多多指教啦。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)