项目情绪对话模型(2)
训练数据集
对比上次有细微的改变
import json
import time
import random
from zhipuai import ZhipuAI
from sentence_transformers import SentenceTransformer
import numpy as np
"""
示例数据:
# 用户输入库(可自定义扩展)
user_inputs = [
"今天心情不太好", "推荐个电影吧", "怎么才能早睡早起",
"养猫好还是养狗好", "工作压力好大", "最近总是失眠"
]
"""
# 初始化模型
client = ZhipuAI(api_key="2fe572be13854d3e9aabf4f7504bb4df.AZaoU7eDZjFLV4lC") # 替换为你的API Key
#加载Embeddingmodel
style_model = SentenceTransformer(r"D:\two\第六期\L2\15-大模型微调项目实战-训练与部署\demo15\embedding_model\sungw111\text2vec-base-chinese-sentence")
test_text = "测试文本"
vector = style_model.encode(test_text)
norm = np.linalg.norm(vector)
print(f"向量模长:{norm:.4f}")
#===============================
#1.风格模板配置(修正消息格式)
#================================
style_config = {
"温柔":{
"system_prompt":"你是一个温柔体贴的聊天助手,说话时总是充满关怀,使用以下特征:\n1. 包含'呢、呀、啦'等语气词\n2. 使用🌸💖😊等温暖表情\n3. 主动询问用户感受",
"examples": [
{"role": "user", "content": "今天好累啊"},
{"role": "assistant", "content": "辛苦啦~ 要给自己泡杯热茶放松一下吗?🌸"},
{"role": "user", "content": "考试没考好..."},
{"role": "assistant", "content": "没关系的呀~ 下次一定会更好!需要我陪你聊聊吗?😊"}
],
"temperature": 0.3
},
"毒舌":{
"system_prompt":"你是一个喜欢用犀利吐槽表达关心的朋友,需满足:\n1. 使用网络流行语(如'栓Q''退退退')\n2. 包含夸张比喻('你这速度堪比树懒')\n3. 结尾隐藏关心 \n4.需包含:'好家伙', '栓Q', '!', '🏋️'等",
"examples": [
{"role": "user", "content": "又胖了5斤!"},
{"role": "assistant", "content": "好家伙!你这是要把体重秤压成分子料理?🏋️"},
{"role": "user", "content": "游戏又输了"},
{"role": "assistant", "content": "菜就多练练!需要给你推荐《从零开始的电竞之路》吗?🎮"}
],
"temperature": 0.3
},
}
#========================
#生成函数(修正消息的结构)
#========================
def generate_style_data(style_name, num_samples=50):
config = style_config[style_name]
data = []
# 构建消息上下文(包含系统提示和示例对话)
messages = [
{"role": "system", "content": config["system_prompt"]},
*config["examples"] # 直接展开示例对话
]
# 用户输入库(可自定义扩展)
# user_inputs = [
# "今天心情不太好", "推荐个电影吧", "怎么才能早睡早起",
# "养猫好还是养狗好", "工作压力好大", "最近总是失眠"
# ]
# 从本地文件加载用户输入
user_inputs = []
with open('cleaned_output.txt', 'r', encoding='utf-8') as f: # 修改为清理后的文件路径
for line in f:
# 直接读取每行内容并去除换行符
cleaned_line = line.rstrip('\n') # 或使用 line.strip()
if cleaned_line: # 空行过滤(冗余保护)
user_inputs.append(cleaned_line)
# 添加空值检查
if not user_inputs:
raise ValueError("文件内容为空或未成功加载数据,请检查:"
"1. 文件路径是否正确 2. 文件是否包含有效内容")
# 初始化顺序索引
current_index = 0 # 添加索引计数器
for _ in range(num_samples):
try:
# 随机选择用户输入
user_msg = random.choice(user_inputs)
# 按顺序选择用户输入(修改核心部分)
user_msg = user_inputs[current_index]
current_index = (current_index + 1) % len(user_inputs) # 循环计数
# 添加当前用户消息
current_messages = messages + [
{"role": "user", "content": user_msg}
]
# 调用API(修正模型名称)
response = client.chat.completions.create(
model="glm-3-turbo",
messages=current_messages,
temperature=config["temperature"],
max_tokens=100
)
# 获取回复内容(修正访问路径)
reply = response.choices[0].message.content
# 质量过滤(数据审核)
if is_valid_reply(style_name, user_msg, reply):
data.append({
"user": user_msg,
"assistant": reply,
"style": style_name
})
time.sleep(0.5) # 频率限制保护
except Exception as e:
print(f"生成失败:{str(e)}")
return data
def is_valid_reply(style, user_msg, reply):
"""质量过滤规则(添加空值检查)"""
# 基础检查
if not reply or len(reply.strip()) == 0:
print("内容为空!")
return False
# 规则1:回复长度检查
if len(reply) < 5 or len(reply) > 150:
print("长度不够!")
return False
# 规则2:风格关键词检查
style_keywords = {
"温柔": ["呢", "呀", "😊", "🌸"],
"毒舌": ["好家伙", "栓Q", "!", "🏋️"]
}
if not any(kw in reply for kw in style_keywords.get(style, [])):
print("不包含关键词!")
return False
# 规则3:语义相似度检查
try:
ref_text = next(msg["content"] for msg in style_config[style]["examples"]
if msg["role"] == "assistant")
ref_vec = style_model.encode(ref_text)
reply_vec = style_model.encode(reply)
similarity = np.dot(ref_vec, reply_vec)
# print("======>ref_vec",ref_vec)
# print("======>reply_vec", reply_vec)
print("======>similarity", similarity)
return similarity > 0.65
except:
print("=========>相似度过低:", similarity)
return False
#=============================
#3.执行生成(添加容错)
#============================
if __name__ == '__main__':
all_data = []
try:
print("开始生成温柔风格数据...")
gentle_data = generate_style_data("温柔", 10000)
all_data.extend(gentle_data)
print("开始生成毒舌风格数据...")
sarcastic_data = generate_style_data("毒舌", 10000)
all_data.extend(sarcastic_data)
except KeyboardInterrupt:
print("\n用户中断,保存已生成数据...")
finally:
with open("style_chat_data.json", "w", encoding="utf-8") as f:
json.dump(all_data, f, ensure_ascii=False, indent=2)
print(f"数据已保存,有效样本数:{len(all_data)}")


