|
import asyncio
import logging
import os
import sys
import threading
import time
from contextlib import asynccontextmanager
from io import BytesIO

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# ===================== Offline deployment (must stay enabled) =====================
# NOTE(review): huggingface_hub / transformers read HF_HUB_OFFLINE,
# TRANSFORMERS_OFFLINE and HF_HOME when they are first imported. The imports
# at the top of this file run before these assignments, so these flags likely
# have no effect on the already-imported libraries. Move this section above
# the `transformers` import if offline mode must be enforced. Loading from the
# local MODEL_DIR path below does not touch the network either way.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HOME"] = "/home/models"
# Restrict the process to GPU 0; this must happen before CUDA is first used.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ===================== Logging =====================
logger = logging.getLogger("qwen_vl_api")
logger.setLevel(logging.INFO)
# Without a handler of its own, INFO records propagate to the (unconfigured)
# root logger and are dropped by logging's last-resort handler, which only
# prints WARNING and above. The loading/heartbeat messages would then never
# appear. Attach a stdout handler once; the guard avoids duplicate handlers if
# this module is imported again.
if not logger.handlers:
    _handler = logging.StreamHandler(sys.stdout)
    _handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )
    logger.addHandler(_handler)

# ===================== Model path =====================
MODEL_DIR = "/home/models/Qwen2.5-VL-32B-Instruct"
# Populated by load_model() at application startup; None until then.
processor = None
model = None
# ===================== Heartbeat log (so a long load is not mistaken for a hang) =====================
def heartbeat(phase):
    """Start a background thread that logs a heartbeat line every 15 seconds.

    Args:
        phase: Label included in every heartbeat message.

    Returns:
        A ``threading.Event``; calling ``.set()`` on it stops the heartbeat.
        The thread is a daemon, so it never keeps the interpreter alive.
    """
    stop_event = threading.Event()

    def _beat():
        count = 0
        while True:
            # Event.wait returns True as soon as the event is set, False on timeout.
            if stop_event.wait(15):
                break
            count += 1
            logger.info(f"【心跳】{phase} 运行中 {count}次")

    threading.Thread(target=_beat, daemon=True).start()
    return stop_event
# ===================== Model loading =====================
def load_model():
    """Load the processor and the model from MODEL_DIR into the module globals.

    Each phase runs with a heartbeat so a long load shows progress in the log.
    The heartbeat is stopped in a ``finally`` block. A failed load (missing
    files, out of memory, ...) therefore propagates its exception without
    leaving the heartbeat thread logging "running" forever.

    Raises:
        Whatever ``from_pretrained`` raises. The global for the failing phase
        is left unassigned.
    """
    global processor, model

    logger.info("开始加载 Processor")
    stop = heartbeat("加载Processor")
    try:
        processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)
    finally:
        stop.set()

    logger.info("开始加载模型")
    stop = heartbeat("加载模型")
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.bfloat16,
            # Let accelerate place the weights on the visible device(s); it may
            # offload to CPU if they do not fit.
            device_map="auto",
            trust_remote_code=True,
        )
    finally:
        stop.set()
    logger.info("模型加载完成")
# ===================== Service lifecycle =====================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that loads the model at application startup.

    Everything before ``yield`` runs once at startup, before requests are
    handled. Nothing is torn down at shutdown.
    """
    load_model()
    logger.info("服务启动完成")
    # The application serves requests while this generator is suspended here.
    yield


app = FastAPI(lifespan=lifespan, title="Qwen2.5-VL 多模态API")
# ===================== Helpers =====================
def load_image(raw: bytes):
    """Decode uploaded image bytes into an RGB ``PIL.Image.Image``.

    Args:
        raw: Raw bytes of the uploaded file.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        HTTPException: status 400 when the bytes cannot be decoded as an image.
    """
    try:
        # Image.open is lazy; convert() forces the full decode, so truncated
        # or corrupt data also fails inside this try.
        return Image.open(BytesIO(raw)).convert("RGB")
    except Exception as err:
        # Previously a bare ``except:``, which also swallowed
        # KeyboardInterrupt/SystemExit. The decode error is now chained for
        # server-side debugging.
        raise HTTPException(status_code=400, detail="图片格式错误") from err
# ===================== Health check =====================
@app.get("/health")
def health():
    """Report liveness and whether the model global has been populated."""
    ready = model is not None
    return dict(status="running", model_ready=ready)
# ===================== Inference endpoint =====================
# Generation runs in worker threads (see ``predict``). This lock keeps only one
# request at a time on the single model instance, because concurrent
# ``generate`` calls on one GPU risk running out of memory.
_INFERENCE_LOCK = threading.Lock()


def _run_inference(image, prompt):
    """Blocking chat-template -> preprocess -> generate -> decode pass.

    Args:
        image: RGB ``PIL.Image.Image`` placed before the text in the user turn.
        prompt: User instruction text.

    Returns:
        The decoded text generated after the prompt tokens.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    prompt_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    with _INFERENCE_LOCK:
        inputs = processor(
            text=[prompt_text],
            images=[image],
            return_tensors="pt",
        ).to(model.device)
        # PyTorch grad mode is thread-local, so no_grad must be entered here,
        # inside the worker thread that runs generate().
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=2048,
                # ``temperature`` only applies when sampling is on. Without an
                # explicit do_sample=True, the checkpoint's generation_config
                # may select greedy decoding and the 0.7 is ignored.
                do_sample=True,
                temperature=0.7,
                use_cache=True,
            )
    # Skip the echoed prompt tokens so only the new answer is decoded.
    prompt_len = len(inputs.input_ids[0])
    return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    prompt: str = Form(default="请描述图片内容"),
):
    """Answer ``prompt`` about an uploaded image.

    Form fields:
        file: The image upload (any format PIL can decode).
        prompt: Instruction for the model; defaults to "请描述图片内容"
            ("describe the image content").

    Returns:
        ``{"code": 0, "message": "success", "answer": <generated text>}``.

    Raises:
        HTTPException: 500 if the model is not loaded, 400 if the upload is
            not a decodable image.
    """
    # Compare with None (the globals' initial value) instead of relying on the
    # loaded objects' truthiness.
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    raw = await file.read()
    image = load_image(raw)
    # Preprocessing and generation are long, blocking CPU/GPU work. Running
    # them in a worker thread keeps the event loop free to serve other
    # requests (e.g. /health) in the meantime.
    answer = await asyncio.to_thread(_run_inference, image, prompt)
    return {"code": 0, "message": "success", "answer": answer}
所有评论(0)