训练数据集

对比上次有细微的改变

import json
import time
import random
from zhipuai import ZhipuAI
from sentence_transformers import SentenceTransformer
import numpy as np

"""
示例数据:
# 用户输入库(可自定义扩展)
    user_inputs = [
        "今天心情不太好", "推荐个电影吧", "怎么才能早睡早起",
        "养猫好还是养狗好", "工作压力好大", "最近总是失眠"
    ]
"""
# 初始化模型
client = ZhipuAI(api_key="2fe572be13854d3e9aabf4f7504bb4df.AZaoU7eDZjFLV4lC")  # 替换为你的API Key
#加载Embeddingmodel
style_model = SentenceTransformer(r"D:\two\第六期\L2\15-大模型微调项目实战-训练与部署\demo15\embedding_model\sungw111\text2vec-base-chinese-sentence")
test_text = "测试文本"
vector = style_model.encode(test_text)
norm = np.linalg.norm(vector)
print(f"向量模长:{norm:.4f}")

#===============================
#1.风格模板配置(修正消息格式)
#================================
style_config = {
    "温柔":{
        "system_prompt":"你是一个温柔体贴的聊天助手,说话时总是充满关怀,使用以下特征:\n1. 包含'呢、呀、啦'等语气词\n2. 使用🌸💖😊等温暖表情\n3. 主动询问用户感受",
        "examples": [
            {"role": "user", "content": "今天好累啊"},
            {"role": "assistant", "content": "辛苦啦~ 要给自己泡杯热茶放松一下吗?🌸"},
            {"role": "user", "content": "考试没考好..."},
            {"role": "assistant", "content": "没关系的呀~ 下次一定会更好!需要我陪你聊聊吗?😊"}
        ],
     "temperature": 0.3
    },
    "毒舌":{
        "system_prompt":"你是一个喜欢用犀利吐槽表达关心的朋友,需满足:\n1. 使用网络流行语(如'栓Q''退退退')\n2. 包含夸张比喻('你这速度堪比树懒')\n3. 结尾隐藏关心 \n4.需包含:'好家伙', '栓Q', '!', '🏋️'等",
        "examples": [
            {"role": "user", "content": "又胖了5斤!"},
            {"role": "assistant", "content": "好家伙!你这是要把体重秤压成分子料理?🏋️"},
            {"role": "user", "content": "游戏又输了"},
            {"role": "assistant", "content": "菜就多练练!需要给你推荐《从零开始的电竞之路》吗?🎮"}
        ],
     "temperature": 0.3
    },
}

#========================
#生成函数(修正消息的结构)
#========================

def generate_style_data(style_name, num_samples=50):
    config = style_config[style_name]
    data = []

    # 构建消息上下文(包含系统提示和示例对话)
    messages = [
        {"role": "system", "content": config["system_prompt"]},
        *config["examples"]  # 直接展开示例对话
    ]

    # 用户输入库(可自定义扩展)
    # user_inputs = [
    #     "今天心情不太好", "推荐个电影吧", "怎么才能早睡早起",
    #     "养猫好还是养狗好", "工作压力好大", "最近总是失眠"
    # ]
    # 从本地文件加载用户输入
    user_inputs = []
    with open('cleaned_output.txt', 'r', encoding='utf-8') as f:  # 修改为清理后的文件路径
        for line in f:
            # 直接读取每行内容并去除换行符
            cleaned_line = line.rstrip('\n')  # 或使用 line.strip()
            if cleaned_line:  # 空行过滤(冗余保护)
                user_inputs.append(cleaned_line)

    # 添加空值检查
    if not user_inputs:
        raise ValueError("文件内容为空或未成功加载数据,请检查:"
                         "1. 文件路径是否正确 2. 文件是否包含有效内容")

    # 初始化顺序索引
    current_index = 0  # 添加索引计数器
    for _ in range(num_samples):
        try:
            # 随机选择用户输入
            user_msg = random.choice(user_inputs)

            # 按顺序选择用户输入(修改核心部分)
            user_msg = user_inputs[current_index]
            current_index = (current_index + 1) % len(user_inputs)  # 循环计数

            # 添加当前用户消息
            current_messages = messages + [
                {"role": "user", "content": user_msg}
            ]

            # 调用API(修正模型名称)
            response = client.chat.completions.create(
                model="glm-3-turbo",
                messages=current_messages,
                temperature=config["temperature"],
                max_tokens=100
            )

            # 获取回复内容(修正访问路径)
            reply = response.choices[0].message.content

            # 质量过滤(数据审核)
            if is_valid_reply(style_name, user_msg, reply):
                data.append({
                    "user": user_msg,
                    "assistant": reply,
                    "style": style_name
                })

            time.sleep(0.5)  # 频率限制保护

        except Exception as e:
            print(f"生成失败:{str(e)}")

    return data