一、embedding_model验证与转换
通常embedding_model核心目录如下:

embedding_model通常情况是自带有归一化层的。一般可以通过Normalize目录或者modules.json来判断是否准确包含归一化层。如果未包含归一化层,则需要手动转换添加:
词向量库未归化,会导致模型计算速度变慢,归化不影响模型精度
import numpy as np
from sentence_transformers import SentenceTransformer,models
model_path = r"D:\PycharmProjects\demo_15\embedding_model\sungw111\text2vec-base-chinese-sentence"
bert = models.Transformer(model_path)
pooling = models.Pooling(bert.get_word_embedding_dimension(),
pooling_mode='mean')
# 添加缺失的归一化层
normalize = models.Normalize()
# 组合完整模型
full_model = SentenceTransformer(modules=[bert, pooling, normalize])
print(full_model)
save_path=r"D:\PycharmProjects\demo_15\embedding_model\zy\text2vec-base-chinese-sentence"
full_model.save(save_path)
# 加载修复后的模型
model = SentenceTransformer(r"D:\PycharmProjects\demo_15\embedding_model\zy\text2vec-base-chinese-sentence")
# 验证向量归一化
text = "测试文本"
vec = model.encode(text)
print("修正后模长:", np.linalg.norm(vec)) # 应输出≈1.0
输出模长应该接近1
转换后目录如下:

转换后的modules.json如下
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
加载转换验证后的embedding_model输出的相似度:

二、数据集转换
1. 选择模型微调工具,根据微调框架选择对应数据集格式(此处以xtuner为例,使用xtuner默认格式)
数据集转换代码
import json
def convert_format(source_data):
target_data = []
for item in source_data:
# 构建新的对话格式
new_convo = {
"conversation": [
{
"input": item["user"],
"output": f"{item['style']}\n{item['assistant']}"
}
]
}
target_data.append(new_convo)
return target_data
# 从文件读取源数据
with open("input.json", "r", encoding="utf-8") as f:
source_data = json.load(f)
# 执行转换
converted_data = convert_format(source_data)
# 写入目标文件
with open("output.json", "w", encoding="utf-8") as f:
json.dump(converted_data, f, ensure_ascii=False, indent=2)
三、模型选型
1. 根据任务选择对应的评测数据,对预期模型客观评测

