08-Flask RESTful API设计与实现深度解析
·
Flask RESTful API设计与实现深度解析
一、概述
RESTful API是现代Web应用的核心接口,Flask通过Flask-RESTful扩展提供了强大的API开发能力。本文将深入解析RESTful设计原则、资源路由、请求解析、响应格式、错误处理、API文档、认证授权以及性能优化,帮助开发者构建规范、高效、易维护的API系统。
二、RESTful设计原则
2.1 REST架构约束
"""
REST (Representational State Transfer) 架构风格核心约束:
1. 客户端-服务器分离 (Client-Server)
- 关注点分离,提高可移植性
2. 无状态 (Stateless)
- 每个请求包含所有必要信息
- 服务器不存储客户端上下文
3. 可缓存 (Cacheable)
- 响应必须标识是否可缓存
- 提高网络效率
4. 统一接口 (Uniform Interface)
- 资源标识 (URI)
- 通过表述操作资源
- 自描述消息
- HATEOAS (超媒体作为应用状态引擎)
5. 分层系统 (Layered System)
- 客户端无法判断连接的是终端服务器还是中间层
6. 按需代码 (Code on Demand) - 可选
- 服务器可以扩展客户端功能
"""
# RESTful API设计规范
class RESTfulGuidelines:
"""RESTful API设计指南"""
# URL设计规范
URL_RULES = {
# 使用名词而非动词
'correct': [
'GET /users', # 获取用户列表
'GET /users/1', # 获取单个用户
'POST /users', # 创建用户
'PUT /users/1', # 更新用户(完整)
'PATCH /users/1', # 更新用户(部分)
'DELETE /users/1', # 删除用户
],
'incorrect': [
'GET /getUsers',
'POST /createUser',
'POST /deleteUser/1',
],
# 使用复数形式
'plural': [
'/users', '/posts', '/comments', '/products'
],
'singular': [
'/user', '/post', '/comment', '/product'
],
# 表达资源关系
'relationships': [
'GET /users/1/posts', # 用户的文章列表
'GET /users/1/posts/2', # 用户的特定文章
'POST /users/1/posts', # 为用户创建文章
'GET /posts/2/comments', # 文章的评论列表
],
# 过滤、排序、分页
'query_params': [
'GET /users?status=active&role=admin', # 过滤
'GET /users?sort=created_at:desc', # 排序
'GET /users?page=1&per_page=20', # 分页
'GET /users?fields=id,name,email', # 字段选择
'GET /users?embed=posts,comments', # 嵌入关联资源
],
}
# HTTP方法语义
HTTP_METHODS = {
'GET': {
'description': '获取资源',
'safe': True, # 是否安全(不修改资源)
'idempotent': True, # 是否幂等
'cacheable': True, # 是否可缓存
},
'POST': {
'description': '创建资源',
'safe': False,
'idempotent': False,
'cacheable': False,
},
'PUT': {
'description': '完整更新资源',
'safe': False,
'idempotent': True,
'cacheable': False,
},
'PATCH': {
'description': '部分更新资源',
'safe': False,
'idempotent': True, # 实现正确时幂等
'cacheable': False,
},
'DELETE': {
'description': '删除资源',
'safe': False,
'idempotent': True,
'cacheable': False,
},
'HEAD': {
'description': '获取资源元数据',
'safe': True,
'idempotent': True,
'cacheable': True,
},
'OPTIONS': {
'description': '获取资源支持的HTTP方法',
'safe': True,
'idempotent': True,
'cacheable': False,
},
}
# HTTP状态码
STATUS_CODES = {
# 2xx 成功
200: 'OK - 请求成功',
201: 'Created - 资源创建成功',
202: 'Accepted - 请求已接受,处理中',
204: 'No Content - 成功但无返回内容',
# 3xx 重定向
301: 'Moved Permanently - 资源已永久移动',
304: 'Not Modified - 资源未修改(缓存有效)',
# 4xx 客户端错误
400: 'Bad Request - 请求格式错误',
401: 'Unauthorized - 未认证',
403: 'Forbidden - 无权限',
404: 'Not Found - 资源不存在',
405: 'Method Not Allowed - 方法不允许',
406: 'Not Acceptable - 无法生成请求的媒体类型',
409: 'Conflict - 资源冲突',
410: 'Gone - 资源已删除',
415: 'Unsupported Media Type - 不支持的媒体类型',
422: 'Unprocessable Entity - 语义错误',
429: 'Too Many Requests - 请求过多',
# 5xx 服务器错误
500: 'Internal Server Error - 服务器内部错误',
501: 'Not Implemented - 功能未实现',
502: 'Bad Gateway - 网关错误',
503: 'Service Unavailable - 服务不可用',
504: 'Gateway Timeout - 网关超时',
}
2.2 RESTful API架构图
三、Flask-RESTful基础
3.1 扩展配置
from flask import Flask
from flask_restful import Api, Resource, fields, marshal_with, reqparse
app = Flask(__name__)

# Flask-RESTful JSON output settings (keyword arguments for the JSON dumper)
app.config['RESTFUL_JSON'] = {
    'ensure_ascii': False,  # emit non-ASCII (e.g. Chinese) text as-is
    'indent': 2,            # pretty-print responses
    'sort_keys': True,      # sort object keys
}

# Initialise the API: every resource URL gets the /api/v1 prefix and
# unmatched routes are answered by Flask-RESTful's 404 handler.
api = Api(app, prefix='/api/v1', catch_all_404s=True)


# ...or with the application-factory pattern
def create_api(app, prefix='/api/v1'):
    """Create an ``Api`` bound to *app*.

    Args:
        app: the Flask application (or blueprint) to bind the API to.
        prefix: URL prefix applied to every resource registered on the API.

    Returns:
        The new ``Api`` instance.
    """
    # Pass the prefix to the constructor rather than assigning ``api.prefix``
    # afterwards, so it is in effect from the moment the Api exists.
    return Api(app, prefix=prefix)
3.2 资源类定义
from flask_restful import Resource, fields, marshal_with, reqparse, abort
from models import User, db
from flask_jwt_extended import jwt_required, get_jwt_identity
# 请求解析器
from flask import request, url_for

# Parser for creating a user (all arguments come from the JSON body)
user_parser = reqparse.RequestParser()
user_parser.add_argument('username', type=str, required=True,
                         help='用户名不能为空', location='json')
user_parser.add_argument('email', type=str, required=True,
                         help='邮箱不能为空', location='json')
user_parser.add_argument('password', type=str, required=True,
                         help='密码不能为空', location='json')
user_parser.add_argument('name', type=str, location='json')
user_parser.add_argument('avatar', type=str, location='json')

# Parser for partial updates: every field is optional
user_update_parser = reqparse.RequestParser()
user_update_parser.add_argument('name', type=str, location='json')
user_update_parser.add_argument('avatar', type=str, location='json')
user_update_parser.add_argument('bio', type=str, location='json')

# Parser for list queries (read from the query string)
user_query_parser = reqparse.RequestParser()
user_query_parser.add_argument('page', type=int, default=1, location='args')
user_query_parser.add_argument('per_page', type=int, default=20, location='args')
user_query_parser.add_argument('status', type=str, choices=['active', 'inactive'],
                               location='args')
user_query_parser.add_argument('sort', type=str, default='created_at:desc',
                               location='args')
user_query_parser.add_argument('search', type=str, location='args')


