一、传入参数probs和参数logits的区别

Categorical()的参数有三个,分别为probslogitsvalidate_args,通过研究其源码,可以看到:

class Categorical(Distribution):
	def __init__(self, probs=None, logits=None, validate_args=None):
	    if (probs is None) == (logits is None):  # 既不传入probs也不传入logits,抛出error
	        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
	    if probs is not None:  # 传入probs
	        if probs.dim() < 1:  # 传入probs维度异常,抛出error
	            raise ValueError("`probs` parameter must be at least one-dimensional.")
	        self.probs = probs / probs.sum(-1, keepdim=True)
	    else:  # 传入logits
	        if logits.dim() < 1:  # 传入logits维度异常,抛出error
	            raise ValueError("`logits` parameter must be at least one-dimensional.")
	        # Normalize
	        self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
	    self._param = self.probs if probs is not None else self.logits
	    self._num_events = self._param.size()[-1]
	    batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
	    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)

如果传入的probs不为空,比如传入probs=[0.4, 0.3, 0.2, 0.1],或者probs=[4.0, 3.0, 2.0, 1.0],代码是直接对传入的probs进行归一化处理,对每个数据除以传入数据的累加和得到归一化后的数值,归一化的数据累加和为1。通过公式表示为: p j = p j ∑ i = 1 n p i p_j = \frac{p_j}{\sum_{i=1}^{n}p_i} pj=i=1npipj经过处理后的所有 p i p_i pi的累加和为1,即 ∑ i = 1 n p i = 1 \sum_{i=1}^{n}p_i=1 i=1npi=1

可以通过代码进行验证:

import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.probs)  # tensor([0.4000, 0.3000, 0.2000, 0.1000])

probs = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(probs=probs)
print(pd.probs)  # tensor([0.4000, 0.3000, 0.2000, 0.1000])

如果不传入probs,传入logits,可以看到代码的处理为:

self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)

这里需要稍微解释一下 logits.logsumexp(),这里具体借用大佬的介绍,传送门在这里 —— 【关于LogSumExp】

简单来说,Categorical()对传入的logits数据做了以下处理: l o g ( e x j ∑ i = 1 n e x i ) = l o g ( e x j ) − l o g ( ∑ i = 1 n e x i ) = x j − l o g ( ∑ i = 1 n e x i ) log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i}) log(i=1nexiexj)=log(exj)log(i=1nexi)=xjlog(i=1nexi)简单来说,就是对logits中的每一个数据都减去其对数指数累加和,公式的最后一部分就是代码的具体实现。公式中减号后面的部分就是LogSumExp,看字面意思很形象。

说多无益,通过代码看一下具体使用效果,同样,我们传入logits=[0.4, 0.3, 0.2, 0.1],或者logits=[4.0, 3.0, 2.0, 1.0],具体结果如下所示:

import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.logits)  # tensor([-0.4402, -1.4402, -2.4402, -3.4402])

logit = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(logits=logit)
pd.logits
print(pd.logits)  # tensor([-1.2425, -1.3425, -1.4425, -1.5425])

我们对logits=[4, 3, 2, 1]进行验证,使用上面介绍的公式: l o g ( e x j ∑ i = 1 n e x i ) = l o g ( e x j ) − l o g ( ∑ i = 1 n e x i ) = x j − l o g ( ∑ i = 1 n e x i ) log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i}) log(i=1nexiexj)=log(exj)log(i=1nexi)=xjlog(i=1nexi)
可以得到:

import math
logit = 4 - math.log(math.exp(4) + math.exp(3) + math.exp(2) + math.exp(1))
print(logits)  # -0.4401896985611957

这里进一步验证了我们的想法是正确的。

这里也可以看到,如果传入的是probs=[0.4, 0.3, 0.2, 0.1],或者probs=[4.0, 3.0, 2.0, 1.0],得到的probs都是一样的,都是tensor([0.4000, 0.3000, 0.2000, 0.1000])

如果传入的是logits=[0.4, 0.3, 0.2, 0.1],或者logits=[4.0, 3.0, 2.0, 1.0],得到的logits是不一样的,logits的结果分别是tensor([-1.2425, -1.3425, -1.4425, -1.5425])tensor([-0.4402, -1.4402, -2.4402, -3.4402]),这是因为在logits中使用到了指数,4和0.4对应的指数值是不同的,所以得到的logits的值是不同的。

二、通过probs获取logits || 通过logits获取probs

