1.掩码(Mask)简介

在 Transformer 中,掩码(Mask)是核心组件之一,用于控制注意力机制的“可见范围”,保证模型的正确性和合理性。Transformer 的 forward 函数接口有六个掩码参数,掩码类型可以分为两大类:注意力掩码和填充掩码。

  Tensor forward(
      const Tensor& src,
      const Tensor& tgt,
      const Tensor& src_mask = {},
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& src_key_padding_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

掩码 

1.src_mask

2.tgt_mask

3.memory_mask

ByteTensor BoolTensor FloatTensor

非0值: 屏蔽(忽略)

0值:   正常

true:    屏蔽(忽略)

false:    正常

数值加到注意力权重上

 -inf(负无穷):  屏蔽(忽略)

填充掩码 

1.src_key_padding_mask

2.tgt_key_padding_mask

3.memory_key_padding_mask

ByteTensor BoolTensor -

非0值: 屏蔽(忽略)

0值:     正常

true:    屏蔽(忽略)

false:    正常

-

以下逐一详解。

2. 源掩码 (src_mask)

用于屏蔽源序列(src)自身的无效位置(如自定义的禁止关注区域)。

适用场景:极少用,一般情况下无需屏蔽。张量形状要求 [seq,seq],seq 是源序列的长度。

 auto src_mask = torch::Tensor();
// 或者 
 auto srclen =src.size(1);
 auto src_mask =torch::zeros({srclen,srclen}, torch::kBool);

3.目标掩码(tgt_mask)

核心作用是屏蔽未来位置,如在预测第 i 个词时,不能提前看到i后面的词(i+1,i+2 ... 位置的词)

形状要求[seq,seq], seq 是目标序列的长度,

适用场景:所有序列生成任务必须使用

transformer提供generate_square_subsequent_mask()函数创建 (上三角掩码)

   auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(1));
   std::cout << "tgt_mask\n" << tgt_mask << std::endl;

/*

tgt_mask
 0 -inf -inf -inf -inf -inf -inf -inf -inf -inf
 0  0 -inf -inf -inf -inf -inf -inf -inf -inf
 0  0  0 -inf -inf -inf -inf -inf -inf -inf
 0  0  0  0 -inf -inf -inf -inf -inf -inf
 0  0  0  0  0 -inf -inf -inf -inf -inf
 0  0  0  0  0  0 -inf -inf -inf -inf
 0  0  0  0  0  0  0 -inf -inf -inf
 0  0  0  0  0  0  0  0 -inf -inf
 0  0  0  0  0  0  0  0  0 -inf
 0  0  0  0  0  0  0  0  0  0
[ CPUFloatType{10,10} ]
*/

4. 记忆(内存)掩码(memory_mask)

作用:编码器输出(memory)传入解码器时,屏蔽其中某些无效位置。

形状要求[tgt_seq,src_seq],tgt_seq 是目标序列的长度,src_seq是源序列的长度

默认情况下,无需屏蔽

5.源序列填充掩码(src_key_padding_mask)

作用:标记源序列中的填充(Pad)位置,使注意力机制忽略这些无效位置。当序列长度不足、需要用 Pad 补齐时(例如批内对齐)必须使用。

形状要求 [batch,seq],seq 是源序列的长度,batch 是批量大小。

句子"Welcome to Machine Learning  Pad Pad Pad Pad" 数字编码[1,2,4,6,0,0,0,0]

适用场景:几乎所有序列任务必须使用


auto src_key_padding_mask = (src == PadId).to(torch::kBool);  // [batch,seq]
std::cout << "src_key_padding_mask\n" << src_key_padding_mask << std::endl;
/*
src_key_padding_mask
 0  0  0  0  1  1  1  1
[ CPUBoolType{1,8} ]
*/

6.目标序列填充掩码(tgt_key_padding_mask)

原理与源序列填充掩码相同:标记目标序列中的填充(Pad)位置,使注意力机制忽略这些无效位置。

 auto tgt_key_padding_mask = (tgt == PadId).to(torch::kBool);

7. 记忆填充掩码(memory_key_padding_mask)