def _user_links(user):
    """Return the hypermedia links of *user* as absolute URLs.

    ``fields.Url`` passes the object's own attributes to ``url_for``, but the
    user routes need a ``user_id`` keyword (the model has ``id``). The
    resources are also registered directly on the app as ``user`` and
    ``user_posts``, without an ``api.`` blueprint prefix.
    """
    return {
        'self': url_for('user', user_id=user.id, _external=True),
        'posts': url_for('user_posts', user_id=user.id, _external=True),
    }


# Serialisation fields for a single user
user_fields = {
    'id': fields.Integer,
    'username': fields.String,
    'email': fields.String,
    'name': fields.String,
    # NOTE(review): the resources write ``user.avatar``; reading ``avatar_url``
    # assumes the model exposes such a property — confirm.
    'avatar': fields.String(attribute='avatar_url'),
    'bio': fields.String,
    'active': fields.Boolean,
    'created_at': fields.DateTime(dt_format='iso8601'),
    'updated_at': fields.DateTime(dt_format='iso8601'),
    # A callable attribute is called with the user object itself.
    '_links': fields.Raw(attribute=_user_links),
}

# Serialisation fields for a page of users
user_list_fields = {
    'items': fields.List(fields.Nested(user_fields)),
    'pagination': fields.Nested({
        'page': fields.Integer,
        'per_page': fields.Integer,
        'total': fields.Integer,
        'pages': fields.Integer,
        'has_next': fields.Boolean,
        'has_prev': fields.Boolean,
    }),
    '_links': fields.Nested({
        # Use the link supplied by the resource, else the URL of the current request.
        'self': fields.String(attribute=lambda links: (links or {}).get('self') or request.url),
        'next': fields.String,
        'prev': fields.String,
    }),
}
class UserListResource(Resource):
    """User collection resource: paginated listing (GET) and creation (POST)."""

    # Applied to every HTTP method, so listing and creation both require a valid JWT.
    method_decorators = [jwt_required()]

    # Columns a client may sort by. Unknown names fall back to created_at, so a
    # query string cannot order results by arbitrary model attributes (e.g. a
    # password hash) or by attributes that are not columns at all.
    _SORTABLE_FIELDS = ('id', 'username', 'email', 'name', 'created_at', 'updated_at')

    @marshal_with(user_list_fields)
    def get(self):
        """Return one page of users.

        Query parameters (see ``user_query_parser``):
            page, per_page (capped at 100), status (active/inactive),
            search (case-insensitive match on username, email or name),
            sort as ``<field>:<asc|desc>``.
        """
        from flask import request, url_for

        args = user_query_parser.parse_args()
        query = User.query

        # Filter by status
        if args['status']:
            query = query.filter(User.active == (args['status'] == 'active'))

        # Substring search
        if args['search']:
            search_term = f"%{args['search']}%"
            query = query.filter(
                User.username.ilike(search_term)
                | User.email.ilike(search_term)
                | User.name.ilike(search_term)
            )

        # Sorting. partition() tolerates a value without ':' (split() plus
        # tuple unpacking raised ValueError -> 500); the order defaults to ascending.
        sort_field, _, sort_order = (args['sort'] or '').partition(':')
        if sort_field not in self._SORTABLE_FIELDS:
            sort_field = 'created_at'
        sort_column = getattr(User, sort_field)
        if sort_order == 'desc':
            sort_column = sort_column.desc()
        query = query.order_by(sort_column)

        # Pagination
        pagination = query.paginate(
            page=args['page'],
            per_page=min(args['per_page'], 100),  # cap the page size
            error_out=False,
        )

        # Navigation links keep the caller's other query parameters (filters,
        # sort, per_page); only the page number changes. Names url_for itself
        # uses ('endpoint', '_external', ...) are dropped to avoid clashes.
        link_params = {
            key: value for key, value in request.args.items()
            if key not in ('page', 'endpoint') and not key.startswith('_')
        }

        def page_url(page):
            # Registered as endpoint 'users' directly on the app (no 'api.' prefix).
            return url_for('users', page=page, _external=True, **link_params)

        return {
            'items': pagination.items,
            'pagination': {
                'page': pagination.page,
                'per_page': pagination.per_page,
                'total': pagination.total,
                'pages': pagination.pages,
                'has_next': pagination.has_next,
                'has_prev': pagination.has_prev,
            },
            '_links': {
                'self': None,  # left to the marshalling fields
                'next': page_url(pagination.next_num) if pagination.has_next else None,
                'prev': page_url(pagination.prev_num) if pagination.has_prev else None,
            },
        }

    @marshal_with(user_fields)
    def post(self):
        """Create a user; 409 if the username or e-mail is already taken."""
        args = user_parser.parse_args()

        # Reject duplicate usernames
        if User.query.filter_by(username=args['username']).first():
            abort(409, message='用户名已存在')

        # Reject duplicate e-mail addresses
        if User.query.filter_by(email=args['email']).first():
            abort(409, message='邮箱已存在')

        user = User(
            username=args['username'],
            email=args['email'],
            name=args.get('name'),
            avatar=args.get('avatar'),
        )
        user.set_password(args['password'])

        db.session.add(user)
        db.session.commit()

        return user, 201
class UserResource(Resource):
    """Single-user resource: read, replace, partially update and delete."""

    method_decorators = [jwt_required()]

    @staticmethod
    def _require_self(user_id, message):
        """Abort with 403 and *message* unless the JWT identity is *user_id*.

        The identity may be an int or a string depending on the identity
        loader and the flask-jwt-extended version, so string forms are compared.
        """
        if str(get_jwt_identity()) != str(user_id):
            abort(403, message=message)

    @marshal_with(user_fields)
    def get(self, user_id):
        """Return one user (404 if it does not exist)."""
        return User.query.get_or_404(user_id)

    @marshal_with(user_fields)
    def put(self, user_id):
        """Replace the user's writable fields; only the user themself may do this."""
        self._require_self(user_id, '无权修改其他用户信息')

        user = User.query.get_or_404(user_id)
        args = user_parser.parse_args()

        # Same uniqueness rules as on creation, excluding this user, so a
        # clash yields a 409 instead of an IntegrityError on commit.
        if User.query.filter(User.username == args['username'], User.id != user.id).first():
            abort(409, message='用户名已存在')
        if User.query.filter(User.email == args['email'], User.id != user.id).first():
            abort(409, message='邮箱已存在')

        # Overwrite every field covered by user_parser
        user.username = args['username']
        user.email = args['email']
        user.name = args.get('name')
        user.avatar = args.get('avatar')
        if args['password']:
            user.set_password(args['password'])

        db.session.commit()
        return user

    @marshal_with(user_fields)
    def patch(self, user_id):
        """Update only the fields present in the request body."""
        self._require_self(user_id, '无权修改其他用户信息')

        user = User.query.get_or_404(user_id)
        args = user_update_parser.parse_args()

        # Absent arguments parse as None and are left unchanged
        if args['name'] is not None:
            user.name = args['name']
        if args['avatar'] is not None:
            user.avatar = args['avatar']
        if args['bio'] is not None:
            user.bio = args['bio']

        db.session.commit()
        return user

    def delete(self, user_id):
        """Delete the user; responds 204 with an empty body."""
        self._require_self(user_id, '无权删除其他用户')

        user = User.query.get_or_404(user_id)
        db.session.delete(user)
        db.session.commit()

        return '', 204
# Register the resource routes. The endpoint names ('users', 'user') are the
# names url_for() must use; the Api's '/api/v1' prefix is added to each URL.
api.add_resource(UserListResource, '/users', endpoint='users')
api.add_resource(UserResource, '/users/<int:user_id>', endpoint='user')
3.3 嵌套资源
from flask import url_for