def is_valid_reply(style, user_msg, reply):
    """质量过滤规则(添加空值检查)"""
    # 基础检查
    if not reply or len(reply.strip()) == 0:
        print("内容为空!")
        return False

    # 规则1:回复长度检查
    if len(reply) < 5 or len(reply) > 150:
        print("长度不够!")
        return False

    # 规则2:风格关键词检查
    style_keywords = {
        "温柔": ["呢", "呀", "😊", "🌸"],
        "毒舌": ["好家伙", "栓Q", "!", "🏋️"]
    }
    if not any(kw in reply for kw in style_keywords.get(style, [])):
        print("不包含关键词!")
        return False

    # 规则3:语义相似度检查
    try:
        ref_text = next(msg["content"] for msg in style_config[style]["examples"]
                        if msg["role"] == "assistant")
        ref_vec = style_model.encode(ref_text)
        reply_vec = style_model.encode(reply)
        similarity = np.dot(ref_vec, reply_vec)
        # print("======>ref_vec",ref_vec)
        # print("======>reply_vec", reply_vec)
        print("======>similarity", similarity)
        return similarity > 0.65
    except:
        print("=========>相似度过低:", similarity)
        return False

#=============================
#3.执行生成(添加容错)
#============================
if __name__ == '__main__':
    all_data = []

    try:
        print("开始生成温柔风格数据...")
        gentle_data = generate_style_data("温柔", 10000)
        all_data.extend(gentle_data)

        print("开始生成毒舌风格数据...")
        sarcastic_data = generate_style_data("毒舌", 10000)
        all_data.extend(sarcastic_data)

    except KeyboardInterrupt:
        print("\n用户中断,保存已生成数据...")
    finally:
        with open("style_chat_data.json", "w", encoding="utf-8") as f:
            json.dump(all_data, f, ensure_ascii=False, indent=2)
        print(f"数据已保存,有效样本数:{len(all_data)}")

一、embedding_model验证与转换

通常embedding_model核心目录如下:

embedding_model通常情况是自带有归一化层的。一般可以通过Normalize目录或者modules.json判断是否准确包含归一化层。如果未包含归一化层,则需要手动转换添加:

词向量库未归化,会导致模型计算速度变慢,归化不影响模型精度

import numpy as np
from sentence_transformers import SentenceTransformer,models

model_path = r"D:\PycharmProjects\demo_15\embedding_model\sungw111\text2vec-base-chinese-sentence"
bert = models.Transformer(model_path)
pooling = models.Pooling(bert.get_word_embedding_dimension(),
                        pooling_mode='mean')

# 添加缺失的归一化层
normalize = models.Normalize()

# 组合完整模型
full_model = SentenceTransformer(modules=[bert, pooling, normalize])
print(full_model)

save_path=r"D:\PycharmProjects\demo_15\embedding_model\zy\text2vec-base-chinese-sentence"
full_model.save(save_path)

# 加载修复后的模型
model = SentenceTransformer(r"D:\PycharmProjects\demo_15\embedding_model\zy\text2vec-base-chinese-sentence")

# 验证向量归一化
text = "测试文本"
vec = model.encode(text)
print("修正后模长:", np.linalg.norm(vec))  # 应输出≈1.0

输出模长应该接近1

转换后目录如下:

转换后的modules.json如下

[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]

加载转换验证后的embedding_model输出的相似度:

二、数据集转换

1. 选择模型微调工具,根据微调框架选择对应数据集格式(此处以xtuner为例,使xtuner默认格式)

数据集转换代码

import json

def convert_format(source_data):
    target_data = []
    for item in source_data:
        # 构建新的对话格式
        new_convo = {
            "conversation": [
                {
                    "input": item["user"],
                    "output": f"{item['style']}\n{item['assistant']}"
                }
            ]
        }
        target_data.append(new_convo)
    return target_data

# 从文件读取源数据
with open("input.json", "r", encoding="utf-8") as f:
    source_data = json.load(f)

# 执行转换
converted_data = convert_format(source_data)

# 写入目标文件
with open("output.json", "w", encoding="utf-8") as f:
    json.dump(converted_data, f, ensure_ascii=False, indent=2)

三、模型选型

1. 根据任务选择对应的评测数据,对预期模型客观评测

当前任务为日常聊天对话模型,主要要求模型的中文理解能力,因此这里以CLUE(中文理解)数据进行评测:

