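Contents

Introduction

Chapter 1: User Module Design and Implementation

1.1 User System Architecture

1.2 JWT Authentication and Permission Management

1.3 User Registration and Login Endpoints

Chapter 2: News Module and Business Logic

2.1 News Data Model Design

2.2 News CRUD Endpoints

2.3 Favorites and Browsing History

Chapter 3: Caching Strategy and AI Model Integration

3.1 Redis Cache Architecture

3.2 Cache Middleware Implementation

3.3 AI Model Invocation Architecture

3.4 Integrating the AI Endpoints into the News System

Summary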


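Introduction

With the rapid development of AI, AI-driven news recommendation has become a core component of content platforms. This article builds the backend of an AI news recommendation system from scratch with FastAPI. The project covers user authentication, news content management, favorites and browsing-history tracking, and caching, and walks through the development flow of a typical web application backend.

The project is modular: each module has a clear responsibility, which makes it easier to maintain and extend. Along the way you will see how FastAPI features such as Pydantic models, routers, and dependency injection fit together, with attention to API design, data modeling, business logic, and performance. To keep the focus on the API layer, persistence, Redis, and the AI models are all mocked in memory.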


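Chapter 1: User Module Design and Implementation

1.1 User System Architecture

The user module is the foundation of the news system. It handles authentication, permission management, and personalization. A well-designed user system has to balance security, extensibility, and user experience.

Architecturally, the user module is split into three layers: an authentication layer, a business layer, and a data layer. The authentication layer handles registration, login, and token verification; the business layer implements profile management and preference settings; the data layer handles persistence and query optimization for user data. This separation keeps each layer's responsibility clear, so the layers can evolve and be tested independently.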

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field, field_validator
from typing import Optional, List
from datetime import datetime, timedelta
from passlib.context import CryptContext
from jose import JWTError, jwt
import re

app = FastAPI(title="AI News System", version="1.0.0")

# Security settings
# Load SECRET_KEY from an environment variable or secret store in production
SECRET_KEY = "your-super-secret-key-change-in-production"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24  # 24 hours

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")

# Pydantic models (Pydantic v2 API)
class UserBase(BaseModel):
    username: str = Field(..., min_length=3, max_length=30)
    email: EmailStr
    phone: Optional[str] = None
    avatar_url: Optional[str] = None

class UserCreate(UserBase):
    password: str = Field(..., min_length=8, max_length=50)

    @field_validator('password')
    @classmethod
    def validate_password_strength(cls, v: str) -> str:
        # Minimum length is already enforced by Field(min_length=8)
        if not re.search(r'[A-Za-z]', v):
            raise ValueError('Password must contain a letter')
        if not re.search(r'\d', v):
            raise ValueError('Password must contain a digit')
        if not re.search(r'[!@#$%^&*(),.?":{}|<>]', v):
            raise ValueError('Password must contain a special character')
        return v

    @field_validator('username')
    @classmethod
    def validate_username(cls, v: str) -> str:
        if not re.match(r'^[a-zA-Z0-9_]+$', v):
            raise ValueError('Username may only contain letters, digits, and underscores')
        # Usernames are stored lowercased
        return v.lower()

class UserUpdate(BaseModel):
    email: Optional[EmailStr] = None
    phone: Optional[str] = None
    avatar_url: Optional[str] = None
    bio: Optional[str] = Field(None, max_length=200)

class UserResponse(UserBase):
    user_id: int
    is_active: bool
    is_verified: bool
    created_at: datetime
    last_login: Optional[datetime] = None

    # Allow building the model from ORM objects
    model_config = {"from_attributes": True}

class UserDetailResponse(UserResponse):
    bio: Optional[str] = None
    followers_count: int = 0
    following_count: int = 0
    articles_count: int = 0
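
The password rules in `UserCreate` can also be exercised on their own, outside Pydantic. The sketch below restates them as a plain function using only the standard library. The helper name `password_problems` and the message wording are assumptions made for this illustration; the character classes are copied from the validator above.

```python
import re
from typing import List

# Same character classes as the UserCreate validator; names here are illustrative.
_RULES = [
    (r"[A-Za-z]", "must contain a letter"),
    (r"\d", "must contain a digit"),
    (r'[!@#$%^&*(),.?":{}|<>]', "must contain a special character"),
]

def password_problems(password: str, min_length: int = 8) -> List[str]:
    """Return every rule the password violates; an empty list means it passes."""
    problems = []
    if len(password) < min_length:
        problems.append(f"must be at least {min_length} characters long")
    for pattern, message in _RULES:
        if not re.search(pattern, password):
            problems.append(message)
    return problems

print(password_problems("abc"))          # several violations
print(password_problems("Str0ng!pass"))  # no violations
```

Returning all violations at once, rather than raising on the first one, lets a client show the user everything that needs fixing in a single round trip.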

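1.2 JWT Authentication and Permission Management

JWT (JSON Web Token) is a widely used authentication scheme for modern web applications. Compared with traditional server-side sessions, a JWT is stateless: the server verifies the token's signature instead of looking up session state, which makes it easier to scale horizontally and to serve cross-origin and mobile clients. This suits distributed systems well. The code below issues two kinds of token: a short-lived access token (24 hours) and a refresh token (30 days) that can be exchanged for a new pair.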

class Token(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int

class TokenData(BaseModel):
    user_id: Optional[int] = None
    username: Optional[str] = None

class PasswordResetRequest(BaseModel):
    email: EmailStr

class PasswordResetConfirm(BaseModel):
    token: str
    new_password: str = Field(..., min_length=8)

class ChangePasswordRequest(BaseModel):
    old_password: str
    new_password: str = Field(..., min_length=8)

def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password: str) -> str:
    return pwd_context.hash(password)

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    to_encode = data.copy()
    expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
    to_encode.update({
        "exp": expire,
        "type": "access"
    })
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

def create_refresh_token(data: dict) -> str:
    to_encode = data.copy()
    expire = datetime.utcnow() + timedelta(days=30)
    to_encode.update({
        "exp": expire,
        "type": "refresh"
    })
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

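async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication failed, please log in again",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id: Optional[int] = payload.get("user_id")
        username: Optional[str] = payload.get("username")
        token_type: Optional[str] = payload.get("type")
        # Only access tokens may authenticate a request; refresh tokens are rejected
        if user_id is None or token_type != "access":
            raise credentials_exception
        token_data = TokenData(user_id=user_id, username=username)
    except JWTError:
        raise credentials_exception

    # Placeholder: a real application should load the user from the database here.
    # The username is taken from the token so that later lookups by username work.
    user = UserResponse(
        user_id=token_data.user_id,
        username=token_data.username or "example_user",
        email="user@example.com",
        is_active=True,
        is_verified=True,
        created_at=datetime.now()
    )
    return user

async def get_current_active_user(
    current_user: UserResponse = Depends(get_current_user)
) -> UserResponse:
    if not current_user.is_active:
        # 403 matches the status the login endpoint uses for disabled accounts
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled"
        )
    return current_user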
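
To make it concrete what `jwt.encode` and `jwt.decode` do for an HS256 token, here is a deliberately simplified sketch built only on the standard library. It is an illustration of the token format (base64url header, payload, and HMAC-SHA256 signature), not a replacement for python-jose: it does not validate the header's `alg` field or any claim other than `exp`. The function names `sign_hs256` and `verify_hs256` are hypothetical.

```python
import base64
import hashlib
import hmac
import json
import time

def _b64url(data: bytes) -> str:
    # JWT uses URL-safe base64 with the trailing padding removed
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")

def _b64url_decode(text: str) -> bytes:
    # Restore the padding that was stripped during encoding
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))

def sign_hs256(payload: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    segments = [
        _b64url(json.dumps(header, separators=(",", ":")).encode()),
        _b64url(json.dumps(payload, separators=(",", ":")).encode()),
    ]
    signing_input = ".".join(segments).encode("ascii")
    signature = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    return ".".join(segments + [_b64url(signature)])

def verify_hs256(token: str, secret: str) -> dict:
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("malformed token")
    header_b64, payload_b64, signature_b64 = parts
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    expected = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    # Constant-time comparison avoids leaking how many bytes matched
    if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    payload = json.loads(_b64url_decode(payload_b64))
    if payload.get("exp", float("inf")) < time.time():
        raise ValueError("token expired")
    return payload

token = sign_hs256({"user_id": 1, "type": "access", "exp": int(time.time()) + 60}, "demo-secret")
print(verify_hs256(token, "demo-secret")["user_id"])
```

Because the payload is only encoded, not encrypted, anyone holding the token can read it; the signature only guarantees that it was not modified. That is why the tokens in this article carry identifiers rather than sensitive data.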

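1.3 User Registration and Login Endpoints

Registration and login are the core of the user module. These endpoints have to handle input validation, secure storage, and error handling. On registration the password is stored only as a hash; on login the credentials are verified and a JWT access token and refresh token are issued.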

from fastapi import APIRouter, Depends, HTTPException, status

router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])

# Mock user store keyed by username (a real project should use a database)
MOCK_USERS_DB = {}

@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate):
    # Check whether the username is already taken
    if user_data.username in MOCK_USERS_DB:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username is already registered"
        )

    # Check whether the email is already taken (stored records are dicts)
    for user in MOCK_USERS_DB.values():
        if user["email"] == user_data.email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email is already registered"
            )

    # Create the new user
    user_id = len(MOCK_USERS_DB) + 1
    hashed_password = get_password_hash(user_data.password)
    now = datetime.now()

    user_response = UserResponse(
        user_id=user_id,
        username=user_data.username,
        email=user_data.email,
        phone=user_data.phone,
        avatar_url=user_data.avatar_url,
        is_active=True,
        is_verified=False,
        created_at=now
    )

    # Store the user (mock). The plaintext password is excluded; only the hash is kept.
    MOCK_USERS_DB[user_data.username] = {
        **user_data.model_dump(exclude={"password"}),
        "user_id": user_id,
        "hashed_password": hashed_password,
        "is_active": True,
        "is_verified": False,
        "created_at": now
    }

    return user_response

@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    # Usernames are lowercased at registration, so normalize before the lookup
    user = MOCK_USERS_DB.get(form_data.username.lower())
    if not user or not verify_password(form_data.password, user["hashed_password"]):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not user["is_active"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled"
        )

    access_token = create_access_token(
        data={"user_id": user["user_id"], "username": user["username"]}
    )
    refresh_token = create_refresh_token(
        data={"user_id": user["user_id"], "username": user["username"]}
    )

    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60
    )

@router.post("/refresh", response_model=Token)
async def refresh_access_token(refresh_token: str):
    # A bare `str` parameter is read from the query string; in practice,
    # sending the refresh token in the request body is the safer choice.
    invalid_token = HTTPException(status_code=401, detail="Invalid or expired refresh token")
    try:
        payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        # Raised for a bad signature, a malformed token, or an expired token
        raise invalid_token

    user_id = payload.get("user_id")
    username = payload.get("username")
    if payload.get("type") != "refresh" or user_id is None:
        raise invalid_token

    new_access_token = create_access_token(
        data={"user_id": user_id, "username": username}
    )
    new_refresh_token = create_refresh_token(
        data={"user_id": user_id, "username": username}
    )

    return Token(
        access_token=new_access_token,
        refresh_token=new_refresh_token,
        expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60
    )

@router.get("/me", response_model=UserDetailResponse)
async def get_user_profile(current_user: UserResponse = Depends(get_current_active_user)):
    # Mock: return the full profile with placeholder values
    return UserDetailResponse(
        **current_user.model_dump(),
        bio="This is a sample bio",
        followers_count=100,
        following_count=50,
        articles_count=10
    )

@router.put("/me", response_model=UserResponse)
async def update_user_profile(
    user_update: UserUpdate,
    current_user: UserResponse = Depends(get_current_active_user)
):
    # Stub: the update logic is not implemented; the current user is returned unchanged
    return current_user

@router.post("/change-password", status_code=status.HTTP_200_OK)
async def change_password(
    password_data: ChangePasswordRequest,
    current_user: UserResponse = Depends(get_current_active_user)
):
    user = MOCK_USERS_DB.get(current_user.username)
    if user is None:
        # Without this check, a missing record would fail on user["hashed_password"]
        raise HTTPException(status_code=404, detail="User not found")
    if not verify_password(password_data.old_password, user["hashed_password"]):
        raise HTTPException(status_code=400, detail="Old password is incorrect")

    user["hashed_password"] = get_password_hash(password_data.new_password)
    return {"message": "Password changed successfully"}
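
The register and login flow above boils down to two rules: store only a salted hash, and compare hashes at login. The sketch below shows that flow end to end with the standard library only. It swaps bcrypt (used through passlib in the article) for PBKDF2-HMAC-SHA256 so that it runs without third-party packages; the iteration count, the `salt$digest` storage format, and all function names are assumptions for this illustration.

```python
import hashlib
import hmac
import os

ITERATIONS = 100_000  # illustrative work factor, not a recommendation

def hash_password(password: str) -> str:
    # A fresh random salt makes equal passwords hash to different values
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"

def check_password(password: str, stored: str) -> bool:
    salt_hex, digest_hex = stored.split("$")
    candidate = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS
    )
    return hmac.compare_digest(candidate, bytes.fromhex(digest_hex))

users = {}  # username -> record; the plaintext password is never stored

def register_user(username: str, password: str) -> None:
    key = username.lower()
    if key in users:
        raise ValueError("username already registered")
    users[key] = {"username": key, "hashed_password": hash_password(password)}

def login_user(username: str, password: str) -> bool:
    record = users.get(username.lower())
    return record is not None and check_password(password, record["hashed_password"])

register_user("Alice", "Str0ng!pass")
print(login_user("alice", "Str0ng!pass"))  # correct password
print(login_user("alice", "wrong"))        # wrong password
```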

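Chapter 2: News Module and Business Logic

2.1 News Data Model Design

The news module is the core of the platform. It has to support the full lifecycle of a news item: creation, editing, review, publication, and recommendation. A sound data model is the basis for both performance and extensibility.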

from typing import Optional, List
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field

class NewsStatus(str, Enum):
    DRAFT = "draft"           # Draft
    PENDING = "pending"       # Awaiting review
    APPROVED = "approved"     # Review passed
    REJECTED = "rejected"     # Review rejected
    PUBLISHED = "published"   # Published
    ARCHIVED = "archived"     # Archived

class NewsCategory(str, Enum):
    TECHNOLOGY = "technology"
    BUSINESS = "business"
    ENTERTAINMENT = "entertainment"
    SPORTS = "sports"
    EDUCATION = "education"
    HEALTH = "health"
    SCIENCE = "science"
    POLITICS = "politics"

class NewsBase(BaseModel):
    title: str = Field(..., min_length=5, max_length=200)
    content: str = Field(..., min_length=50)
    summary: Optional[str] = Field(None, max_length=500)
    cover_image_url: Optional[str] = None
    category: NewsCategory
    # At most 10 tags (Pydantic v2 uses max_length for list size)
    tags: List[str] = Field(default_factory=list, max_length=10)
    source: Optional[str] = None
    author_name: Optional[str] = None

class NewsCreate(NewsBase):
    pass

class NewsUpdate(BaseModel):
    title: Optional[str] = Field(None, min_length=5, max_length=200)
    content: Optional[str] = Field(None, min_length=50)
    summary: Optional[str] = Field(None, max_length=500)
    cover_image_url: Optional[str] = None
    category: Optional[NewsCategory] = None
    tags: Optional[List[str]] = None
    source: Optional[str] = None

class NewsResponse(NewsBase):
    news_id: int
    status: NewsStatus
    view_count: int = 0
    like_count: int = 0
    comment_count: int = 0
    share_count: int = 0
    created_at: datetime
    updated_at: datetime
    published_at: Optional[datetime] = None
    is_featured: bool = False
    is_top: bool = False

    # Allow building the model from ORM objects
    model_config = {"from_attributes": True}

class NewsListResponse(BaseModel):
    items: List[NewsResponse]
    total: int
    page: int
    page_size: int
    total_pages: int

class NewsSearchParams(BaseModel):
    keyword: Optional[str] = None
    category: Optional[NewsCategory] = None
    tags: Optional[List[str]] = None
    status: Optional[NewsStatus] = None
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    sort_by: str = "created_at"
    sort_order: str = "desc"
    page: int = Field(1, ge=1)
    page_size: int = Field(20, ge=1, le=100)
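
`NewsStatus` names the stages of a news item's lifecycle, but the models do not say which moves between stages are legal. One common way to pin that down is a transition table. The workflow encoded below (draft, then review, then publication, then archive, with rejected items returning to draft) is an assumption made for this sketch; the article itself does not define it, and `change_status` is a hypothetical helper.

```python
from enum import Enum

class NewsStatus(str, Enum):
    DRAFT = "draft"
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"
    PUBLISHED = "published"
    ARCHIVED = "archived"

# Assumed editorial workflow; adjust the table to match the real process.
ALLOWED_TRANSITIONS = {
    NewsStatus.DRAFT: {NewsStatus.PENDING},
    NewsStatus.PENDING: {NewsStatus.APPROVED, NewsStatus.REJECTED},
    NewsStatus.APPROVED: {NewsStatus.PUBLISHED},
    NewsStatus.REJECTED: {NewsStatus.DRAFT},
    NewsStatus.PUBLISHED: {NewsStatus.ARCHIVED},
    NewsStatus.ARCHIVED: set(),
}

def change_status(current: NewsStatus, target: NewsStatus) -> NewsStatus:
    """Return the new status, or raise ValueError if the move is not allowed."""
    if target not in ALLOWED_TRANSITIONS[current]:
        raise ValueError(f"cannot move from {current.value} to {target.value}")
    return target

status = NewsStatus.DRAFT
for step in (NewsStatus.PENDING, NewsStatus.APPROVED, NewsStatus.PUBLISHED):
    status = change_status(status, step)
print(status.value)
```

Keeping the rules in one table means an endpoint that changes status only has to call one function, and the allowed workflow can be reviewed at a glance.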

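2.2 News CRUD Endpoints

Create, read, update, and delete are the basic operations on news. FastAPI's routers and dependency injection keep these endpoints short: write operations depend on `get_current_active_user`, while read operations are public.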

from fastapi import APIRouter, Depends, HTTPException, Query, status
from typing import Optional

news_router = APIRouter(prefix="/api/v1/news", tags=["News"])

# Mock news store keyed by news_id
MOCK_NEWS_DB = {}

@news_router.post("/", response_model=NewsResponse, status_code=status.HTTP_201_CREATED)
async def create_news(
    news_data: NewsCreate,
    current_user: UserResponse = Depends(get_current_active_user)
):
    # max()+1 instead of len()+1: after a deletion, len()+1 can collide with an existing ID
    news_id = max(MOCK_NEWS_DB, default=0) + 1
    now = datetime.now()

    news = NewsResponse(
        news_id=news_id,
        **news_data.model_dump(),
        status=NewsStatus.DRAFT,
        view_count=0,
        like_count=0,
        comment_count=0,
        share_count=0,
        created_at=now,
        updated_at=now
    )

    MOCK_NEWS_DB[news_id] = news
    return news

@news_router.get("/{news_id}", response_model=NewsResponse)
async def get_news(news_id: int):
    if news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    news = MOCK_NEWS_DB[news_id]
    # Count this request as a view
    news.view_count += 1
    return news

@news_router.get("/", response_model=NewsListResponse)
async def list_news(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    category: Optional[NewsCategory] = None,
    keyword: Optional[str] = None,
    is_featured: Optional[bool] = None
):
    # Mock query logic: filter the in-memory store
    news_list = list(MOCK_NEWS_DB.values())

    # Filters
    if category:
        news_list = [n for n in news_list if n.category == category]
    if keyword:
        kw = keyword.lower()
        news_list = [n for n in news_list
                     if kw in n.title.lower() or kw in n.content.lower()]
    if is_featured is not None:
        news_list = [n for n in news_list if n.is_featured == is_featured]

    # Only published items are listed
    news_list = [n for n in news_list if n.status == NewsStatus.PUBLISHED]

    # Newest first
    news_list.sort(key=lambda x: x.created_at, reverse=True)

    # Pagination
    total = len(news_list)
    start = (page - 1) * page_size
    end = start + page_size
    items = news_list[start:end]

    return NewsListResponse(
        items=items,
        total=total,
        page=page,
        page_size=page_size,
        total_pages=(total + page_size - 1) // page_size  # ceiling division
    )

@news_router.put("/{news_id}", response_model=NewsResponse)
async def update_news(
    news_id: int,
    news_update: NewsUpdate,
    current_user: UserResponse = Depends(get_current_active_user)
):
    if news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    news = MOCK_NEWS_DB[news_id]
    # Only fields the client actually sent are applied
    update_data = news_update.model_dump(exclude_unset=True)

    for field, value in update_data.items():
        setattr(news, field, value)

    news.updated_at = datetime.now()
    return news

@news_router.delete("/{news_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_news(
    news_id: int,
    current_user: UserResponse = Depends(get_current_active_user)
):
    if news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    del MOCK_NEWS_DB[news_id]
    return None
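
The pagination arithmetic used by `list_news` is worth isolating, because the same calculation recurs in every list endpoint. The following is a minimal sketch of that logic as a reusable function; `paginate` is a hypothetical helper, not part of the article's code.

```python
from typing import Any, Dict, Sequence

def paginate(items: Sequence[Any], page: int, page_size: int) -> Dict[str, Any]:
    """Slice a sequence into one page and report the totals (pages are 1-based)."""
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be at least 1")
    total = len(items)
    start = (page - 1) * page_size
    return {
        "items": list(items[start:start + page_size]),
        "total": total,
        "page": page,
        "page_size": page_size,
        # Ceiling division: 45 items at 20 per page need 3 pages
        "total_pages": (total + page_size - 1) // page_size,
    }

result = paginate(list(range(45)), page=3, page_size=20)
print(result["items"], result["total_pages"])
```

A page past the end simply yields an empty `items` list, which matches how Python slicing behaves in the endpoint above.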

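2.3 Favorites and Browsing History

Favorites and browsing history both improve the user experience. Favorites let users save items to read later, while browsing history records reading behavior (progress and duration) and gives the recommendation system data to work with.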

from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional, List

class FavoriteBase(BaseModel):
    news_id: int

class FavoriteCreate(FavoriteBase):
    pass

class FavoriteResponse(BaseModel):
    favorite_id: int
    user_id: int
    news_id: int
    created_at: datetime
    news: Optional[NewsResponse] = None

    model_config = {"from_attributes": True}

class FavoriteListResponse(BaseModel):
    items: List[FavoriteResponse]
    total: int
    page: int
    page_size: int

class BrowseHistoryResponse(BaseModel):
    history_id: int
    user_id: int
    news_id: int
    read_progress: int = Field(0, ge=0, le=100)  # Reading progress in percent
    read_duration: int = 0  # Reading time in seconds
    created_at: datetime
    last_read_at: datetime
    news: Optional[NewsResponse] = None

    model_config = {"from_attributes": True}

class UpdateReadProgressRequest(BaseModel):
    news_id: int
    read_progress: int = Field(0, ge=0, le=100)
    read_duration: int = Field(0, ge=0)

favorite_router = APIRouter(prefix="/api/v1/favorites", tags=["Favorites"])
history_router = APIRouter(prefix="/api/v1/history", tags=["Browsing History"])

# Mock stores keyed by favorite_id and history_id
MOCK_FAVORITES_DB = {}
MOCK_HISTORY_DB = {}

@favorite_router.post("/", response_model=FavoriteResponse, status_code=status.HTTP_201_CREATED)
async def create_favorite(
    favorite_data: FavoriteCreate,
    current_user: UserResponse = Depends(get_current_active_user)
):
    if favorite_data.news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    # Reject duplicates
    for fav in MOCK_FAVORITES_DB.values():
        if fav.user_id == current_user.user_id and fav.news_id == favorite_data.news_id:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="This news item is already in your favorites"
            )

    # max()+1 instead of len()+1: after a deletion, len()+1 can overwrite an existing favorite
    favorite_id = max(MOCK_FAVORITES_DB, default=0) + 1
    favorite = FavoriteResponse(
        favorite_id=favorite_id,
        user_id=current_user.user_id,
        news_id=favorite_data.news_id,
        created_at=datetime.now(),
        news=MOCK_NEWS_DB[favorite_data.news_id]
    )

    MOCK_FAVORITES_DB[favorite_id] = favorite
    return favorite

@favorite_router.get("/", response_model=FavoriteListResponse)
async def list_favorites(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    current_user: UserResponse = Depends(get_current_active_user)
):
    user_favorites = [
        fav for fav in MOCK_FAVORITES_DB.values()
        if fav.user_id == current_user.user_id
    ]

    # Most recently added first
    user_favorites.sort(key=lambda x: x.created_at, reverse=True)
    total = len(user_favorites)

    start = (page - 1) * page_size
    end = start + page_size
    items = user_favorites[start:end]

    return FavoriteListResponse(
        items=items,
        total=total,
        page=page,
        page_size=page_size
    )

@favorite_router.delete("/{news_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_favorite(
    news_id: int,
    current_user: UserResponse = Depends(get_current_active_user)
):
    # Find the current user's favorite for this news item
    favorite_to_delete = None
    for fav_id, fav in MOCK_FAVORITES_DB.items():
        if fav.user_id == current_user.user_id and fav.news_id == news_id:
            favorite_to_delete = fav_id
            break

    if favorite_to_delete is None:
        raise HTTPException(status_code=404, detail="Favorite not found")

    del MOCK_FAVORITES_DB[favorite_to_delete]
    return None

@history_router.post("/read", response_model=BrowseHistoryResponse)
async def record_browse_history(
    history_data: UpdateReadProgressRequest,
    current_user: UserResponse = Depends(get_current_active_user)
):
    if history_data.news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    # Look for an existing record for this user and news item
    existing_history = None
    for hist in MOCK_HISTORY_DB.values():
        if hist.user_id == current_user.user_id and hist.news_id == history_data.news_id:
            existing_history = hist
            break

    now = datetime.now()

    if existing_history:
        # Update the existing record
        existing_history.read_progress = history_data.read_progress
        existing_history.read_duration = history_data.read_duration
        existing_history.last_read_at = now
        return existing_history
    else:
        # Create a new record
        history_id = max(MOCK_HISTORY_DB, default=0) + 1
        history = BrowseHistoryResponse(
            history_id=history_id,
            user_id=current_user.user_id,
            news_id=history_data.news_id,
            read_progress=history_data.read_progress,
            read_duration=history_data.read_duration,
            created_at=now,
            last_read_at=now,
            news=MOCK_NEWS_DB[history_data.news_id]
        )

        MOCK_HISTORY_DB[history_id] = history
        return history

@history_router.get("/", response_model=List[BrowseHistoryResponse])
async def list_browse_history(
    limit: int = Query(50, ge=1, le=100),
    current_user: UserResponse = Depends(get_current_active_user)
):
    user_history = [
        hist for hist in MOCK_HISTORY_DB.values()
        if hist.user_id == current_user.user_id
    ]

    # Most recently read first
    user_history.sort(key=lambda x: x.last_read_at, reverse=True)
    return user_history[:limit]
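
Both the favorites and the history endpoints scan every record to find the one belonging to a given user and news item. In a mock this is harmless, but it hints at the access pattern a real schema needs: a unique key on `(user_id, news_id)`. The sketch below shows the same "update if present, otherwise create" logic with a dictionary keyed by that pair. `HistoryStore` and its methods are hypothetical names used only for this illustration.

```python
from datetime import datetime
from typing import Dict, List, Tuple

class HistoryStore:
    """In-memory browsing history indexed by (user_id, news_id)."""

    def __init__(self) -> None:
        self._records: Dict[Tuple[int, int], dict] = {}
        self._next_id = 1  # monotonically increasing, never reused

    def upsert(self, user_id: int, news_id: int,
               read_progress: int, read_duration: int) -> dict:
        now = datetime.now()
        key = (user_id, news_id)
        record = self._records.get(key)  # O(1) lookup instead of a full scan
        if record is None:
            record = {
                "history_id": self._next_id,
                "user_id": user_id,
                "news_id": news_id,
                "created_at": now,
            }
            self._next_id += 1
            self._records[key] = record
        record.update(
            read_progress=read_progress,
            read_duration=read_duration,
            last_read_at=now,
        )
        return record

    def recent(self, user_id: int, limit: int = 50) -> List[dict]:
        rows = [r for (uid, _), r in self._records.items() if uid == user_id]
        rows.sort(key=lambda r: r["last_read_at"], reverse=True)
        return rows[:limit]

store = HistoryStore()
store.upsert(user_id=1, news_id=10, read_progress=30, read_duration=45)
store.upsert(user_id=1, news_id=10, read_progress=80, read_duration=120)
print(len(store.recent(1)), store.recent(1)[0]["read_progress"])
```

In a relational database the equivalent would be a unique constraint on the two columns plus an upsert statement, which also protects against duplicate rows under concurrent requests.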

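Chapter 3: Caching Strategy and AI Model Integration

3.1 Redis Cache Architecture

Under high concurrency, caching is one of the most effective ways to improve performance. A sensible caching strategy reduces database load and shortens response times. Redis is a widely used in-memory data store and works well as the cache layer of a FastAPI application. The code below defines the key naming scheme and a mock client that mimics the handful of Redis commands the application uses.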

from typing import Optional, List, Any
import json
import hashlib
from functools import wraps

class CacheStrategy:
    CACHE_PREFIX = "news_app"
    DEFAULT_TTL = 300  # 5 minutes

    @staticmethod
    def generate_key(*args, **kwargs) -> str:
        # Hash the call arguments into a fixed-length key
        key_data = f"{args}:{sorted(kwargs.items())}"
        key_hash = hashlib.md5(key_data.encode()).hexdigest()
        return f"{CacheStrategy.CACHE_PREFIX}:{key_hash}"

    @staticmethod
    def get_news_key(news_id: int) -> str:
        return f"{CacheStrategy.CACHE_PREFIX}:news:{news_id}"

    @staticmethod
    def get_user_favorites_key(user_id: int) -> str:
        return f"{CacheStrategy.CACHE_PREFIX}:favorites:{user_id}"

    @staticmethod
    def get_news_list_key(page: int, page_size: int, **filters) -> str:
        # Skip only unset (None) filters; a falsy value such as is_featured=False
        # must stay in the key, otherwise it collides with the unfiltered list
        filter_str = ":".join(
            f"{k}={v}" for k, v in sorted(filters.items()) if v is not None
        )
        return f"{CacheStrategy.CACHE_PREFIX}:list:{page}:{page_size}:{filter_str}"

# Mock Redis client (a real project would use redis-py's asyncio client)
class MockRedisClient:
    def __init__(self):
        self._cache = {}

    async def get(self, key: str) -> Optional[str]:
        return self._cache.get(key)

    async def set(
        self,
        key: str,
        value: str,
        ex: Optional[int] = None
    ) -> bool:
        # The mock accepts `ex` (TTL in seconds) but never expires keys
        self._cache[key] = value
        return True

    async def delete(self, key: str) -> int:
        if key in self._cache:
            del self._cache[key]
            return 1
        return 0

    async def exists(self, key: str) -> bool:
        return key in self._cache

    async def expire(self, key: str, seconds: int) -> bool:
        # No-op in the mock: only reports whether the key exists
        return key in self._cache

    async def incr(self, key: str) -> int:
        if key in self._cache:
            self._cache[key] = str(int(self._cache[key]) + 1)
        else:
            self._cache[key] = "1"
        return int(self._cache[key])

redis_client = MockRedisClient()

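def cache_key(*_args, **_kwargs):
    """Decorator factory; use as @cache_key().

    The wrapped coroutine receives the generated cache key as its first
    positional argument, followed by the original call arguments.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            key = CacheStrategy.generate_key(*args, **kwargs)
            return await func(key, *args, **kwargs)
        return wrapper
    return decorator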
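
`MockRedisClient` ignores the `ex` argument, so cached entries never expire. When testing code that depends on expiry, a mock that honors the TTL is more useful. The sketch below is one way to do that with lazy expiry on read. The class name `TTLMockRedis` and the injectable `clock` parameter are assumptions for this illustration; real Redis also expires keys in the background, which this sketch does not attempt to imitate.

```python
import asyncio
import time
from typing import Callable, Dict, Optional, Tuple

class TTLMockRedis:
    """In-memory stand-in for a few Redis commands that honors `ex` (seconds)."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock  # injectable so tests can control time
        self._data: Dict[str, Tuple[str, Optional[float]]] = {}

    def _live_value(self, key: str) -> Optional[str]:
        entry = self._data.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if expires_at is not None and self._clock() >= expires_at:
            del self._data[key]  # lazy expiry: drop the key when it is read
            return None
        return value

    async def get(self, key: str) -> Optional[str]:
        return self._live_value(key)

    async def set(self, key: str, value: str, ex: Optional[int] = None) -> bool:
        expires_at = self._clock() + ex if ex is not None else None
        self._data[key] = (value, expires_at)
        return True

    async def exists(self, key: str) -> bool:
        return self._live_value(key) is not None

async def demo() -> None:
    fake_now = [0.0]
    client = TTLMockRedis(clock=lambda: fake_now[0])
    await client.set("news_app:news:1", "cached", ex=300)
    print(await client.get("news_app:news:1"))  # still cached
    fake_now[0] = 300.0                         # advance the fake clock past the TTL
    print(await client.get("news_app:news:1"))  # expired

asyncio.run(demo())
```

Passing the clock in as a parameter keeps the tests fast and deterministic: no real waiting is needed to observe an expiry.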

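3.2 Cache Middleware Implementation

The caching logic is wrapped in a service class that endpoints receive through dependency injection. The service handles JSON serialization, TTLs, and invalidation on top of the Redis client from the previous section: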

from fastapi import Depends, HTTPException
from typing import Any, Optional, Callable
import json
from datetime import datetime

class CacheService:
    def __init__(self, redis=redis_client):
        self.redis = redis
        self.default_ttl = CacheStrategy.DEFAULT_TTL

    async def get_cached(self, key: str) -> Optional[Any]:
        cached_data = await self.redis.get(key)
        if cached_data:
            try:
                return json.loads(cached_data)
            except json.JSONDecodeError:
                # Not JSON: return the raw string
                return cached_data
        return None

    async def set_cached(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None
    ) -> bool:
        # Dicts and lists are stored as JSON; default=str covers datetimes
        if isinstance(value, (dict, list)):
            value = json.dumps(value, default=str)
        return await self.redis.set(key, value, ex=ttl or self.default_ttl)

    async def invalidate_cache(self, key: str) -> int:
        return await self.redis.delete(key)

    async def invalidate_pattern(self, pattern: str) -> int:
        # Mock-only pattern delete: treats the pattern as a substring after
        # stripping "*". With real Redis, iterate keys with SCAN and a MATCH pattern.
        keys_to_delete = [
            k for k in self.redis._cache.keys()
            if pattern.replace("*", "") in k
        ]
        for key in keys_to_delete:
            await self.redis.delete(key)
        return len(keys_to_delete)

    async def get_news(self, news_id: int) -> Optional[NewsResponse]:
        cache_key = CacheStrategy.get_news_key(news_id)
        cached = await self.get_cached(cache_key)
        if cached:
            return NewsResponse(**cached)
        return None

    async def set_news(self, news: NewsResponse, ttl: int = 600) -> bool:
        cache_key = CacheStrategy.get_news_key(news.news_id)
        # mode="json" turns datetimes and enums into JSON-safe values
        return await self.set_cached(cache_key, news.model_dump(mode="json"), ttl)

    async def invalidate_news(self, news_id: int) -> int:
        cache_key = CacheStrategy.get_news_key(news_id)
        return await self.invalidate_cache(cache_key)

cache_service = CacheService()

async def get_cache_service() -> CacheService:
    return cache_service
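
The service above provides the building blocks for the cache-aside pattern: read from the cache first, fall back to the source on a miss, then store the result. The sketch below shows that pattern as a single helper. `DictCache` is a stand-in for the Redis-backed service so the example runs on its own, and `get_or_set` is a hypothetical helper name; neither appears in the article's code.

```python
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict, Optional

class DictCache:
    """Stand-in for a Redis-backed cache; stores values as JSON strings."""

    def __init__(self) -> None:
        self._data: Dict[str, str] = {}

    async def get(self, key: str) -> Optional[Any]:
        raw = self._data.get(key)
        return None if raw is None else json.loads(raw)

    async def set(self, key: str, value: Any) -> None:
        self._data[key] = json.dumps(value, default=str)

async def get_or_set(cache: DictCache, key: str,
                     loader: Callable[[], Awaitable[Any]]) -> Any:
    """Cache-aside: return the cached value, or load, store, and return it."""
    cached = await cache.get(key)
    if cached is not None:
        return cached
    value = await loader()
    await cache.set(key, value)
    return value

async def demo() -> None:
    cache = DictCache()
    load_count = 0

    async def load_news() -> dict:
        nonlocal load_count
        load_count += 1  # stands in for a database query
        return {"news_id": 1, "title": "Hello"}

    first = await get_or_set(cache, "news_app:news:1", load_news)
    second = await get_or_set(cache, "news_app:news:1", load_news)
    print(first == second, load_count)

asyncio.run(demo())
```

The other half of the pattern is invalidation: whenever the source record changes (for example in an update or delete endpoint), the cached key has to be deleted, or readers will keep seeing the old value until the TTL runs out.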

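3.3 AI Model Invocation Architecture

In the AI news system, model calls are used for summary generation, content recommendation, and automatic tagging. A model invocation layer has to account for asynchronous execution, retries on failure, and circuit breaking or graceful degradation. The service below mocks the models themselves and implements retries with exponential backoff.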

from typing import Optional, List, Dict, Any
from enum import Enum
import asyncio
import random

class ModelProvider(str, Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    LOCAL = "local"
    MOCK = "mock"

class AIModelType(str, Enum):
    TEXT_SUMMARIZATION = "text_summarization"
    NEWS_CLASSIFICATION = "news_classification"
    SIMILARITY_MATCHING = "similarity_matching"
    SENTIMENT_ANALYSIS = "sentiment_analysis"

class ModelConfig:
    def __init__(
        self,
        provider: ModelProvider,
        model_name: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
        timeout: int = 30
    ):
        self.provider = provider
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.max_retries = max_retries
        self.timeout = timeout

class ModelRequest:
    def __init__(
        self,
        model_type: AIModelType,
        input_text: str,
        parameters: Optional[Dict[str, Any]] = None
    ):
        self.model_type = model_type
        self.input_text = input_text
        self.parameters = parameters or {}

class ModelResponse:
    def __init__(
        self,
        success: bool,
        result: Any = None,
        error: Optional[str] = None,
        model_used: Optional[str] = None,
        tokens_used: Optional[int] = None,
        latency_ms: Optional[float] = None
    ):
        self.success = success
        self.result = result
        self.error = error
        self.model_used = model_used
        self.tokens_used = tokens_used
        self.latency_ms = latency_ms

class ModelService:
    def __init__(self):
        self.configs = {
            AIModelType.TEXT_SUMMARIZATION: ModelConfig(
                provider=ModelProvider.MOCK,
                model_name="mock-summarizer"
            ),
            AIModelType.NEWS_CLASSIFICATION: ModelConfig(
                provider=ModelProvider.MOCK,
                model_name="mock-classifier"
            ),
            AIModelType.SIMILARITY_MATCHING: ModelConfig(
                provider=ModelProvider.MOCK,
                model_name="mock-similarity"
            )
        }

    async def call_model(
        self,
        request: ModelRequest,
        retry_count: int = 0
    ) -> ModelResponse:
        config = self.configs.get(request.model_type)
        if not config:
            return ModelResponse(
                success=False,
                error=f"No model configured for type: {request.model_type}"
            )

        try:
            # Mock model call: dispatch on the requested model type
            if request.model_type == AIModelType.TEXT_SUMMARIZATION:
                result = await self._mock_summarize(request.input_text)
            elif request.model_type == AIModelType.NEWS_CLASSIFICATION:
                result = await self._mock_classify(request.input_text)
            elif request.model_type == AIModelType.SIMILARITY_MATCHING:
                result = await self._mock_similarity(request.input_text)
            else:
                result = {"message": "Unknown model type"}

            return ModelResponse(
                success=True,
                result=result,
                model_used=config.model_name,
                tokens_used=len(request.input_text) // 4,  # rough estimate
                latency_ms=random.uniform(100, 500)  # simulated, not measured
            )

        except Exception as e:
            # Retry with exponential backoff: 1s, 2s, 4s, ...
            if retry_count < config.max_retries:
                await asyncio.sleep(2 ** retry_count)
                return await self.call_model(request, retry_count + 1)

            return ModelResponse(
                success=False,
                error=str(e),
                model_used=config.model_name
            )

    async def _mock_summarize(self, text: str) -> Dict[str, Any]:
        await asyncio.sleep(0.2)  # Simulated latency
        # Mock summary: the first 50 whitespace-separated words
        words = text.split()[:50]
        return {
            "summary": " ".join(words) + "...",
            "key_points": ["Key point 1", "Key point 2", "Key point 3"],
            "word_count_original": len(text.split()),
            "word_count_summary": len(words)
        }

    async def _mock_classify(self, text: str) -> Dict[str, Any]:
        await asyncio.sleep(0.15)
        # Mock classification: random category and scores
        categories = ["technology", "business", "entertainment", "sports"]
        return {
            "primary_category": random.choice(categories),
            "confidence": random.uniform(0.7, 0.99),
            "all_scores": {cat: random.uniform(0, 1) for cat in categories}
        }

    async def _mock_similarity(self, text: str) -> Dict[str, Any]:
        await asyncio.sleep(0.1)
        # Mock similarity: fixed results
        return {
            "similar_news": [
                {"news_id": 1, "score": 0.95},
                {"news_id": 2, "score": 0.87},
                {"news_id": 3, "score": 0.82}
            ]
        }

model_service = ModelService()

async def get_model_service() -> ModelService:
    return model_service
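
`ModelConfig` carries a `timeout` field, but `call_model` never applies it, and the retry is written recursively. For a real provider call, a loop that combines a per-attempt timeout with exponential backoff is a common alternative. The following is a minimal sketch under those assumptions; `call_with_retry` and its parameters are hypothetical names, and the operation being retried is whatever coroutine performs the actual model request.

```python
import asyncio
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")

async def call_with_retry(
    operation: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    timeout: float = 30.0,
    base_delay: float = 1.0,
) -> T:
    """Run `operation` with a per-attempt timeout, retrying with exponential backoff."""
    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        try:
            # wait_for cancels the attempt and raises TimeoutError if it runs too long
            return await asyncio.wait_for(operation(), timeout=timeout)
        except Exception as exc:  # includes the timeout error
            last_error = exc
            if attempt < max_retries:
                await asyncio.sleep(base_delay * 2 ** attempt)
    assert last_error is not None
    raise last_error

async def demo() -> None:
    attempts = 0

    async def flaky_model_call() -> str:
        nonlocal attempts
        attempts += 1
        if attempts < 3:
            raise RuntimeError("temporary upstream error")
        return "summary text"

    result = await call_with_retry(flaky_model_call, base_delay=0.01)
    print(result, attempts)

asyncio.run(demo())
```

Retries only make sense for transient failures. In practice it is worth distinguishing errors such as invalid input or authentication failures, which will fail the same way on every attempt, and returning those immediately.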

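3.4 Integrating the AI Endpoints into the News System

With the model service in place, the news system can expose summarization, automatic classification, and similar-news recommendation. The summarize and classify endpoints cache their results, since the output for a given article does not change between calls: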

from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any

class NewsSummaryRequest(BaseModel):
    news_id: int
    max_length: int = Field(200, ge=50, le=500)

class NewsSummaryResponse(BaseModel):
    news_id: int
    summary: str
    key_points: List[str]
    original_length: int
    summary_length: int
    model_used: str
    generated_at: datetime

class NewsClassificationRequest(BaseModel):
    news_id: int

class NewsClassificationResponse(BaseModel):
    news_id: int
    category: str
    confidence: float
    all_categories: Dict[str, float]
    model_used: str
    classified_at: datetime

class SimilarNewsRequest(BaseModel):
    news_id: int
    limit: int = Field(5, ge=1, le=20)

class SimilarNewsResponse(BaseModel):
    source_news_id: int
    similar_news: List[Dict[str, Any]]
    model_used: str
    generated_at: datetime

ai_router = APIRouter(prefix="/api/v1/ai", tags=["AI"])

@ai_router.post("/summarize", response_model=NewsSummaryResponse)
async def summarize_news(
    request: NewsSummaryRequest,
    cache: CacheService = Depends(get_cache_service),
    model: ModelService = Depends(get_model_service)
):
    # Check the cache. max_length is part of the key so that requests with
    # different lengths do not share a cached summary.
    cache_key = f"{CacheStrategy.CACHE_PREFIX}:summarize:{request.news_id}:{request.max_length}"
    cached = await cache.get_cached(cache_key)
    if cached:
        return NewsSummaryResponse(**cached)

    # Load the news content
    if request.news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    news = MOCK_NEWS_DB[request.news_id]

    # Call the AI model to generate the summary
    model_request = ModelRequest(
        model_type=AIModelType.TEXT_SUMMARIZATION,
        input_text=news.content,
        parameters={"max_length": request.max_length}
    )

    model_response = await model.call_model(model_request)
    if not model_response.success:
        raise HTTPException(status_code=500, detail=f"AI service error: {model_response.error}")

    result = model_response.result
    response = NewsSummaryResponse(
        news_id=request.news_id,
        summary=result["summary"],
        key_points=result["key_points"],
        original_length=result["word_count_original"],
        summary_length=result["word_count_summary"],
        model_used=model_response.model_used,
        generated_at=datetime.now()
    )

    # Store in the cache for one hour
    await cache.set_cached(cache_key, response.model_dump(mode="json"), ttl=3600)

    return response

@ai_router.post("/classify", response_model=NewsClassificationResponse)
async def classify_news(
    request: NewsClassificationRequest,
    cache: CacheService = Depends(get_cache_service),
    model: ModelService = Depends(get_model_service)
):
    cache_key = f"{CacheStrategy.CACHE_PREFIX}:classify:{request.news_id}"
    cached = await cache.get_cached(cache_key)
    if cached:
        return NewsClassificationResponse(**cached)

    if request.news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    news = MOCK_NEWS_DB[request.news_id]

    # Classify on title plus body
    model_request = ModelRequest(
        model_type=AIModelType.NEWS_CLASSIFICATION,
        input_text=f"{news.title} {news.content}"
    )

    model_response = await model.call_model(model_request)
    if not model_response.success:
        raise HTTPException(status_code=500, detail=f"AI service error: {model_response.error}")

    result = model_response.result
    response = NewsClassificationResponse(
        news_id=request.news_id,
        category=result["primary_category"],
        confidence=result["confidence"],
        all_categories=result["all_scores"],
        model_used=model_response.model_used,
        classified_at=datetime.now()
    )

    # Store in the cache for two hours
    await cache.set_cached(cache_key, response.model_dump(mode="json"), ttl=7200)
    return response

@ai_router.post("/similar", response_model=SimilarNewsResponse)
async def find_similar_news(
    request: SimilarNewsRequest,
    model: ModelService = Depends(get_model_service)
):
    if request.news_id not in MOCK_NEWS_DB:
        raise HTTPException(status_code=404, detail="News not found")

    news = MOCK_NEWS_DB[request.news_id]

    model_request = ModelRequest(
        model_type=AIModelType.SIMILARITY_MATCHING,
        input_text=f"{news.title} {news.content}",
        parameters={"limit": request.limit}
    )

    model_response = await model.call_model(model_request)
    if not model_response.success:
        raise HTTPException(status_code=500, detail=f"AI service error: {model_response.error}")

    return SimilarNewsResponse(
        source_news_id=request.news_id,
        # Truncate to the requested number of results
        similar_news=model_response.result["similar_news"][:request.limit],
        model_used=model_response.model_used,
        generated_at=datetime.now()
    )
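
The endpoints above answer with an error whenever the model call fails. For a feature like summarization, graceful degradation is often preferable: if the model is unavailable, return a truncated excerpt instead of failing the request. The sketch below illustrates the idea. `summarize_with_fallback` is a hypothetical helper, and truncating the article text is only one possible fallback, chosen here for simplicity.

```python
import asyncio
from typing import Awaitable, Callable, Tuple

async def summarize_with_fallback(
    content: str,
    summarizer: Callable[[str], Awaitable[str]],
    max_length: int = 200,
) -> Tuple[str, str]:
    """Return (summary, source), where source is "model" or "fallback"."""
    try:
        return await summarizer(content), "model"
    except Exception:
        # Degrade instead of failing: use the start of the article as the summary
        suffix = "..." if len(content) > max_length else ""
        return content[:max_length] + suffix, "fallback"

async def demo() -> None:
    async def broken_summarizer(text: str) -> str:
        raise RuntimeError("model unavailable")

    summary, source = await summarize_with_fallback("A" * 300, broken_summarizer)
    print(source, len(summary))

asyncio.run(demo())
```

Reporting the source alongside the summary lets the caller decide whether to cache the result. A fallback summary is usually not worth caching for long, since the model may recover shortly afterwards.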

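Summary

This article walked through building the backend of an AI news recommendation system with FastAPI. The user module provides authentication and authorization: JWT tokens, password hashing, and profile management. The news module covers CRUD operations, categories, and keyword search. Favorites and browsing history support personalization.

The caching chapter showed how to structure a cache layer to shorten response times, and the AI integration chapter showed how to wrap AI capabilities (summarization, classification, and similarity matching) behind a single service interface. Together these modules form the skeleton of an AI news platform backend, with storage, Redis, and the models mocked in memory.

A real project also needs database persistence, Docker-based deployment, CI/CD automation, and monitoring and alerting. Once you are comfortable with the material here, microservice architecture and cloud-native deployment are natural next topics on the way to a production-grade application.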