def _post_links(post):
    """Return the hypermedia links of *post* as absolute URLs.

    ``fields.Url`` passes the post's own attributes to ``url_for``, but the
    routes need keyword names such as ``post_id``/``user_id``. The resources
    are also registered directly on the app, without an ``api.`` blueprint prefix.
    """
    # NOTE(review): the 'post' and 'post_comments' endpoints and their URL
    # variable name (post_id) are assumed here — confirm against their
    # add_resource registrations.
    return {
        'self': url_for('post', post_id=post.id, _external=True),
        'author': url_for('user', user_id=post.user_id, _external=True),
        'comments': url_for('post_comments', post_id=post.id, _external=True),
    }


# Post serialisation fields
post_fields = {
    'id': fields.Integer,
    'title': fields.String,
    'content': fields.String,
    'summary': fields.String,
    'status': fields.String,
    'view_count': fields.Integer,
    'created_at': fields.DateTime(dt_format='iso8601'),
    'updated_at': fields.DateTime(dt_format='iso8601'),
    'author': fields.Nested(user_fields, attribute='user'),  # read from post.user
    # A callable attribute is called with the post object itself.
    '_links': fields.Raw(attribute=_post_links),
}

# Parser for creating a post (JSON body)
post_parser = reqparse.RequestParser()
post_parser.add_argument('title', type=str, required=True, location='json')
post_parser.add_argument('content', type=str, required=True, location='json')
post_parser.add_argument('summary', type=str, location='json')
post_parser.add_argument('status', type=str, choices=['draft', 'published', 'archived'],
                         default='draft', location='json')
class UserPostsResource(Resource):
    """Posts of one user: list published posts (GET) and create a post (POST)."""

    method_decorators = [jwt_required()]

    # marshal_with expects a dict of fields and marshals list results item by
    # item. Wrapping the dict in fields.List made marshal() call .items() on a
    # Field object, which raises AttributeError.
    @marshal_with(post_fields)
    def get(self, user_id):
        """Return the user's published posts (404 if the user does not exist)."""
        user = User.query.get_or_404(user_id)
        # NOTE(review): calling .filter() on user.posts assumes a
        # lazy='dynamic' relationship — confirm on the model.
        return user.posts.filter(Post.status == 'published').all()

    @marshal_with(post_fields)
    def post(self, user_id):
        """Create a post for *user_id*; only that user may do so."""
        # The JWT identity may be an int or a string depending on the identity
        # loader, so compare string forms.
        if str(get_jwt_identity()) != str(user_id):
            abort(403, message='无权为其他用户创建文章')

        User.query.get_or_404(user_id)  # 404 when the user does not exist
        args = post_parser.parse_args()

        post = Post(
            title=args['title'],
            content=args['content'],
            summary=args.get('summary'),
            status=args['status'],
            user_id=user_id,
        )
        db.session.add(post)
        db.session.commit()

        return post, 201
# Register the nested resource route: a user's posts live under /users/<id>/posts.
api.add_resource(UserPostsResource, '/users/<int:user_id>/posts', endpoint='user_posts')
四、请求解析与验证
4.1 RequestParser详解
from flask_restful import reqparse
# 基本用法
from flask_restful import inputs

# Basic usage
parser = reqparse.RequestParser()

# Argument types
parser.add_argument('name', type=str)                        # string
parser.add_argument('age', type=int)                         # integer
parser.add_argument('price', type=float)                     # float
# type=bool would turn every non-empty string ('false', '0', ...) into True;
# inputs.boolean accepts 'true'/'false'/'1'/'0' and rejects anything else.
parser.add_argument('active', type=inputs.boolean)           # boolean
parser.add_argument('tags', type=list, location='json')      # list
parser.add_argument('settings', type=dict, location='json')  # dict

# Argument locations
parser.add_argument('query_param', location='args')      # URL query string
parser.add_argument('form_param', location='form')       # form data
parser.add_argument('json_param', location='json')       # JSON body
parser.add_argument('header_param', location='headers')  # request headers
parser.add_argument('cookie_param', location='cookies')  # cookies
parser.add_argument('file_param', location='files')      # uploaded files

# Required arguments
parser.add_argument('required_field', required=True,
                    help='此字段为必填项')

# Default values
parser.add_argument('page', type=int, default=1)
parser.add_argument('per_page', type=int, default=20)

# Restricting to a fixed set of choices
parser.add_argument('status', type=str,
                    choices=['active', 'inactive', 'pending'],
                    help='状态必须是 active, inactive 或 pending')

# Regex validation. The phone argument previously had no pattern, so any
# string was accepted. The pattern below matches mainland-China mobile numbers.
import re
parser.add_argument('phone', type=inputs.regex(r'^1[3-9]\d{9}$'),
                    help='请输入有效的手机号码')
# 自定义类型验证
def email_type(value):
    """Validate an e-mail address argument and return it unchanged.

    Performs a lightweight syntax check only: a non-empty local part, a
    single '@', a domain containing a dot, and no whitespace. It does not
    verify that the address exists.

    Raises:
        ValueError: with a user-facing message if *value* is not a plausible address.
    """
    import re

    if not isinstance(value, str) or not re.fullmatch(r'[^@\s]+@[^@\s]+\.[^@\s]+', value):
        raise ValueError('请输入有效的邮箱地址')
    return value
parser.add_argument('email', type=email_type)  # a callable type: its ValueError message becomes the 400 error
# Custom validation function
def validate_password(value):
    """Check password strength and return *value* unchanged if it passes.

    Rules are checked in this order and the first failing one is reported:
    at least 8 characters, an upper-case letter, a lower-case letter, a digit.

    Raises:
        ValueError: carrying the message of the first rule that failed.
    """
    if len(value) < 8:
        raise ValueError('密码长度至少8位')
    required_patterns = (
        (r'[A-Z]', '密码必须包含大写字母'),
        (r'[a-z]', '密码必须包含小写字母'),
        (r'\d', '密码必须包含数字'),
    )
    for pattern, message in required_patterns:
        if re.search(pattern, value) is None:
            raise ValueError(message)
    return value
parser.add_argument('password', type=validate_password)
# Strict mode: with parse_args(strict=True) any argument the parser does not
# declare makes the request fail with 400 instead of being ignored.
parser.add_argument('known_param', type=str)
# NOTE: parse_args() reads the current request, so it only works inside a
# request context (e.g. in a Resource method); at import time it raises RuntimeError.
args = parser.parse_args(strict=True)  # strict mode: reject unknown arguments
# Parser "inheritance": copy() clones the base parser so it can be extended
# without modifying the original. add_argument() returns the parser, so calls chain.
base_parser = (
    reqparse.RequestParser()
    .add_argument('page', type=int, default=1)
    .add_argument('per_page', type=int, default=20)
)