#输出数据集清单
python tools/list_configs .py clue
#输出如下:
+-----------------------------+--------------------------------------------------
----------------------------+
| Dataset                     | Config Path
|
|-----------------------------+--------------------------------------------------
----------------------------|
| CLUE_C3_gen                 |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_gen.py                          |
| CLUE_C3_gen_8c358f          |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_gen_8c358f.py                   |
| CLUE_C3_ppl                 |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl.py                          |
| CLUE_C3_ppl_56b537          |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl_56b537.py                   |
| CLUE_C3_ppl_e24a31          |
opencompass/configs/datasets/CLUE_C3/CLUE_C3_ppl_e24a31.py                   |
| CLUE_CMRC_gen               |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py                      |
| CLUE_CMRC_gen_1bd3c8        |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_1bd3c8.py               |
| CLUE_CMRC_gen_3749cd        |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_3749cd.py               |
| CLUE_CMRC_gen_8484b9        |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_8484b9.py               |
| CLUE_CMRC_gen_941108        |
opencompass/configs/datasets/CLUE_CMRC/CLUE_CMRC_gen_941108.py               |
| CLUE_DRCD_gen               |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen.py                      |
| CLUE_DRCD_gen_1bd3c8        |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_1bd3c8.py               |
| CLUE_DRCD_gen_3749cd        |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_3749cd.py               |
| CLUE_DRCD_gen_8484b9        |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_8484b9.py               |
| CLUE_DRCD_gen_941108        |
opencompass/configs/datasets/CLUE_DRCD/CLUE_DRCD_gen_941108.py               |
| CLUE_afqmc_gen              |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_gen.py                    |
| CLUE_afqmc_gen_901306       |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_gen_901306.py             |
| CLUE_afqmc_ppl              |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl.py                    |
| CLUE_afqmc_ppl_378c5b       |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_378c5b.py             |
| CLUE_afqmc_ppl_6507d7       |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_6507d7.py             |
| CLUE_afqmc_ppl_7b0c1e       |
opencompass/configs/datasets/CLUE_afqmc/CLUE_afqmc_ppl_7b0c1e.py             |
| CLUE_cmnli_gen              |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen.py                    |
| CLUE_cmnli_gen_1abf97       |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen_1abf97.py             |
| CLUE_cmnli_gen_51e956       |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_gen_51e956.py             |
| CLUE_cmnli_ppl              |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl.py                    |
| CLUE_cmnli_ppl_98dd6e       |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_98dd6e.py             |
| CLUE_cmnli_ppl_ef69e7       |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_ef69e7.py             |
| CLUE_cmnli_ppl_fdc6de       |
opencompass/configs/datasets/CLUE_cmnli/CLUE_cmnli_ppl_fdc6de.py             |
| CLUE_ocnli_gen              |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen.py                    |
| CLUE_ocnli_gen_51e956       |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_51e956.py             |
| CLUE_ocnli_gen_c4cb6c       |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_gen_c4cb6c.py             |
| CLUE_ocnli_ppl              |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl.py                    |
| CLUE_ocnli_ppl_98dd6e       |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_98dd6e.py             |
| CLUE_ocnli_ppl_ef69e7       |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_ef69e7.py             |
| CLUE_ocnli_ppl_fdc6de       |
opencompass/configs/datasets/CLUE_ocnli/CLUE_ocnli_ppl_fdc6de.py             |
| FewCLUE_bustm_gen           |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_gen.py              |
| FewCLUE_bustm_gen_634f41    |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_gen_634f41.py       |
| FewCLUE_bustm_ppl           |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl.py              |
| FewCLUE_bustm_ppl_4b16c0    |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_4b16c0.py       |
| FewCLUE_bustm_ppl_9ef540    |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_9ef540.py       |
| FewCLUE_bustm_ppl_e53034    |
opencompass/configs/datasets/FewCLUE_bustm/FewCLUE_bustm_ppl_e53034.py       |
| FewCLUE_chid_gen            |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_gen.py                |
| FewCLUE_chid_gen_0a29a2     |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_gen_0a29a2.py         |
| FewCLUE_chid_ppl            |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl.py                |
| FewCLUE_chid_ppl_8f2872     |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl_8f2872.py         |
| FewCLUE_chid_ppl_acccb5     |
opencompass/configs/datasets/FewCLUE_chid/FewCLUE_chid_ppl_acccb5.py         |
| FewCLUE_cluewsc_gen         |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_gen.py          |
| FewCLUE_cluewsc_gen_c68933  |
 opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_gen_c68933.py   |
| FewCLUE_cluewsc_ppl         |
opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl.py          |
| FewCLUE_cluewsc_ppl_12e4e0  |
 opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_12e4e0.py   |
| FewCLUE_cluewsc_ppl_4284a0  |
 opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_4284a0.py   |
| FewCLUE_cluewsc_ppl_868415  |
 opencompass/configs/datasets/FewCLUE_cluewsc/FewCLUE_cluewsc_ppl_868415.py   |
| FewCLUE_csl_gen             |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen.py                  |
| FewCLUE_csl_gen_28b223      |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_28b223.py           |
| FewCLUE_csl_gen_87f4a8      |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_gen_87f4a8.py           |
| FewCLUE_csl_ppl             |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl.py                  |
| FewCLUE_csl_ppl_769f8d      |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl_769f8d.py           |
| FewCLUE_csl_ppl_841b62      |
opencompass/configs/datasets/FewCLUE_csl/FewCLUE_csl_ppl_841b62.py           |
| FewCLUE_eprstmt_gen         |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_gen.py          |
| FewCLUE_eprstmt_gen_740ea0  |
 opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_gen_740ea0.py   |
| FewCLUE_eprstmt_ppl         |
opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py          |
| FewCLUE_eprstmt_ppl_1ce587  |
 opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl_1ce587.py   |
| FewCLUE_eprstmt_ppl_f1e631  |
 opencompass/configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl_f1e631.py   |
| FewCLUE_ocnli_fc_gen        |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen.py        |
| FewCLUE_ocnli_fc_gen_f97a97 |
 opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_gen_f97a97.py |
| FewCLUE_ocnli_fc_ppl        |
opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl.py        |
| FewCLUE_ocnli_fc_ppl_9e8b3d |
 opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl_9e8b3d.py |
