最近在用tensorflow平台,需要用到自己构造cost函数,故记录如下:
tensorflow求lost(cost)损失函数的几种典型实现方法
参考文献
参考文献

1:Negative log-likelihood function( 也被称作 Cross-Entropy cost function.)
一般在softmax函数之后使用 negiative log-likelihood 作为代价函数( cost function)
原理部分不再赘述:
这里写图片描述

这种典型的cost函数在tensorflo里面代码如下:

把向量化后的图片 x 和权重矩阵 W 相乘,加上偏置 b ,然后计算每个分类的softmax概率值。

y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

注意, tf.reduce_sum 把minibatch里的每张图片的交叉熵值都加起来了。我们计算的交叉熵是指整个minibatch的。

2:但是上文的网络的output先经过了softmax操作,tensorflow的默认计算cost函数tf.nn.softmax_cross_entropy_with_logits是不需要在输入前对网络的output进行softmax的:
这里写图片描述
注意这里:
WARNING: This op expects unscaled logits, since it performs a softmax
on logits internally for efficiency. Do not call this op with the
output of softmax, as it will produce incorrect results.

因此用默认函数计算cost函数方法如下:

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
                                                        onehot_labels,
                                                        name='xentropy')
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

注意这里的onehot_labels需要对样本类别标签进行onthot编码,onehot编码是机器学习中常用的工具,假如当前训练样本共5个类,当前label为3,onthot后为00100,类别5编码后是00001
对应代码是onehot_labelstf.one_hot(y,n_classes)

3:tensorlayer中计算soct函数提供了一种不需要编码的方法:
这里的y不需要onthot编码

import tensorlayer as tl
cost = tl.cost.cross_entropy(pred, y, name='cost') 
#这里的y不需要onehot,是整数,例如3

未安装tensorlayer 的话,可参考tensorlayer

4:接下来我想讲讲对损失函数计算的具体理解
假如batchsize=8
训练样本总类别为5
那么一个batchsize的数据经过神经网络最后的输出为[8,5]
对其进行softmax操作之后每一行5个元素的含义就是属于五个类别的概率,假如当前某一个正确标签为2的softmax输出为[0.1,0.5,0.01,0.3,0..09]
那么它的cost为-log([0.1,0.5,0.01,0.3,0..09])*[0,1,0,0,0]T
[0,1,0,0,0]T是一个列向量,cost=-log(0.5)=0.3,我们的目的是使得cost最小,由于-log(1)=0是最小值,因此我们就是不断让该样本的分数向着自己正确类别对应的概率靠近。
对于同一个batch内部不用的样本,分别像上面一样计算然后求mean或者sum最后cost的值,然后最小化cost即可!

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