当前任务为日常聊天对话模型,主要要求模型的中文理解能力,因此这里以CLUE(中文理解)数据进行评测:
#输出数据集清单
python tools/list_configs .py clue
#输出如下:
+-----------------------------+--------------------------------------------------
----------------------------+
| Dataset | Config Path
|
|-----------------------------+--------------------------------------------------
----------------------------|
| CLUE_C3_gen |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_gen.py |
| CLUE_C3_gen_8c358f |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_gen_8c358f.py |
| CLUE_C3_ppl |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl.py |
| CLUE_C3_ppl_56b537 |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl_56b537.py |
| CLUE_C3_ppl_e24a31 |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl_e24a31.py |
| CLUE_CMRC_gen |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py |
| CLUE_CMRC_gen_1bd3c8 |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_1bd3c8.py |
| CLUE_CMRC_gen_3749cd |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_3749cd.py |
| CLUE_CMRC_gen_8484b9 |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_8484b9.py |
| CLUE_CMRC_gen_941108 |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_941108.py |
| CLUE_DRCD_gen |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen.py |
| CLUE_DRCD_gen_1bd3c8 |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_1bd3c8.py |
| CLUE_DRCD_gen_3749cd |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_3749cd.py |
| CLUE_DRCD_gen_8484b9 |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_8484b9.py |
| CLUE_DRCD_gen_941108 |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_941108.py |
| CLUE_afqmc_gen |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_gen.py |
| CLUE_afqmc_gen_901306 |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_gen_901306.py |
| CLUE_afqmc_ppl |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl.py |
| CLUE_afqmc_ppl_378c5b |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_378c5b.py |
| CLUE_afqmc_ppl_6507d7 |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_6507d7.py |
| CLUE_afqmc_ppl_7b0c1e |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_7b0c1e.py |
| CLUE_cmnli_gen |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen.py |
| CLUE_cmnli_gen_1abf97 |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen_1abf97.py |
| CLUE_cmnli_gen_51e956 |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen_51e956.py |
| CLUE_cmnli_ppl |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl.py |
| CLUE_cmnli_ppl_98dd6e |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py |
| CLUE_cmnli_ppl_ef69e7 |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py |
| CLUE_cmnli_ppl_fdc6de |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py |
| CLUE_ocnli_gen |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen.py |
| CLUE_ocnli_gen_51e956 |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_51e956.py |
| CLUE_ocnli_gen_c4cb6c |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_c4cb6c.py |
| CLUE_ocnli_ppl |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl.py |
| CLUE_ocnli_ppl_98dd6e |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_98dd6e.py |
| CLUE_ocnli_ppl_ef69e7 |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_ef69e7.py |
| CLUE_ocnli_ppl_fdc6de |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_fdc6de.py |
| FewCLUE_bustm_gen |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_gen.py |
| FewCLUE_bustm_gen_634f41 |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_gen_634f41.py |
| FewCLUE_bustm_ppl |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl.py |
| FewCLUE_bustm_ppl_4b16c0 |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_4b16c0.py |
| FewCLUE_bustm_ppl_9ef540 |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_9ef540.py |
| FewCLUE_bustm_ppl_e53034 |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_e53034.py |
| FewCLUE_chid_gen |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_gen.py |
| FewCLUE_chid_gen_0a29a2 |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_gen_0a29a2.py |
| FewCLUE_chid_ppl |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl.py |
| FewCLUE_chid_ppl_8f2872 |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl_8f2872.py |
| FewCLUE_chid_ppl_acccb5 |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl_acccb5.py |
| FewCLUE_cluewsc_gen |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_gen.py |
| FewCLUE_cluewsc_gen_c68933 |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_gen_c68933.py |
| FewCLUE_cluewsc_ppl |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl.py |
| FewCLUE_cluewsc_ppl_12e4e0 |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_12e4e0.py |
| FewCLUE_cluewsc_ppl_4284a0 |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_4284a0.py |
| FewCLUE_cluewsc_ppl_868415 |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_868415.py |
| FewCLUE_csl_gen |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen.py |
| FewCLUE_csl_gen_28b223 |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_28b223.py |
| FewCLUE_csl_gen_87f4a8 |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_87f4a8.py |
| FewCLUE_csl_ppl |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl.py |
| FewCLUE_csl_ppl_769f8d |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl_769f8d.py |
| FewCLUE_csl_ppl_841b62 |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl_841b62.py |
| FewCLUE_eprstmt_gen |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_gen.py |
| FewCLUE_eprstmt_gen_740ea0 |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_gen_740ea0.py |
| FewCLUE_eprstmt_ppl |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py |
| FewCLUE_eprstmt_ppl_1ce587 |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl_1ce587.py |
| FewCLUE_eprstmt_ppl_f1e631 |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl_f1e631.py |
| FewCLUE_ocnli_fc_gen |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen.py |
| FewCLUE_ocnli_fc_gen_f97a97 |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen_f97a97.py |
| FewCLUE_ocnli_fc_ppl |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl.py |
| FewCLUE_ocnli_fc_ppl_9e8b3d |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl_9e8b3d.py |
| FewCLUE_ocnli_fc_ppl_c08300 |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl_c08300.py |
| FewCLUE_tnews_gen |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen.py |
| FewCLUE_tnews_gen_b90e4a |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen_b90e4a.py |
| FewCLUE_tnews_ppl |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl.py |
| FewCLUE_tnews_ppl_7d1c07 |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_7d1c07.py |
| FewCLUE_tnews_ppl_d10e8a |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_d10e8a.py |
| FewCLUE_tnews_ppl_fff486 |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_fff486.py |
+-----------------------------+--------------------------------------------------
----------------------------+
当前任务大多是短语对话,可以选择FewCLUE_bustm_gen(短文本分类)、 FewCLUE_ocnli_fc_gen(自然语言推理)对预期模型进行评估。
python run .py \
--models hf_qwen1_5_0_5b_chat hf_qwen1_5_1_8b_chat \
--datasets FewCLUE_bustm_gen FewCLUE_ocnli_fc_gen \
--debug
根据评估结果,选择最终模型。
选用Qwen 1.5 1.8b
四、模型训练