| FewCLUE_ocnli_fc_ppl_c08300 |
 opencompass/configs/datasets/FewCLUE_ocnli_fc/FewCLUE_ocnli_fc_ppl_c08300.py |
| FewCLUE_tnews_gen           |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen.py              |
| FewCLUE_tnews_gen_b90e4a    |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_gen_b90e4a.py       |
| FewCLUE_tnews_ppl           |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl.py              |
| FewCLUE_tnews_ppl_7d1c07    |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_7d1c07.py       |
| FewCLUE_tnews_ppl_d10e8a    |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_d10e8a.py       |
| FewCLUE_tnews_ppl_fff486    |
opencompass/configs/datasets/FewCLUE_tnews/FewCLUE_tnews_ppl_fff486.py       |
+-----------------------------+--------------------------------------------------
----------------------------+

当前任务大多是短语对话,可以选择FewCLUE_bustm_gen(短文本分类)、 FewCLUE_ocnli_fc_gen(自然语言推理)对预期模型进行评估。

python run .py \
 --models hf_qwen1_5_0_5b_chat hf_qwen1_5_1_8b_chat \
 --datasets FewCLUE_bustm_gen FewCLUE_ocnli_fc_gen \
 --debug

根据评估结果,选择最终模型。

选用Qwen 1.5 1.8b

四、模型训练

1. 配置训练文件

### PART 1中

#预训练模型存放的位置
pretrained_model_name_or_path = 'model_path'#基座模型路径
#微调数据存放的位置

data_files = '/root/public/data/target_data.json'

# 训练中最大的文本长度

max_length = 512
# 每一批训练样本的大小

batch_size = 2
#最大训练轮数
max_epochs = 3

#验证数据
evaluation_inputs = [
    '只剩一个心脏了还能活吗?', '爸爸再婚,我是不是就有了个新娘?',
    '樟脑丸是我吃过最难吃的硬糖有奇怪的味道怎么还有人买','马上要上游泳课了,昨天洗的泳裤还没干,怎么办',
    '我只出生了一次,为什么每年都要庆生'
]

# PART 3中

dataset=dict(type=load_dataset, path="json",data_files=data_files)
dataset_map_fn=None
# Copyright (c) OpenMMLab. All rights reserved.
 
# 导入必要的库
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (
    CheckpointHook,
    DistSamplerSeedHook,
    IterTimerHook,
    LoggerHook,
    ParamSchedulerHook,
)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (
    DatasetInfoHook,
    EvaluateChatHook,
    VarlenAttnArgsToMessageHubHook,
)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
 
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# 本部分定义了训练的核心超参数和路径设置
 
# Model
pretrained_model_name_or_path = "/hy-tmp/llm/Qwen/Qwen1.5-0.5B-Chat"  # 预训练模型本地路径,使用0.5B版本的Qwen1.5-Chat
use_varlen_attn = False  # 是否使用可变长度注意力机制
 
# Data
alpaca_en_path = "/hy-tmp/day15/data/output.json"  # 训练数据文件路径(JSON格式),包含指令-输出对
prompt_template = PROMPT_TEMPLATE.qwen_chat  # 使用Qwen聊天模型的对话模板
max_length = 250 # 单个样本的最大token长度,较长的长度允许处理更长的上下文
pack_to_max_length = True  # 是否将多个短样本打包到max_length以提高训练效率
 
# parallel
sequence_parallel_size = 1  # 序列并行大小,1表示不使用序列并行
 
# Scheduler & Optimizer
batch_size = 35  # 每个GPU设备的批处理大小,设为1可能是因为序列较长(2048)导致显存需求大
accumulative_counts = 16  # 梯度累积步数,实际等效批大小 = batch_size * accumulative_counts = 16
accumulative_counts *= sequence_parallel_size  # 序列并行时调整累积计数
dataloader_num_workers = 0  # 多线程
max_epochs = 500  # 最大训练轮数
optim_type = AdamW  # 优化器类型
lr = 2e-4  # 学习率
betas = (0.9, 0.999)  # AdamW优化器的动量参数
weight_decay = 0  # 权重衰减系数
max_norm = 1  # 梯度裁剪的最大范数
warmup_ratio = 0.03  # 学习率warmup阶段占总训练步数的比例
 
# Save
save_steps = 500  # 每隔500个训练步保存一次检查点
save_total_limit = 2  # 最多保留2个最新检查点
 
# ========== 新增:指定工作目录 ==========
work_dir = '/hy-tmp/xtuner/work_dirs/my_config/20060326_qwen1.5_0.8_chat'
 
# Evaluate the generation performance during the training
evaluation_freq = 200  # 每隔500个训练步进行一次生成效果评估
SYSTEM = SYSTEM_TEMPLATE.alpaca  # 使用Alpaca格式的系统提示词
evaluation_inputs = ['老妈非让我嫁给她同事儿子,怎么逃啊!',
                     '同事抢功时故意提高音量,要当场揭穿吗?',
                     '音乐节踩到不明液体,鞋底黏糊一生阴影!',
                     '拍Vlog被猴抢包,素材变动物世界!']  # 评估时使用的固定输入问题
 
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
# 本部分定义模型和分词器的配置
 
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,  # 自动加载对应的分词器
    pretrained_model_name_or_path=pretrained_model_name_or_path,  # 与模型相同的路径
    trust_remote_code=True,  # 信任来自HuggingFace的远程代码(Qwen模型需要)
    padding_side="right",  # 在右侧进行填充,适用于自回归模型
)
 
