PyTorch中的nn.Embedding的使用、参数及案例
·
PyTorch中的nn.Embedding的使用
Embedding层在神经网络中主要起到降维或升维的作用。具体来说,它通过将输入(通常是离散的、不连续的数据,如单词或类别)映射到连续的向量空间,从而实现数据的降维或升维。
在降维方面,Embedding层可以用来降低数据的维度,减少计算和存储开销。例如,在自然语言处理任务中,词嵌入可以将每个单词表示为一个实数向量,从而将高维的词汇空间映射到一个低维的连续向量空间。这有助于提高模型的泛化能力和计算效率。
在升维方面,Embedding层可以将低维的数据映射到高维的空间,以便于提取更丰富的特征或进行更复杂的分析。例如,在图像处理任务中,嵌入层可以将图像的像素值映射到一个高维的空间,从而更好地捕捉图像中的复杂特征和结构。
此外,Embedding层还可以通过训练来学习输入数据中的语义或结构信息,使得相似的输入在嵌入空间中具有相似的向量表示。这种嵌入向量可以用于各种高级任务,如聚类、分类、推荐等。
nn.Embedding:
参数 | 作用 |
---|---|
num_embeddings | 嵌入字典的大小 |
embedding_dim | 每个嵌入词向量的大小 |
padding_idx | 可选参数:(索引指定填充)如果给定,则遇到padding_idx中的索引,将其位置填0 |
max_norm | 可选参数:如果给定,则将范数大于的每个嵌入向量重新规范化为具有范数 |
norm_type | 要为选项计算的p-norm的p |
scale_grad_by_freq | 如果给定,这将按小批量中单词频率的倒数来缩放梯度 |
sparse | 稀疏张量 |
在PyTorch中,nn.Embedding用来实现词与词向量的映射。nn.Embedding具有一个权重(weight),形状是(num_words, embedding_dim)----->(字典大小, 词向量维度)。
例子:
导包:
import torch
import torch.nn
import jieba
分词生成word2indexDict、index2wordDict:
sentence = "Embedding层是深度学习中的一种重要技术,它可以有效地处理高维、离散和非线性的数据,使得神经网络能够更好地理解和处理复杂的问题。"
wordList = list(set(" ".join(jieba.cut(sentence)).split()))
word2indexDict = dict()
index2wordDict = dict()
for index, word in enumerate(wordList):
word2indexDict[word] = index + 1
index2wordDict[index + 1] = word
方法word2index、index2word:
def word2index(wordList, word2indexDict):
indexList = list()
for word in wordList:
indexList.append(word2indexDict[word])
return indexList
def index2word(indexList, index2wordDict):
wordList = list()
for index in indexList:
wordList.append(index2wordDict[index])
return wordList
在自然语言处理中,padding是一个常见的处理方法,用于统一不同长度的序列到一个固定的长度。这主要是因为深度学习模型,特别是循环神经网络(RNN)和长短时记忆网络(LSTM),通常需要固定长度的输入序列。
- 统一不同长度的序列:当一个批次中有不同长度的句子时,需要将它们统一到一个固定的长度。padding可以将短的句子填充到固定长度,使得每个序列的长度都一致。
- 保证模型稳定性:固定长度的输入序列可以提高模型的稳定性。因为模型在训练时可以固定输入的维度,从而更容易训练和优化。
- 防止过拟合:当使用RNN或LSTM等循环神经网络时,如果输入的序列长度差异很大,模型的训练时间会增加,而且容易产生过拟合。通过padding,可以使得每个序列的长度相同,从而更容易控制模型的复杂度。
padding句子index方法:
def paddingIndexList(indexList, maxLength):
if len(indexList) > maxLength:
return indexList[:maxLength]
else:
for i in range(maxLength - len(indexList)):
indexList.append(0)
return indexList
def paddingIndex(sentenceList, word2indexDict, maxLength):
sentenceIndexList = list()
for sentence in sentenceList:
# 分词
sentenceList = " ".join(jieba.cut(sentence)).split()
indexList = word2index(sentenceList, word2indexDict)
indexList = paddingIndexList(indexList, maxLength)
sentenceIndexList.append(indexList)
return sentenceIndexList
nn.Embedding的使用:
# 创建最大词个数为词字典长度,每个词用维度为3表示
# index从1开始的所以len(word2indexDict) + 1
embedding = nn.Embedding(len(word2indexDict) + 1, 3)
# 转换为tensor
x = torch.LongTensor(x)
out = embedding(x)
# 输入的形状
print(x.shape)
# 词嵌入矩阵形状
print(out.shape)
# 词嵌入矩阵
print(out)
# 词嵌入权重
print(embedding.weight)
输出:
torch.Size([3, 12])
torch.Size([3, 12, 3])
tensor([[[-0.2272, -0.6535, 1.4155],
[-0.2961, 0.5435, 0.3829],
[ 0.9255, -1.4417, 0.6543],
[-0.0796, -0.5464, -0.1243],
[-0.7846, -1.2652, -0.8326],
[-0.2464, 1.0438, 0.3106],
[-0.4759, 0.5017, -1.6207],
[ 0.6134, 0.9412, -0.9230],
[ 0.1368, -0.8741, -2.1221],
[ 0.4055, -0.1992, 1.2301],
[-0.4898, 1.0194, 0.3048],
[-0.4898, 1.0194, 0.3048]],
[[-0.1783, 1.0126, 2.7392],
[ 2.9278, 0.6301, -1.4197],
[ 1.9807, 1.7687, -0.1354],
[-1.3087, -0.1205, 0.4163],
[ 0.2636, 0.6082, 1.6028],
[ 1.4193, -0.1932, -0.9246],
[-0.5884, -1.4016, 0.7351],
[-0.4310, 0.6724, 1.2530],
[ 1.3704, -0.0102, 0.3129],
[ 1.0435, -0.9689, 0.4574],
[-0.4759, 0.5017, -1.6207],
[ 0.5711, 1.0981, -0.1208]],
[[-0.0969, 0.4424, 2.7773],
[ 0.1563, -1.5930, 0.0504],
[ 0.2673, -0.1552, -0.9901],
[ 0.6867, -1.3702, 1.6596],
[-1.3087, -0.1205, 0.4163],
[ 1.3590, -0.4433, 0.4342],
[ 1.3704, -0.0102, 0.3129],
[ 0.2636, 0.6082, 1.6028],
[ 0.4744, -0.3505, -1.1645],
[-0.4759, 0.5017, -1.6207],
[-0.3297, 2.3305, 1.3228],
[-0.4898, 1.0194, 0.3048]]], grad_fn=<EmbeddingBackward0>)
Parameter containing:
tensor([[-0.4898, 1.0194, 0.3048],
[-0.1783, 1.0126, 2.7392],
[ 0.2673, -0.1552, -0.9901],
[ 1.4193, -0.1932, -0.9246],
[-0.4310, 0.6724, 1.2530],
[-0.2272, -0.6535, 1.4155],
[ 2.9278, 0.6301, -1.4197],
[ 0.4055, -0.1992, 1.2301],
[-0.2464, 1.0438, 0.3106],
[ 0.6777, -1.4049, -1.0489],
[ 1.9807, 1.7687, -0.1354],
[ 1.3590, -0.4433, 0.4342],
[ 1.0435, -0.9689, 0.4574],
[-0.3297, 2.3305, 1.3228],
[-0.2961, 0.5435, 0.3829],
[ 0.4744, -0.3505, -1.1645],
[ 0.6134, 0.9412, -0.9230],
[ 1.3704, -0.0102, 0.3129],
[ 0.5711, 1.0981, -0.1208],
[-0.0969, 0.4424, 2.7773],
[ 0.1563, -1.5930, 0.0504],
[-0.7846, -1.2652, -0.8326],
[ 0.2636, 0.6082, 1.6028],
[-1.3087, -0.1205, 0.4163],
[-0.0796, -0.5464, -0.1243],
[ 0.8210, 0.2765, 0.4816],
[ 0.1368, -0.8741, -2.1221],
[-0.4759, 0.5017, -1.6207],
[ 0.6867, -1.3702, 1.6596],
[ 0.9255, -1.4417, 0.6543],
[-0.5884, -1.4016, 0.7351]], requires_grad=True)
更多推荐
已为社区贡献2条内容
所有评论(0)