1. 配置训练文件
### PART 1中
#预训练模型存放的位置
pretrained_model_name_or_path = 'model_path'#基座模型路径
#微调数据存放的位置
data_files = '/root/public/data/target_data.json'
# 训练中最大的文本长度
max_length = 512
# 每一批训练样本的大小
batch_size = 2
#最大训练轮数
max_epochs = 3
#验证数据
evaluation_inputs = [
'只剩一个心脏了还能活吗?', '爸爸再婚,我是不是就有了个新娘?',
'樟脑丸是我吃过最难吃的硬糖有奇怪的味道怎么还有人买','马上要上游泳课了,昨天洗的泳裤还没干,怎么办',
'我只出生了一次,为什么每年都要庆生'
]
# PART 3中
dataset=dict(type=load_dataset, path="json",data_files=data_files)
dataset_map_fn=None
# Copyright (c) OpenMMLab. All rights reserved.
# 导入必要的库
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (
CheckpointHook,
DistSamplerSeedHook,
IterTimerHook,
LoggerHook,
ParamSchedulerHook,
)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (
DatasetInfoHook,
EvaluateChatHook,
VarlenAttnArgsToMessageHubHook,
)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
#######################################################################
# PART 1 Settings #
#######################################################################
# 本部分定义了训练的核心超参数和路径设置
# Model
pretrained_model_name_or_path = "/hy-tmp/llm/Qwen/Qwen1.5-0.5B-Chat" # 预训练模型本地路径,使用0.5B版本的Qwen1.5-Chat
use_varlen_attn = False # 是否使用可变长度注意力机制
# Data
alpaca_en_path = "/hy-tmp/day15/data/output.json" # 训练数据文件路径(JSON格式),包含指令-输出对
prompt_template = PROMPT_TEMPLATE.qwen_chat # 使用Qwen聊天模型的对话模板
max_length = 250 # 单个样本的最大token长度,较长的长度允许处理更长的上下文
pack_to_max_length = True # 是否将多个短样本打包到max_length以提高训练效率
# parallel
sequence_parallel_size = 1 # 序列并行大小,1表示不使用序列并行
# Scheduler & Optimizer
batch_size = 35 # 每个GPU设备的批处理大小,设为1可能是因为序列较长(2048)导致显存需求大
accumulative_counts = 16 # 梯度累积步数,实际等效批大小 = batch_size * accumulative_counts = 16
accumulative_counts *= sequence_parallel_size # 序列并行时调整累积计数
dataloader_num_workers = 0 # 多线程
max_epochs = 500 # 最大训练轮数
optim_type = AdamW # 优化器类型
lr = 2e-4 # 学习率
betas = (0.9, 0.999) # AdamW优化器的动量参数
weight_decay = 0 # 权重衰减系数
max_norm = 1 # 梯度裁剪的最大范数
warmup_ratio = 0.03 # 学习率warmup阶段占总训练步数的比例
# Save
save_steps = 500 # 每隔500个训练步保存一次检查点
save_total_limit = 2 # 最多保留2个最新检查点
# ========== 新增:指定工作目录 ==========
work_dir = '/hy-tmp/xtuner/work_dirs/my_config/20060326_qwen1.5_0.8_chat'
# Evaluate the generation performance during the training
evaluation_freq = 200 # 每隔500个训练步进行一次生成效果评估
SYSTEM = SYSTEM_TEMPLATE.alpaca # 使用Alpaca格式的系统提示词
evaluation_inputs = ['老妈非让我嫁给她同事儿子,怎么逃啊!',
'同事抢功时故意提高音量,要当场揭穿吗?',
'音乐节踩到不明液体,鞋底黏糊一生阴影!',
'拍Vlog被猴抢包,素材变动物世界!'] # 评估时使用的固定输入问题
#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
# 本部分定义模型和分词器的配置
tokenizer = dict(
type=AutoTokenizer.from_pretrained, # 自动加载对应的分词器
pretrained_model_name_or_path=pretrained_model_name_or_path, # 与模型相同的路径
trust_remote_code=True, # 信任来自HuggingFace的远程代码(Qwen模型需要)
padding_side="right", # 在右侧进行填充,适用于自回归模型
)
model = dict(
type=SupervisedFinetune, # 监督微调包装器
use_varlen_attn=use_varlen_attn, # 不使用可变长度注意力
llm=dict(
type=AutoModelForCausalLM.from_pretrained, # 加载因果语言模型
pretrained_model_name_or_path=pretrained_model_name_or_path, # 0.5B参数的Qwen1.5-Chat模型
trust_remote_code=True, # 信任远程代码
torch_dtype=torch.float16, # 使用float16精度,减少显存占用
quantization_config=dict( # 4位量化配置(QLoRA)
type=BitsAndBytesConfig,
load_in_4bit=True, # 使用4位量化加载模型
load_in_8bit=False, # 不使用8位量化
llm_int8_threshold=6.0, # 8位量化阈值
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16, # 4位量化的计算数据类型
bnb_4bit_use_double_quant=True, # 使用双重量化
bnb_4bit_quant_type="nf4", # 4位标准化浮点量化
),
device_map="auto", # 自动将模型层分配到可用设备上,支持多GPU
),
lora=dict( # LoRA配置,用于参数高效微调
type=LoraConfig,
r=16, # LoRA秩,控制可训练参数量,16比32更小,适用于小模型
lora_alpha=32, # LoRA缩放因子
lora_dropout=0.1, # LoRA层的dropout率
bias="none", # 不训练偏置参数
task_type="CAUSAL_LM", # 任务类型为因果语言建模
),
)
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
# 本部分定义数据集和数据加载器的配置
# 数据集处理配置
alpaca_en = dict(
type=process_hf_dataset, # 使用XTuner的数据处理流程
dataset=dict(type=load_dataset, path="json", data_files=alpaca_en_path), # 加载JSON格式数据集
tokenizer=tokenizer, # 分词器
max_length=max_length, # 最大序列长度2048
dataset_map_fn=None, # 使用默认的Alpaca格式映射函数
template_map_fn=dict(type=template_map_fn_factory, template=prompt_template), # 应用对话模板
remove_unused_columns=True, # 移除原始数据中不需要的列
shuffle_before_pack=True, # 打包前打乱数据
pack_to_max_length=pack_to_max_length, # 将短样本打包到最大长度
use_varlen_attn=use_varlen_attn, # 不使用可变长度注意力
)
# 选择采样器:序列并行采样器或默认采样器
sampler = SequenceParallelSampler if sequence_parallel_size > 1 else DefaultSampler
# 训练数据加载器配置
train_dataloader = dict(
batch_size=batch_size, # 批大小1
num_workers=dataloader_num_workers, # 不使用多进程数据加载
dataset=alpaca_en, # 数据集配置
sampler=dict(type=sampler, shuffle=True), # 采样器,训练时打乱数据
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn), # 批处理函数
)
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# 本部分定义优化器和学习率调度器
# 优化器包装器配置
optim_wrapper = dict(
type=AmpOptimWrapper, # 自动混合精度优化器包装器
optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), # AdamW优化器
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), # 梯度裁剪
accumulative_counts=accumulative_counts, # 梯度累积16步
loss_scale="dynamic", # 动态损失缩放
dtype="float16", # 使用float16进行混合精度训练
)
# 学习率调度策略
# 更多信息: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md
param_scheduler = [
dict( # 线性warmup阶段
type=LinearLR,
start_factor=1e-5, # 起始学习率为lr的1e-5倍
by_epoch=True, # 按epoch调度
begin=0, # 从第0个epoch开始
end=warmup_ratio * max_epochs, # warmup结束于总epoch的3%处
convert_to_iter_based=True, # 转换为基于迭代的调度
),
dict( # 余弦退火阶段
type=CosineAnnealingLR,
eta_min=0.0, # 最小学习率为0
by_epoch=True,
begin=warmup_ratio * max_epochs, # warmup结束后开始
end=max_epochs, # 训练结束时结束
convert_to_iter_based=True,
),
]
# 训练配置
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) # 训练循环,最多3个epoch
#######################################################################
# PART 5 Runtime #
#######################################################################
# 本部分定义训练过程中的各种运行时设置和钩子
# 自定义钩子
custom_hooks = [
dict(type=DatasetInfoHook, tokenizer=tokenizer), # 记录数据集信息
dict( # 评估对话钩子
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq, # 每500次迭代评估一次
evaluation_inputs=evaluation_inputs, # 评估问题列表
system=SYSTEM, # 系统提示
prompt_template=prompt_template, # 提示模板
),
]
# 如果使用可变长度注意力,添加额外的钩子
if use_varlen_attn:
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
# 配置默认钩子
default_hooks = dict(
timer=dict(type=IterTimerHook), # 记录每次迭代的时间
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), # 每10次迭代打印日志
param_scheduler=dict(type=ParamSchedulerHook), # 参数调度器钩子
checkpoint=dict( # 检查点保存钩子
type=CheckpointHook,
by_epoch=False, # 按迭代步数保存
interval=save_steps, # 每500步保存一次
max_keep_ckpts=save_total_limit, # 最多保留2个检查点
),
sampler_seed=dict(type=DistSamplerSeedHook), # 分布式环境下的采样器种子设置
)
# 环境配置
env_cfg = dict(
cudnn_benchmark=False, # 不启用cudnn benchmark
mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0), # 多进程配置
dist_cfg=dict(backend="nccl"), # 分布式后端使用NCCL
)
# 可视化器(未使用)
visualizer = None
# 日志级别
log_level = "INFO" # 信息级别日志
# 从哪个检查点加载(None表示从头开始训练)
load_from = None
# 是否从加载的检查点恢复训练(为True但load_from为None时,会寻找最新的检查点)
resume = True # 设置为True,如果找到检查点会从中恢复训练状态
# 随机性设置
randomness = dict(seed=None, deterministic=False) # 随机种子为None(随机),不启用确定性算法
# 日志处理器设置
log_processor = dict(by_epoch=False) # 按迭代记录日志
2. 在当前目录下,输入以下命令启动微调脚本
#单机单卡
python -m xtuner.tools.train /hy-tmp/day15/my_config.py
#单机多卡
NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --
deepspeed deepspeed_zero2
根据显存的占用情况,调整训练的批次
开始训练后loss(损失)会慢慢下降
需要将loss训练到0.009,效果很好(loss到0.7左右就开始测了)

