torch.distributions.Categorical()的简单记录
一、传入参数probs和参数logits的区别
Categorical()
的参数有三个,分别为probs
,logits
,validate_args
,通过研究其源码,可以看到:
class Categorical(Distribution):
def __init__(self, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None): # 既不传入probs也不传入logits,抛出error
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None: # 传入probs
if probs.dim() < 1: # 传入probs维度异常,抛出error
raise ValueError("`probs` parameter must be at least one-dimensional.")
self.probs = probs / probs.sum(-1, keepdim=True)
else: # 传入logits
if logits.dim() < 1: # 传入logits维度异常,抛出error
raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1]
batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
如果传入的probs
不为空,比如传入probs=[0.4, 0.3, 0.2, 0.1]
,或者probs=[4.0, 3.0, 2.0, 1.0]
,代码是直接对传入的probs
进行归一化处理,对每个数据除以传入数据的累加和得到归一化后的数值,归一化的数据累加和为1。通过公式表示为:
p
j
=
p
j
∑
i
=
1
n
p
i
p_j = \frac{p_j}{\sum_{i=1}^{n}p_i}
pj=∑i=1npipj经过处理后的所有
p
i
p_i
pi的累加和为1,即
∑
i
=
1
n
p
i
=
1
\sum_{i=1}^{n}p_i=1
i=1∑npi=1
可以通过代码进行验证:
import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.probs) # tensor([0.4000, 0.3000, 0.2000, 0.1000])
probs = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(probs=probs)
print(pd.probs) # tensor([0.4000, 0.3000, 0.2000, 0.1000])
如果不传入probs
,传入logits
,可以看到代码的处理为:
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
这里需要稍微解释一下 logits.logsumexp()
,这里具体借用大佬的介绍,传送门在这里 —— 【关于LogSumExp】
简单来说,Categorical()
对传入的logits
数据做了以下处理:
l
o
g
(
e
x
j
∑
i
=
1
n
e
x
i
)
=
l
o
g
(
e
x
j
)
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
=
x
j
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i})
log(∑i=1nexiexj)=log(exj)−log(i=1∑nexi)=xj−log(i=1∑nexi)简单来说,就是对logits
中的每一个数据都减去其对数指数累加和,公式的最后一部分就是代码的具体实现。公式中减号后面的部分就是LogSumExp
,看字面意思很形象。
说多无益,通过代码看一下具体使用效果,同样,我们传入logits=[0.4, 0.3, 0.2, 0.1]
,或者logits=[4.0, 3.0, 2.0, 1.0]
,具体结果如下所示:
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.logits) # tensor([-0.4402, -1.4402, -2.4402, -3.4402])
logit = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(logits=logit)
pd.logits
print(pd.logits) # tensor([-1.2425, -1.3425, -1.4425, -1.5425])
我们对logits=[4, 3, 2, 1]
进行验证,使用上面介绍的公式:
l
o
g
(
e
x
j
∑
i
=
1
n
e
x
i
)
=
l
o
g
(
e
x
j
)
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
=
x
j
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i})
log(∑i=1nexiexj)=log(exj)−log(i=1∑nexi)=xj−log(i=1∑nexi)
可以得到:
import math
logit = 4 - math.log(math.exp(4) + math.exp(3) + math.exp(2) + math.exp(1))
print(logits) # -0.4401896985611957
这里进一步验证了我们的想法是正确的。
这里也可以看到,如果传入的是probs=[0.4, 0.3, 0.2, 0.1]
,或者probs=[4.0, 3.0, 2.0, 1.0]
,得到的probs
都是一样的,都是tensor([0.4000, 0.3000, 0.2000, 0.1000])
。
如果传入的是logits=[0.4, 0.3, 0.2, 0.1]
,或者logits=[4.0, 3.0, 2.0, 1.0]
,得到的logits
是不一样的,logits
的结果分别是tensor([-1.2425, -1.3425, -1.4425, -1.5425])
和tensor([-0.4402, -1.4402, -2.4402, -3.4402])
,这是因为在logits
中使用到了指数,4和0.4对应的指数值是不同的,所以得到的logits
的值是不同的。
二、通过probs获取logits || 通过logits获取probs
在torch.distributions.Categorical()
中可以通过logits_to_probs
获取从logits
转换的probs
,通过probs_to_logits
获取从probs
转换的logits
,其具体实现如下所示:
import torch
import torch.nn.functional as F
def logits_to_probs(logits, is_binary=False):
r"""
Converts a tensor of logits into probabilities. Note that for the
binary case, each value denotes log odds, whereas for the
multi-dimensional case, the values along the last dimension denote
the log probabilities (possibly unnormalized) of the events.
"""
if is_binary: # 二分类
return torch.sigmoid(logits) # 二分类问题,使用sigmoid
return F.softmax(logits, dim=-1) # 多分类问题,使用softmax
def clamp_probs(probs):
eps = torch.finfo(probs.dtype).eps # 获取probs对应的dtype数据类型使得1.0 + eps != 1.0 的最小值
return probs.clamp(min=eps, max=1 - eps) # 对probs进行处理,probs的最小值为eps,最大值为1-eps
def probs_to_logits(probs, is_binary=False):
r"""
Converts a tensor of probabilities into logits. For the binary case,
this denotes the probability of occurrence of the event indexed by `1`.
For the multi-dimensional case, the values along the last dimension
denote the probabilities of occurrence of each of the events.
"""
ps_clamped = clamp_probs(probs)
if is_binary: # 二分类
return torch.log(ps_clamped) - torch.log1p(-ps_clamped) # 二分类问题,使用对数几率
return torch.log(ps_clamped) # 多分类问题,使用对数概率
@lazy_property
def logits(self):
return probs_to_logits(self.probs)
@lazy_property
def probs(self):
return logits_to_probs(self.logits)
假如传入logits=[4.0, 3.0, 2.0, 1.0]
,在上面我们介绍了,Categorical()
中使用公式:
l
o
g
(
e
x
j
∑
i
=
1
n
e
x
i
)
=
l
o
g
(
e
x
j
)
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
=
x
j
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i})
log(∑i=1nexiexj)=log(exj)−log(i=1∑nexi)=xj−log(i=1∑nexi)对传入logits
进行处理。如果想获取Categorical.probs
,代码中的实现方式为:
def probs(self):
return logits_to_probs(self.logits)
def logits_to_probs(logits, is_binary=False):
if is_binary: # 二分类,使用sigmoid
return torch.sigmoid(logits)
return F.softmax(logits, dim=-1) # 多分类,使用softmax
logits_to_probs
函数中的is_binary
是一个布尔值,用于指示logits_to_probs
函数的输入是否为二分类问题。
如果is_binary
为True
,则表示输入为二分类问题,此时logits
表示每个样本的log odds
(对数几率),需要使用sigmoid
函数将logits
转换为probs
概率分布。如果is_binary
为False
,则表示输入为多分类问题,此时logits
表示每个样本的log probabilities
(对数概率),需要使用softmax
函数将logits
转换为probs
概率分布。
对数几率和对数概率都是常用的概率表示方法,它们的区别在于所表示的概率不同。
对数几率(log odds
)是指一个事件发生的概率与该事件不发生的概率的比值的对数。对于一个事件A,其对数几率为:
l
o
g
i
t
(
A
)
=
l
o
g
P
(
A
)
1
−
P
(
A
)
logit(A) = log\frac{P(A)}{1 - P(A)}
logit(A)=log1−P(A)P(A)
其中,P(A)
表示事件A发生的概率。对数几率的取值范围为负无穷到正无穷,当对数几率为0
时(
l
o
g
1
=
0
log1=0
log1=0),表示事件A发生的概率为0.5
,当对数几率为正数时,表示事件A发生的概率大于0.5
,当对数几率为负数时,表示事件A发生的概率小于0.5
。
对数概率(log probability
)是指一个事件发生的概率的对数。对于一个事件A,其对数概率为:
l o g ( P ( A ) ) log(P(A)) log(P(A))
其中,P(A)
表示事件A发生的概率。对数概率的取值范围为负无穷到0,当对数概率为0
时,表示事件A发生的概率为1
,当对数概率为负数时,表示事件A发生的概率小于1
。
在机器学习中,对数几率和对数概率常用于表示分类模型的输出。对于二分类问题,通常使用对数几率表示模型的输出,对于多分类问题,通常使用对数概率表示模型的输出。在实际应用中,对数几率和对数概率可以通过sigmoid
函数和softmax
函数进行转换。
在二分类问题中,logits
表示每个样本属于正类的对数几率,即
l
o
g
p
1
−
p
log\frac{p}{1-p}
log1−pp,其中p
表示样本属于正类的概率。使用sigmoid
函数将logits
转换为概率分布后,可以得到样本属于正类的概率。
在多分类问题中,logits
表示每个样本属于每个类别加粗样式的对数概率,即log(p1)
, log(p2)
, …, log(pk)
,其中p1
, p2
, …, pk
表示样本属于每个类别的概率。使用softmax
函数将logits
转换为概率分布后,可以得到样本属于每个类别的概率。
因此,is_binary
参数的作用是指示logits_to_probs
函数的输入类型,以便在转换概率分布时选择合适的函数。
通过简单的方式来说明def logits()
的具体实现原理。
- 二分类
sigmoid
的具体计算可以表示为:
X
=
[
x
1
,
x
2
]
X=[x_1, x_2]
X=[x1,x2]
s
i
g
m
o
i
d
(
X
)
=
[
1
1
+
e
−
x
1
,
1
1
+
e
−
x
2
]
sigmoid(X) = [\frac{1}{1+e^{-x_1}}, \frac{1}{1+e^{-x_2}}]
sigmoid(X)=[1+e−x11,1+e−x21]
假定传入到假定传入到Categorical()
的logits
为:
l
o
g
i
t
s
=
[
l
n
p
1
−
p
,
l
n
1
−
p
p
]
logits=[ln\frac{p}{1-p},ln\frac{1-p}{p}]
logits=[ln1−pp,lnp1−p]注意到logits
表示为对数几率的形式,p
表示样本属于正类的概率,1-p
表示样本属于负类的概率。对logits
进行sigmoid
处理可以得到:
s
i
g
m
o
i
d
(
l
i
g
i
t
s
)
=
[
1
1
+
e
−
l
n
p
1
−
p
,
1
1
+
e
−
l
n
1
−
p
p
]
=
[
1
1
+
e
l
n
1
−
p
p
,
1
1
+
e
l
n
p
1
−
p
]
=
[
1
1
+
1
−
p
p
,
1
1
+
p
1
−
p
]
=
[
1
1
p
,
1
1
1
−
p
]
=
[
p
,
1
−
p
]
\begin{aligned} sigmoid(ligits) &= [\frac{1}{1+e^{-ln\frac{p}{1-p}}}, \frac{1}{1+e^{-ln\frac{1-p}{p}}}] \\ &= [\frac{1}{1+e^{ln\frac{1-p}{p}}}, \frac{1}{1+e^{ln\frac{p}{1-p}}}] \\ &= [\frac{1}{1+\frac{1-p}{p}}, \frac{1}{1+\frac{p}{1-p}}] \\ &= [\frac{1}{\frac{1}{p}}, \frac{1}{\frac{1}{1-p}}] \\ &= [p, 1-p] \\ \end{aligned}
sigmoid(ligits)=[1+e−ln1−pp1,1+e−lnp1−p1]=[1+elnp1−p1,1+eln1−pp1]=[1+p1−p1,1+1−pp1]=[p11,1−p11]=[p,1−p]
公式中第一行将
e
−
l
n
p
1
−
p
e^{-ln\frac{p}{1-p}}
e−ln1−pp中的-
移动到
l
n
p
1
−
p
ln\frac{p}{1-p}
ln1−pp中得到
e
l
n
(
p
1
−
p
)
−
1
e^{ln(\frac{p}{1-p})^{-1}}
eln(1−pp)−1,根据指数-1
的特性将
e
l
n
(
p
1
−
p
)
−
1
e^{ln(\frac{p}{1-p})^{-1}}
eln(1−pp)−1中的分子和分母交换位置得到第二行公式,第二行公式和第三行公式使用到了下面的原理:
e
l
n
x
=
y
l
n
e
l
n
x
=
l
n
y
l
n
x
=
l
n
y
x
=
y
\begin{aligned} e^{ln \space x} &= y \\ ln \space e^{ln \space x}&= ln\space y \\ ln \space x&= ln \space y \\ x &= y \end{aligned}
eln xln eln xln xx=y=ln y=ln y=y
对于
e
l
n
1
−
p
p
e^{ln\frac{1-p}{p}}
elnp1−p,根据上面的原理,可以得到:
e
l
n
1
−
p
p
=
1
−
p
p
e^{ln\frac{1-p}{p}}=\frac{1-p}{p}
elnp1−p=p1−p由此,推出logits_to_probs
中二分类的具体实现原理,由于在def logits()
中is_binary
默认为False
,我们需要用特殊的方法来测试一下logits_to_probs
在二分类中的具体运行情况:
import torch
import math
from torch.distributions import utils
probs = torch.tensor([0.2, 0.8]) # 概率和为1
print(utils.probs_to_logits(probs, True)) # tensor([-1.3863, 1.3863])
print(math.log(0.2) - math.log(0.8)) # -1.3862943611198906
通过代码可以看到,使用二分类的logits_to_probs
的结果和我们预期想要得到的结果一致。
- 多分类
softmax
的具体计算可以表示为:
X
=
[
x
1
,
x
2
,
x
3
]
X = [x_1,x_2,x_3]
X=[x1,x2,x3]
s
o
f
t
m
a
x
(
X
)
=
[
e
x
1
e
x
1
+
e
x
2
+
e
x
3
,
e
x
2
e
x
1
+
e
x
2
+
e
x
3
,
e
x
3
e
x
1
+
e
x
2
+
e
x
3
]
softmax(X) = [\frac{e^{x_1}}{e^{x_1} + e^{x_2}+e^{x_3}}, \frac{e^{x_2}}{e^{x_1} + e^{x_2}+e^{x_3}},\frac{e^{x_3}}{e^{x_1} + e^{x_2}+e^{x_3}}]
softmax(X)=[ex1+ex2+ex3ex1,ex1+ex2+ex3ex2,ex1+ex2+ex3ex3]
假定传入到Categorical()
的logits
为:
l
o
g
i
t
s
=
[
l
1
,
l
2
,
l
3
]
logits = [l_1,l_2,l_3]
logits=[l1,l2,l3]使用公式:
l
o
g
(
e
x
j
∑
i
=
1
n
e
x
i
)
=
l
o
g
(
e
x
j
)
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
=
x
j
−
l
o
g
(
∑
i
=
1
n
e
x
i
)
log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i})
log(∑i=1nexiexj)=log(exj)−log(i=1∑nexi)=xj−log(i=1∑nexi)对logits
进行处理后,可以得到
l
o
g
i
t
s
=
[
l
1
−
l
o
g
∑
i
=
1
3
e
l
i
,
l
2
−
l
o
g
∑
i
=
1
3
e
l
i
,
l
3
−
l
o
g
∑
i
=
1
3
e
l
i
]
logits=[l1-log\sum_{i=1}^3e^{l_i},l2-log\sum_{i=1}^3e^{l_i},l3-log\sum_{i=1}^3e^{l_i}]
logits=[l1−logi=1∑3eli,l2−logi=1∑3eli,l3−logi=1∑3eli]
对处理后的logits
进行softmax
处理,可以得到:
s
o
f
t
m
a
x
(
l
i
g
i
t
s
)
=
[
e
l
1
−
l
o
g
∑
i
=
1
3
e
l
i
e
l
1
−
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
2
−
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
3
−
l
o
g
∑
i
=
1
3
e
l
i
,
e
l
2
−
l
o
g
∑
i
=
1
3
e
l
i
e
l
1
−
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
2
−
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
3
−
l
o
g
∑
i
=
1
3
e
l
i
,
e
l
3
−
l
o
g
∑
i
=
1
3
e
l
i
e
l
1
−
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
2
−
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
3
−
l
o
g
∑
i
=
1
3
e
l
i
]
softmax(ligits)=[\frac{e^{l1-log\sum_{i=1}^3e^{l_i}}}{e^{l1-log\sum_{i=1}^3e^{l_i}} + e^{l2-log\sum_{i=1}^3e^{l_i}}+e^{l3-log\sum_{i=1}^3e^{l_i}}},\frac{e^{l2-log\sum_{i=1}^3e^{l_i}}}{e^{l1-log\sum_{i=1}^3e^{l_i}} + e^{l2-log\sum_{i=1}^3e^{l_i}}+e^{l3-log\sum_{i=1}^3e^{l_i}}},\frac{e^{l3-log\sum_{i=1}^3e^{l_i}}}{e^{l1-log\sum_{i=1}^3e^{l_i}} + e^{l2-log\sum_{i=1}^3e^{l_i}}+e^{l3-log\sum_{i=1}^3e^{l_i}}}]
softmax(ligits)=[el1−log∑i=13eli+el2−log∑i=13eli+el3−log∑i=13eliel1−log∑i=13eli,el1−log∑i=13eli+el2−log∑i=13eli+el3−log∑i=13eliel2−log∑i=13eli,el1−log∑i=13eli+el2−log∑i=13eli+el3−log∑i=13eliel3−log∑i=13eli]
对上述公式进行简单转化:
s
o
f
t
m
a
x
(
l
o
g
i
t
s
)
=
[
e
l
1
l
o
g
∑
i
=
1
3
e
l
i
e
l
2
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
3
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
1
l
o
g
∑
i
=
1
3
e
l
i
,
e
l
2
l
o
g
∑
i
=
1
3
e
l
i
e
l
2
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
3
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
1
l
o
g
∑
i
=
1
3
e
l
i
,
e
l
3
l
o
g
∑
i
=
1
3
e
l
i
e
l
2
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
3
l
o
g
∑
i
=
1
3
e
l
i
+
e
l
1
l
o
g
∑
i
=
1
3
e
l
i
]
softmax(logits)=[\frac{\frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}}{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}},\frac{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}}}{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}},\frac{\frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}}}{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}}]
softmax(logits)=[log∑i=13eliel2+log∑i=13eliel3+log∑i=13eliel1log∑i=13eliel1,log∑i=13eliel2+log∑i=13eliel3+log∑i=13eliel1log∑i=13eliel2,log∑i=13eliel2+log∑i=13eliel3+log∑i=13eliel1log∑i=13eliel3]
最后可以得到:
s
o
f
t
m
a
x
(
l
o
g
i
t
s
)
=
[
e
l
1
e
l
1
+
e
l
2
+
e
l
3
,
e
l
2
e
l
1
+
e
l
2
+
e
l
3
,
e
l
3
e
l
1
+
e
l
2
+
e
l
3
]
softmax(logits)= [\frac{e^{l_1}}{e^{l_1} + e^{l_2}+e^{l_3}}, \frac{e^{l_2}}{e^{l_1} + e^{l_2}+e^{l_3}},\frac{e^{l_3}}{e^{l_1} + e^{l_2}+e^{l_3}}]
softmax(logits)=[el1+el2+el3el1,el1+el2+el3el2,el1+el2+el3el3]
即可以理解为直接对初始输入的logits
进行softmax
处理即可得到对应的probs
。
通过代码来验证一下:
import math
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.probs) # tensor([0.6439, 0.2369, 0.0871, 0.0321])
num = math.exp(4.0)/(math.exp(4.0)+math.exp(3.0)+math.exp(2.0)+math.exp(1.0))
print(num) # 0.6439142598879722
可以看到,通过接口获取的probs
和我们手动计算得到的probs
的结果是一致的,验证了我们的想法。
如果传入的数据为probs
,使用Categorical.logits
获取对应的对数概率值,代码实现为:
torch.log(ps_clamped)
注意到代码中存在clamp
操作,如下所示:
def clamp_probs(probs):
eps = torch.finfo(probs.dtype).eps # 获取probs对应的dtype数据类型使得1.0 + eps != 1.0 的最小值
return probs.clamp(min=eps, max=1 - eps) # 对probs进行处理,probs的最小值为eps,最大值为1-eps
clamp_probs
函数接受一个参数probs
,它是一个概率分布。函数首先使用torch.finfo
函数获取probs
对应的dtype数据类型,以获取使得1.0 + eps != 1.0
的最小值eps
。然后,使用probs.clamp
函数对概率分布进行截断处理,将概率分布的最小值设为eps
,最大值设为1-eps
。这样可以避免概率分布中出现0
或1
的情况,从而避免在计算交叉熵损失时出现NaN
的情况。
这里我们不讨论 is_binary
的情况,可以看到代码实现只是简单地使用了对数转换,我们通过代码检查一下:
import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.logits) # tensor([-0.9163, -1.2040, -1.6094, -2.3026])
print(torch.log(torch.tensor(0.4))) # tensor(-0.9163)
可以看到,两者的值是一样的,验证了我们的想法。
三、sample()采样
代码简单如下所示:
def sample(self, sample_shape=torch.Size()):
if not isinstance(sample_shape, torch.Size):
sample_shape = torch.Size(sample_shape)
# self._num_events = self._param.size()[-1]
probs_2d = self.probs.reshape(-1, self._num_events) # 维度变换
samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T # 采样
return samples_2d.reshape(self._extended_shape(sample_shape)) # 维度变换
sample()的操作比较简单,这里主要记录两个地方:
(1)torch.multinomial()
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
- 对
input
的每一行进行num_samples
次取值,输出为每次取值的索引; input
每一行中的元素是该索引被采样的权重。如果元素为0,那么其他位置被采样完之前,这个位置都不会被采样;replacement=False
为不放回采样,replacement=True
为有放回采样;- 在不放回采样中,
num_samples
的值必须小于等于input.size(-1)
的值,否则会报错(在不放回采样中,每个样本只能被采样一次,样本被采样后就会从采样池中删除,不会被后续采样过程获取到该样本);
举个例子:
import torch
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
torch.multinomial(weights, 2) # tensor([1, 2])
torch.multinomial(weights, 6) # 不放回取样,报错,sample n_sample > prob_dist.size(-1)
torch.multinomial(weights, 4, replacement=True) # tensor([2, 1, 1, 2])
(2)numel()
- numel()用来统计tensor中元素的个数;
更多推荐
所有评论(0)