search_parser = (
    base_parser.copy()
    .add_argument('keyword', type=str)
    .add_argument('sort', type=str)
)
4.2 高级验证
from werkzeug.datastructures import FileStorage
from flask_restful import inputs
# Built-in validators from flask_restful.inputs (add_argument() returns the parser, so calls chain)
parser = (
    reqparse.RequestParser()
    # Date: YYYY-MM-DD
    .add_argument('date', type=inputs.date,
                  help='日期格式: YYYY-MM-DD')
    # Date-time: ISO 8601
    .add_argument('datetime', type=inputs.datetime_from_iso8601,
                  help='日期时间格式: ISO 8601')
    # URL
    .add_argument('website', type=inputs.url,
                  help='请输入有效的URL')
    # Regex: starts with a letter, 3-20 characters of letters/digits/underscore
    .add_argument('username', type=inputs.regex('^[a-zA-Z][a-zA-Z0-9_]{2,19}$'),
                  help='用户名必须以字母开头,3-20个字符,只能包含字母、数字和下划线')
    # Boolean (accepts several spellings)
    .add_argument('active', type=inputs.boolean,
                  help='请输入 true 或 false')
    # Integer within an inclusive range
    .add_argument('rating', type=inputs.int_range(1, 5),
                  help='评分必须在1-5之间')
)
# File upload validation
def file_type(allowed_extensions):
    """Build a reqparse ``type`` callable that accepts uploads with an allowed extension.

    Args:
        allowed_extensions: iterable of permitted extensions without the dot;
            the comparison is case-insensitive.

    Returns:
        A validator that returns the ``FileStorage`` unchanged, or raises
        ``ValueError`` with a user-facing message.
    """
    display = list(allowed_extensions)  # materialise once (may be a generator)
    allowed = {ext.lower() for ext in display}

    def validate(file):
        if not isinstance(file, FileStorage):
            raise ValueError('请上传文件')
        # FileStorage.filename can be None or empty when the client sends no
        # name; "'.' in None" would raise TypeError.
        filename = file.filename or ''
        if '.' not in filename:
            raise ValueError('文件必须有扩展名')
        ext = filename.rsplit('.', 1)[1].lower()
        if ext not in allowed:
            raise ValueError(f'只允许以下文件类型: {", ".join(display)}')
        return file

    return validate


parser.add_argument('avatar', type=file_type(['jpg', 'jpeg', 'png', 'gif']),
                    location='files')
# List validation
def list_of_ints(value):
    """Parse a list of integers from ``"1,2,3"`` or an already-split sequence.

    Whitespace around string items is ignored and empty string items (for
    example from a trailing comma or an empty parameter) are skipped.

    Raises:
        ValueError: if an item cannot be converted with ``int()``.
    """
    if isinstance(value, str):
        value = value.split(',')
    result = []
    for item in value:
        if isinstance(item, str):
            item = item.strip()
            if not item:
                continue
        result.append(int(item))
    return result
parser.add_argument('ids', type=list_of_ints, location='args')  # e.g. ?ids=1,2,3
# JSON Schema验证
import jsonschema
def validate_json_schema(schema):
    """Build a reqparse ``type`` callable that validates a value against a JSON Schema.

    Args:
        schema: the JSON Schema (a dict) the argument must satisfy.

    Returns:
        A validator that returns the value unchanged, or raises ``ValueError``
        describing the first schema violation.
    """
    def validate(value):
        try:
            jsonschema.validate(value, schema)
        except jsonschema.ValidationError as e:
            # Chain the original error so the full jsonschema context
            # (path, failing validator) stays available in logs.
            raise ValueError(f'JSON验证失败: {e.message}') from e
        return value

    return validate
# Schema for the 'address' JSON argument: province, city and address are
# required; postal_code, if present, must be exactly six digits.
address_schema = dict(
    type='object',
    properties=dict(
        province={'type': 'string'},
        city={'type': 'string'},
        address={'type': 'string'},
        postal_code={'type': 'string', 'pattern': r'^\d{6}$'},
    ),
    required=['province', 'city', 'address'],
)

parser.add_argument(
    'address',
    type=validate_json_schema(address_schema),
    location='json',
)
五、响应格式与序列化
5.1 序列化字段详解
from flask_restful import fields, marshal, marshal_with
# Basic field types
user_fields = {
    'id': fields.Integer,                                # integer
    'username': fields.String,                           # string
    'email': fields.String,                              # string
    'active': fields.Boolean,                            # boolean
    'balance': fields.Float,                             # float
    'price': fields.Fixed(decimals=2),                   # fixed-point, 2 decimal places
    'created_at': fields.DateTime(dt_format='iso8601'),  # date-time
    # flask_restful.fields has no Date or Time class (fields.Date/fields.Time
    # raise AttributeError). str() of a date/time object is already ISO 8601
    # ('2024-01-31', '13:45:00'), so String serialises them correctly.
    'birthday': fields.String,                           # date
    'login_time': fields.String,                         # time
    'settings': fields.Raw,                              # passed through unchanged
    'count': fields.Integer(default=0),                  # default when the value is None
}

# Formatted fields
# FormattedString fills *named* placeholders from the object's attributes or
# keys (src_str.format(**data)); a positional '{0}' raises IndexError.
formatted_fields = {
    'price': fields.Price(decimals=2),                         # Price is Fixed, which defaults to 5 decimals
    'amount': fields.FormattedString('¥{amount:.2f}'),         # custom format
    'percentage': fields.FormattedString('{percentage:.1%}'),  # percentage format
}

# Attribute mapping
user_fields = {
    'id': fields.Integer,
    'name': fields.String(attribute='username'),      # read from the username attribute
    'avatar_url': fields.String(attribute='avatar'),  # expose avatar under another name
    # A string attribute naming a method returns the bound method itself
    # (rendered as "<bound method ...>"). A callable attribute is called with
    # the object, so the method actually runs.
    'full_name': fields.String(attribute=lambda user: user.get_full_name()),
}

# Nested fields
post_fields = {
    'id': fields.Integer,
    'title': fields.String,
    'content': fields.String,
    'author': fields.Nested(user_fields, attribute='user'),  # nested object
    'tags': fields.List(fields.String),                      # list of strings
    # NOTE(review): comment_fields is not defined in this section — it must be declared before this dict.
    'comments': fields.List(fields.Nested(comment_fields)),  # list of objects
}

# URL fields
# NOTE(review): fields.Url calls url_for with the object's own attributes, so
# the route's URL variables must exist on the object under the same names — confirm for 'api.post'.
post_fields = {
    'id': fields.Integer,
    'title': fields.String,
    'url': fields.Url('api.post', absolute=True),            # absolute URL
    'relative_url': fields.Url('api.post', absolute=False),  # relative URL
}
# Custom fields
class EnumField(fields.Raw):
    """Serialise members of *enum_class* as their ``.value``; other values pass through."""

    def __init__(self, enum_class, **kwargs):
        super().__init__(**kwargs)
        self.enum_class = enum_class

    def format(self, value):
        return value.value if isinstance(value, self.enum_class) else value
class FileSizeField(fields.Raw):
    """Render a byte count as a human-readable size with two decimals (1024-based units, B up to PB)."""

    def format(self, value):
        units = ('B', 'KB', 'MB', 'GB', 'TB')
        index = 0
        # Divide by 1024 until the value drops below 1024 or the units run out.
        while index < len(units) and not value < 1024:
            value /= 1024
            index += 1
        unit = units[index] if index < len(units) else 'PB'
        return f'{value:.2f} {unit}'
class MaskedEmailField(fields.Raw):
    """Mask the local part of an e-mail address, e.g. ``alice@x.com`` -> ``a***e@x.com``.

    Local parts of 1-2 characters keep only the first character
    (``ab@x.com`` -> ``a***@x.com``). Values that are empty or contain no '@'
    are returned unchanged.
    """

    def format(self, value):
        if not value or '@' not in value:
            return value
        local, domain = value.split('@', 1)
        if not local:
            # e.g. '@example.com': there is nothing to keep, and local[0]
            # would raise IndexError.
            return f'***@{domain}'
        if len(local) > 2:
            masked_local = f'{local[0]}***{local[-1]}'
        else:
            masked_local = f'{local[0]}***'
        return f'{masked_local}@{domain}'
