目录

一、模型部署的目的和用到的技术

二、系统架构设计

1.整体架构(c/s)

2.具体流程

三、服务端实现

1. 服务端功能概述

2.导入所需库

3.Flask应用初始化

4. 模型加载函数

5. 图像预处理函数

6. API路由定义

7. 服务启动

四、客户端

1. 客户端功能概述

2. 客户端代码

五、部署

1. 环境配置

2. 模型准备

3. 启动服务

4. 修改IP地址

5. 客户端调用


在深度学习领域,模型训练只是整个项目生命周期的一部分。如何将训练好的模型部署到生产环境中,让其他应用程序能够方便地调用,是许多开发者面临的挑战。本文将详细介绍如何使用Flask框架和PyTorch,构建一个完整的图像分类API服务,实现模型的在线部署与调用。

一、模型部署的目的和用到的技术

  • 提供实时预测服务:让模型能够处理实时请求
  • 实现模型复用:避免重复训练,提高开发效率
  • 支持多端调用:Web、移动端、桌面应用都可以通过API访问
  • 便于模型更新:只需更新服务端,客户端无需改动

需要用到的:

  • Flask:轻量级Web框架,适合快速构建API服务
  • PyTorch:深度学习框架,用于加载和运行模型
  • Requests:Python HTTP库,用于客户端请求
  • PIL:图像处理库
  • Torchvision:提供预训练模型和图像变换工具

二、系统架构设计

1.整体架构(c/s)

客户端 (Client) → HTTP请求 → Flask服务器(server) → PyTorch模型 → 预测结果
客户端 (Client) ← HTTP响应 ← Flask服务器 ← 返回结果

2.具体流程

  • 服务端启动:加载预训练模型,监听指定端口
  • 客户端请求:发送图像文件到服务器
  • 服务器处理:接收图像,预处理,模型推理
  • 结果返回:将预测结果封装成JSON格式返回
  • 客户端展示:解析结果并显示

三、服务端实现

1. 服务端功能概述

根据代码注释,服务端需要实现以下核心功能:

  • 接收来自客户端的信息,24小时运行
  • 将模型部署起来
  • 对图片进行识别
  • 将识别结果返回给客户端

2.导入所需库

import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models

io:处理字节流数据,用于图像读取
flask:Web框架,构建API服务
torch 和 torch.nn.functional:PyTorch核心库和函数式接口
PIL:Python图像处理库
torchvision:提供预训练模型和数据预处理工具

3.Flask应用初始化

app = flask.Flask(__name__)
model = None
use_gpu = False

app = flask.Flask(__name__):初始化Flask应用程序实例。__name__参数用于定位应用程序的根路径,这样Flask就可以知道在哪里找到模板、静态文件等,是Flask应用程序的起点,为后续添加路由、配置等奠定了基础。
model = None:全局变量,用于存储加载的模型
use_gpu = False:是否使用GPU进行推理

4. 模型加载函数

def load_model():
    """Load the pre-trained model, you can use your model just as easily."""
    global model
    # 加载resnet18网络
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 修改最后的类别数

    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    # 将模型指定为测试模式
    model.eval()

    # 是否使用gpu
    if use_gpu:
        model.cuda()

global model:声明使用全局变量model,在函数内部修改全局变量
models.resnet18():加载ResNet18预训练网络结构
model.fc.in_features:获取原全连接层的输入特征数
nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后的全连接层,适配102类分类任务
torch.load('best.pth'):加载训练好的模型权重文件
model.load_state_dict(checkpoint['state_dict']):将权重加载到模型中
model.eval():将模型设置为评估模式,关闭Dropout和Batch Normalization的训练行为
model.cuda():如果use_gpu为True,将模型移动到GPU

5. 图像预处理函数

def prepare_image(image, target_size):
    if image.mode != 'RGB':
        image = image.convert('RGB')

    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)

    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    image = image[None]
    if use_gpu:
        image = image.cuda()

    return torch.tensor(image)

