1. 前提

最近用roberta模型需要添加special tokens,但每次运行在GPU上会报错(上面还有一堆的block)
在这里插入图片描述

而在CPU上则报错
在这里插入图片描述
网上搜了很多资料,说是如果增加了special tokens或是修改了vocab.txt,则需要加上model.resize_token_embeddings(len(tokenizer)),不然维度会不对,但一直不太清楚加在哪里,刚开始加在了dataset处理的地方,但仍然报错。

2. 具体操作

先展示一下roberta文件夹
在这里插入图片描述
added_tokens.json放需要添加的tokens

{"[CH-2]": 21133, "[CH-0]": 21131, "[CH-3]": 21134, "[CH-6]": 21137, "[CH-9]": 21140, "[CH-4]": 21135, "[CH-1]": 21132, "[CH-8]": 21139, "”": 21129, "</s>": 21130, "“": 21128, "[CH-5]": 21136, "[CH-7]": 21138}

special_tokens_map.json放特殊tokens

{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}

tokenizer_config.json放tokenizer的一些的配置

{"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "special_tokens_map_file": "special_tokens_map.json", "name_or_path": "chinese-roberta-wwm-ext", "use_fast": true, "tokenizer_file": "tokenizer.json", "tokenizer_class": "BertTokenizer"}

在bert模型代码处添上self.bert.resize_token_embeddings(len(self.tokenizer))

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config['bert_path'])
        self.tokenizer = BertTokenizer.from_pretrained(config['bert_path'])

        # self.tokenizer.add_tokens(self.new_tokens, special_tokens=True)
        self.bert.resize_token_embeddings(len(self.tokenizer))

        for param in self.bert.parameters():
            param.requires_grad = True

这样就大功告成啦~

GitHub 加速计划 / tra / transformers
130.24 K
25.88 K
下载
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:2 个月前 )
049682a5 * Added Example Doc for token classification on all tokenClassificationModels copied from llama * Refactor code to add code sample docstrings for Gemma and Gemma2 models (including modular Gemma) * Refactor code to update model checkpoint names for Qwen2 models 18 小时前
644d5287 * docs: ko: model_doc/bartpho.md * feat: nmt draft * Update docs/source/ko/model_doc/bartpho.md * Update docs/source/ko/_toctree.yml Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> 19 小时前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