model = dict(
    type=SupervisedFinetune,  # 监督微调包装器
    use_varlen_attn=use_varlen_attn,  # 不使用可变长度注意力
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,  # 加载因果语言模型
        pretrained_model_name_or_path=pretrained_model_name_or_path,  # 0.5B参数的Qwen1.5-Chat模型
        trust_remote_code=True,  # 信任远程代码
        torch_dtype=torch.float16,  # 使用float16精度,减少显存占用
        quantization_config=dict(  # 4位量化配置(QLoRA)
            type=BitsAndBytesConfig,
            load_in_4bit=True,  # 使用4位量化加载模型
            load_in_8bit=False,  # 不使用8位量化
            llm_int8_threshold=6.0,  # 8位量化阈值
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,  # 4位量化的计算数据类型
            bnb_4bit_use_double_quant=True,  # 使用双重量化
            bnb_4bit_quant_type="nf4",  # 4位标准化浮点量化
        ),
        device_map="auto",  # 自动将模型层分配到可用设备上,支持多GPU
    ),
    lora=dict(  # LoRA配置,用于参数高效微调
        type=LoraConfig,
        r=16,  # LoRA秩,控制可训练参数量,16比32更小,适用于小模型
        lora_alpha=32,  # LoRA缩放因子
        lora_dropout=0.1,  # LoRA层的dropout率
        bias="none",  # 不训练偏置参数
        task_type="CAUSAL_LM",  # 任务类型为因果语言建模
    ),
)
 
#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
# 本部分定义数据集和数据加载器的配置
 
# 数据集处理配置
alpaca_en = dict(
    type=process_hf_dataset,  # 使用XTuner的数据处理流程
    dataset=dict(type=load_dataset, path="json", data_files=alpaca_en_path),  # 加载JSON格式数据集
    tokenizer=tokenizer,  # 分词器
    max_length=max_length,  # 最大序列长度2048
    dataset_map_fn=None,  # 使用默认的Alpaca格式映射函数
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),  # 应用对话模板
    remove_unused_columns=True,  # 移除原始数据中不需要的列
    shuffle_before_pack=True,  # 打包前打乱数据
    pack_to_max_length=pack_to_max_length,  # 将短样本打包到最大长度
    use_varlen_attn=use_varlen_attn,  # 不使用可变长度注意力
)
 
# 选择采样器:序列并行采样器或默认采样器
sampler = SequenceParallelSampler if sequence_parallel_size > 1 else DefaultSampler
 
# 训练数据加载器配置
train_dataloader = dict(
    batch_size=batch_size,  # 批大小1
    num_workers=dataloader_num_workers,  # 不使用多进程数据加载
    dataset=alpaca_en,  # 数据集配置
    sampler=dict(type=sampler, shuffle=True),  # 采样器,训练时打乱数据
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn),  # 批处理函数
)
 
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# 本部分定义优化器和学习率调度器
 
# 优化器包装器配置
optim_wrapper = dict(
    type=AmpOptimWrapper,  # 自动混合精度优化器包装器
    optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),  # AdamW优化器
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),  # 梯度裁剪
    accumulative_counts=accumulative_counts,  # 梯度累积16步
    loss_scale="dynamic",  # 动态损失缩放
    dtype="float16",  # 使用float16进行混合精度训练
)
 
# 学习率调度策略
# 更多信息: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md
param_scheduler = [
    dict(  # 线性warmup阶段
        type=LinearLR,
        start_factor=1e-5,  # 起始学习率为lr的1e-5倍
        by_epoch=True,  # 按epoch调度
        begin=0,  # 从第0个epoch开始
        end=warmup_ratio * max_epochs,  # warmup结束于总epoch的3%处
        convert_to_iter_based=True,  # 转换为基于迭代的调度
    ),
    dict(  # 余弦退火阶段
        type=CosineAnnealingLR,
        eta_min=0.0,  # 最小学习率为0
        by_epoch=True,
        begin=warmup_ratio * max_epochs,  # warmup结束后开始
        end=max_epochs,  # 训练结束时结束
        convert_to_iter_based=True,
    ),
]
 
# 训练配置
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)  # 训练循环,最多3个epoch
 
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# 本部分定义训练过程中的各种运行时设置和钩子
 
# 自定义钩子
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),  # 记录数据集信息
    dict(  # 评估对话钩子
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,  # 每500次迭代评估一次
        evaluation_inputs=evaluation_inputs,  # 评估问题列表
        system=SYSTEM,  # 系统提示
        prompt_template=prompt_template,  # 提示模板
    ),
]
 
# 如果使用可变长度注意力,添加额外的钩子
if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
 