# Using the custom fields
# NOTE(review): UserStatus is not defined in this section — presumably an Enum declared with the models; confirm.
user_fields = {
    'id': fields.Integer,
    'email': MaskedEmailField(),
    'file_size': FileSizeField(attribute='file_size_bytes'),  # read from the file_size_bytes attribute
    'status': EnumField(UserStatus),
}
5.2 响应构建
from flask import jsonify, make_response
from flask_restful import marshal
class ResponseBuilder:
    """Helpers that build the API's standard response envelopes.

    ``success``, ``error``, ``paginated`` and ``no_content`` return values a
    Flask-RESTful resource can return directly; ``created`` returns a
    finished Flask response.
    """

    @staticmethod
    def success(data=None, message='操作成功', status_code=200):
        """Return ``({'success': True, 'message': ..., 'data'?: ...}, status_code)``.

        ``data`` is included whenever it is not None, so empty lists/dicts are kept.
        """
        response = {
            'success': True,
            'message': message,
        }
        if data is not None:
            response['data'] = data
        return response, status_code

    @staticmethod
    def error(message='操作失败', status_code=400, errors=None):
        """Return ``({'success': False, 'message': ..., 'errors'?: ...}, status_code)``; falsy *errors* are omitted."""
        response = {
            'success': False,
            'message': message,
        }
        if errors:
            response['errors'] = errors
        return response, status_code

    @staticmethod
    def paginated(items, pagination, fields):
        """Marshal one page of *items* and add pagination metadata and navigation links.

        Args:
            items: objects on the current page.
            pagination: a Flask-SQLAlchemy pagination object (page, per_page,
                total, pages, has_next/has_prev, next_num/prev_num).
            fields: Flask-RESTful field mapping used to marshal each item.

        The links point at the current endpoint. They keep the request's URL
        variables and query arguments and change only ``page``.
        """
        from flask import request, url_for

        # request.args may already contain 'page'; passing it together with
        # page=... raised "got multiple values for keyword argument 'page'".
        # URL variables (view_args, e.g. user_id) are needed to rebuild routes
        # that have them. Names url_for itself uses are skipped.
        link_args = dict(request.view_args or {})
        link_args.update(
            (key, value) for key, value in request.args.items()
            if key not in ('page', 'endpoint') and not key.startswith('_')
        )

        def page_link(page):
            return url_for(request.endpoint, page=page, **link_args)

        return {
            'items': marshal(items, fields),
            'pagination': {
                'page': pagination.page,
                'per_page': pagination.per_page,
                'total': pagination.total,
                'pages': pagination.pages,
                'has_next': pagination.has_next,
                'has_prev': pagination.has_prev,
            },
            '_links': {
                'self': page_link(pagination.page),
                'next': page_link(pagination.next_num) if pagination.has_next else None,
                'prev': page_link(pagination.prev_num) if pagination.has_prev else None,
            },
        }

    @staticmethod
    def created(data, location=None):
        """Return a 201 JSON response, with a ``Location`` header when *location* is given."""
        response = make_response(jsonify(data), 201)
        if location:
            response.headers['Location'] = location
        return response

    @staticmethod
    def no_content():
        """Return an empty 204 response."""
        return '', 204
# Usage example
class PostResource(Resource):
    """Example resource returning data inside the ResponseBuilder envelope."""

    def get(self, post_id):
        """Return one post wrapped in the standard success envelope."""
        post = Post.query.get_or_404(post_id)
        # Marshal the post itself. Decorating get() with
        # marshal_with(post_fields) would instead marshal the
        # {'success', 'message', 'data'} envelope and yield only nulls.
        return ResponseBuilder.success(marshal(post, post_fields))

    def delete(self, post_id):
        """Delete the post and respond 204."""
        post = Post.query.get_or_404(post_id)
        db.session.delete(post)
        db.session.commit()
        return ResponseBuilder.no_content()
六、错误处理
6.1 统一错误处理
from flask_restful import abort
from werkzeug.exceptions import HTTPException
import traceback
# Custom exceptions
class APIError(Exception):
    """Base class for errors that map directly to a JSON error response.

    Attributes:
        message: human-readable message sent to the client.
        status_code: HTTP status code of the response.
        errors: optional field-level error details (or None).
    """

    def __init__(self, message, status_code=400, errors=None):
        super().__init__(message)
        self.message, self.status_code, self.errors = message, status_code, errors


class ValidationError(APIError):
    """422: the request data failed validation."""

    def __init__(self, message='数据验证失败', errors=None):
        super().__init__(message, status_code=422, errors=errors)


class NotFoundError(APIError):
    """404: the requested resource does not exist."""

    def __init__(self, message='资源不存在'):
        super().__init__(message, status_code=404)


class ForbiddenError(APIError):
    """403: the caller may not access the resource."""

    def __init__(self, message='无权访问'):
        super().__init__(message, status_code=403)


class ConflictError(APIError):
    """409: the request conflicts with the resource's current state."""

    def __init__(self, message='资源冲突'):
        super().__init__(message, status_code=409)
# Error handlers
@app.errorhandler(APIError)
def handle_api_error(error):
    """Render an APIError as JSON with its own status code; ``errors`` is included only when truthy."""
    payload = {'success': False, 'message': error.message}
    if error.errors:
        payload['errors'] = error.errors
    return jsonify(payload), error.status_code
@app.errorhandler(HTTPException)
def handle_http_error(error):
    """Render a werkzeug HTTPException as a JSON body while keeping its status and headers.

    Starting from ``error.get_response()`` preserves headers the exception
    sets, such as ``Allow`` on 405 or ``WWW-Authenticate`` on 401. A bare
    ``jsonify(...)`` response would drop them.
    """
    import json

    response = error.get_response()
    response.data = json.dumps({
        'success': False,
        'message': error.description,
        'status_code': error.code,
    }, ensure_ascii=False)
    response.content_type = 'application/json'
    return response
@app.errorhandler(Exception)
def handle_generic_error(error):
    """Last-resort handler: log the exception and return a generic 500 JSON body.

    In debug mode the exception type, message and traceback lines are added
    under ``debug``. The traceback is taken from the exception object itself,
    not from ``traceback.format_exc()``, so it is correct even when no
    exception is currently being handled.
    """
    tb_text = ''.join(traceback.format_exception(type(error), error, error.__traceback__))

    # Record the error (lazy %-formatting)
    app.logger.error('Unhandled exception: %s\n%s', error, tb_text)

    response = {
        'success': False,
        'message': '服务器内部错误',
        'status_code': 500,
    }

    # Expose details only in development
    if app.debug:
        response['debug'] = {
            'type': type(error).__name__,
            'message': str(error),
            'traceback': tb_text.split('\n'),
        }

    return jsonify(response), 500
# 数据库错误处理
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
@app.errorhandler(IntegrityError)
def handle_integrity_error(error):
    """Translate a database integrity violation into a 409 JSON response.

    The message is refined by inspecting the driver's error text for unique
    or foreign-key constraint wording.
    """
    # The failed flush leaves the session in a state where it must be rolled
    # back before it can be used again.
    db.session.rollback()

    response = {
        'success': False,
        'message': '数据完整性错误',
    }

    # Inspect the driver error message
    error_msg = str(error.orig).lower()
    if 'unique constraint' in error_msg:
        response['message'] = '数据已存在,请检查唯一字段'
    elif 'foreign key constraint' in error_msg:
        response['message'] = '关联数据不存在'

    return jsonify(response), 409
