tensorflow 2.0 (三)损失函数——MSE\Entropy\Hinge Loss
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
计算预测值与正确解之间的误差是很重要的一步,一般有三种计算误差的方式
1.MSE =>
有以下两种表现形式
来看看代码实现:
import tensorflow as tf
y = tf.constant([1,2,3,0,2])
y = tf.one_hot(y, depth=4)
y = tf.cast(y, dtype=tf.float32)
out = tf.random.normal([5,4])
loss1 = tf.reduce_mean(tf.square(y - out))
loss2 = tf.square(tf.norm(y - out)) / (5*4) # norm范式平方(tf.square(y-out)),再除以总的数据(其实就是reduce_mean)
loss3 = tf.reduce_mean(tf.losses.MSE(y, out)) # MSE
print(loss1)
print(loss2)
print(loss3)
结果:可以发现都是一样的
然后了解一下惊喜度的概念:越是混乱不均匀(或者可以理解为方差越大),惊喜度越高。
再来看个例子:
a = tf.fill([4], 0.25) # [0.25, 0.25, 0.25, 0.25]
b = tf.reduce_sum(a * tf.math.log(a) / tf.math.log(2.) )
# tensorflow的不方便之处,其实b = loge (a) / loge (2.) = log2 (a) (以2为底的log a)
print(-b)
#结果:
# tf.Tensor(2.0, shape=(), dtype=float32)
a = tf.constant([0.1, 0.1, 0.1, 0.7])
b = tf.reduce_sum(a * tf.math.log(a) / tf.math.log(2.) )
print(-b)
#结果:
# tf.Tensor(1.3567797, shape=(), dtype=float32)
其实上面两个例子可以看出:概率分布越均匀,惊喜度越低(我print是加了负号的),越不均匀,惊喜度越高
2.Entropy =>
下面是求导的推导公式:
1.当 i = j:
2.当 i ≠ j:
下面来看看实战:
x = tf.random.normal([2,4])
w = tf.random.normal([4,3])
b = tf.zeros([3])
y = tf.constant([2,0])
with tf.GradientTape() as tape:
tape.watch([w, b])
logits = x@w + b
loss = tf.reduce_mean(
tf.losses.categorical_crossentropy(tf.one_hot(y, depth=3), logits, from_logits=True))
# categorical 特意指的是分类问题,from_logits 使得原数据比较稳定
grads = tape.gradient(loss, [w,b]) # loss 分别对 w, b 求偏导
print(grads)
print(grads[1])
输出结果如下:
[<tf.Tensor: id=87, shape=(4, 3), dtype=float32, numpy=
array([[-0.14071652, 0.11563519, 0.02508131],
[-0.07422685, 0.01591128, 0.05831555],
[ 0.22571921, -0.11906768, -0.1066515 ],
[ 0.02883573, 0.0865538 , -0.11538951]], dtype=float32)>, <tf.Tensor: id=85, shape=(3,), dtype=float32, numpy=array([-0.2822519 , 0.32046527, -0.03821341], dtype=float32)>]
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献4条内容
所有评论(0)