【大模型】Transformers库单机多卡推理之device_map
transformers
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
项目地址:https://gitcode.com/gh_mirrors/tra/transformers
·
Hugging Face的transformers库支持自动模型(AutoModel)的模型实例化方法,来自动载入并使用GPT、ChatGLM等模型。在AutoModel.from_pretrained() 方法中的 device_map 参数,可实现单机多卡推理。
device_map参数解析
device_map是AutoModel.from_pretrained() 方法中的一个重要参数,它用于指定模型的各个部件应加载到哪个具体的计算设备上,以实现资源的有效分配和利用。这个参数在进行模型并行或分布式训练时特别有用。
device_map 参数有 auto, balanced, balanced_low_0, sequential 几种选项,具体如下:
“auto” 和 “balanced”:将会在所有的GPU上平衡切分模型。主要是有可能发现更高效的分配策略。“balanced” 参数的功能则保持稳定。(可按需使用)“balanced_low_0”:会在除了第一个GPU上的其它GPU上平衡划分模型,并且在第一个 GPU 上占据较少资源。这个选项符合需要在第一个 GPU 上进行额外操作的需求,例如需要在第一个 GPU 执行 generate 函数(迭代过程)。(推荐使用)“sequential”:按照GPU的顺序分配模型分片,从 GPU 0 开始,直到最后的 GPU(那么最后的 GPU 往往不会被占满,和 - “balanced_low_0” 的区别就是第一个还是最后一个,以及非均衡填充),但是我在实际使用当中GPU 0 会直接爆显存了。(不推荐使用)
device_map=“auto” 代码示例
这里我们的环境为单机两张显卡,使用 device_map="auto" 来加载 ChatGLM-6B模型,观察显卡占用情况。
示例代码如下:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# 加载模型
model_path = "./model/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
text = '什么是机器学习?'
inputs = tokenizer(text, return_tensors="pt")
print(inputs)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
-
程序运行前显卡占用情况如下:

可以看到0号显卡本身被其他程序占用了约13G的显存。 -
使用
device_map="auto"后的显卡占用情况如下:

使用auto策略后,显卡0和1分别多占用了约6~7G的显存。
手动配置
- 配置为单卡
device = "cuda:1"
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map=device)
- 配置为多卡
假设想要模型的某些部分在第一张显卡,另一部分在第二张显卡,需要知道模型的层名或者按照模型的组件大小进行合理分配。不过,具体层名需要根据实际模型来确定,这里提供一个概念性的示例:
device = {
"transformer.h.0": "cuda:0", # 第一部分放在GPU 0
"transformer.h.1": "cuda:1", # 第二部分放在GPU 1
# ... 根据模型结构继续分配
}
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map=device)
参考资料
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:2 个月前 )
33868a05
* [i18n-HI] Translated accelerate page to Hindi
* Update docs/source/hi/accelerate.md
Co-authored-by: K.B.Dharun Krishna <kbdharunkrishna@gmail.com>
* Update docs/source/hi/accelerate.md
Co-authored-by: K.B.Dharun Krishna <kbdharunkrishna@gmail.com>
* Update docs/source/hi/accelerate.md
Co-authored-by: K.B.Dharun Krishna <kbdharunkrishna@gmail.com>
* Update docs/source/hi/accelerate.md
Co-authored-by: K.B.Dharun Krishna <kbdharunkrishna@gmail.com>
---------
Co-authored-by: Kay <kay@Kays-MacBook-Pro.local>
Co-authored-by: K.B.Dharun Krishna <kbdharunkrishna@gmail.com> 11 天前
e2ac16b2
* rework converter
* Update modular_model_converter.py
* Update modular_model_converter.py
* Update modular_model_converter.py
* Update modular_model_converter.py
* cleaning
* cleaning
* finalize imports
* imports
* Update modular_model_converter.py
* Better renaming to avoid visiting same file multiple times
* start converting files
* style
* address most comments
* style
* remove unused stuff in get_needed_imports
* style
* move class dependency functions outside class
* Move main functions outside class
* style
* Update modular_model_converter.py
* rename func
* add augmented dependencies
* Update modular_model_converter.py
* Add types_to_file_type + tweak annotation handling
* Allow assignment dependency mapping + fix regex
* style + update modular examples
* fix modular_roberta example (wrong redefinition of __init__)
* slightly correct order in which dependencies will appear
* style
* review comments
* Performance + better handling of dependencies when they are imported
* style
* Add advanced new classes capabilities
* style
* add forgotten check
* Update modeling_llava_next_video.py
* Add prority list ordering in check_conversion as well
* Update check_modular_conversion.py
* Update configuration_gemma.py 12 天前
更多推荐

所有评论(0)