@app.errorhandler(SQLAlchemyError)
def handle_sqlalchemy_error(error):
    """Log any other SQLAlchemy error (with traceback) and return a generic 500 JSON body."""
    # Reset the session so it is usable again after the failure.
    db.session.rollback()
    app.logger.error('Database error: %s', error, exc_info=error)
    return jsonify({
        'success': False,
        'message': '数据库操作失败'
    }), 500
6.2 验证错误处理
from marshmallow import ValidationError as MarshmallowValidationError
class ValidationErrorHandler:
    """Helpers that normalise validation errors into ``{field: [messages]}`` dicts."""

    @staticmethod
    def format_errors(errors):
        """Normalise an error mapping recursively.

        Lists are kept as they are, nested dicts are normalised the same way,
        and any other value becomes a one-element list holding its ``str()``.
        """
        def normalise(messages):
            if isinstance(messages, list):
                return messages
            if isinstance(messages, dict):
                return ValidationErrorHandler.format_errors(messages)
            return [str(messages)]

        return {field: normalise(messages) for field, messages in errors.items()}

    @staticmethod
    def handle_wtforms_errors(form):
        """Collect ``{field_name: field.errors}`` for every WTForms field that has errors."""
        return {
            field_name: field.errors
            for field_name, field in form._fields.items()
            if field.errors
        }

    @staticmethod
    def handle_marshmallow_errors(error):
        """Normalise the ``messages`` of a marshmallow ValidationError."""
        return ValidationErrorHandler.format_errors(error.messages)
# Usage example: a Flask error handler for marshmallow validation errors
@app.errorhandler(MarshmallowValidationError)
def handle_marshmallow_validation_error(error):
    """Respond 422 with the marshmallow messages normalised by ValidationErrorHandler."""
    body = {
        'success': False,
        'message': '数据验证失败',
        'errors': ValidationErrorHandler.format_errors(error.messages),
    }
    return jsonify(body), 422
七、API文档
7.1 Flask-RESTX Swagger文档
from flask_restx import Api, Resource, fields, Namespace
# Initialise the Flask-RESTX API; Swagger UI is served at /api/docs
api = Api(
    app,
    version='1.0',
    title='Flask API',
    description='Flask RESTful API文档',
    doc='/api/docs',  # URL of the Swagger UI
    prefix='/api/v1',
    # Declares a header-based "Bearer" scheme so Swagger UI can send the JWT
    authorizations={
        'Bearer': {
            'type': 'apiKey',
            'in': 'header',
            'name': 'Authorization',
            'description': 'JWT Token格式: Bearer {token}'
        }
    },
    security='Bearer'  # apply the scheme to every operation by default
)

# Namespaces group related routes in the documentation
ns_users = Namespace('users', description='用户相关操作')
ns_posts = Namespace('posts', description='文章相关操作')

api.add_namespace(ns_users, path='/users')
api.add_namespace(ns_posts, path='/posts')

# Models: used both for the Swagger schema and for marshalling/validation
user_model = api.model('User', {
    'id': fields.Integer(description='用户ID'),
    'username': fields.String(description='用户名', required=True),
    'email': fields.String(description='邮箱', required=True),
    'name': fields.String(description='姓名'),
    'avatar': fields.String(description='头像URL'),
    'active': fields.Boolean(description='是否激活'),
    'created_at': fields.DateTime(description='创建时间'),
})

# Request body for creating a user; length limits are checked when expect(..., validate=True)
user_create_model = api.model('UserCreate', {
    'username': fields.String(description='用户名', required=True, min_length=3, max_length=20),
    'email': fields.String(description='邮箱', required=True),
    'password': fields.String(description='密码', required=True, min_length=8),
    'name': fields.String(description='姓名'),
})

# Request body for updating a user (all fields optional)
user_update_model = api.model('UserUpdate', {
    'name': fields.String(description='姓名'),
    'avatar': fields.String(description='头像URL'),
    'bio': fields.String(description='个人简介'),
})

pagination_model = api.model('Pagination', {
    'page': fields.Integer(description='当前页'),
    'per_page': fields.Integer(description='每页数量'),
    'total': fields.Integer(description='总数'),
    'pages': fields.Integer(description='总页数'),
    'has_next': fields.Boolean(description='是否有下一页'),
    'has_prev': fields.Boolean(description='是否有上一页'),
})

user_list_model = api.model('UserList', {
    'items': fields.List(fields.Nested(user_model)),
    'pagination': fields.Nested(pagination_model),
})

# Shape of error responses, referenced by @ns.response(...) declarations
error_model = api.model('Error', {
    'success': fields.Boolean(description='是否成功'),
    'message': fields.String(description='错误信息'),
    'errors': fields.Raw(description='详细错误'),
})
# Documentation stub for the user collection. The method docstrings are used
# as Swagger summaries; the bodies are intentionally empty.
@ns_users.route('')
class UserList(Resource):
    @ns_users.doc('list_users')
    # Query-string parameters documented via an inline parser
    # (add_argument returns the parser, so calls chain)
    @ns_users.expect(ns_users.parser().add_argument('page', type=int, default=1, help='页码')
                     .add_argument('per_page', type=int, default=20, help='每页数量')
                     .add_argument('search', type=str, help='搜索关键词'))
    @ns_users.marshal_with(user_list_model)
    @ns_users.response(200, '成功')
    def get(self):
        """获取用户列表"""
        pass

    @ns_users.doc('create_user')
    # validate=True checks the request body against user_create_model
    @ns_users.expect(user_create_model, validate=True)
    @ns_users.marshal_with(user_model, code=201)
    @ns_users.response(400, '参数错误', error_model)
    @ns_users.response(409, '用户已存在', error_model)
    def post(self):
        """创建用户"""
        pass
# Documentation stub for a single user. The class-level response(404)
# applies to every method.
@ns_users.route('/<int:user_id>')
@ns_users.response(404, '用户不存在', error_model)
class UserResource(Resource):
    @ns_users.doc('get_user')
    @ns_users.marshal_with(user_model)
    def get(self, user_id):
        """获取单个用户"""
        pass

    @ns_users.doc('update_user')
    # NOTE(review): user_update_model has only optional fields, i.e. partial
    # update semantics, which is usually documented as PATCH rather than PUT.
    @ns_users.expect(user_update_model, validate=True)
    @ns_users.marshal_with(user_model)
    def put(self, user_id):
        """更新用户"""
        pass

    @ns_users.doc('delete_user')
    @ns_users.response(204, '删除成功')
    def delete(self, user_id):
        """删除用户"""
        pass
7.2 API文档架构图
八、认证与授权
8.1 JWT认证
from flask_jwt_extended import (
JWTManager, create_access_token, create_refresh_token,
jwt_required, get_jwt_identity, get_jwt,
set_access_cookies, set_refresh_cookies,
unset_jwt_cookies, get_csrf_token
)
import os
from datetime import timedelta

# JWT configuration
# The signing key is read from the environment. The literal is only a
# development fallback and must not be used in production.
app.config['JWT_SECRET_KEY'] = os.environ.get('JWT_SECRET_KEY', 'your-jwt-secret-key')
app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(hours=1)
app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(days=30)
app.config['JWT_TOKEN_LOCATION'] = ['headers', 'cookies']  # accept tokens from both places
app.config['JWT_COOKIE_SECURE'] = True        # send JWT cookies over HTTPS only
app.config['JWT_COOKIE_CSRF_PROTECT'] = True  # require the CSRF token for cookie-based auth
app.config['JWT_CSRF_IN_COOKIES'] = True