一般情况下和 源序列填充掩码是同一个掩码!直接复用即可

8.掩码总结

掩码类型 形状要求 场景
src_mask [seq,seq] 极少用 空张量
tgt_mask [seq, seq]

必须使用

屏蔽未来位置

memory_mask

[tgt_seq,

src_seq]

默认情况下,无需屏蔽
src_key_padding_mask [batch,seq]

必须使用

屏蔽源序列中的填充位置

tgt_key_padding_mask [batch,seq]

必须使用

屏蔽目标序列中的填充位置

memory_key_padding_mask [batch,seq] 同src_key_padding_mask

9. 完善Transformer示例

1. 损失函数,使用了掩码之后计算损失函数应该过滤掩码

    auto options = torch::nn::CrossEntropyLossOptions().ignore_index(PadId);
    torch::nn::CrossEntropyLoss loss_fn(options);

2.训练模型,加掩码处理

 torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
 {
     auto none_mask = torch::Tensor();
     auto srclen =src.size(1);
     auto src_mask =torch::zeros({srclen,srclen}, torch::kBool);
     auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(1));
     auto src_key_padding_mask = (src == PadId).to(torch::kBool);  // [batch,seq]
     auto tgt_key_padding_mask = (tgt == PadId).to(torch::kBool);  // [batch,seq]
     auto memory_key_padding_mask = src_key_padding_mask;
 
       ......
     // tgt & src: (seq, batch, dim)
     //auto outs = transformer->forward(src, tgt);
     auto outs = transformer->forward(src, tgt, src_mask, tgt_mask, none_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask);

     outs = fc->forward(outs);
     
     return outs;

 }

3.加掩码之后运行效果

模型参数有30多万,只有翻译“Tutorials”时不对,可以说Transformer功能很强大。

4. Transformer示例完整代码

#include <torch/torch.h>
#include <iostream>
#include <torch/serialize.h>
#include <regex>
//#include <iostream>
#include <fstream>


// ---- Model / data hyper-parameters ----
#define  dim_model   128      // transformer model width (d_model)
#define  dim_feed    256      // feed-forward hidden size
#define  max_vocab_len  500   // max sequence length supported by positional encoding
#define  max_train       100  // max training epochs

#define PadId 0               // vocabulary id reserved for padding

// TableVocab: word -> id lookup table.
// WordList: per-position pairs of (source id, target id).
typedef const std::unordered_map< std::string, int64_t> TableVocab;
typedef std::vector<std::pair<int64_t, int64_t>>  WordList;


/// <translation pairs used for training>
///  Welcome to PyTorch Tutorials  ---> 欢迎来到派托奇教程
///  Welcome to Machine Learning -----> 欢迎来到机器学习
/// </translation pairs>
TableVocab src_vocab =
{
    {"Pad",PadId},
    {"Welcome",1},
    {"to",2},
    {"PyTorch",3},
    {"Machine",4},
    {"Tutorials",5},
    {"Learning",6 }
};

// Target vocabulary: "S"/"E" are the start/end-of-sequence markers.
TableVocab tgt_vocab =
{
    {"Pad",PadId},
    {"S",1},
    {"E",2},
    {"欢",3},
    {"迎",4},
    {"来",5},
    {"到",6},
    {"派",7},
    {"托",8},
    {"奇",9},
    {"教",10},
    {"程",11},
    {"机",12},
    {"器",13},
    {"学",14},
    {"习",15}
};


int64_t src_vocab_size = src_vocab.size();

int64_t tgt_vocab_size = tgt_vocab.size();


// Tokenize a sentence on whitespace into individual words.
std::vector<std::string> Split(const std::string& s)
{
    std::istringstream stream(s);
    std::vector<std::string> tokens;

    for (std::string token; stream >> token; )
    {
        tokens.push_back(std::move(token));
    }
    return tokens;
}

// Reverse lookup: return the word mapped to `dataid`, or "0" when no
// entry in the vocabulary carries that id.
std::string GetWordById(TableVocab& vocabId,int64_t dataid)
{
    for (const auto& entry : vocabId)
    {
        if (entry.second == dataid)
        {
            return entry.first;
        }
    }
    return "0";
}