# 配置默认钩子
default_hooks = dict(
    timer=dict(type=IterTimerHook),  # 记录每次迭代的时间
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),  # 每10次迭代打印日志
    param_scheduler=dict(type=ParamSchedulerHook),  # 参数调度器钩子
    checkpoint=dict(  # 检查点保存钩子
        type=CheckpointHook,
        by_epoch=False,  # 按迭代步数保存
        interval=save_steps,  # 每500步保存一次
        max_keep_ckpts=save_total_limit,  # 最多保留2个检查点
    ),
    sampler_seed=dict(type=DistSamplerSeedHook),  # 分布式环境下的采样器种子设置
)
 
# 环境配置
env_cfg = dict(
    cudnn_benchmark=False,  # 不启用cudnn benchmark
    mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0),  # 多进程配置
    dist_cfg=dict(backend="nccl"),  # 分布式后端使用NCCL
)
 
# 可视化器(未使用)
visualizer = None
 
# 日志级别
log_level = "INFO"  # 信息级别日志
 
# 从哪个检查点加载(None表示从头开始训练)
load_from = None
 
# 是否从加载的检查点恢复训练(为True但load_from为None时,会寻找最新的检查点)
resume = True  # 设置为True,如果找到检查点会从中恢复训练状态
 
# 随机性设置
randomness = dict(seed=None, deterministic=False)  # 随机种子为None(随机),不启用确定性算法
 
# 日志处理器设置
log_processor = dict(by_epoch=False)  # 按迭代记录日志

2. 在当前目录下,输入以下命令启动微调脚本

#单机单卡
python -m xtuner.tools.train /hy-tmp/day15/my_config.py

#单机多卡
NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --
deepspeed deepspeed_zero2

根据显存的占用情况,调整训练的批次

开始训练后loss(损失)会慢慢下降

需要将loss训练到0.009,效果很好(loss到0.7左右就开始测了)

3.模型转换

模型训练后会自动保存成PTH 模型(例如iter_2000 .pth,如果使用了DeepSpeed,则将会是一个文件夹),我们需要利用xtuner convert pth_to_hf 将其转换为HuggingFace 模型,以便于后续使用。具体命令为:

python /hy-tmp/xtuner/xtuner/tools/model_converters/pth_to_hf.py \
    /hy-tmp/day15/my_config.py \
    /hy-tmp/xtuner/work_dirs/my_config/20060326_qwen1.5_0.8_chat/iter_1000.pth \
    /hy-tmp/llm/Qwen/zhuaun_Xtuner_Qwen1.5_0.5b_chat
# 例如:xtuner convert pth_to_hf internlm2_chat_7b_qlora_custom_sft_e1_copy.py 
#./iter_2000.pth ./iter_2000_

4.模型合并

如果使用了LoRA / QLoRA 微调,则模型转换后将得到adapter 参数,而并不包含原LLM 参数。如果您期望获得合并后的模型权重(例如用于后续评测),那么可以利用xtuner convert merge

python xtuner/tools/model_converters/merge.py /hy-tmp/llm/Qwen/Qwen1.5-0.5B-Chat /hy-tmp/llm/Qwen/zhuaun_Xtuner_Qwen1.5_0.5b_chat /hy-tmp/llm/Qwen/he_xtuner_qwen1.5_0.5b_chat

合并之后的模型

五、模型推理部署

LMDeploy 支持两种添加对话模板的形式:

一种是利用现有对话模板,直接配置一个如下的 json 文件使用。

{
    "model_name": "your awesome chat template name",
    "system": "<|im_start|>system\n",
    "meta_instruction": "You are a robot developed by LMDeploy.",
    "eosys": "<|im_end|>\n",
    "user": "<|im_start|>user\n",
    "eoh": "<|im_end|>\n",
    "assistant": "<|im_start|>assistant\n",
    "eoa": "<|im_end|>",
    "separator": "\n",
    "capability": "chat",
    "stop_words": ["<|im_end|>"]
}

model_name 为必填项,可以是LMDeploy 内置对话模板名(通过lmdeploy list 可查阅),也可以是新名字。其他字段可选填。当model_name 是内置对话模板名时,json文件中各非null字段会覆盖原有对话模板的对应属性。而当model_name 是新名字时,它会把将 BaseChatTemplate直接注册成新的对话模板。其具体定义可以参考BaseChatTemplate

这样一个模板将会以下面的形式进行拼接。

{system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant}
{assistant_content}{eoa}{separator}{user}...

在使用 CLI 工具时,可以通过 --chat-template 传入自定义对话模板,比如:

lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template 
${JSON_FILE}

也可以在通过接口函数传入,比如:

from lmdeploy import ChatTemplateConfig, serve
serve('internlm/internlm2_5-7b-chat',
      chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}'))

我采用的是以 LMDeploy 现有对话模板,自定义一个python对话模板类,注册成功后直接用即可。优 点是自定义程度高,可控性强。 下面是一个注册 LMDeploy 对话模板的例子

复制圈住的部分

粘贴代码到这个地方

对话模板转换脚本:完整代码