jwt = JWTManager(app)
# JWT callbacks
@jwt.user_identity_loader
def user_identity_lookup(user):
    """Return the value stored in the token's ``sub`` claim.

    Accepts either a User object (whose ``id`` is used) or a bare identity
    value. That lets callers pass an already-extracted identity, such as the
    result of ``get_jwt_identity()``, back into ``create_access_token``
    without an AttributeError on ``.id``.
    """
    # NOTE(review): recent flask-jwt-extended/PyJWT releases require ``sub`` to
    # be a string; switching to str(...) also requires every place that
    # compares the identity with integer ids to compare string forms.
    return getattr(user, 'id', user)
@jwt.user_lookup_loader
def user_lookup_callback(_jwt_header, jwt_data):
    """Load the User whose primary key is the token's ``sub`` claim (None if it does not exist)."""
    return User.query.get(jwt_data['sub'])
@jwt.expired_token_loader
def expired_token_callback(jwt_header, jwt_payload):
    """Respond 401 with error code ``token_expired`` for an expired token."""
    body = {
        'success': False,
        'message': 'Token已过期',
        'error': 'token_expired',
    }
    return jsonify(body), 401
@jwt.invalid_token_loader
def invalid_token_callback(error_string):
    """Respond 401 with error code ``invalid_token`` for a malformed or badly signed token."""
    body = {
        'success': False,
        'message': '无效的Token',
        'error': 'invalid_token',
    }
    return jsonify(body), 401
@jwt.unauthorized_loader
def missing_token_callback(error_string):
    """Respond 401 with error code ``authorization_required`` when no token was sent."""
    body = {
        'success': False,
        'message': '缺少认证Token',
        'error': 'authorization_required',
    }
    return jsonify(body), 401
@jwt.revoked_token_loader
def revoked_token_callback(jwt_header, jwt_payload):
    """Respond 401 with error code ``token_revoked`` for a token found in the blocklist."""
    body = {
        'success': False,
        'message': 'Token已被撤销',
        'error': 'token_revoked',
    }
    return jsonify(body), 401
# Token blocklist
# NOTE(review): JWT_BLACKLIST_ENABLED / JWT_BLACKLIST_TOKEN_CHECKS are
# flask-jwt-extended 3.x options. 4.x dropped them and consults the
# token_in_blocklist_loader for every token once it is registered — confirm the version in use.
app.config['JWT_BLACKLIST_ENABLED'] = True
app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh']


class TokenBlacklist(db.Model):
    """A revoked JWT, identified by its unique ``jti`` claim."""

    __tablename__ = 'token_blacklist'

    # Column/Integer/... were never imported in this file; Flask-SQLAlchemy
    # exposes them on the ``db`` object.
    id = db.Column(db.Integer, primary_key=True)
    jti = db.Column(db.String(120), unique=True, nullable=False)  # token's unique id claim
    token_type = db.Column(db.String(10), nullable=False)         # e.g. 'access' or 'refresh'
    user_id = db.Column(db.Integer, nullable=False)
    revoked = db.Column(db.Boolean, default=False)
    expires_at = db.Column(db.DateTime)                           # token expiry, lets old rows be purged
@jwt.token_in_blocklist_loader
def check_if_token_revoked(jwt_header, jwt_payload):
    """Report whether the token's ``jti`` has a blocklist entry marked as revoked."""
    entry = TokenBlacklist.query.filter_by(jti=jwt_payload['jti']).first()
    return entry is not None and entry.revoked
# Login endpoint (the method docstring is the Swagger summary)
@ns_auth.route('/login')
class Login(Resource):
    @ns_auth.expect(login_model)
    def post(self):
        """用户登录"""
        from flask import request

        # silent=True: a missing or malformed JSON body yields None instead of
        # raising. It, or a non-object body, is treated as empty credentials,
        # so the client gets a 401 rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        username = data.get('username')
        password = data.get('password')

        # The "username" field may hold either the username or the e-mail address.
        user = None
        if username and password:
            user = User.query.filter(
                (User.username == username) |
                (User.email == username)
            ).first()

        # One message for unknown users and wrong passwords, so the response
        # does not reveal which accounts exist.
        if not user or not user.check_password(password):
            abort(401, message='用户名或密码错误')

        # Issue the token pair
        access_token = create_access_token(identity=user)
        refresh_token = create_refresh_token(identity=user)

        response = jsonify({
            'success': True,
            'access_token': access_token,
            'refresh_token': refresh_token,
            'user': marshal(user, user_model)
        })

        # Also set the tokens as cookies when cookie transport is enabled
        if 'cookies' in (app.config.get('JWT_TOKEN_LOCATION') or ()):
            set_access_cookies(response, access_token)
            set_refresh_cookies(response, refresh_token)

        return response
# Issues a new access token from a valid refresh token (docstring = Swagger summary)
@ns_auth.route('/refresh')
class Refresh(Resource):
    @jwt_required(refresh=True)
    def post(self):
        """刷新Token"""
        # The identity loader builds the token subject from a User object
        # (user.id), so load the user instead of passing the raw identity back
        # in. This also stops deleted accounts from refreshing.
        user = User.query.get(get_jwt_identity())
        if user is None:
            abort(401, message='用户不存在')

        access_token = create_access_token(identity=user)
        return jsonify({
            'success': True,
            'access_token': access_token
        })
# Revokes the token used for this request by recording its jti (docstring = Swagger summary)
@ns_auth.route('/logout')
class Logout(Resource):
    @jwt_required()
    def post(self):
        """登出(撤销Token)"""
        from datetime import datetime, timezone

        claims = get_jwt()

        # Keep the expiry (as naive UTC, matching a plain DateTime column) so
        # rows for tokens that would have expired anyway can be purged later.
        expires_at = None
        if 'exp' in claims:
            expires_at = datetime.fromtimestamp(claims['exp'], tz=timezone.utc).replace(tzinfo=None)

        # Add the token to the blocklist, recording its real type
        token = TokenBlacklist(
            jti=claims['jti'],
            token_type=claims.get('type', 'access'),
            user_id=get_jwt_identity(),
            revoked=True,
            expires_at=expires_at,
        )
        db.session.add(token)
        db.session.commit()

        # NOTE(review): only the presented token is revoked; the refresh token
        # stays valid until it expires unless it is revoked too.
        response = jsonify({'success': True, 'message': '登出成功'})
        unset_jwt_cookies(response)
        return response
8.2 权限控制
from functools import wraps
def role_required(*roles):
    """Decorator factory: require a valid JWT whose user has at least one of *roles*.

    Aborts with 403 when the user cannot be loaded or has none of the roles.
    """
    def decorator(fn):
        @wraps(fn)
        @jwt_required()
        def wrapper(*args, **kwargs):
            user = User.query.get(get_jwt_identity())
            authorised = bool(user) and any(user.has_role(role) for role in roles)
            if not authorised:
                abort(403, message='权限不足')
            return fn(*args, **kwargs)

        return wrapper

    return decorator
def permission_required(*permissions):
    """Decorator factory: require a valid JWT whose user holds at least one of *permissions*.

    Aborts with 403 when the user cannot be loaded or holds none of them.
    """
    def decorator(fn):
        @wraps(fn)
        @jwt_required()
        def wrapper(*args, **kwargs):
            user = User.query.get(get_jwt_identity())
            authorised = bool(user) and any(user.has_permission(p) for p in permissions)
            if not authorised:
                abort(403, message='权限不足')
            return fn(*args, **kwargs)

        return wrapper

    return decorator