// Convert a whitespace-separated sentence into its sequence of
// vocabulary ids. Throws std::out_of_range on unknown words (via at()).
std::vector<int64_t> GetWordId(TableVocab& vocabId,std::string data)
{
    const auto words = Split(data);

    std::vector<int64_t> ids;
    ids.reserve(words.size());
    for (const auto& w : words)
    {
        ids.push_back(vocabId.at(w));
    }
    return ids;
}

// Encode one training sample: pair up source-token ids and target-token
// ids position-by-position, truncated to the shorter of the two
// sequences.
// @param data  {source sentence, target sentence}, whitespace-separated.
// @return per-position (src id, tgt id) pairs.
WordList GetLoadDataWordId(std::pair<std::string, std::string> data)
{
    std::vector<int64_t> input = GetWordId(src_vocab,data.first);
    std::vector<int64_t> target = GetWordId(tgt_vocab,data.second);

    // Use size_t for the loop bound: the original compared a signed int
    // against size(), a signed/unsigned mismatch.
    const size_t count = input.size() < target.size() ? input.size() : target.size();

    WordList item;
    item.reserve(count);
    for (size_t i = 0; i < count; i++)
    {
        item.push_back({ input[i], target[i] });
    }

    return item;
}



// Tiny in-memory dataset: two parallel (source, target) token-id
// sequences produced by GetLoadDataWordId().
class translatDataset : public torch::data::Dataset<translatDataset>
{
public:
    
    translatDataset()
    {
        // Source sentences are right-padded with "Pad" so each sample has
        // as many source tokens as target tokens.
        wordCount.push_back(GetLoadDataWordId({ "Welcome to PyTorch Tutorials Pad Pad Pad Pad Pad","欢 迎 来 到 派 托 奇 教 程" }));
        wordCount.push_back(GetLoadDataWordId({ "Welcome to Machine Learning  Pad Pad Pad Pad","欢 迎 来 到 机 器 学 习" }));

    }

    // Number of samples.
    torch::optional<size_t> size() const
    {
        return wordCount.size();
    }

    // Return one (input, target) pair of Long tensors, each of shape [seq].
    torch::data::Example<torch::Tensor, torch::Tensor>  get(size_t index) override
    {
        auto item = wordCount[index];
        std::vector<int64_t> tmpinput;
        std::vector<int64_t> tmptarget1;
        tmpinput.reserve(item.size());
        tmptarget1.reserve(item.size());

        // BUGFIX: the original used the non-standard MSVC "for each (... in ...)"
        // extension, which does not compile on conforming compilers; a
        // standard range-based for is equivalent and portable.
        for (const auto& i : item)
        {
            tmpinput.push_back(i.first);
            tmptarget1.push_back(i.second);
        }

        auto input = torch::tensor(tmpinput, torch::kLong);
        auto target = torch::tensor(tmptarget1, torch::kLong);

        return { input, target};
    }

public:
    std::vector<WordList> wordCount;

};

// Build teacher-forcing tensors from a target sequence `data` of shape [1, seq]:
//   decoder input  = <S> + data  (shifted right)
//   decoder output = data + <E>  (shifted left)
std::pair<torch::Tensor, torch::Tensor>  CreateDecoderInputOutput(torch::Tensor data)
{
    auto startTok = torch::tensor(GetWordId(tgt_vocab, "S"), torch::kLong).view({ 1, 1 });
    auto endTok = torch::tensor(GetWordId(tgt_vocab, "E"), torch::kLong).view({ 1, 1 });

    auto decoderIn = torch::cat({ startTok, data }, 1);
    auto decoderOut = torch::cat({ data, endTok }, 1);

    return { decoderIn, decoderOut };
}