import re
import json
from typing import Dict, Any
 
 
def universal_converter(original_template: Dict[str, Any]) -> Dict[str, Any]:
    """将多种风格的原始模板转换为lmdeploy官方格式"""
    # 字段映射关系(核心逻辑)
    field_mapping = {
        # 基础字段映射
        "SYSTEM": "system",
        "INSTRUCTION": ("user", "assistant"),  # 需要拆分处理
        "SUFFIX": "eoa",
        "SEP": "separator",
        "STOP_WORDS": "stop_words",
        # 特殊处理字段
        "SUFFIX_AS_EOS": None,  # 该字段在官方模板中不需要
    }
 
    # 初始化目标模板(包含必填字段默认值)
    converted = {
        "meta_instruction": "You are a helpful assistant.",  # 必填项
        "capability": "chat",  # 必填项
        "eosys": "<|im_end|>\n",  # 通常固定格式
        "eoh": "<|im_end|>\n",  # 通常固定格式
    }
 
    # 自动处理字段映射
    for src_key, dest_key in field_mapping.items():
        if src_key in original_template:
            value = original_template[src_key]
            # 处理需要拆分的字段(如INSTRUCTION)
            if isinstance(dest_key, tuple) and src_key == "INSTRUCTION":
                # 使用正则拆分user和assistant部分
                parts = re.split(r'(\<\|im_start\|>assistant\n?)', value)
                converted["user"] = parts[0].strip()
                if len(parts) > 1:
                    converted["assistant"] = parts[1] + parts[2] if len(parts) > 2 else parts[1]
            # 处理直接映射字段
            elif dest_key and not isinstance(dest_key, tuple):
                converted[dest_key] = value
 
    # 特殊处理system字段的占位符
    if "system" in converted:
        converted["system"] = converted["system"].replace("{system}", "{{ system }}")
 
    # 处理用户输入占位符
    if "user" in converted:
        converted["user"] = converted["user"].replace("{input}", "{{ input }}")
 
    # 自动处理停止词(兼容列表和字符串)
    if "stop_words" in converted and isinstance(converted["stop_words"], str):
        converted["stop_words"] = [converted["stop_words"]]
 
    # 保留原始模板中的额外字段(带警告)
    for key in original_template:
        if key not in field_mapping:
            print(f"Warning: 发现未映射字段 [{key}], 已保留原样")
            converted[key] = original_template[key]
 
    return converted
 
 
# 示例用法
original_qwen_chat = dict(
    SYSTEM="<|im_start|>system\n{system}<|im_end|>\n",
    INSTRUCTION="<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n",
    SUFFIX="<|im_end|>",
    SUFFIX_AS_EOS=True,
    SEP="\n",
    STOP_WORDS=["<|im_end|>", "<|endoftext|>"]
)
 
# 执行转换
converted_template = universal_converter(original_qwen_chat)
 
# 生成JSON文件
with open('chat_template.json', 'w') as f:
    json.dump(converted_template, f,
              indent=2,
              ensure_ascii=False,
              separators=(',', ':'))

LMDeploy部署

