AI 基础知识十二 Transformer掩码(Mask)
1.掩码(Mask)简介
在 Transformer 中,掩码(Mask)是核心组件之一,用于控制注意力机制的“可见范围”,保证模型的正确性和合理性。Transformer 的 forward 函数接口有六个掩码参数,按用途可分为两大类:注意力掩码(mask)和填充掩码(key_padding_mask)。
// LibTorch nn::Transformer forward interface: the six optional mask
// arguments split into attention masks (src_mask / tgt_mask / memory_mask)
// and key-padding masks (*_key_padding_mask).
Tensor forward(
const Tensor& src,
const Tensor& tgt,
const Tensor& src_mask = {},              // [S, S] attention mask over the source
const Tensor& tgt_mask = {},              // [T, T] causal mask over the target
const Tensor& memory_mask = {},           // [T, S] mask on encoder output
const Tensor& src_key_padding_mask = {},  // [N, S] true / non-zero = padded position
const Tensor& tgt_key_padding_mask = {},  // [N, T] padding mask for the target
const Tensor& memory_key_padding_mask = {}); // usually reuses src_key_padding_mask
两类掩码对不同张量类型的取值约定如下:

| 类别 | 掩码参数 | ByteTensor | BoolTensor | FloatTensor |
| --- | --- | --- | --- | --- |
| 注意力掩码 | src_mask、tgt_mask、memory_mask | 非0值:屏蔽(忽略);0值:正常 | true:屏蔽(忽略);false:正常 | 数值加到注意力权重上,-inf(负无穷):屏蔽(忽略) |
| 填充掩码 | src_key_padding_mask、tgt_key_padding_mask、memory_key_padding_mask | 非0值:屏蔽(忽略);0值:正常 | true:屏蔽(忽略);false:正常 | 不支持 |
以下逐一详解。
2. 源掩码 (src_mask)
用于屏蔽源序列(src)自身的无效位置(如自定义的禁止关注区域)。
适用场景:极少用,一般情况下无需屏蔽。张量形状要求 [seq, seq],seq 是源序列的长度。
auto src_mask = torch::Tensor();
// 或者
auto srclen =src.size(1);
auto src_mask =torch::zeros({srclen,srclen}, torch::kBool);
3.目标掩码(tgt_mask)
核心作用是屏蔽未来位置,如在预测第 i 个词时,不能提前看到i后面的词(i+1,i+2 ... 位置的词)
形状要求[seq,seq], seq 是目标序列的长度,
适用场景:所有序列生成任务必须使用
transformer提供generate_square_subsequent_mask()函数创建 (上三角掩码)
auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(1));
std::cout << "tgt_mask\n" << tgt_mask << std::endl;
/*
tgt_mask
0 -inf -inf -inf -inf -inf -inf -inf -inf -inf
0 0 -inf -inf -inf -inf -inf -inf -inf -inf
0 0 0 -inf -inf -inf -inf -inf -inf -inf
0 0 0 0 -inf -inf -inf -inf -inf -inf
0 0 0 0 0 -inf -inf -inf -inf -inf
0 0 0 0 0 0 -inf -inf -inf -inf
0 0 0 0 0 0 0 -inf -inf -inf
0 0 0 0 0 0 0 0 -inf -inf
0 0 0 0 0 0 0 0 0 -inf
0 0 0 0 0 0 0 0 0 0
[ CPUFloatType{10,10} ]
*/
4. 记忆(内存)掩码(memory_mask)
作用:编码器输出(memory)送入解码器时,屏蔽其中的某些无效位置。
形状要求[tgt_seq,src_seq],tgt_seq 是目标序列的长度,src_seq是源序列的长度
默认情况下,无需屏蔽
5.源序列填充掩码(src_key_padding_mask)
作用:当源序列长度不足时(例如需与目标序列长度对齐)用 Pad 补齐,并用该掩码标记这些填充位置,使注意力机制忽略它们。
形状要求[batch,seq],seq是源序列的长度,batch批量大小
句子"Welcome to Machine Learning Pad Pad Pad Pad" 数字编码[1,2,4,6,0,0,0,0]
适用场景:几乎所有序列任务必须使用
auto src_key_padding_mask = (src == PadId).to(torch::kBool); // [batch,seq]
std::cout << "src_key_padding_mask\n" << src_key_padding_mask << std::endl;
/*
src_key_padding_mask
0 0 0 0 1 1 1 1
[ CPUBoolType{1,8} ]
*/
6.目标序列填充掩码(tgt_key_padding_mask)
原理与源序列填充掩码相同,当目标序列长度小于源序列长度时填充目标序列长度
auto tgt_key_padding_mask = (tgt == PadId).to(torch::kBool);
7. 记忆填充掩码(memory_key_padding_mask)
一般情况下和 源序列填充掩码是同一个掩码!直接复用即可
8.掩码总结
| 掩码类型 | 形状要求 | 场景 |
| --- | --- | --- |
| src_mask | [seq, seq] | 极少用,通常传空张量 |
| tgt_mask | [seq, seq] | 必须使用,屏蔽未来位置 |
| memory_mask | [tgt_seq, src_seq] | 默认情况下无需屏蔽 |
| src_key_padding_mask | [batch, seq] | 必须使用,屏蔽源序列的填充位置 |
| tgt_key_padding_mask | [batch, seq] | 必须使用,屏蔽目标序列的填充位置 |
| memory_key_padding_mask | [batch, seq] | 同 src_key_padding_mask,直接复用 |
9. 完善Transformer示例
1. 损失函数,使用了掩码之后计算损失函数应该过滤掩码
auto options = torch::nn::CrossEntropyLossOptions().ignore_index(PadId);
torch::nn::CrossEntropyLoss loss_fn(options);
2.训练模型,加掩码处理
torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
{
auto none_mask = torch::Tensor();
auto srclen =src.size(1);
auto src_mask =torch::zeros({srclen,srclen}, torch::kBool);
auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(1));
auto src_key_padding_mask = (src == PadId).to(torch::kBool); // [batch,seq]
auto tgt_key_padding_mask = (tgt == PadId).to(torch::kBool); // [batch,seq]
auto memory_key_padding_mask = src_key_padding_mask;
......
// tgt & src: (seq, batch, dim)
//auto outs = transformer->forward(src, tgt);
auto outs = transformer->forward(src, tgt, src_mask, tgt_mask, none_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask);
outs = fc->forward(outs);
return outs;
}
3.加掩码之后运行效果