// Sinusoidal positional encoding ("Attention Is All You Need").
// Precomputes a [max_len, 1, d_model] table once in the constructor and
// adds its first `seq` rows to the input embeddings in forward().
class PositionalEncodingImpl :public torch::nn::Module
{
public:
    PositionalEncodingImpl(int64_t d_model, int64_t max_len)
    {
        _d_model = d_model;
        _max_len = max_len;
        _posEncode = torch::zeros({ _max_len, _d_model }, torch::kFloat32);

        Encoding();

        // Buffer: persisted with the module's state but never trained.
        register_buffer("posEncode", _posEncode);
    }


    // Add positional encodings to x.
    // Expects x as (seq, batch, dim); a 2-D input first gets a batch axis.
    torch::Tensor forward(torch::Tensor x)
    {
        if ((x.dim() == 2))
        {
            // NOTE(review): unsqueeze_ mutates the caller's tensor in
            // place — presumably intentional; confirm no caller relies on
            // the original 2-D shape afterwards.
            x = x.unsqueeze_(-2);
        }

        auto dim = x.size(0);  // current sequence length

        // std::cout <<"pos " << _posEncode.slice(0, 0, dim).sizes() << std::endl;
         //std::cout <<"x " << x.sizes() << std::endl;
        // Broadcast add: [dim, 1, d_model] + [dim, batch, d_model].
        x = x + _posEncode.slice(0, 0, dim);
        return  x;
    }

private:
    // Fill _posEncode: sin on even feature indices, cos on odd ones, with
    // frequencies 1/10000^(2i/d_model); finally insert a middle broadcast
    // axis so the table becomes [max_len, 1, d_model].
    void Encoding()
    {
        auto pos = torch::arange(0, _max_len, torch::kFloat32).reshape({ _max_len, 1 });
        auto den_indices = torch::arange(0, _d_model, 2, torch::kFloat32);
        auto den = torch::exp(-den_indices * std::log(10000.0f) / _d_model);
        _posEncode.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(0, _d_model, 2) }, torch::sin(pos * den));
        _posEncode.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(1, _d_model, 2) }, torch::cos(pos * den));
        _posEncode.unsqueeze_(-2);

    }

public:
    torch::Tensor _posEncode;   // [max_len, 1, d_model] after Encoding()
    int64_t _d_model = dim_model;
    int64_t _max_len = max_vocab_len;
};
TORCH_MODULE(PositionalEncoding);