def owner_required(get_owner_id):
    """Decorator factory: allow the resource's owner, or an admin, to call the view.

    Args:
        get_owner_id: callable that receives the resource id and returns the
            owner's user id.

    The resource id is read from the view's ``id`` or ``user_id`` keyword
    argument. If neither is present and the view received exactly one
    keyword argument (e.g. ``post_id``), that value is used. Other callers
    are rejected with 403.
    """
    def decorator(fn):
        @wraps(fn)
        @jwt_required()
        def wrapper(*args, **kwargs):
            current_user_id = get_jwt_identity()

            resource_id = kwargs.get('id')
            if resource_id is None:
                resource_id = kwargs.get('user_id')
            if resource_id is None and len(kwargs) == 1:
                # e.g. /posts/<int:post_id>: previously this yielded None
                resource_id = next(iter(kwargs.values()))

            owner_id = get_owner_id(resource_id)

            # The JWT identity may be an int or a string depending on the
            # identity loader, so compare string forms.
            if str(current_user_id) != str(owner_id):
                user = User.query.get(current_user_id)
                # A deleted account (user is None) is rejected instead of
                # crashing on None.has_role.
                if not user or not user.has_role('admin'):
                    abort(403, message='无权访问此资源')
            return fn(*args, **kwargs)

        return wrapper

    return decorator
# Usage example (method docstrings are the Swagger summaries)
@ns_posts.route('/<int:post_id>')
class PostResource(Resource):
    @jwt_required()
    def get(self, post_id):
        """获取文章(需要登录)"""
        pass

    # get_or_404 turns a missing post into a 404; Post.query.get(...) returned
    # None and then failed with AttributeError on .user_id (a 500).
    @owner_required(lambda post_id: Post.query.get_or_404(post_id).user_id)
    def put(self, post_id):
        """更新文章(需要所有者)"""
        pass

    @role_required('admin', 'moderator')
    def delete(self, post_id):
        """删除文章(需要管理员或版主)"""
        pass
九、性能优化
9.1 缓存策略
from flask_caching import Cache
cache = Cache(app, config={'CACHE_TYPE': 'RedisCache'})


class CachedResource:
    """Helpers for caching resource responses with Flask-Caching."""

    @staticmethod
    def cache_key_prefix():
        """Return ``api:<endpoint>`` for the current request."""
        from flask import request

        return f"api:{request.endpoint}"

    @staticmethod
    def cache_key_with_params():
        """Return the cache key for the current request.

        Format: ``api:<endpoint>:<URL-variable values>:<query string>``. The
        URL variables must be part of the key: ``/users/1`` and ``/users/2``
        share an endpoint and an empty query string. Without them, every user
        would be served whichever user's response was cached first.
        """
        from flask import request

        view_args = request.view_args or {}
        path_part = ':'.join(str(view_args[name]) for name in sorted(view_args))
        return f"{CachedResource.cache_key_prefix()}:{path_part}:{request.query_string.decode()}"

    @staticmethod
    def cached_response(timeout=300):
        """Decorator factory: cache the wrapped view's return value for *timeout* seconds.

        A cached value of None is treated as a miss and recomputed.
        NOTE(review): the key is not user-specific — do not use this on views
        whose output depends on the authenticated user.
        """
        def decorator(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                cache_key = CachedResource.cache_key_with_params()
                response = cache.get(cache_key)
                if response is None:
                    response = fn(*args, **kwargs)
                    cache.set(cache_key, response, timeout=timeout)
                return response

            return wrapper

        return decorator
# Usage example
class UserListResource(Resource):
    # cached_response sits outside marshal_with, so the cache stores the
    # already-marshalled output of get().
    @CachedResource.cached_response(timeout=60)
    @marshal_with(user_list_fields)
    def get(self):
        """获取用户列表(缓存60秒)"""
        pass
# Cache invalidation
def invalidate_cache(pattern):
    """Remove cached entries whose key contains *pattern*.

    On a Redis backend only the matching keys are deleted. SCAN is used
    rather than KEYS so a large keyspace does not block the server. Other
    backends cannot match keys by pattern, so the whole cache is cleared.
    """
    # The backend config was passed to Cache(...), not to app.config, so
    # app.config['CACHE_TYPE'] raised KeyError. Inspect the backend instead.
    # Flask-Caching/cachelib RedisCache keeps its redis.Redis client in the
    # private ``_write_client`` attribute. redis-py has no ``delete_pattern``
    # (that is a django-redis extension).
    client = getattr(cache.cache, '_write_client', None)
    if client is None:
        cache.clear()
        return

    # The leading '*' also matches Flask-Caching's CACHE_KEY_PREFIX.
    keys = list(client.scan_iter(match=f"*{pattern}*"))
    if keys:
        client.delete(*keys)
class UserResource(Resource):
    def put(self, user_id):
        """Update a user, then invalidate the cache entries that may contain it."""
        # ``user`` was returned without ever being defined (NameError).
        user = User.query.get_or_404(user_id)
        # ... update logic (apply the parsed arguments, then db.session.commit())

        # Invalidate the affected caches: this user's entry and the user lists
        invalidate_cache(f"api:user:{user_id}")
        invalidate_cache("api:users")
        return user
9.2 查询优化
from sqlalchemy.orm import lazyload, joinedload, subqueryload
class QueryOptimization:
    """Examples of SQLAlchemy loading strategies for API queries."""

    @staticmethod
    def optimize_user_query():
        """Return a User query that eager-loads the profile, posts and the posts' comments.

        The profile is joined into the main SELECT. Posts are loaded by an
        additional query that also joins their comments, which avoids one
        query per user (N+1).
        """
        # NOTE(review): eager loading is not supported for lazy='dynamic'
        # relationships; elsewhere user.posts is used with .filter(), which
        # suggests posts is dynamic — confirm before using this.
        return User.query.options(
            joinedload(User.profile),
            subqueryload(User.posts).joinedload(Post.comments),
        )

    @staticmethod
    def paginate_optimized(query, page, per_page):
        """Paginate *query* with every relationship set to lazy loading.

        Listing pages then fetch only the main rows; ``error_out=False``
        clamps invalid page arguments instead of aborting with 404.
        """
        return query.options(lazyload('*')).paginate(
            page=page,
            per_page=per_page,
            error_out=False,
        )

    @staticmethod
    def count_optimized(model, filters=None):
        """Return ``SELECT count(model.id)`` with optional *filters* (an iterable of criteria)."""
        # ``func`` was never imported in this file (NameError).
        from sqlalchemy import func

        query = db.session.query(func.count(model.id))
        if filters:
            query = query.filter(*filters)
        return query.scalar()
# Bulk operations
class BulkOperations:
    """Bulk insert/update helpers that apply a whole batch and commit once."""

    @staticmethod
    def _commit_bulk(operation, model_class, data_list):
        """Run a session bulk *operation* and commit; roll back and re-raise on failure."""
        try:
            operation(model_class, data_list)
            db.session.commit()
        except Exception:
            # Leave the session usable for the caller instead of in a failed state.
            db.session.rollback()
            raise

    @staticmethod
    def bulk_create(model_class, data_list):
        """Insert every mapping in *data_list* (dicts of column values) and commit."""
        BulkOperations._commit_bulk(db.session.bulk_insert_mappings, model_class, data_list)

    @staticmethod
    def bulk_update(model_class, data_list):
        """Update rows from the mappings in *data_list* and commit.

        Each mapping presumably needs the primary key to identify its row —
        see the SQLAlchemy ``bulk_update_mappings`` docs.
        """
        BulkOperations._commit_bulk(db.session.bulk_update_mappings, model_class, data_list)
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)