模型参数有30多万,只有翻译“Tutorials”时不对,可以说Transformer功能很强大。
4. Transformer示例完整代码
#include <torch/torch.h>
#include <iostream>
#include <torch/serialize.h>
#include <regex>
//#include <iostream>
#include <fstream>
#define dim_model 128
#define dim_feed 256
#define max_vocab_len 500
#define max_train 100
#define PadId 0
typedef const std::unordered_map< std::string, int64_t> TableVocab;
typedef std::vector<std::pair<int64_t, int64_t>> WordList;
/// <translation examples used as training data>
/// Welcome to PyTorch Tutorials ---> 欢迎来到派托奇教程
/// Welcome to Machine Learning -----> 欢迎来到机器学习
/// </translation>
// Source-language (English) vocabulary: word -> token id.
// Id 0 (PadId) is reserved for the padding token "Pad".
TableVocab src_vocab =
{
{"Pad",PadId},
{"Welcome",1},
{"to",2},
{"PyTorch",3},
{"Machine",4},
{"Tutorials",5},
{"Learning",6 }
};
// Target-language (Chinese character) vocabulary: token -> id.
// "S" is the decoder start-of-sequence token, "E" the end-of-sequence
// token; id 0 (PadId) is the padding token.
TableVocab tgt_vocab =
{
{"Pad",PadId},
{"S",1},
{"E",2},
{"欢",3},
{"迎",4},
{"来",5},
{"到",6},
{"派",7},
{"托",8},
{"奇",9},
{"教",10},
{"程",11},
{"机",12},
{"器",13},
{"学",14},
{"习",15}
};
int64_t src_vocab_size = src_vocab.size();
int64_t tgt_vocab_size = tgt_vocab.size();
// Tokenize a sentence into whitespace-separated words.
std::vector<std::string> Split(const std::string& s)
{
    std::vector<std::string> tokens;
    std::istringstream stream(s);
    for (std::string token; stream >> token;)
    {
        tokens.push_back(std::move(token));
    }
    return tokens;
}
// Reverse lookup: return the word whose id equals dataid,
// or the sentinel string "0" when no entry matches.
std::string GetWordById(TableVocab& vocabId,int64_t dataid)
{
    std::string found = "0";
    for (const auto& entry : vocabId)
    {
        if (entry.second == dataid)
        {
            found = entry.first;
            break;
        }
    }
    return found;
}
// Encode a whitespace-separated sentence into vocabulary ids.
// Throws std::out_of_range if a word is missing from the vocabulary.
std::vector<int64_t> GetWordId(TableVocab& vocabId,std::string data)
{
    std::vector<int64_t> ids;
    const auto words = Split(data);
    for (const auto& word : words)
    {
        ids.push_back(vocabId.at(word));
    }
    return ids;
}
// Pair up source and target token ids position-by-position, truncating
// to the shorter of the two sequences.
WordList GetLoadDataWordId(std::pair<std::string, std::string> data)
{
    std::vector<int64_t> input = GetWordId(src_vocab, data.first);
    std::vector<int64_t> target = GetWordId(tgt_vocab, data.second);
    WordList item;
    // size_t index: the original used a signed int against size(), a
    // signed/unsigned comparison that warns and overflows for huge inputs.
    for (size_t i = 0; i < input.size() && i < target.size(); i++)
    {
        item.push_back({ input.at(i), target.at(i) });
    }
    return item;
}
// Tiny two-sentence translation dataset. Each element is a (src, tgt)
// pair of equal-length token-id tensors; the source side is padded with
// "Pad" so both sides of a pair line up position-by-position.
class translatDataset : public torch::data::Dataset<translatDataset>
{
public:
    translatDataset()
    {
        wordCount.push_back(GetLoadDataWordId({ "Welcome to PyTorch Tutorials Pad Pad Pad Pad Pad","欢 迎 来 到 派 托 奇 教 程" }));
        wordCount.push_back(GetLoadDataWordId({ "Welcome to Machine Learning Pad Pad Pad Pad","欢 迎 来 到 机 器 学 习" }));
    }
    // Number of sentence pairs.
    torch::optional<size_t> size() const override
    {
        return wordCount.size();
    }
    // Return the index-th pair as (input ids, target ids) Long tensors.
    torch::data::Example<torch::Tensor, torch::Tensor> get(size_t index) override
    {
        const auto& item = wordCount[index];
        std::vector<int64_t> srcIds;
        std::vector<int64_t> tgtIds;
        srcIds.reserve(item.size());
        tgtIds.reserve(item.size());
        // BUGFIX: standard range-based for. The original used the
        // MSVC-only "for each (... in ...)" extension, which does not
        // compile with gcc/clang.
        for (const auto& p : item)
        {
            srcIds.push_back(p.first);
            tgtIds.push_back(p.second);
        }
        auto input = torch::tensor(srcIds, torch::kLong);
        auto target = torch::tensor(tgtIds, torch::kLong);
        return { input, target };
    }
public:
    std::vector<WordList> wordCount;
};
// Build the decoder's teacher-forcing pair from a [batch, seq] target:
//   input  = "S" prepended to data (shifted right),
//   output = data with "E" appended (shifted left).
std::pair<torch::Tensor, torch::Tensor> CreateDecoderInputOutput(torch::Tensor data)
{
    auto startTok = torch::tensor(GetWordId(tgt_vocab, "S"), torch::kLong).view({ 1,1 });
    auto endTok = torch::tensor(GetWordId(tgt_vocab, "E"), torch::kLong).view({ 1,1 });
    auto decoderIn = torch::cat({ startTok, data }, 1);
    auto decoderOut = torch::cat({ data, endTok }, 1);
    return { decoderIn, decoderOut };
}
// Sinusoidal positional encoding ("Attention Is All You Need"): even
// feature indices get sin, odd indices get cos, with frequencies
// exp(-log(10000) * 2i / d_model). The table is precomputed once for
// max_len positions and registered as a non-trainable buffer shaped
// [max_len, 1, d_model] so it broadcasts over the batch dimension of
// (seq, batch, dim) inputs.
class PositionalEncodingImpl :public torch::nn::Module
{
public:
PositionalEncodingImpl(int64_t d_model, int64_t max_len)
{
_d_model = d_model;
_max_len = max_len;
_posEncode = torch::zeros({ _max_len, _d_model }, torch::kFloat32);
Encoding();
// Buffer, not parameter: saved/loaded with the module, never trained.
register_buffer("posEncode", _posEncode);
}
// Add positional encodings to x. Expects (seq, batch, dim); a 2-D
// (seq, dim) input first gains a singleton batch axis.
torch::Tensor forward(torch::Tensor x)
{
if ((x.dim() == 2))
{
// NOTE(review): unsqueeze_ is in-place, so a 2-D caller tensor is
// reshaped as a side effect — confirm callers don't rely on x's shape.
x = x.unsqueeze_(-2);
}
auto dim = x.size(0);
// std::cout <<"pos " << _posEncode.slice(0, 0, dim).sizes() << std::endl;
//std::cout <<"x " << x.sizes() << std::endl;
// Take only the first seq rows of the table; broadcasts over batch.
x = x + _posEncode.slice(0, 0, dim);
return x;
}
private:
// Fill _posEncode with the sin/cos table described above.
void Encoding()
{
auto pos = torch::arange(0, _max_len, torch::kFloat32).reshape({ _max_len, 1 });
auto den_indices = torch::arange(0, _d_model, 2, torch::kFloat32);
auto den = torch::exp(-den_indices * std::log(10000.0f) / _d_model);
// Even columns <- sin(pos * den), odd columns <- cos(pos * den).
_posEncode.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(0, _d_model, 2) }, torch::sin(pos * den));
_posEncode.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(1, _d_model, 2) }, torch::cos(pos * den));
// Insert the middle batch axis: [max_len, d_model] -> [max_len, 1, d_model].
_posEncode.unsqueeze_(-2);
}
public:
torch::Tensor _posEncode;
int64_t _d_model = dim_model;
int64_t _max_len = max_vocab_len;
};
TORCH_MODULE(PositionalEncoding);
// Seq2seq translator: embeddings + sinusoidal positional encoding feeding
// a torch::nn::Transformer (1 encoder layer, 1 decoder layer, 2 heads),
// with a final Linear projecting decoder output to target-vocab logits.
class TranslatorImpl : public torch::nn::Module
{
public:
TranslatorImpl()
{
src_emb = register_module("src_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(src_vocab_size, dim_model)));
tgt_emb_ = register_module("tgt_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(tgt_vocab_size, dim_model)));
pos_encoder = register_module("pos_encoder", PositionalEncoding(dim_model, max_vocab_len));
torch::nn::TransformerOptions opts;
opts.nhead(2);
opts.dim_feedforward(dim_feed);
opts.num_decoder_layers(1);
opts.num_encoder_layers(1);
// Dropout disabled so the tiny dataset can be memorized deterministically.
opts.dropout(0.0);
opts.d_model(dim_model);
transformer = register_module("transformer", torch::nn::Transformer(opts));
fc = register_module("fc", torch::nn::Linear(dim_model, tgt_vocab_size));
}
// Training forward pass. src, tgt: [batch, seq] token ids.
// Returns [seq, batch, tgt_vocab_size] logits.
torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
{
auto none_mask = torch::Tensor();
auto srclen =src.size(1);
// All-false [S, S] source mask: nothing in the source is hidden.
auto src_mask =torch::zeros({srclen,srclen}, torch::kBool);
// Upper-triangular -inf mask: position i may not attend past i.
auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(1));
// true where the token is padding, so attention ignores it.
auto src_key_padding_mask = (src == PadId).to(torch::kBool); // [batch,seq]
auto tgt_key_padding_mask = (tgt == PadId).to(torch::kBool); // [batch,seq]
// Memory comes from the encoder, so its padding layout equals the source's.
auto memory_key_padding_mask = src_key_padding_mask;
//std::cout << "tgt_mask\n" << tgt_mask << std::endl;
//std::cout << "src_key_padding_mask\n" << src_key_padding_mask << std::endl;
//std::cout << "tgt_key_padding_mask\n" << tgt_key_padding_mask << std::endl;
//[batch, seq] --> [seq, batch]
src = src.permute({ 1,0 });
tgt = tgt.permute({ 1,0 });
//std::cout << "input " << src << std::endl;
// Scale embeddings by sqrt(d_model) as in the original Transformer paper.
src = src_emb->forward(src) * std::sqrt(dim_model);
src = pos_encoder->forward(src);
tgt = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
tgt = pos_encoder->forward(tgt);
// tgt & src: (seq, batch, dim)
//auto outs = transformer->forward(src, tgt);
auto outs = transformer->forward(src, tgt, src_mask, tgt_mask, none_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask);
outs = fc->forward(outs);
return outs;
}
// Greedy decoding. src: 1-D [seq] token ids (unbatched; the positional
// encoder inserts the batch axis). Encodes once, then appends the argmax
// token step-by-step until "E" is produced or 2*vocab_size steps elapse.
// Returns the decoded id sequence, including the leading "S".
torch::Tensor predict(torch::Tensor src)
{
//std::cout << src << std::endl;
auto srclen = src.size(0);
auto src_mask = torch::zeros({ srclen,srclen }, torch::kBool);
auto srcemb = src_emb->forward(src) * std::sqrt(dim_model);
srcemb = pos_encoder->forward(srcemb);
// Run the encoder once; reuse its memory for every decode step.
auto memory = transformer->encoder.forward(srcemb, src_mask);
std::vector<int64_t> tgtpad = GetWordId(tgt_vocab, "S");
int i = 0;
// Hard cap prevents an infinite loop if "E" is never emitted.
while (i < tgt_vocab_size*2)
{
torch::Tensor tgt = torch::tensor(tgtpad, torch::kLong);
auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(0));
///std::cout << "tgt_mask " << tgt_mask << std::endl;
auto tgt_emb = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
tgt_emb = pos_encoder->forward(tgt_emb);
auto out = transformer->decoder.forward(tgt_emb, memory, tgt_mask);
out = fc->forward(out).squeeze(-2);
auto next_token = out.argmax(-1);
// Index i is the last position of the current sequence: greedy pick.
int64_t key = next_token[i].item<int64_t>();
tgtpad.push_back(key);
//tgtpad.insert(tgtpad.begin(), );
if ("E" == GetWordById(tgt_vocab, key))
{
break;
}
i++;
}
//tgtpad.pop_back();
return torch::tensor(tgtpad, torch::kLong);
}
torch::nn::Embedding src_emb{ nullptr };
torch::nn::Embedding tgt_emb_{ nullptr };
PositionalEncoding pos_encoder{ nullptr };
torch::nn::Transformer transformer{ nullptr };
torch::nn::Linear fc{ nullptr };
};
TORCH_MODULE(Translator);
void TestData(Translator& model);
void TrainData(Translator& model);
// Count model parameter elements: returns {total, trainable, frozen}.
std::tuple<int64_t, int64_t, int64_t> count_model_parameters(Translator& model)
{
    int64_t total = 0;
    int64_t trainable = 0;
    int64_t frozen = 0;
    for (const auto& param : model->parameters())
    {
        const int64_t count = param.numel();
        total += count;
        // Route the count to the matching bucket.
        if (param.requires_grad())
        {
            trainable += count;
        }
        else
        {
            frozen += count;
        }
    }
    return { total, trainable, frozen };
}
// Entry point: build the model, report parameter counts, then either train
// and save a checkpoint, or load an existing one, and finally run tests.
void TransformerMain()
{
// Fixed seed so weight init (and thus results) are reproducible.
torch::manual_seed(4);
std::string model_path = "translator_model.pt";
Translator model;
auto[a,b,c] = count_model_parameters(model);
std::cout << "模型总参数: "<< a << std::endl;
std::cout << "可训练参数: " << b << std::endl;
std::cout << "不可训参数: " << c << std::endl << std::endl;
// Existence check via ifstream: train only when no checkpoint is on disk.
std::ifstream filem(model_path);
bool bmodel = filem.is_open();
if (!bmodel)
{
TrainData(model);
torch::save(model, model_path);
}
else
{
torch::load(model, model_path);
std::cout << "load model ...." << std::endl;
}
filem.close();
TestData(model);
}
// Train until max_train epochs or until the total epoch loss drops below
// `accuracy`. Padded target positions are excluded from the loss via
// CrossEntropyLoss's ignore_index(PadId).
void TrainData(Translator& model)
{
    double accuracy = 0.03;
    auto datasetTrain = translatDataset().map(torch::data::transforms::Stack<>());
    auto train_data_loader = torch::data::make_data_loader(std::move(datasetTrain), torch::data::DataLoaderOptions().batch_size(1));
    auto options = torch::nn::CrossEntropyLossOptions().ignore_index(PadId);
    torch::nn::CrossEntropyLoss loss_fn(options);
    torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(1e-3));
    model->train();
    std::cout << "训练模型" << std::endl;
    for (int i = 0; i < max_train; i++)
    {
        float total_loss = 0;
        for (auto& item : *train_data_loader)
        {
            /// item.data, item.target : [batch, seq]
            auto[tgtInput, tgtOutput] = CreateDecoderInputOutput(item.target);
            auto tgtOut = model->forward(item.data, tgtInput);
            auto output = tgtOut.reshape({ -1, tgt_vocab_size });
            optimizer.zero_grad();
            auto tgt = tgtOutput.squeeze(0);
            auto loss = loss_fn(output, tgt);
            total_loss += loss.item<float>();
            // BUGFIX: clip AFTER backward() has populated gradients. The
            // original called clip_grad_norm_ before backward, so it acted
            // on absent/stale gradients and the clip was a no-op.
            loss.backward();
            torch::nn::utils::clip_grad_norm_(model->parameters(), 1.0);
            optimizer.step();
        }
        if (i % 10 == 0 || (i + 1 == max_train))
        {
            std::cout << "i: " << i + 1 << " , loss: " << total_loss << std::endl;
        }
        // Early stop once the model has effectively memorized the data.
        if (total_loss <= accuracy)
        {
            std::cout << "i: " << i + 1 << " , loss: " << total_loss << std::endl << std::endl;
            break;
        }
    }
    std::cout << std::endl;
}
// Run greedy-decoding translation on a fixed set of test phrases and
// print each source phrase alongside its decoded target tokens.
void TestData(Translator& model)
{
    model->eval();
    std::cout << "测试&翻译:" << std::endl;
    const std::vector<std::string> tests = {
        "Welcome",
        "Welcome to",
        "Welcome to PyTorch",
        "Welcome to Machine",
        "Welcome to PyTorch Tutorials",
        "Welcome to Machine Learning",
        "Learning",
        "Tutorials",
        "PyTorch Tutorials",
        "Machine Learning",
    };
    for (const auto& sentence : tests)
    {
        auto ids = GetWordId(src_vocab, sentence);
        auto src = torch::tensor(ids, torch::kLong);
        auto decoded = model->predict(src);
        std::cout << sentence << " : ";
        for (int64_t k = 0; k < decoded.numel(); k++)
        {
            std::cout << GetWordById(tgt_vocab, decoded[k].item<int64_t>()) << " ";
        }
        std::cout << std::endl;
    }
}
感谢大家的支持,如有问题欢迎提问指正。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)