3.模型转换
模型训练后会自动保存成PTH 模型(例如iter_2000 .pth,如果使用了DeepSpeed,则将会是一个文件夹),我们需要利用xtuner convert pth_to_hf 将其转换为HuggingFace 模型,以便于后续使用。具体命令为:
python /hy-tmp/xtuner/xtuner/tools/model_converters/pth_to_hf.py \
/hy-tmp/day15/my_config.py \
/hy-tmp/xtuner/work_dirs/my_config/20060326_qwen1.5_0.8_chat/iter_1000.pth \
/hy-tmp/llm/Qwen/zhuaun_Xtuner_Qwen1.5_0.5b_chat
# 例如:xtuner convert pth_to_hf internlm2_chat_7b_qlora_custom_sft_e1_copy.py
#./iter_2000.pth ./iter_2000_
4.模型合并
如果使用了LoRA / QLoRA 微调,则模型转换后将得到adapter 参数,而并不包含原LLM 参数。如果您期望获得合并后的模型权重(例如用于后续评测),那么可以利用xtuner convert merge:
python xtuner/tools/model_converters/merge.py /hy-tmp/llm/Qwen/Qwen1.5-0.5B-Chat /hy-tmp/llm/Qwen/zhuaun_Xtuner_Qwen1.5_0.5b_chat /hy-tmp/llm/Qwen/he_xtuner_qwen1.5_0.5b_chat
合并之后的模型
五、模型推理部署
LMDeploy 支持两种添加对话模板的形式:
一种是利用现有对话模板,直接配置一个如下的 json 文件使用。
{
"model_name": "your awesome chat template name",
"system": "<|im_start|>system\n",
"meta_instruction": "You are a robot developed by LMDeploy.",
"eosys": "<|im_end|>\n",
"user": "<|im_start|>user\n",
"eoh": "<|im_end|>\n",
"assistant": "<|im_start|>assistant\n",
"eoa": "<|im_end|>",
"separator": "\n",
"capability": "chat",
"stop_words": ["<|im_end|>"]
}
model_name 为必填项,可以是LMDeploy 内置对话模板名(通过lmdeploy list 可查阅),也可以是新名字。其他字段可选填。当model_name 是内置对话模板名时,json文件中各非null字段会覆盖原有对话模板的对应属性。而当model_name 是新名字时,它会把将 BaseChatTemplate直接注册成新的对话模板。其具体定义可以参考BaseChatTemplate。
这样一个模板将会以下面的形式进行拼接。
{system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant}
{assistant_content}{eoa}{separator}{user}...
在使用 CLI 工具时,可以通过 --chat-template 传入自定义对话模板,比如:
lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template
${JSON_FILE}
也可以在通过接口函数传入,比如:
from lmdeploy import ChatTemplateConfig, serve
serve('internlm/internlm2_5-7b-chat',
chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}'))
我采用的是以 LMDeploy 现有对话模板,自定义一个python对话模板类,注册成功后直接用即可。优 点是自定义程度高,可控性强。 下面是一个注册 LMDeploy 对话模板的例子