if image.mode != 'RGB':检查图像颜色模式,如果不是RGB则转换
transforms.Resize(target_size)(image):调整图像尺寸到目标大小(224×224)
transforms.ToTensor()(image):将PIL图像转换为PyTorch张量,并将像素值从[0,255]缩放到[0.0,1.0]
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(image):使用ImageNet数据集的均值和标准差进行标准化
image = image[None]:在第0维添加一个维度,将形状从(C,H,W)变为(1,C,H,W),满足模型输入的batch要求
return torch.tensor(image):返回处理后的张量

6. API路由定义

@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False}

    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            # 读取图像文件
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))

            # 预处理图像
            image = prepare_image(image, target_size=(224, 224))
            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            data['predictions'] = list()

            for prob, label in zip(results[0][0], results[1][0]):
                r = {'label': str(label), 'probability': float(prob)}
                data['predictions'].append(r)
            data['success'] = True

    return flask.jsonify(data)

@app.route("/predict", methods=["POST"]):装饰器,将URL路径"/predict"与predict函数关联,并指定只处理POST请求
data = {"success": False}:初始化返回数据,success默认为False
flask.request.files.get("image"):获取上传的文件,key为"image"
image = flask.request.files["image"].read():读取文件的二进制数据
Image.open(io.BytesIO(image)):将二进制数据转换为PIL图像
F.softmax(model(image), dim=1):模型推理后应用softmax函数,将输出转换为概率分布
torch.topk(preds.cpu().data, k=3, dim=1):获取概率最高的前3个类别及其概率
results[0].cpu().numpy():将概率值转换为numpy数组
results[1].cpu().numpy():将类别索引转换为numpy数组
循环遍历,将结果封装成字典列表
flask.jsonify(data):将字典转换为JSON格式的HTTP响应

7. 服务启动

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model()  # 加载模型
    # 开启服务
    app.run(host='0.0.0.0', port=5010)

if __name__ == '__main__':判断是否直接运行此脚本

load_model():启动时加载模型

app.run(host='0.0.0.0', port=5010):启动Flask开发服务器,host='0.0.0.0'表示监听所有网络接口,允许局域网内的其他设备访问,port=5010指定端口号

四、客户端

1. 客户端功能概述

根据代码注释,客户端需要实现以下核心功能:

  1. 负责发送图片到指定的服务器

  2. 接收服务器端返回的结果信息

2. 客户端代码

import requests

flask_url = 'http://192.168.31.83:5010/predict'  # url和端口携程自己的本地ip

def predict_result(image_path):
    image = open(image_path, 'rb').read()
    payload = {'image': image}
    r = requests.post(flask_url, files=payload).json()
    # 向fLask_url服务发送一个P0ST请求,并尝试将返回的JSON响应解析为一个Python字典

    if r['success']:
        for (i, result) in enumerate(r['predictions']):
            print('{}.预测类别为{}:的概率{}'.format(i + 1, result['label'], result['probability']))
    else:
        print('Request failed')

if __name__ == '__main__':
    predict_result('flower.jpg')

import requests:导入requests库,用于发送HTTP请求

flask_url = 'http://192.168.31.83:5010/predict':服务器端API地址,需要根据实际情况修改IP和端口

open(image_path, 'rb').read():以二进制模式打开并读取图像文件

payload = {'image': image}:构建请求数据,键名"image"必须与服务器端flask.request.files.get("image")一致

requests.post(flask_url, files=payload).json():发送POST请求,files参数用于上传文件,.json()将响应解析为Python字典

if r['success']:检查服务器返回的success字段

enumerate(r['predictions']):遍历预测结果列表,同时获取索引和值

打印每个预测结果的类别和概率

五、部署

1. 环境配置

# 创建虚拟环境(可选)
python -m venv venv

# 激活虚拟环境(Windows)
venv\Scripts\activate

# 安装依赖
pip install flask torch torchvision pillow requests

2. 模型准备

确保在服务端目录下有训练好的模型文件 best.pth。根据代码,这个文件应该包含:

模型权重(通过checkpoint['state_dict']访问)

3. 启动服务

先执行server代码

4. 修改IP地址

客户端代码中的IP地址需要根据实际情况修改:

flask_url = 'http://你的实际IP:5010/predict'

5. 客户端调用

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