选择合适的大模型推理框架部署模型(这里选择LMDeploy,须注意对话模板对其操作!

切换环境

conda activate lmdeploy

操作

这里我们选用CLI 工具推理,可以通过 --chat-template 传入自定义对话模板:

合并后的模型

lmdeploy serve api_server /hy-tmp/llm/Qwen/he_xtuner_qwen1.5_0.5b_chat --chat-template /hy-tmp/day15/chat_template.json

这里可以看到暴露出了 23333 这个端口,更改web代码中对应的端口即可

Streamlit

Streamlit 简介
Streamlit 是一个基于 Python 的开源 Web 应用程序框架,
专为数据科学家和机器学习工程师设计,用于快速创建数据可视
化应用。它的主要优势在于简单性和快速迭代,用户可以使用几
行代码创建交互式应用,无需深入了解前端技术如 HTML、CSS 等。

Streamlit 特点

易于使用:Streamlit 的设计哲学是简洁性,它允许用户通过简单的 Python 脚本快速构建 Web 应用。

即时预览:Streamlit 应用在编写代码时可以实时预览,这使得开发过程更加高效。

强大的交互性:Streamlit 提供了多种交互组件,如按钮、滑块、文本输入等,使得创建交互式应用变得简单。

数据可视化:支持多种图表库,如 Matplotlib、Altair 等,方便进行数据可视化。

适用于快速原型设计:适合快速开发和展示数据分析结果,特别适用于个人或小团队。

环境
pip install streamlit

web界面代码

import streamlit as st
from openai import OpenAI
 
# ------------------------------
# 1. 页面配置与样式美化
# ------------------------------
st.set_page_config(
    page_title="智能助手演示",          # 浏览器标签页标题
    page_icon="🤖",                    # 标签页图标
    layout="wide",                    # 宽屏布局
    initial_sidebar_state="expanded"  # 侧边栏默认展开
)
 
# 注入自定义 CSS 样式(覆盖 Streamlit 默认风格)
st.markdown(
    """
    <style>
    /* 整体背景色 */
    .stApp {
        background-color: #f5f7fb;
    }
    
    /* 聊天消息容器 */
    .stChatMessage {
        background-color: transparent;
        padding: 10px;
    }
    
    /* 用户消息气泡 */
    .stChatMessage.user .stMarkdown {
        background-color: #dcf8c6;
        border-radius: 18px;
        padding: 8px 16px;
        display: inline-block;
        max-width: 80%;
        margin-left: auto;
        text-align: left;
        box-shadow: 0 1px 1px rgba(0,0,0,0.1);
    }
    
    /* 助手消息气泡 */
    .stChatMessage.assistant .stMarkdown {
        background-color: #ffffff;
        border-radius: 18px;
        padding: 8px 16px;
        display: inline-block;
        max-width: 80%;
        margin-right: auto;
        text-align: left;
        box-shadow: 0 1px 1px rgba(0,0,0,0.1);
    }
    
    /* 输入框样式 */
    .stChatInputContainer textarea {
        border-radius: 24px;
        border: 1px solid #e0e0e0;
        box-shadow: 0 2px 4px rgba(0,0,0,0.05);
    }
    
    /* 侧边栏样式 */
    .css-1d391kg {
        background-color: #ffffff;
        border-right: 1px solid #e9ecef;
    }
    
    /* 按钮样式 */
    .stButton button {
        border-radius: 20px;
        background-color: #4CAF50;
        color: white;
        font-weight: 500;
        transition: all 0.3s ease;
    }
    .stButton button:hover {
        background-color: #45a049;
        transform: translateY(-1px);
        box-shadow: 0 2px 6px rgba(0,0,0,0.1);
    }
    
    /* 标题样式 */
    .main-title {
        font-size: 2.2rem;
        font-weight: 600;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        margin-bottom: 0.5rem;
    }
    .subtitle {
        color: #6c757d;
        margin-bottom: 1rem;
        font-size: 1rem;
    }
    
    /* 信息提示框 */
    .info-box {
        background-color: #e9ecef;
        border-radius: 12px;
        padding: 12px;
        font-size: 0.9rem;
        color: #495057;
        margin-bottom: 1rem;
    }
    </style>
    """,
    unsafe_allow_html=True
)
 
# ------------------------------
# 2. 初始化 API 客户端
# ------------------------------
client = OpenAI(
    base_url="http://localhost:23333/v1",  # LMDeploy 服务地址
    api_key="suibianxie"                    # 任意字符串,无实际校验
)
 
# ------------------------------
# 3. 侧边栏:模型信息与辅助功能
# ------------------------------
with st.sidebar:
    st.markdown("## 🧠 智能助手")
    st.markdown(
        """
        <div class="info-box">
        ✅ **模型**:Qwen1.5-0.5B-Chat (情感微调版)<br>
        🔗 **API**:LMDeploy 本地服务<br>
        🎯 **模式**:单轮对话(每次仅基于当前问题回答)<br>
        💬 **历史记录**:仅用于展示,不参与上下文
        </div>
        """,
        unsafe_allow_html=True
    )
    
    # 清除历史按钮
    if st.button("🗑️ 清空对话记录", use_container_width=True):
        st.session_state.messages = []
        st.rerun()          # 刷新页面,清空显示
    
    st.markdown("---")
    st.markdown("### 📌 使用说明")
    st.markdown(
        """
        - 输入问题后按回车发送
        - 输入 `exit` 可退出当前对话(仅退出,历史保留)
        - 模型每次仅针对当前输入作答,不记忆前文
        - 所有对话历史仅用于界面展示
        """
    )
    st.markdown("---")
    st.markdown("💡 *提示:如需多轮对话,请修改 API 请求中的 `messages` 参数包含历史*")
 
# ------------------------------
# 4. 主界面标题
# ------------------------------
st.markdown('<div class="main-title">✨ 项目一效果演示</div>', unsafe_allow_html=True)
st.markdown('<div class="subtitle">基于 Qwen1.5-0.5B-Chat  的智能问答助手</div>', unsafe_allow_html=True)
 
# ------------------------------
# 5. 初始化对话历史(仅用于显示)
# ------------------------------
if "messages" not in st.session_state:
    st.session_state.messages = []
 
# ------------------------------
# 6. 显示历史消息(自定义消息气泡)
# ------------------------------
for message in st.session_state.messages:
    # 使用 streamlit 原生聊天组件,但通过 CSS 已美化
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
 
# ------------------------------
# 7. 处理用户输入
# ------------------------------
# 获取用户输入(输入框位于底部)
prompt = st.chat_input("请输入您的问题,或输入 exit 退出...")
 
if prompt:
    # 处理退出命令
    if prompt.strip().lower() == "exit":
        st.info("👋 您已退出当前会话。如需继续,请刷新页面或输入新问题。")
        st.stop()   # 停止当前脚本执行,不再调用模型
 
    # 添加用户消息到显示历史
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
 
    # 调用 API(注意:这里只发送当前问题,无历史上下文)
    try:
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],  # 仅当前问题
            model="/hy-tmp/llm/Qwen/he_xtuner_qwen1.5_0.5b_chat"
        )
        model_response = response.choices[0].message.content
 
        # 添加助手回复到显示历史
        st.session_state.messages.append({"role": "assistant", "content": model_response})
        with st.chat_message("assistant"):
            st.markdown(model_response)
 
    except Exception as e:
        st.error(f"❌ 发生错误:{e}")
        # 可选:记录错误日志或提示重试

启动

使用下面命令启动streamlit web前端,测试最终效果:

streamlit run /hy-tmp/day15/chat_app.py

测试

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