// Seq2seq translator: source/target embeddings + positional encoding +
// a 1-encoder-layer / 1-decoder-layer torch::nn::Transformer + a linear
// projection to target-vocabulary logits.
class TranslatorImpl : public torch::nn::Module
{
public:
    TranslatorImpl()
    {
        src_emb = register_module("src_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(src_vocab_size, dim_model)));
        tgt_emb_ = register_module("tgt_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(tgt_vocab_size, dim_model)));
        pos_encoder = register_module("pos_encoder", PositionalEncoding(dim_model, max_vocab_len));
        torch::nn::TransformerOptions opts;
        opts.nhead(2);                   // 2 attention heads
        opts.dim_feedforward(dim_feed);
        opts.num_decoder_layers(1);
        opts.num_encoder_layers(1);
        opts.dropout(0.0);               // tiny dataset: no dropout
        opts.d_model(dim_model);
        transformer = register_module("transformer", torch::nn::Transformer(opts));
        fc = register_module("fc", torch::nn::Linear(dim_model, tgt_vocab_size));

    }

    // Training forward pass.
    // @param src  source token ids, [batch, seq]
    // @param tgt  decoder-input token ids, [batch, seq]
    // @return logits over the target vocabulary, (seq, batch, tgt_vocab_size)
    torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
    {
        auto none_mask = torch::Tensor();  // empty tensor => mask unused (memory_mask)
        auto srclen =src.size(1);
        // All-false [seq, seq] mask: nothing blocked in encoder self-attention.
        auto src_mask =torch::zeros({srclen,srclen}, torch::kBool);
        // Causal mask: position i may not attend to positions > i.
        auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(1));
        auto src_key_padding_mask = (src == PadId).to(torch::kBool);  // [batch,seq]
        auto tgt_key_padding_mask = (tgt == PadId).to(torch::kBool);  // [batch,seq]
        // Memory padding mask reuses the source padding mask.
        auto memory_key_padding_mask = src_key_padding_mask;
        //std::cout << "tgt_mask\n" << tgt_mask << std::endl;
        //std::cout << "src_key_padding_mask\n" << src_key_padding_mask << std::endl;
        //std::cout << "tgt_key_padding_mask\n" << tgt_key_padding_mask << std::endl;


        //[batch, seq]  --> [seq, batch]
        src = src.permute({ 1,0 });
        tgt = tgt.permute({ 1,0 });

        //std::cout << "input " << src << std::endl;
        // Scale embeddings by sqrt(d_model), as in the original paper.
        src = src_emb->forward(src) * std::sqrt(dim_model);
        src = pos_encoder->forward(src);

        tgt = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
        tgt = pos_encoder->forward(tgt);

        // tgt & src: (seq, batch, dim)
        //auto outs = transformer->forward(src, tgt);
        auto outs = transformer->forward(src, tgt, src_mask, tgt_mask, none_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask);

        // Project hidden states to vocabulary logits.
        outs = fc->forward(outs);
        
        return outs;

    }


    // Greedy autoregressive decoding for a single source sequence.
    // @param src  source token ids, 1-D [seq]
    // @return decoded target ids, starting with "S" and (normally) ending with "E"
    torch::Tensor predict(torch::Tensor src)
    {
        //std::cout << src << std::endl;
        auto srclen = src.size(0);
        auto src_mask = torch::zeros({ srclen,srclen }, torch::kBool);
        auto srcemb = src_emb->forward(src) * std::sqrt(dim_model);
        srcemb = pos_encoder->forward(srcemb);
        // Encode once; reuse `memory` for every decoding step.
        auto memory = transformer->encoder.forward(srcemb, src_mask);

        std::vector<int64_t> tgtpad = GetWordId(tgt_vocab, "S");
     
        int i = 0;
        // Cap the output length to avoid an infinite loop if "E" never appears.
        while (i < tgt_vocab_size*2)
        {
            torch::Tensor tgt = torch::tensor(tgtpad, torch::kLong);
            auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(0));
            ///std::cout << "tgt_mask " << tgt_mask << std::endl;
            auto tgt_emb = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
            tgt_emb = pos_encoder->forward(tgt_emb);
            auto out = transformer->decoder.forward(tgt_emb, memory, tgt_mask);
            out = fc->forward(out).squeeze(-2);
            auto next_token = out.argmax(-1);
            // Take the prediction at the last generated position (index i).
            int64_t key = next_token[i].item<int64_t>();
            tgtpad.push_back(key);
            //tgtpad.insert(tgtpad.begin(), );
            if ("E" == GetWordById(tgt_vocab, key))
            {
                break;
            }
            i++;
        }
       
        //tgtpad.pop_back();
        return torch::tensor(tgtpad, torch::kLong);
    }


    torch::nn::Embedding src_emb{ nullptr };
    torch::nn::Embedding tgt_emb_{ nullptr };
    PositionalEncoding pos_encoder{ nullptr };
    torch::nn::Transformer transformer{ nullptr };
    
    torch::nn::Linear fc{ nullptr };
};
TORCH_MODULE(Translator);

// Forward declarations: definitions follow TransformerMain() below.
void TestData(Translator& model);
void TrainData(Translator& model);


// Count the model's parameters.
// @return {total, trainable (requires_grad), non-trainable}.
std::tuple<int64_t, int64_t, int64_t> count_model_parameters(Translator& model) 
{
    int64_t total_params = 0;
    int64_t trainable_params = 0;
    int64_t non_trainable_params = 0;

    for (const auto& p : model->parameters())
    {
        const int64_t numel = p.numel();
        total_params += numel;
        // Bucket by trainability.
        (p.requires_grad() ? trainable_params : non_trainable_params) += numel;
    }

    return { total_params, trainable_params, non_trainable_params };
}