torch.distributions.Categorical()中可以通过logits_to_probs获取从logits转换的probs,通过probs_to_logits获取从probs转换的logits,其具体实现如下所示:

import torch
import torch.nn.functional as F
def logits_to_probs(logits, is_binary=False):
    r"""
    Converts a tensor of logits into probabilities. Note that for the
    binary case, each value denotes log odds, whereas for the
    multi-dimensional case, the values along the last dimension denote
    the log probabilities (possibly unnormalized) of the events.
    """
    if is_binary:  # 二分类
        return torch.sigmoid(logits)  # 二分类问题,使用sigmoid
    return F.softmax(logits, dim=-1)  # 多分类问题,使用softmax

def clamp_probs(probs):
    eps = torch.finfo(probs.dtype).eps  # 获取probs对应的dtype数据类型使得1.0 + eps != 1.0 的最小值
    return probs.clamp(min=eps, max=1 - eps)  # 对probs进行处理,probs的最小值为eps,最大值为1-eps

def probs_to_logits(probs, is_binary=False):
    r"""
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:  # 二分类
        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)  # 二分类问题,使用对数几率
    return torch.log(ps_clamped)  # 多分类问题,使用对数概率

@lazy_property
def logits(self):
    return probs_to_logits(self.probs)

@lazy_property
def probs(self):
    return logits_to_probs(self.logits)

假如传入logits=[4.0, 3.0, 2.0, 1.0],在上面我们介绍了,Categorical()中使用公式: l o g ( e x j ∑ i = 1 n e x i ) = l o g ( e x j ) − l o g ( ∑ i = 1 n e x i ) = x j − l o g ( ∑ i = 1 n e x i ) log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i}) log(i=1nexiexj)=log(exj)log(i=1nexi)=xjlog(i=1nexi)对传入logits进行处理。如果想获取Categorical.probs,代码中的实现方式为:

def probs(self):
    return logits_to_probs(self.logits)

def logits_to_probs(logits, is_binary=False):
    if is_binary:  # 二分类,使用sigmoid
        return torch.sigmoid(logits)
    return F.softmax(logits, dim=-1)  # 多分类,使用softmax

logits_to_probs函数中的is_binary是一个布尔值,用于指示logits_to_probs函数的输入是否为二分类问题。

如果is_binaryTrue,则表示输入为二分类问题,此时logits表示每个样本的log odds(对数几率),需要使用sigmoid函数将logits转换为probs概率分布。如果is_binaryFalse,则表示输入为多分类问题,此时logits表示每个样本的log probabilities(对数概率),需要使用softmax函数将logits转换为probs概率分布。

对数几率对数概率都是常用的概率表示方法,它们的区别在于所表示的概率不同。

对数几率log odds)是指一个事件发生的概率与该事件不发生的概率的比值的对数。对于一个事件A,其对数几率为:

l o g i t ( A ) = l o g P ( A ) 1 − P ( A ) logit(A) = log\frac{P(A)}{1 - P(A)} logit(A)=log1P(A)P(A)
其中,P(A)表示事件A发生的概率。对数几率的取值范围为负无穷正无穷,当对数几率0时( l o g 1 = 0 log1=0 log1=0),表示事件A发生的概率为0.5,当对数几率正数时,表示事件A发生的概率大于0.5,当对数几率负数时,表示事件A发生的概率小于0.5

对数概率log probability)是指一个事件发生的概率的对数。对于一个事件A,其对数概率为:

l o g ( P ( A ) ) log(P(A)) log(P(A))

其中,P(A)表示事件A发生的概率。对数概率的取值范围为负无穷到0,当对数概率0时,表示事件A发生的概率为1,当对数概率为负数时,表示事件A发生的概率小于1

在机器学习中,对数几率对数概率常用于表示分类模型的输出。对于二分类问题,通常使用对数几率表示模型的输出,对于多分类问题,通常使用对数概率表示模型的输出。在实际应用中,对数几率对数概率可以通过sigmoid函数和softmax函数进行转换。

二分类问题中,logits表示每个样本属于正类对数几率,即 l o g p 1 − p log\frac{p}{1-p} log1pp,其中p表示样本属于正类的概率。使用sigmoid函数将logits转换为概率分布后,可以得到样本属于正类的概率

多分类问题中,logits表示每个样本属于每个类别加粗样式对数概率,即log(p1), log(p2), …, log(pk),其中p1, p2, …, pk表示样本属于每个类别的概率。使用softmax函数将logits转换为概率分布后,可以得到样本属于每个类别的概率

因此,is_binary参数的作用是指示logits_to_probs函数的输入类型,以便在转换概率分布时选择合适的函数。

通过简单的方式来说明def logits()的具体实现原理。

  • 二分类

sigmoid的具体计算可以表示为: X = [ x 1 , x 2 ] X=[x_1, x_2] X=[x1,x2] s i g m o i d ( X ) = [ 1 1 + e − x 1 , 1 1 + e − x 2 ] sigmoid(X) = [\frac{1}{1+e^{-x_1}}, \frac{1}{1+e^{-x_2}}] sigmoid(X)=[1+ex11,1+ex21]

假定传入到假定传入到Categorical()logits为: l o g i t s = [ l n p 1 − p , l n 1 − p p ] logits=[ln\frac{p}{1-p},ln\frac{1-p}{p}] logits=[ln1pp,lnp1p]注意到logits表示为对数几率的形式,p表示样本属于正类的概率,1-p表示样本属于负类的概率。对logits进行sigmoid处理可以得到: s i g m o i d ( l i g i t s ) = [ 1 1 + e − l n p 1 − p , 1 1 + e − l n 1 − p p ] = [ 1 1 + e l n 1 − p p , 1 1 + e l n p 1 − p ] = [ 1 1 + 1 − p p , 1 1 + p 1 − p ] = [ 1 1 p , 1 1 1 − p ] = [ p , 1 − p ] \begin{aligned} sigmoid(ligits) &= [\frac{1}{1+e^{-ln\frac{p}{1-p}}}, \frac{1}{1+e^{-ln\frac{1-p}{p}}}] \\ &= [\frac{1}{1+e^{ln\frac{1-p}{p}}}, \frac{1}{1+e^{ln\frac{p}{1-p}}}] \\ &= [\frac{1}{1+\frac{1-p}{p}}, \frac{1}{1+\frac{p}{1-p}}] \\ &= [\frac{1}{\frac{1}{p}}, \frac{1}{\frac{1}{1-p}}] \\ &= [p, 1-p] \\ \end{aligned} sigmoid(ligits)=[1+eln1pp1,1+elnp1p1]=[1+elnp1p1,1+eln1pp1]=[1+p1p1,1+1pp1]=[p11,1p11]=[p,1p]

公式中第一行将 e − l n p 1 − p e^{-ln\frac{p}{1-p}} eln1pp中的-移动到 l n p 1 − p ln\frac{p}{1-p} ln1pp中得到 e l n ( p 1 − p ) − 1 e^{ln(\frac{p}{1-p})^{-1}} eln(1pp)1,根据指数-1的特性将 e l n ( p 1 − p ) − 1 e^{ln(\frac{p}{1-p})^{-1}} eln(1pp)1中的分子和分母交换位置得到第二行公式,第二行公式和第三行公式使用到了下面的原理: e l n   x = y l n   e l n   x = l n   y l n   x = l n   y x = y \begin{aligned} e^{ln \space x} &= y \\ ln \space e^{ln \space x}&= ln\space y \\ ln \space x&= ln \space y \\ x &= y \end{aligned} eln xln eln xln xx=y=ln y=ln y=y
对于 e l n 1 − p p e^{ln\frac{1-p}{p}} elnp1p,根据上面的原理,可以得到: e l n 1 − p p = 1 − p p e^{ln\frac{1-p}{p}}=\frac{1-p}{p} elnp1p=p1p由此,推出logits_to_probs中二分类的具体实现原理,由于在def logits()is_binary默认为False,我们需要用特殊的方法来测试一下logits_to_probs在二分类中的具体运行情况:

import torch
import math
from torch.distributions import utils

probs = torch.tensor([0.2, 0.8])           # 概率和为1
print(utils.probs_to_logits(probs, True))  # tensor([-1.3863,  1.3863])
print(math.log(0.2) - math.log(0.8))       # -1.3862943611198906

通过代码可以看到,使用二分类的logits_to_probs的结果和我们预期想要得到的结果一致。

  • 多分类

softmax的具体计算可以表示为: X = [ x 1 , x 2 , x 3 ] X = [x_1,x_2,x_3] X=[x1,x2,x3] s o f t m a x ( X ) = [ e x 1 e x 1 + e x 2 + e x 3 , e x 2 e x 1 + e x 2 + e x 3 , e x 3 e x 1 + e x 2 + e x 3 ] softmax(X) = [\frac{e^{x_1}}{e^{x_1} + e^{x_2}+e^{x_3}}, \frac{e^{x_2}}{e^{x_1} + e^{x_2}+e^{x_3}},\frac{e^{x_3}}{e^{x_1} + e^{x_2}+e^{x_3}}] softmax(X)=[ex1+ex2+ex3ex1,ex1+ex2+ex3ex2,ex1+ex2+ex3ex3]

假定传入到Categorical()logits为: l o g i t s = [ l 1 , l 2 , l 3 ] logits = [l_1,l_2,l_3] logits=[l1,l2,l3]使用公式: l o g ( e x j ∑ i = 1 n e x i ) = l o g ( e x j ) − l o g ( ∑ i = 1 n e x i ) = x j − l o g ( ∑ i = 1 n e x i ) log(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}) = log(e^{x_j})-log(\sum_{i=1}^{n}e^{x_i}) = x_j - log(\sum_{i=1}^{n}e^{x_i}) log(i=1nexiexj)=log(exj)log(i=1nexi)=xjlog(i=1nexi)logits进行处理后,可以得到 l o g i t s = [ l 1 − l o g ∑ i = 1 3 e l i , l 2 − l o g ∑ i = 1 3 e l i , l 3 − l o g ∑ i = 1 3 e l i ] logits=[l1-log\sum_{i=1}^3e^{l_i},l2-log\sum_{i=1}^3e^{l_i},l3-log\sum_{i=1}^3e^{l_i}] logits=[l1logi=13eli,l2logi=13eli,l3logi=13eli]
对处理后的logits进行softmax处理,可以得到: s o f t m a x ( l i g i t s ) = [ e l 1 − l o g ∑ i = 1 3 e l i e l 1 − l o g ∑ i = 1 3 e l i + e l 2 − l o g ∑ i = 1 3 e l i + e l 3 − l o g ∑ i = 1 3 e l i , e l 2 − l o g ∑ i = 1 3 e l i e l 1 − l o g ∑ i = 1 3 e l i + e l 2 − l o g ∑ i = 1 3 e l i + e l 3 − l o g ∑ i = 1 3 e l i , e l 3 − l o g ∑ i = 1 3 e l i e l 1 − l o g ∑ i = 1 3 e l i + e l 2 − l o g ∑ i = 1 3 e l i + e l 3 − l o g ∑ i = 1 3 e l i ] softmax(ligits)=[\frac{e^{l1-log\sum_{i=1}^3e^{l_i}}}{e^{l1-log\sum_{i=1}^3e^{l_i}} + e^{l2-log\sum_{i=1}^3e^{l_i}}+e^{l3-log\sum_{i=1}^3e^{l_i}}},\frac{e^{l2-log\sum_{i=1}^3e^{l_i}}}{e^{l1-log\sum_{i=1}^3e^{l_i}} + e^{l2-log\sum_{i=1}^3e^{l_i}}+e^{l3-log\sum_{i=1}^3e^{l_i}}},\frac{e^{l3-log\sum_{i=1}^3e^{l_i}}}{e^{l1-log\sum_{i=1}^3e^{l_i}} + e^{l2-log\sum_{i=1}^3e^{l_i}}+e^{l3-log\sum_{i=1}^3e^{l_i}}}] softmax(ligits)=[el1logi=13eli+el2logi=13eli+el3logi=13eliel1logi=13eli,el1logi=13eli+el2logi=13eli+el3logi=13eliel2logi=13eli,el1logi=13eli+el2logi=13eli+el3logi=13eliel3logi=13eli]
对上述公式进行简单转化:
s o f t m a x ( l o g i t s ) = [ e l 1 l o g ∑ i = 1 3 e l i e l 2 l o g ∑ i = 1 3 e l i + e l 3 l o g ∑ i = 1 3 e l i + e l 1 l o g ∑ i = 1 3 e l i , e l 2 l o g ∑ i = 1 3 e l i e l 2 l o g ∑ i = 1 3 e l i + e l 3 l o g ∑ i = 1 3 e l i + e l 1 l o g ∑ i = 1 3 e l i , e l 3 l o g ∑ i = 1 3 e l i e l 2 l o g ∑ i = 1 3 e l i + e l 3 l o g ∑ i = 1 3 e l i + e l 1 l o g ∑ i = 1 3 e l i ] softmax(logits)=[\frac{\frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}}{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}},\frac{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}}}{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}},\frac{\frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}}}{\frac{e^{l2}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l3}}{log\sum_{i=1}^3e^{l_i}} + \frac{e^{l1}}{log\sum_{i=1}^3e^{l_i}}}] softmax(logits)=[logi=13eliel2+logi=13eliel3+logi=13eliel1logi=13eliel1,logi=13eliel2+logi=13eliel3+logi=13eliel1logi=13eliel2,logi=13eliel2+logi=13eliel3+logi=13eliel1logi=13eliel3]
最后可以得到: s o f t m a x ( l o g i t s ) = [ e l 1 e l 1 + e l 2 + e l 3 , e l 2 e l 1 + e l 2 + e l 3 , e l 3 e l 1 + e l 2 + e l 3 ] softmax(logits)= [\frac{e^{l_1}}{e^{l_1} + e^{l_2}+e^{l_3}}, \frac{e^{l_2}}{e^{l_1} + e^{l_2}+e^{l_3}},\frac{e^{l_3}}{e^{l_1} + e^{l_2}+e^{l_3}}] softmax(logits)=[el1+el2+el3el1,el1+el2+el3el2,el1+el2+el3el3]
即可以理解为直接对初始输入的logits进行softmax处理即可得到对应的probs

通过代码来验证一下:

import math
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)

print(pd.probs)  # tensor([0.6439, 0.2369, 0.0871, 0.0321])

num = math.exp(4.0)/(math.exp(4.0)+math.exp(3.0)+math.exp(2.0)+math.exp(1.0))
print(num)  # 0.6439142598879722

可以看到,通过接口获取的probs和我们手动计算得到的probs的结果是一致的,验证了我们的想法。

如果传入的数据为probs,使用Categorical.logits获取对应的对数概率值,代码实现为:

torch.log(ps_clamped)

注意到代码中存在clamp操作,如下所示:

def clamp_probs(probs):
    eps = torch.finfo(probs.dtype).eps  # 获取probs对应的dtype数据类型使得1.0 + eps != 1.0 的最小值
    return probs.clamp(min=eps, max=1 - eps)  # 对probs进行处理,probs的最小值为eps,最大值为1-eps

clamp_probs函数接受一个参数probs,它是一个概率分布。函数首先使用torch.finfo函数获取probs对应的dtype数据类型,以获取使得1.0 + eps != 1.0 的最小值eps。然后,使用probs.clamp函数对概率分布进行截断处理,将概率分布的最小值设为eps,最大值设为1-eps。这样可以避免概率分布中出现01的情况,从而避免在计算交叉熵损失时出现NaN的情况。

这里我们不讨论 is_binary的情况,可以看到代码实现只是简单地使用了对数转换,我们通过代码检查一下:

import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)

print(pd.logits)  # tensor([-0.9163, -1.2040, -1.6094, -2.3026])

print(torch.log(torch.tensor(0.4)))  # tensor(-0.9163)

可以看到,两者的值是一样的,验证了我们的想法。

三、sample()采样

代码简单如下所示:

def sample(self, sample_shape=torch.Size()):
    if not isinstance(sample_shape, torch.Size):
        sample_shape = torch.Size(sample_shape)
    # self._num_events = self._param.size()[-1]
    probs_2d = self.probs.reshape(-1, self._num_events)  # 维度变换
    samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T  # 采样
    return samples_2d.reshape(self._extended_shape(sample_shape))  # 维度变换

sample()的操作比较简单,这里主要记录两个地方:

(1)torch.multinomial()

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
  • input的每一行进行num_samples次取值,输出为每次取值的索引;
  • input每一行中的元素是该索引被采样的权重。如果元素为0,那么其他位置被采样完之前,这个位置都不会被采样;
  • replacement=False为不放回采样,replacement=True为有放回采样;
  • 在不放回采样中,num_samples的值必须小于等于input.size(-1)的值,否则会报错(在不放回采样中,每个样本只能被采样一次,样本被采样后就会从采样池中删除,不会被后续采样过程获取到该样本);

举个例子:

import torch
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
torch.multinomial(weights, 2)  # tensor([1, 2])
torch.multinomial(weights, 6) # 不放回取样,报错,sample n_sample > prob_dist.size(-1)
torch.multinomial(weights, 4, replacement=True)  # tensor([2, 1, 1, 2])

(2)numel()

  • numel()用来统计tensor中元素的个数;
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