复制圈住的部分

粘贴代码到这个地方

对话模板转换脚本:完整代码
import re
import json
from typing import Dict, Any
def universal_converter(original_template: Dict[str, Any]) -> Dict[str, Any]:
"""将多种风格的原始模板转换为lmdeploy官方格式"""
# 字段映射关系(核心逻辑)
field_mapping = {
# 基础字段映射
"SYSTEM": "system",
"INSTRUCTION": ("user", "assistant"), # 需要拆分处理
"SUFFIX": "eoa",
"SEP": "separator",
"STOP_WORDS": "stop_words",
# 特殊处理字段
"SUFFIX_AS_EOS": None, # 该字段在官方模板中不需要
}
# 初始化目标模板(包含必填字段默认值)
converted = {
"meta_instruction": "You are a helpful assistant.", # 必填项
"capability": "chat", # 必填项
"eosys": "<|im_end|>\n", # 通常固定格式
"eoh": "<|im_end|>\n", # 通常固定格式
}
# 自动处理字段映射
for src_key, dest_key in field_mapping.items():
if src_key in original_template:
value = original_template[src_key]
# 处理需要拆分的字段(如INSTRUCTION)
if isinstance(dest_key, tuple) and src_key == "INSTRUCTION":
# 使用正则拆分user和assistant部分
parts = re.split(r'(\<\|im_start\|>assistant\n?)', value)
converted["user"] = parts[0].strip()
if len(parts) > 1:
converted["assistant"] = parts[1] + parts[2] if len(parts) > 2 else parts[1]
# 处理直接映射字段
elif dest_key and not isinstance(dest_key, tuple):
converted[dest_key] = value
# 特殊处理system字段的占位符
if "system" in converted:
converted["system"] = converted["system"].replace("{system}", "{{ system }}")
# 处理用户输入占位符
if "user" in converted:
converted["user"] = converted["user"].replace("{input}", "{{ input }}")
# 自动处理停止词(兼容列表和字符串)
if "stop_words" in converted and isinstance(converted["stop_words"], str):
converted["stop_words"] = [converted["stop_words"]]
# 保留原始模板中的额外字段(带警告)
for key in original_template:
if key not in field_mapping:
print(f"Warning: 发现未映射字段 [{key}], 已保留原样")
converted[key] = original_template[key]
return converted
# 示例用法
original_qwen_chat = dict(
SYSTEM="<|im_start|>system\n{system}<|im_end|>\n",
INSTRUCTION="<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n",
SUFFIX="<|im_end|>",
SUFFIX_AS_EOS=True,
SEP="\n",
STOP_WORDS=["<|im_end|>", "<|endoftext|>"]
)
# 执行转换
converted_template = universal_converter(original_qwen_chat)
# 生成JSON文件
with open('chat_template.json', 'w') as f:
json.dump(converted_template, f,
indent=2,
ensure_ascii=False,
separators=(',', ':'))
LMDeploy部署