// Entry point: build the model, report its parameter counts, then either
// load a previously saved checkpoint or train from scratch and save it,
// and finally run the translation tests.
void TransformerMain()
{
    torch::manual_seed(4);  // fixed seed for reproducible runs

    const std::string model_path = "translator_model.pt";
    Translator model;

    auto [total, trainable, frozen] = count_model_parameters(model);
    std::cout << "模型总参数: "<< total << std::endl;
    std::cout << "可训练参数: " << trainable << std::endl;
    std::cout << "不可训参数: " << frozen << std::endl << std::endl;

    // A readable file means a checkpoint already exists.
    std::ifstream filem(model_path);
    if (filem.is_open())
    {
        torch::load(model, model_path);
        std::cout << "load model ...." << std::endl;
    }
    else
    {
        TrainData(model);
        torch::save(model, model_path);
    }
    filem.close();

    TestData(model);
}




// Train the translator until the summed per-epoch loss drops below
// `accuracy` or `max_train` epochs elapse.
// @param model  the model to train in place.
void TrainData(Translator& model)
{
    // Early-stop threshold on the total epoch loss.
    double accuracy = 0.03;

    auto datasetTrain = translatDataset().map(torch::data::transforms::Stack<>());
    auto train_data_loader = torch::data::make_data_loader(std::move(datasetTrain), torch::data::DataLoaderOptions().batch_size(1));

    // ignore_index(PadId): padded target positions contribute nothing to the loss.
    auto options = torch::nn::CrossEntropyLossOptions().ignore_index(PadId);
    torch::nn::CrossEntropyLoss loss_fn(options);

    torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(1e-3));
    model->train();
    std::cout << "训练模型" << std::endl;

    for (int i = 0; i < max_train; i++)
    {
        float total_loss = 0;
        for (auto& item : *train_data_loader)
        {
            /// item.data, item.target : [batch, seq]

            // Teacher forcing: input is <S>+target, expected output is target+<E>.
            auto[tgtInput,tgtOutput]  = CreateDecoderInputOutput(item.target);

            auto tgtOut = model->forward(item.data, tgtInput);

            // (seq, batch, vocab) -> (seq*batch, vocab) for CrossEntropyLoss.
            auto output = tgtOut.reshape({ -1, tgt_vocab_size });
            optimizer.zero_grad();
            auto tgt = tgtOutput.squeeze(0);
     
            auto loss = loss_fn(output, tgt);
            total_loss += loss.item<float>();
            loss.backward();
            // BUGFIX: gradient clipping must happen AFTER backward() has
            // populated the gradients and BEFORE the optimizer step; the
            // original clipped before backward, i.e. clipped stale (or
            // nonexistent) gradients.
            torch::nn::utils::clip_grad_norm_(model->parameters(), 1.0);
            optimizer.step();

        }


        // Log every 10 epochs and on the final epoch.
        if (i % 10 == 0 || (i + 1 == max_train))
        {
            std::cout << "i: " << i + 1 << " , loss: " << total_loss << std::endl;
        }

        // Converged: stop early.
        if (total_loss <= accuracy)
        {
            std::cout <<"i: " << i + 1  << " , loss: " << total_loss <<  std::endl << std::endl;
            break;
        }
    }
    std::cout << std::endl;
}

// Run the trained model on a list of source sentences — including
// partial and reordered ones never seen verbatim in training — and
// print each greedy-decoded translation.
void TestData(Translator& model)
{
    model->eval();
    std::cout << "测试&翻译:" << std::endl;

    const std::vector<std::string> tests = {
        "Welcome",
        "Welcome to",
        "Welcome to PyTorch",
        "Welcome to Machine",
        "Welcome to PyTorch Tutorials",
        "Welcome to Machine Learning",
        "Learning",
        "Tutorials",
        "PyTorch Tutorials",
        "Machine Learning",
    };

    for (const auto& sentence : tests)
    {
        auto src = torch::tensor(GetWordId(src_vocab, sentence), torch::kLong);
        auto result = model->predict(src);

        std::cout << sentence << " :  ";
        for (int k = 0; k < result.numel(); k++)
        {
            std::cout << GetWordById(tgt_vocab, result[k].item<int64_t>()) << " ";
        }
        std::cout << std::endl;
    }

}


感谢大家的支持,如有问题欢迎提问指正。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