text_matching-sentence_transformers 双塔结构
transformers
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
项目地址:https://gitcode.com/gh_mirrors/tra/transformers
·
3层RoBERTa效果(RBT3)
class SentenceTransformer(nn.Layer):
def __init__(self, pretrained_model, dropout=None):
super().__init__()
self.ptm = pretrained_model
self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
# num_labels = 2 (similar or dissimilar)
self.classifier = nn.Linear(self.ptm.config["hidden_size"] * 3, 2)
def forward(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):
query_token_embedding, _ = self.ptm(
query_input_ids, query_token_type_ids, query_position_ids,
query_attention_mask)
query_token_embedding = self.dropout(query_token_embedding)
query_attention_mask = paddle.unsqueeze(
(query_input_ids != self.ptm.pad_token_id
).astype(self.ptm.pooler.dense.weight.dtype),
axis=2)
# Set token embeddings to 0 for padding tokens
query_token_embedding = query_token_embedding * query_attention_mask
query_sum_embedding = paddle.sum(query_token_embedding, axis=1)
query_sum_mask = paddle.sum(query_attention_mask, axis=1)
query_mean = query_sum_embedding / query_sum_mask
title_token_embedding, _ = self.ptm(
title_input_ids, title_token_type_ids, title_position_ids,
title_attention_mask)
title_token_embedding = self.dropout(title_token_embedding)
title_attention_mask = paddle.unsqueeze(
(title_input_ids != self.ptm.pad_token_id
).astype(self.ptm.pooler.dense.weight.dtype),
axis=2)
# Set token embeddings to 0 for padding tokens
title_token_embedding = title_token_embedding * title_attention_mask
title_sum_embedding = paddle.sum(title_token_embedding, axis=1)
title_sum_mask = paddle.sum(title_attention_mask, axis=1)
title_mean = title_sum_embedding / title_sum_mask
sub = paddle.abs(paddle.subtract(query_mean, title_mean))
projection = paddle.concat([query_mean, title_mean, sub], axis=-1)
logits = self.classifier(projection)
probs = F.softmax(logits)
return probs
双塔模型结构构造
Sentence Transformer采用了双塔(Siamese)的网络结构。Query和Title分别输入ERNIE,共享一个ERNIE参数,得到各自的token embedding特征。之后对token embedding进行pooling(此处教程使用mean pooling操作),之后输出分别记作u,v。之后将三个表征(u,v,|u-v|)拼接起来,进行二分类。网络结构如上图所示。
那么Sentence Transformer采用Siamese的网路结构,是如何提升预测速度呢?
Siamese的网络结构好处在于query和title分别输入同一套网络。如在信息搜索任务中,此时就可以将数据库中的title文本提前计算好对应sequence_output特征,保存在数据库中。当用户搜索query时,只需计算query的sequence_output特征与保存在数据库中的title sequence_output特征,通过一个简单的mean_pooling和全连接层进行二分类即可。从而大幅提升预测效率,同时也保障了模型性能。
实际场景中,只要改变下模型的输出,就可以提前把所以的待检索文本向量话,然后进行存储。
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:1 个月前 )
1094dd34
* Remove type annotation
* remove print statement 15 小时前
afb35a10
Num parameters in index.json 16 小时前
更多推荐




所有评论(0)