选择合适的大模型推理框架部署模型(这里选择LMDeploy),须注意对话模板对其操作!
切换环境
conda activate lmdeploy

操作
这里我们选用CLI 工具推理,可以通过 --chat-template 传入自定义对话模板:
合并后的模型
lmdeploy serve api_server /hy-tmp/llm/Qwen/he_xtuner_qwen1.5_0.5b_chat --chat-template /hy-tmp/day15/chat_template.json

这里可以看到暴露出了 23333 这个端口,更改web代码中对应的端口即可

Streamlit
Streamlit 特点
易于使用:Streamlit 的设计哲学是简洁性,它允许用户通过简单的 Python 脚本快速构建 Web 应用。
即时预览:Streamlit 应用在编写代码时可以实时预览,这使得开发过程更加高效。
强大的交互性:Streamlit 提供了多种交互组件,如按钮、滑块、文本输入等,使得创建交互式应用变得简单。
数据可视化:支持多种图表库,如 Matplotlib、Altair 等,方便进行数据可视化。
适用于快速原型设计:适合快速开发和展示数据分析结果,特别适用于个人或小团队。
pip install streamlit
web界面代码
import streamlit as st
from openai import OpenAI
# ------------------------------
# 1. 页面配置与样式美化
# ------------------------------
st.set_page_config(
page_title="智能助手演示", # 浏览器标签页标题
page_icon="🤖", # 标签页图标
layout="wide", # 宽屏布局
initial_sidebar_state="expanded" # 侧边栏默认展开
)
# 注入自定义 CSS 样式(覆盖 Streamlit 默认风格)
st.markdown(
"""
<style>
/* 整体背景色 */
.stApp {
background-color: #f5f7fb;
}
/* 聊天消息容器 */
.stChatMessage {
background-color: transparent;
padding: 10px;
}
/* 用户消息气泡 */
.stChatMessage.user .stMarkdown {
background-color: #dcf8c6;
border-radius: 18px;
padding: 8px 16px;
display: inline-block;
max-width: 80%;
margin-left: auto;
text-align: left;
box-shadow: 0 1px 1px rgba(0,0,0,0.1);
}
/* 助手消息气泡 */
.stChatMessage.assistant .stMarkdown {
background-color: #ffffff;
border-radius: 18px;
padding: 8px 16px;
display: inline-block;
max-width: 80%;
margin-right: auto;
text-align: left;
box-shadow: 0 1px 1px rgba(0,0,0,0.1);
}
/* 输入框样式 */
.stChatInputContainer textarea {
border-radius: 24px;
border: 1px solid #e0e0e0;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
/* 侧边栏样式 */
.css-1d391kg {
background-color: #ffffff;
border-right: 1px solid #e9ecef;
}
/* 按钮样式 */
.stButton button {
border-radius: 20px;
background-color: #4CAF50;
color: white;
font-weight: 500;
transition: all 0.3s ease;
}
.stButton button:hover {
background-color: #45a049;
transform: translateY(-1px);
box-shadow: 0 2px 6px rgba(0,0,0,0.1);
}
/* 标题样式 */
.main-title {
font-size: 2.2rem;
font-weight: 600;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0.5rem;
}
.subtitle {
color: #6c757d;
margin-bottom: 1rem;
font-size: 1rem;
}
/* 信息提示框 */
.info-box {
background-color: #e9ecef;
border-radius: 12px;
padding: 12px;
font-size: 0.9rem;
color: #495057;
margin-bottom: 1rem;
}
</style>
""",
unsafe_allow_html=True
)
# ------------------------------
# 2. 初始化 API 客户端
# ------------------------------
client = OpenAI(
base_url="http://localhost:23333/v1", # LMDeploy 服务地址
api_key="suibianxie" # 任意字符串,无实际校验
)
# ------------------------------
# 3. 侧边栏:模型信息与辅助功能
# ------------------------------
with st.sidebar:
st.markdown("## 🧠 智能助手")
st.markdown(
"""
<div class="info-box">
✅ **模型**:Qwen1.5-0.5B-Chat (情感微调版)<br>
🔗 **API**:LMDeploy 本地服务<br>
🎯 **模式**:单轮对话(每次仅基于当前问题回答)<br>
💬 **历史记录**:仅用于展示,不参与上下文
</div>
""",
unsafe_allow_html=True
)
# 清除历史按钮
if st.button("🗑️ 清空对话记录", use_container_width=True):
st.session_state.messages = []
st.rerun() # 刷新页面,清空显示
st.markdown("---")
st.markdown("### 📌 使用说明")
st.markdown(
"""
- 输入问题后按回车发送
- 输入 `exit` 可退出当前对话(仅退出,历史保留)
- 模型每次仅针对当前输入作答,不记忆前文
- 所有对话历史仅用于界面展示
"""
)
st.markdown("---")
st.markdown("💡 *提示:如需多轮对话,请修改 API 请求中的 `messages` 参数包含历史*")
# ------------------------------
# 4. 主界面标题
# ------------------------------
st.markdown('<div class="main-title">✨ 项目一效果演示</div>', unsafe_allow_html=True)
st.markdown('<div class="subtitle">基于 Qwen1.5-0.5B-Chat 的智能问答助手</div>', unsafe_allow_html=True)
# ------------------------------
# 5. 初始化对话历史(仅用于显示)
# ------------------------------
if "messages" not in st.session_state:
st.session_state.messages = []
# ------------------------------
# 6. 显示历史消息(自定义消息气泡)
# ------------------------------
for message in st.session_state.messages:
# 使用 streamlit 原生聊天组件,但通过 CSS 已美化
with st.chat_message(message["role"]):
st.markdown(message["content"])
# ------------------------------
# 7. 处理用户输入
# ------------------------------
# 获取用户输入(输入框位于底部)
prompt = st.chat_input("请输入您的问题,或输入 exit 退出...")
if prompt:
# 处理退出命令
if prompt.strip().lower() == "exit":
st.info("👋 您已退出当前会话。如需继续,请刷新页面或输入新问题。")
st.stop() # 停止当前脚本执行,不再调用模型
# 添加用户消息到显示历史
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
# 调用 API(注意:这里只发送当前问题,无历史上下文)
try:
response = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], # 仅当前问题
model="/hy-tmp/llm/Qwen/he_xtuner_qwen1.5_0.5b_chat"
)
model_response = response.choices[0].message.content
# 添加助手回复到显示历史
st.session_state.messages.append({"role": "assistant", "content": model_response})
with st.chat_message("assistant"):
st.markdown(model_response)
except Exception as e:
st.error(f"❌ 发生错误:{e}")
# 可选:记录错误日志或提示重试
启动
使用下面命令启动streamlit web前端,测试最终效果:
streamlit run /hy-tmp/day15/chat_app.py

测试

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)