Introduction

This article implements a complete Transformer architecture from scratch with libtorch. It covers:

1. Encoder
2. Decoder
3. Multi-head attention (MultiHeadAttention)
4. Feed-forward network (FFN)

We then take Chapter 11, "A Simple Transformer Example", and swap its built-in libtorch Transformer for our handwritten MyTransformer to verify the results.

Transformer Architecture Analysis

A handwritten version will of course not be as feature-complete as the built-in libtorch Transformer. Following the keep-it-simple principle, we implement the basic points of the paper, which is still enough for a translation task. Once the multi-head attention is written, roughly 50% of the work is done; the rest is assembly. The analysis breaks down into four points (see the figure): points 1, 2 and 4 still need their implementation details worked out below, while point 3 simply says that exactly one multi-head attention needs a mask operation, i.e. the decoder gets the mask.

1. Layer normalization: we use libtorch's built-in torch::nn::LayerNorm directly (a small usage sketch follows this list). Its effects:

    1.  Makes training very stable
    2.  Speeds up convergence
    3.  Normalizes the feature dimension to zero mean (0) and unit variance (1)
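
A minimal usage sketch (assuming dim = 4) that checks the third point: after torch::nn::LayerNorm each position's feature vector is close to zero mean and unit variance.

	torch::nn::LayerNorm norm(torch::nn::LayerNormOptions({ 4 }));
	auto x = torch::randn({ 3, 2, 4 });    // [seq, batch, dim]
	auto y = norm->forward(x);
	std::cout << y.mean(-1) << std::endl;  // per-position mean over the feature dimension is ~0 (and its variance ~1)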

2. Feed-forward network

The activation function is ReLU and the implementation is simple; per the paper, FFN(x) = max(0, xW1 + b1)W2 + b2.

3. Target mask, i.e. the upper-triangular (look-ahead) mask; also simple to implement

	torch::Tensor generate_square_subsequent_mask(int64_t sz)
	{
		auto mask = torch::triu(torch::ones({ sz, sz }, torch::kFloat32), 1);
		
		mask = mask.masked_fill(mask == 1, -std::numeric_limits<float>::infinity());
		return mask;
	}
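
For example, for a sequence length of 4 the mask looks like this (0 = visible, -inf = a future position that must be ignored):

	auto m = generate_square_subsequent_mask(4);
	std::cout << m << std::endl;
	//   0  -inf  -inf  -inf
	//   0     0  -inf  -inf
	//   0     0     0  -inf
	//   0     0     0     0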

4. The encoder output flows into the decoder as both K and V (see the paper); Q comes from the output of the decoder's masked multi-head attention.

Note that the encoder produces only a single output tensor, which means K and V are identical.
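
In code this means the decoder's cross-attention simply receives the encoder output (memory) twice; a one-line sketch using the names from the DecoderLayer code below:

	// y      : output of the decoder's masked self-attention -> Q
	// memory : encoder output                                 -> K and V
	auto y2 = attention2->forward(y, memory, memory);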

MyTransformer Code Implementation

1. Feed-forward network (FFN)

class FeedForwardNetImpl : public torch::nn::Module
{
public:
	FeedForwardNetImpl(int64_t dim=512, int64_t dff=2048)
	{

		ffn = register_module("SeqFFN", torch::nn::Sequential(torch::nn::Linear(dim, dff),
			torch::nn::ReLU(),
			torch::nn::Linear(dff, dim)
		));
	}

	auto forward(torch::Tensor x)
	{
		return ffn->forward(x);
	}

	torch::nn::Sequential ffn{};
	
};
TORCH_MODULE(FeedForwardNet);

Knowing only the model dimension and the hidden size, torch::nn::Sequential keeps the assembly trivial.
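
A quick usage sketch (input and output share the [seq, batch, dim] shape):

	FeedForwardNet ffn(128, 256);           // dim = 128, dff = 256
	auto x = torch::randn({ 10, 2, 128 });  // [seq, batch, dim]
	auto y = ffn->forward(x);               // [10, 2, 128]
	std::cout << y.sizes() << std::endl;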

2. MultiHeadAttention

We modify the previous chapter's code so that it accepts Q, K, V and a mask, projects Q, K and V through their weight matrices (linear layers), and finally runs the Q/K/V attention computation.

class MultiHeadAttentionImpl : public torch::nn::Module
{
public:

	//q k v : [seq, batch, dim]
	torch::Tensor forward(torch::Tensor& q, torch::Tensor& k, torch::Tensor& v,torch::Tensor mask = {})
	{
		auto q1 = Q->forward(q);
		auto k1 = K->forward(k);
		auto v1 = V->forward(v);

		return ScaledDotProductAttention(q1, k1, v1, mask);
	}


};
TORCH_MODULE(MultiHeadAttention);
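
A usage sketch for self-attention (Q = K = V = x) in the [seq, batch, dim] layout used throughout this article; the constructor (dim, head) is shown in the complete code, and dim must be divisible by head:

	MultiHeadAttention attn(128, 2);        // dim = 128, head = 2
	auto x = torch::randn({ 10, 1, 128 });  // [seq, batch, dim]
	auto out = attn->forward(x, x, x);      // same shape: [10, 1, 128]
	std::cout << out.sizes() << std::endl;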

3. Encoder

A Transformer has one or more encoder layers. The class EncoderLayer implements everything a single encoder layer needs, and the class Encoders makes several of them work together, chaining the EncoderLayer instances with torch::nn::ModuleList.


class EncoderLayerImpl : public torch::nn::Module
{
public:
	EncoderLayerImpl(int64_t dim, int64_t head, int64_t dff)
	{
		torch::nn::LayerNormOptions normOpt({ dim });
		norm1 = register_module("norm1", torch::nn::LayerNorm(normOpt));
		norm2 = register_module("norm2", torch::nn::LayerNorm(normOpt));
		ffn = register_module("ffn", FeedForwardNet(dim, dff));
		attention = register_module("attention", MultiHeadAttention(dim, head));
	}

	auto forward(torch::Tensor x)
	{
		auto y = attention->forward(x,x,x);

		y = norm1->forward(x + y); ///  residual connection

		auto y2 = ffn->forward(y);

		return norm2->forward(y + y2); ///  residual connection
	}

	FeedForwardNet ffn{ nullptr };
	torch::nn::LayerNorm norm1{ nullptr }, norm2{ nullptr };
	MultiHeadAttention attention{ nullptr };
};

TORCH_MODULE(EncoderLayer);

class EncodersImpl : public torch::nn::Module
{
public:
	EncodersImpl(int64_t dim, int64_t head, int64_t ffn, int64_t layers)
	{
		moduleLayers = register_module("moduleLayers", torch::nn::ModuleList());

		for (int i = 0; i < layers; i++)
		{
			moduleLayers->push_back(EncoderLayer(dim, head, ffn));
		}
	}

	auto forward(torch::Tensor x)
	{
	
		for (auto& item : *moduleLayers)
		{
			x = item->as<EncoderLayer>()->forward(x);
		}
		return x;
	}
	torch::nn::ModuleList moduleLayers{ nullptr };
	
};

TORCH_MODULE(Encoders);
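
From the caller's point of view the stack behaves like a single layer; a minimal sketch with two layers:

	Encoders encoders(128, 2, 256, 2);      // dim, head, dff, layers
	auto x = torch::randn({ 10, 1, 128 });  // [seq, batch, dim]
	auto memory = encoders->forward(x);     // [10, 1, 128], later fed to the decoder as K and V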

4. Decoder

Its structure is much like the encoder's, so only the differences are listed here.

1. DecoderLayer's forward takes the target sequence, the encoder output (memory) and the target mask

class DecoderLayerImpl : public torch::nn::Module
{
public:


	auto forward(torch::Tensor& tgt, torch::Tensor& memory,torch::Tensor tgtmask)
	{
		
		auto y = MaskAttention(tgt, tgtmask);

		//cout << "y\n" << y.sizes() << endl;
		//cout << "memory\n" << memory.sizes() << endl;

		auto y2 = attention2->forward(y, memory, memory);

		auto y3 = norm2->forward(y + y2); //  residual connection

		auto y4 = ffn->forward(y3);

		return norm3->forward(y3 + y4); //  residual connection
	}

private:

	torch::Tensor MaskAttention(torch::Tensor x, torch::Tensor mask)
	{
		auto y = attention->forward(x, x, x, mask);
		y = norm1->forward(x + y); //  residual connection
		return y;
	}

};



class DecodersImpl : public torch::nn::Module
{
public:

	// The constructor and forward mirror EncodersImpl: forward loops over the
	// DecoderLayer list and additionally passes memory and tgtmask to every layer
	// (see the complete code below).

	torch::nn::ModuleList moduleLayers{ nullptr };

};

TORCH_MODULE(Decoders);
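
A sketch of how the decoder stack is driven (constructor signature taken from the complete code: dim, head, dff, layers); tgt and memory use the [seq, batch, dim] layout and the mask comes from generate_square_subsequent_mask:

	Decoders decoders(128, 2, 256, 1);                     // dim, head, dff, layers
	auto tgt = torch::randn({ 8, 1, 128 });                // [tgt_seq, batch, dim]
	auto memory = torch::randn({ 10, 1, 128 });            // encoder output, [src_seq, batch, dim]
	auto tgt_mask = generate_square_subsequent_mask(8);    // [tgt_seq, tgt_seq]
	auto out = decoders->forward(tgt, memory, tgt_mask);   // [8, 1, 128]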

5. MyTransformer

  All the basic pieces are now in place; this step wires them together.


class MyTransformerImpl : public torch::nn::Module
{
public:
	MyTransformerImpl(int64_t dim, int64_t head, int64_t ffn, int64_t layerEncoder, int64_t layerDecoder)
	{
		src_emb = register_module("src_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(src_vocab_size, dim)));
		tgt_emb_ = register_module("tgt_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(tgt_vocab_size, dim)));
		pos_encoder = register_module("pos_encoder", PositionalEncoding(dim, max_vocab_len));
		encoders = register_module("Encoders", Encoders(dim, head, ffn, layerEncoder));
		decoders = register_module("Decoders", Decoders(dim, head, ffn, layerDecoder));
		fc = register_module("fc", torch::nn::Linear(dim, tgt_vocab_size));
	}

	torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
	{
		// body shown in "7. Training code" below
	}

	torch::Tensor predict(torch::Tensor src)
	{
		// body shown in "8. Test code" below
	}

private:


	torch::Tensor generate_square_subsequent_mask(int64_t sz)
	{
		auto mask = torch::triu(torch::ones({ sz, sz }, torch::kFloat32), 1);
		
		mask = mask.masked_fill(mask == 1, -std::numeric_limits<float>::infinity());
		return mask;
	}


	Encoders encoders{nullptr};
	Decoders decoders{nullptr};

	torch::nn::Embedding src_emb{ nullptr };
	torch::nn::Embedding tgt_emb_{ nullptr };
	PositionalEncoding pos_encoder{ nullptr };

	torch::nn::Linear fc{ nullptr };
};

TORCH_MODULE(MyTransformer);
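
For orientation, a minimal usage sketch of the assembled model once forward is filled in (section 7 below): token ids go in as [batch, seq] and logits come out as [seq, batch, tgt_vocab_size]. The ids follow the vocabularies defined in TransformerTestData.cpp:

	MyTransformer model(dim_model, 2, dim_feed, 1, 1);  // dim, head, dff, encoder layers, decoder layers
	auto src = torch::tensor(std::vector<int64_t>{ 1, 2, 3, 5 }, torch::kLong).unsqueeze(0); // [1, 4]  "Welcome to PyTorch Tutorials"
	auto tgt = torch::tensor(std::vector<int64_t>{ 1, 3, 4 }, torch::kLong).unsqueeze(0);    // [1, 3]  "S 欢 迎"
	auto logits = model->forward(src, tgt);             // [3, 1, tgt_vocab_size]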

6. Entry point

void HandwrittenTransformerMain()
{
	torch::manual_seed(6);
	std::string model_path = "MyTransformer_model2.pt";
	MyTransformer model(dim_model,2, dim_feed,1,1);

	std::ifstream filem(model_path);
	bool bmodel = filem.is_open();
	if (!bmodel || true) // "|| true" forces retraining on every run; remove it to load the saved model instead
	{
		TrainData2(model);
		torch::save(model, model_path);
	}
	else
	{
		torch::load(model, model_path);
		std::cout << "load model ...." << std::endl;
	}

	filem.close();

	TestData2(model);
}

7. Training code

class MyTransformerImpl : public torch::nn::Module
{
public:


	torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
	{
		auto none_mask = torch::Tensor();

		auto tgt_mask = generate_square_subsequent_mask(tgt.size(1));
	

		//[batch, seq]  --> [seq, batch]
		src = src.permute({ 1,0 });
		tgt = tgt.permute({ 1,0 });

		//std::cout << "input " << src << std::endl;
		src = src_emb->forward(src) * std::sqrt(dim_model);
		src = pos_encoder->forward(src);

		tgt = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
		tgt = pos_encoder->forward(tgt);
	

		return TransformerForward(src, tgt, tgt_mask);
	}

	

private:
	torch::Tensor TransformerForward(torch::Tensor& src,torch::Tensor& tgt,torch::Tensor tgtmask)
	{
		 auto outputEncoder = encoders->forward(src);
		 auto  outputDecoder = decoders->forward(tgt, outputEncoder, tgtmask);
		 return fc->forward(outputDecoder);
	}



};

TORCH_MODULE(MyTransformer);
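
Before moving on, here is how the training loop (TrainData2 in the complete code) consumes this output for the loss; a minimal shape sketch with batch_size = 1, where tgtOutput is the target sequence shifted to end with E:

	// tgtOut : the forward() output, [seq, 1, tgt_vocab_size] (batch_size = 1)
	auto output = tgtOut.reshape({ -1, tgt_vocab_size });  // [seq, tgt_vocab_size]
	auto tgt = tgtOutput.squeeze(0);                       // [1, seq] -> [seq]
	auto loss = loss_fn(output, tgt);                      // CrossEntropyLoss with ignore_index = PadId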

8. Test code

class MyTransformerImpl : public torch::nn::Module
{
public:


torch::Tensor predict(torch::Tensor src)
{
	//std::cout << src << std::endl;

	auto srcemb = src_emb->forward(src) * std::sqrt(dim_model);
	srcemb = pos_encoder->forward(srcemb);
	auto memory = encoders->forward(srcemb);

	std::vector<int64_t> tgtpad = GetWordId(tgt_vocab, "S");

	int i = 0;
	while (i < tgt_vocab_size * 2)
	{
		torch::Tensor tgt = torch::tensor(tgtpad, torch::kLong);
		auto tgt_mask = generate_square_subsequent_mask(tgt.size(0));
		///std::cout << "tgt_mask " << tgt_mask << std::endl;
		auto tgt_emb = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
		tgt_emb = pos_encoder->forward(tgt_emb);
		auto out = decoders->forward(tgt_emb, memory, tgt_mask);
		out = fc->forward(out).squeeze(-2);
		auto next_token = out.argmax(-1);
		int64_t key = next_token[i].item<int64_t>();
		tgtpad.push_back(key);
		//tgtpad.insert(tgtpad.begin(), );
		if ("E" == GetWordById(tgt_vocab, key))
		{
			break;
		}
		i++;
	}

	return torch::tensor(tgtpad, torch::kLong);
}


};

TORCH_MODULE(MyTransformer);

Results and Complete Code

Results:

Code file: TransformerTestData.h

#pragma once

#include <torch/torch.h>
#include <iostream>
#include <torch/serialize.h>
#include <regex>
#include <fstream>


#define  dim_model   128
#define  dim_feed    256
#define  max_vocab_len  500
#define  max_train       100

#define PadId 0

typedef const std::unordered_map< std::string, int64_t> TableVocab;
typedef std::vector<std::pair<int64_t, int64_t>>  WordList;


/// <Translation task>
///  Welcome to PyTorch Tutorials  ---> 欢迎来到派托奇教程
///  Welcome to Machine Learning -----> 欢迎来到机器学习
/// </Translation task>
extern TableVocab src_vocab;
extern TableVocab tgt_vocab;

extern  int64_t src_vocab_size;
extern  int64_t tgt_vocab_size;


std::vector<std::string> Split(const std::string& s);


std::string GetWordById(TableVocab& vocabId, int64_t dataid);


std::vector<int64_t> GetWordId(TableVocab& vocabId, std::string data);


WordList GetLoadDataWordId(std::pair<std::string, std::string> data);




class translatDataset : public torch::data::Dataset<translatDataset>
{
public:

    translatDataset()
    {
        wordCount.push_back(GetLoadDataWordId({ "Welcome to PyTorch Tutorials Pad Pad Pad Pad Pad","欢 迎 来 到 派 托 奇 教 程" }));
        wordCount.push_back(GetLoadDataWordId({ "Welcome to Machine Learning  Pad Pad Pad Pad","欢 迎 来 到 机 器 学 习" }));

    }


    torch::optional<size_t> size() const
    {
        return wordCount.size();
    }

    torch::data::Example<torch::Tensor, torch::Tensor>  get(size_t index) override
    {
        auto item = wordCount[index];
        std::vector<int64_t> tmpinput;
        std::vector<int64_t> tmptarget1;


        for (auto& i : item)
        {
            tmpinput.push_back(i.first);
            tmptarget1.push_back(i.second);

        }

        auto input = torch::tensor(tmpinput, torch::kLong);
        auto target = torch::tensor(tmptarget1, torch::kLong);


        return { input, target };
    }

public:
    std::vector<WordList> wordCount;

};

std::pair<torch::Tensor, torch::Tensor>  CreateDecoderInputOutput(torch::Tensor data);


class PositionalEncodingImpl :public torch::nn::Module
{
public:
    PositionalEncodingImpl(int64_t d_model, int64_t max_len)
    {
        _d_model = d_model;
        _max_len = max_len;
        _posEncode = torch::zeros({ _max_len, _d_model }, torch::kFloat32);

        Encoding();

        register_buffer("posEncode", _posEncode);
    }


    torch::Tensor forward(torch::Tensor x)
    {
        if ((x.dim() == 2))
        {
            x = x.unsqueeze_(-2);
        }

        auto dim = x.size(0);

        // std::cout <<"pos " << _posEncode.slice(0, 0, dim).sizes() << std::endl;
         //std::cout <<"x " << x.sizes() << std::endl;
        x = x + _posEncode.slice(0, 0, dim);
        return  x;
    }

private:
    void Encoding()
    {
        auto pos = torch::arange(0, _max_len, torch::kFloat32).reshape({ _max_len, 1 });
        auto den_indices = torch::arange(0, _d_model, 2, torch::kFloat32);
        auto den = torch::exp(-den_indices * std::log(10000.0f) / _d_model);
        _posEncode.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(0, _d_model, 2) }, torch::sin(pos * den));
        _posEncode.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(1, _d_model, 2) }, torch::cos(pos * den));
        _posEncode.unsqueeze_(-2);

    }

public:
    torch::Tensor _posEncode;
    int64_t _d_model = dim_model;
    int64_t _max_len = max_vocab_len;
};

TORCH_MODULE(PositionalEncoding);

TransformerTestData.cpp

#include "TransformerTestData.h"

TableVocab src_vocab =
{
    {"Pad",PadId},
    {"Welcome",1},
    {"to",2},
    {"PyTorch",3},
    {"Machine",4},
    {"Tutorials",5},
    {"Learning",6 }
};

TableVocab tgt_vocab =
{
    {"Pad",PadId},
    {"S",1},
    {"E",2},
    {"欢",3},
    {"迎",4},
    {"来",5},
    {"到",6},
    {"派",7},
    {"托",8},
    {"奇",9},
    {"教",10},
    {"程",11},
    {"机",12},
    {"器",13},
    {"学",14},
    {"习",15}
};


int64_t src_vocab_size = src_vocab.size();
int64_t tgt_vocab_size = tgt_vocab.size();


std::vector<std::string> Split(const std::string& s)
{
    std::vector<std::string> res;
    std::stringstream ss(s);
    std::string word;

    while (ss >> word)
    {
        res.push_back(word);
    }
    return res;
}

std::string GetWordById(TableVocab& vocabId, int64_t dataid)
{
    std::string Word = { "0" };
    for (auto& w : vocabId)
    {
        if (w.second == dataid)
        {
            Word = w.first;
            break;
        }
    }
    return Word;
}


std::vector<int64_t> GetWordId(TableVocab& vocabId, std::string data)
{
    std::vector<int64_t> input;
    for (auto ch : Split(data))
    {
        input.push_back(vocabId.at(ch));
    }
    return input;
}

WordList GetLoadDataWordId(std::pair<std::string, std::string> data)
{
    std::vector<int64_t> input = GetWordId(src_vocab, data.first);
    std::vector<int64_t> target = GetWordId(tgt_vocab, data.second);

    WordList item;
    for (int i = 0; i < input.size() && i < target.size(); i++)
    {
        item.push_back({ input.at(i),target.at(i) });
    }

    return item;
}

std::pair<torch::Tensor, torch::Tensor>  CreateDecoderInputOutput(torch::Tensor data)
{
    auto E = torch::tensor(GetWordId(tgt_vocab, "E"), torch::kLong).view({ 1,1 });
    auto S = torch::tensor(GetWordId(tgt_vocab, "S"), torch::kLong).view({ 1,1 });

    auto  input = torch::cat({ S, data }, 1);
    auto  output = torch::cat({ data,E }, 1);

    return { input ,output };
}

HandwrittenTransformer.cpp

#include <torch/torch.h>
#include <iostream>
#include <torch/serialize.h>
#include <regex>
//#include <iostream>
#include <fstream>

#include "TransformerTestData.h"

using namespace std;


class FeedForwardNetImpl : public torch::nn::Module
{
public:
	FeedForwardNetImpl(int64_t dim=512, int64_t dff=2048)
	{

		ffn = register_module("SeqFFN", torch::nn::Sequential(torch::nn::Linear(dim, dff),
			torch::nn::ReLU(),
			torch::nn::Linear(dff, dim)
		));
	}

	auto forward(torch::Tensor x)
	{
		return ffn->forward(x);
	}

	torch::nn::Sequential ffn{};
	
};
TORCH_MODULE(FeedForwardNet);


class MultiHeadAttentionImpl : public torch::nn::Module
{
public:

	MultiHeadAttentionImpl(int64_t dim, int64_t head)
	{
		assert(dim % head == 0);

		InitQKV(dim, head);
	}

	//q k v : [seq, batch, dim]
	torch::Tensor forward(torch::Tensor& q, torch::Tensor& k, torch::Tensor& v,torch::Tensor mask = {})
	{
		auto q1 = Q->forward(q);
		auto k1 = K->forward(k);
		auto v1 = V->forward(v);

		return ScaledDotProductAttention(q1, k1, v1, mask);
	}


private:
	void InitQKV(int64_t dim, int64_t head)
	{
		
		auto linear = torch::nn::LinearOptions(dim, dim).bias(false);

		Q = register_module("q", torch::nn::Linear(linear));
		K = register_module("k", torch::nn::Linear(linear));
		V = register_module("v", torch::nn::Linear(linear));
		Wo = register_module("Wo", torch::nn::Linear(linear)); // output projection

		norm_fact = 1.0 / sqrt(dim);

		Dk = dim / head;
		H = head;	
	}

	/// q: [seq, batch, dim]
	torch::Tensor ScaledDotProductAttention(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor& mask)
	{
		/// k == v; q need not equal k and v
		auto seq = q.size(0);
		auto batch = q.size(1);
		auto dim = q.size(2);

		auto seq2 = k.size(0);
		auto batch2 = k.size(1);
		auto dim2 = k.size(2);

		q = q.view({ seq,batch,H,Dk }); //q: [seq, batch, dim] ->   [S, B, H, Dk] 
		k = k.view({ seq2,batch2,H,Dk });
		v = v.view({ seq2,batch2,H,Dk });

		q = q.permute({ 1,2,0,3 }); //[S, B, H, Dk] --->[B, H, S, Dk]
		k = k.permute({ 1,2,0,3 });
		v = v.permute({ 1,2,0,3 });

		auto kt = k.permute({ 0,1,3,2 }); //kt:  [B, H, S, Dk] --> [B, H, Dk, S]

		//cout << "q\n" << q.sizes() << endl;
		//cout << "kt\n" << kt.sizes() << endl;


		auto attn_score = torch::matmul(q, kt);

		attn_score = attn_score * norm_fact;
	
		if (mask.defined())
		{
			attn_score += mask;
		}

		attn_score = torch::softmax(attn_score, -1); /// attn_score: [B, H, S, S]

		auto out = torch::matmul(attn_score, v); // [B, H, S, S] * [B, H, S, Dk]  ->  out: [B, H, S, Dk]
		out = out.permute({ 2,0,1,3 }).contiguous().view({ seq,batch, dim }); //  [B, H, S, Dk] -> [S, B, H, Dk] -> [seq, batch, dim]
		//cout <<"out\n" << out.squeeze() << endl;
		out = Wo->forward(out);
		return out;
	}

	torch::nn::Linear Q{ nullptr };
	torch::nn::Linear K{ nullptr };
	torch::nn::Linear V{ nullptr };
	torch::nn::Linear Wo{ nullptr };

	double norm_fact = 0;
	int64_t Dk;
	int64_t H;
};
TORCH_MODULE(MultiHeadAttention);


class EncoderLayerImpl : public torch::nn::Module
{
public:
	EncoderLayerImpl(int64_t dim, int64_t head, int64_t dff)
	{
		torch::nn::LayerNormOptions normOpt({ dim });
		norm1 = register_module("norm1", torch::nn::LayerNorm(normOpt));
		norm2 = register_module("norm2", torch::nn::LayerNorm(normOpt));
		ffn = register_module("ffn", FeedForwardNet(dim, dff));
		attention = register_module("attention", MultiHeadAttention(dim, head));
	}

	auto forward(torch::Tensor x)
	{
		auto y = attention->forward(x,x,x);

		y = norm1->forward(x + y); ///  residual connection

		auto y2 = ffn->forward(y);

		return norm2->forward(y + y2); ///  residual connection
	}

	FeedForwardNet ffn{ nullptr };
	torch::nn::LayerNorm norm1{ nullptr }, norm2{ nullptr };
	MultiHeadAttention attention{ nullptr };
};

TORCH_MODULE(EncoderLayer);

class EncodersImpl : public torch::nn::Module
{
public:
	EncodersImpl(int64_t dim, int64_t head, int64_t ffn, int64_t layers)
	{
		moduleLayers = register_module("moduleLayers", torch::nn::ModuleList());

		for (int i = 0; i < layers; i++)
		{
			moduleLayers->push_back(EncoderLayer(dim, head, ffn));
		}
	}

	auto forward(torch::Tensor x)
	{
	
		for (auto& item : *moduleLayers)
		{
			x = item->as<EncoderLayer>()->forward(x);
		}
		return x;
	}
	torch::nn::ModuleList moduleLayers{ nullptr };
	
};

TORCH_MODULE(Encoders);




class DecoderLayerImpl : public torch::nn::Module
{
public:
	DecoderLayerImpl(int64_t dim, int64_t head, int64_t dff)
	{
		torch::nn::LayerNormOptions normOpt({ dim });
		norm1 = register_module("norm1", torch::nn::LayerNorm(normOpt));
		norm2 = register_module("norm2", torch::nn::LayerNorm(normOpt));
		norm3 = register_module("norm3", torch::nn::LayerNorm(normOpt));
		ffn = register_module("ffn", FeedForwardNet(dim, dff));
		attention = register_module("attention", MultiHeadAttention(dim, head));
		attention2 = register_module("attention2", MultiHeadAttention(dim, head));
	}

	auto forward(torch::Tensor& tgt, torch::Tensor& memory,torch::Tensor tgtmask)
	{
		
		auto y = MaskAttention(tgt, tgtmask);

		//cout << "y\n" << y.sizes() << endl;
		//cout << "memory\n" << memory.sizes() << endl;

		auto y2 = attention2->forward(y, memory, memory);

		auto y3 = norm2->forward(y + y2); //  residual connection

		auto y4 = ffn->forward(y3);

		return norm3->forward(y3 + y4); //  residual connection
	}

private:

	torch::Tensor MaskAttention(torch::Tensor x, torch::Tensor mask)
	{
		auto y = attention->forward(x,x,x, mask);
		y = norm1->forward(x + y); //  residual connection
		return y;
	}

public:
	FeedForwardNet ffn{ nullptr };
	torch::nn::LayerNorm norm1{ nullptr }, norm2{ nullptr }, norm3{ nullptr };
	MultiHeadAttention attention{ nullptr };
	MultiHeadAttention attention2{ nullptr };
};

TORCH_MODULE(DecoderLayer);


class DecodersImpl : public torch::nn::Module
{
public:
	DecodersImpl(int64_t dim, int64_t head, int64_t ffn, int64_t layers)
	{
		moduleLayers = register_module("moduleLayers2", torch::nn::ModuleList());

		for (int i = 0; i < layers; i++)
		{
			moduleLayers->push_back(DecoderLayer(dim, head, ffn));
		}
	}

	auto forward(torch::Tensor& tgt, torch::Tensor& memory, torch::Tensor tgtmask)
	{

		for (auto& item : *moduleLayers)
		{
			tgt = item->as<DecoderLayer>()->forward(tgt, memory, tgtmask);
		}

		return tgt;
	}

	torch::nn::ModuleList moduleLayers{ nullptr };

};

TORCH_MODULE(Decoders);


class MyTransformerImpl : public torch::nn::Module
{
public:
	MyTransformerImpl(int64_t dim, int64_t head, int64_t ffn, int64_t layerEncoder, int64_t layerDecoder)
	{
		src_emb = register_module("src_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(src_vocab_size, dim)));
		tgt_emb_ = register_module("tgt_emb", torch::nn::Embedding(torch::nn::EmbeddingOptions(tgt_vocab_size, dim)));
		pos_encoder = register_module("pos_encoder", PositionalEncoding(dim, max_vocab_len));
		encoders = register_module("Encoders", Encoders(dim, head, ffn, layerEncoder));
		decoders = register_module("Decoders", Decoders(dim, head, ffn, layerDecoder));
		fc = register_module("fc", torch::nn::Linear(dim, tgt_vocab_size));
	}

	torch::Tensor forward(torch::Tensor src, torch::Tensor tgt)
	{
		auto none_mask = torch::Tensor();

		auto tgt_mask = generate_square_subsequent_mask(tgt.size(1));
	

		//[batch, seq]  --> [seq, batch]
		src = src.permute({ 1,0 });
		tgt = tgt.permute({ 1,0 });

		//std::cout << "input " << src << std::endl;
		src = src_emb->forward(src) * std::sqrt(dim_model);
		src = pos_encoder->forward(src);

		tgt = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
		tgt = pos_encoder->forward(tgt);
	

		return TransformerForward(src, tgt, tgt_mask);
	}

	torch::Tensor predict(torch::Tensor src)
	{
		//std::cout << src << std::endl;

		auto srcemb = src_emb->forward(src) * std::sqrt(dim_model);
		srcemb = pos_encoder->forward(srcemb);
		auto memory = encoders->forward(srcemb);

		std::vector<int64_t> tgtpad = GetWordId(tgt_vocab, "S");

		int i = 0;
		while (i < tgt_vocab_size * 2)
		{
			torch::Tensor tgt = torch::tensor(tgtpad, torch::kLong);
			auto tgt_mask = generate_square_subsequent_mask(tgt.size(0));
			///std::cout << "tgt_mask " << tgt_mask << std::endl;
			auto tgt_emb = tgt_emb_->forward(tgt) * std::sqrt(dim_model);
			tgt_emb = pos_encoder->forward(tgt_emb);
			auto out = decoders->forward(tgt_emb, memory, tgt_mask);
			out = fc->forward(out).squeeze(-2);
			auto next_token = out.argmax(-1);
			int64_t key = next_token[i].item<int64_t>();
			tgtpad.push_back(key);
			//tgtpad.insert(tgtpad.begin(), );
			if ("E" == GetWordById(tgt_vocab, key))
			{
				break;
			}
			i++;
		}

		return torch::tensor(tgtpad, torch::kLong);
	}

private:
	torch::Tensor TransformerForward(torch::Tensor& src,torch::Tensor& tgt,torch::Tensor tgtmask)
	{
		 auto outputEncoder = encoders->forward(src);
		 auto  outputDecoder = decoders->forward(tgt, outputEncoder, tgtmask);
		 return fc->forward(outputDecoder);
	}

	torch::Tensor generate_square_subsequent_mask(int64_t sz)
	{
		auto mask = torch::triu(torch::ones({ sz, sz }, torch::kFloat32), 1);
		
		mask = mask.masked_fill(mask == 1, -std::numeric_limits<float>::infinity());
		return mask;
	}


	Encoders encoders{nullptr};
	Decoders decoders{nullptr};

	torch::nn::Embedding src_emb{ nullptr };
	torch::nn::Embedding tgt_emb_{ nullptr };
	PositionalEncoding pos_encoder{ nullptr };

	torch::nn::Linear fc{ nullptr };
};

TORCH_MODULE(MyTransformer);

void TestData2(MyTransformer& model)
{
	 
	model->eval();
	std::cout << "测试&翻译:" << std::endl;
	std::vector<std::string> tests;
	tests.push_back("Welcome");
	tests.push_back("Welcome to");
	tests.push_back("Welcome to PyTorch");
	tests.push_back("Welcome to Machine");
	tests.push_back("Welcome to PyTorch Tutorials");
	tests.push_back("Welcome to Machine Learning");

	tests.push_back("Learning");
	tests.push_back("Tutorials");
	tests.push_back("PyTorch Tutorials");
	tests.push_back("Machine Learning");

	for (auto ch : tests)
	{
		auto item = GetWordId(src_vocab, ch);
		auto src = torch::tensor(item, torch::kLong);

		auto result = model->predict(src);

		// std::cout << std::regex_replace(ch, std::regex("Pad"), "") << " :  ";
		std::cout << ch << " :  ";


		for (int k = 0; k < result.numel(); k++)
		{
			std::cout << GetWordById(tgt_vocab, result[k].item<int64_t>()) << " ";
		}

		std::cout << std::endl;
	}
	
}




void TrainData2(MyTransformer& model)
{
	double accuracy = 0.03;

	auto datasetTrain = translatDataset().map(torch::data::transforms::Stack<>());
	auto train_data_loader = torch::data::make_data_loader(std::move(datasetTrain), torch::data::DataLoaderOptions().batch_size(1));
	auto options = torch::nn::CrossEntropyLossOptions().ignore_index(PadId);
	torch::nn::CrossEntropyLoss loss_fn(options);

	torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(1e-3));
	model->train();
	std::cout << "训练模型" << std::endl;

	for (int i = 0; i < max_train; i++)
	{
		float total_loss = 0;
		for (auto& item : *train_data_loader)
		{
			/// item.data, item.target : [batch, seq]

			auto [tgtInput, tgtOutput] = CreateDecoderInputOutput(item.target);

			auto tgtOut = model->forward(item.data, tgtInput);

			auto output = tgtOut.reshape({ -1, tgt_vocab_size });
			optimizer.zero_grad();
			auto tgt = tgtOutput.squeeze(0);

			auto loss = loss_fn(output, tgt);
			total_loss += loss.item<float>();
			loss.backward();
			torch::nn::utils::clip_grad_norm_(model->parameters(), 1.0); // clip after backward so the freshly computed gradients are clipped

			optimizer.step();

		}


		if (i % 10 == 0 || (i + 1 == max_train))
		{
			std::cout << "i: " << i + 1 << " , loss: " << total_loss << std::endl;
		}

		if (total_loss <= accuracy)
		{
			std::cout << "i: " << i + 1 << " , loss: " << total_loss << std::endl << std::endl;
			break;
		}
	}
	std::cout << std::endl;
}




void HandwrittenTransformerMain()
{
	torch::manual_seed(6);
	std::string model_path = "MyTransformer_model2.pt";
	MyTransformer model(dim_model,2, dim_feed,1,1);

	std::ifstream filem(model_path);
	bool bmodel = filem.is_open();
	if (!bmodel || true) // "|| true" forces retraining on every run; remove it to load the saved model instead
	{
		TrainData2(model);
		torch::save(model, model_path);
	}
	else
	{
		torch::load(model, model_path);
		std::cout << "load model ...." << std::endl;
	}

	filem.close();

	TestData2(model);
}


Thanks for your support; if you have any questions, feel free to ask or point out corrections.
