Python依赖注入与控制反转
一、什么是依赖注入
依赖注入(DI)是一种设计模式,将对象的依赖从内部创建改为外部传入,实现控制反转(IoC)。
# Without dependency injection (tight coupling)
class OrderService:
    """Order workflow that constructs its own collaborators.

    This is the anti-pattern: the concrete database and mailer classes are
    hard-wired into the constructor, so they cannot be replaced (e.g. by
    fakes in tests) without editing this class.
    """

    def __init__(self):
        # Hard-coded concrete dependencies.
        self.db = MySQLDatabase()
        self.mailer = SMTPMailer()

    def create_order(self, order_data):
        """Insert *order_data* into 'orders', notify the buyer, return the order."""
        saved = self.db.insert('orders', order_data)
        self.mailer.send(saved.user_email, "订单已创建")
        return saved
# With dependency injection (loose coupling)
class OrderService:
    """Order workflow whose collaborators are supplied by the caller.

    Args:
        db: Object exposing ``insert(table, data)`` that returns the stored order.
        mailer: Object exposing ``send(to, message)``.
    """

    def __init__(self, db, mailer):
        # Injected from outside, so callers decide which implementations to use.
        self.db, self.mailer = db, mailer

    def create_order(self, order_data):
        """Store *order_data* in 'orders', email the buyer, return the stored order."""
        created = self.db.insert('orders', order_data)
        self.mailer.send(created.user_email, "订单已创建")
        return created
# Dependencies are easy to swap, e.g. for test doubles:
service = OrderService(
    db=FakeDatabase(),
    mailer=FakeMailer(),
)
二、构造函数注入
from typing import Protocol
class Database(Protocol):
    """Storage backend interface used by UserService."""

    def query(self, sql: str) -> list:
        """Run the SQL string *sql* and return a list of result rows."""
        ...

    def insert(self, table: str, data: dict) -> dict:
        """Insert *data* into *table* and return the stored record."""
        ...


class EmailSender(Protocol):
    """Outgoing-mail interface."""

    def send(self, to: str, subject: str, body: str) -> None:
        """Send a message with *subject* and *body* to *to*."""
        ...


class Logger(Protocol):
    """Minimal logging interface."""

    def info(self, message: str) -> None:
        """Record an informational *message*."""
        ...

    def error(self, message: str) -> None:
        """Record an error *message*."""
        ...


class UserService:
    """User operations; every collaborator is injected through the constructor."""

    def __init__(self, db: Database, email: EmailSender, logger: Logger):
        self.db = db
        self.email = email
        self.logger = logger

    def register(self, name: str, email_addr: str) -> dict:
        """Create a user, send a welcome email, and return the inserted record.

        The log line is written before the insert; the email is sent only
        after ``db.insert`` returns, so a failing insert sends no email.
        """
        self.logger.info(f"注册用户: {name}")
        user = self.db.insert('users', {'name': name, 'email': email_addr})
        self.email.send(email_addr, "欢迎", f"欢迎 {name}!")
        return user

    def get_user(self, user_id: int) -> dict:
        """Return the first row for *user_id*.

        Args:
            user_id: An int, or anything ``int()`` accepts (e.g. ``"5"``).

        Raises:
            ValueError: if *user_id* cannot be converted with ``int()``, or
                if no user has that id.
        """
        # Database.query only accepts a finished SQL string, so the id must be
        # made safe before interpolation: coercing to int rejects injection
        # payloads such as "1 OR 1=1" before they reach the database.
        try:
            safe_id = int(user_id)
        except (TypeError, ValueError):
            raise ValueError(f"无效的用户ID: {user_id!r}") from None
        users = self.db.query(f"SELECT * FROM users WHERE id = {safe_id}")
        if not users:
            raise ValueError(f"用户 {user_id} 不存在")
        return users[0]
三、简单的DI容器
class Container:
    """A minimal dependency-injection container.

    Dependencies are keyed by an *interface* object (typically a class or a
    Protocol). Each key maps either to a factory called as
    ``factory(container)`` or to a ready-made instance. The most recent
    registration for a key wins.
    """

    def __init__(self):
        self._factories = {}   # interface -> factory(container)
        self._singletons = {}  # interface -> True when the factory result is cached
        self._instances = {}   # interface -> registered instance or cached singleton

    def register(self, interface, factory, singleton=False):
        """Register *factory* for *interface*, replacing any earlier registration.

        Args:
            interface: Key later passed to ``resolve``.
            factory: Callable receiving this container and returning an instance.
            singleton: If True, the first result is cached and reused.
        """
        self._factories[interface] = factory
        # Drop any instance cached by (or registered before) a previous
        # registration; otherwise resolve() would keep returning the stale
        # object and the new factory would never run.
        self._instances.pop(interface, None)
        if singleton:
            self._singletons[interface] = True
        else:
            # A singleton flag from an earlier registration must not leak into this one.
            self._singletons.pop(interface, None)

    def register_instance(self, interface, instance):
        """Register an already-built *instance*; ``resolve`` returns it as-is."""
        self._instances[interface] = instance

    def resolve(self, interface):
        """Return the dependency registered for *interface*.

        Raises:
            KeyError: if nothing is registered under *interface*.
        """
        # Explicit instances and already-built singletons take priority.
        if interface in self._instances:
            return self._instances[interface]
        if interface not in self._factories:
            raise KeyError(f"未注册的依赖: {interface}")
        instance = self._factories[interface](self)
        if interface in self._singletons:
            self._instances[interface] = instance
        return instance

    def __getitem__(self, interface):
        """Shorthand for ``resolve``: ``container[Interface]``."""
        return self.resolve(interface)
# Usage
container = Container()

# Register dependencies; each factory receives the container itself.
container.register(
    Database,
    lambda c: PostgresDatabase(host='localhost'),
    singleton=True,  # built once, then cached
)
container.register(
    EmailSender,
    lambda c: SMTPMailer(host='smtp.example.com'),
)
container.register(
    Logger,
    lambda c: FileLogger('app.log'),
    singleton=True,
)
container.register(
    UserService,
    lambda c: UserService(
        db=c.resolve(Database),
        email=c.resolve(EmailSender),
        logger=c.resolve(Logger),
    ),
)

# Resolve: the container builds UserService together with its dependencies.
user_service = container.resolve(UserService)
user_service.register("Alice", "alice@example.com")
四、基于装饰器的DI
import inspect
from typing import get_type_hints
class AutoWireContainer:
    """DI container that wires constructor arguments from type hints.

    Implementations are registered with the ``register`` decorator. On
    ``resolve``, every ``__init__`` parameter whose annotation is itself a
    registered interface is resolved recursively and passed by keyword.
    """

    def __init__(self):
        # interface -> {'impl': class, 'singleton': bool, 'instance': cached object or None}
        self._registry = {}

    def register(self, cls, singleton=False):
        """Return a class decorator that registers the decorated class for *cls*."""
        def decorator(impl):
            self._registry[cls] = {
                'impl': impl,
                'singleton': singleton,
                'instance': None,
            }
            return impl
        return decorator

    def resolve(self, cls):
        """Build (or return the cached singleton) implementation for *cls*.

        Raises:
            KeyError: if *cls* was never registered.
        """
        if cls not in self._registry:
            raise KeyError(f"未注册: {cls}")
        entry = self._registry[cls]
        # Compare against None, not truthiness: a cached singleton that is
        # falsy (e.g. defines __len__ returning 0) must still be reused.
        if entry['singleton'] and entry['instance'] is not None:
            return entry['instance']

        impl = entry['impl']
        hints = get_type_hints(impl.__init__)
        # Only parameters annotated with a registered interface are injected;
        # anything else is left to the constructor's own defaults.
        kwargs = {
            name: self.resolve(hint)
            for name, hint in hints.items()
            if name != 'return' and hint in self._registry
        }
        instance = impl(**kwargs)
        if entry['singleton']:
            entry['instance'] = instance
        return instance
# Usage
container = AutoWireContainer()


@container.register(Database, singleton=True)
class PostgresDB:
    """Stub Database implementation returning canned results."""

    def query(self, sql):
        return list()

    def insert(self, table, data):
        return data


@container.register(EmailSender)
class SMTPEmail:
    """Stub EmailSender that only prints the recipient."""

    def send(self, to, subject, body):
        print(f"发送邮件到 {to}")


@container.register(Logger, singleton=True)
class AppLogger:
    """Logger that prints to stdout."""

    def info(self, msg):
        print(f"[INFO] {msg}")

    def error(self, msg):
        print(f"[ERROR] {msg}")


@container.register(UserService)
class UserServiceImpl:
    """UserService implementation whose constructor hints drive auto-wiring."""

    # AutoWireContainer reads these annotations to decide what to inject.
    def __init__(self, db: Database, email: EmailSender, logger: Logger):
        self.db, self.email, self.logger = db, email, logger


# Auto-wiring: db, email and logger are injected based on the hints above.
service = container.resolve(UserService)
五、作用域管理
from contextlib import contextmanager
from enum import Enum
class Scope(Enum):
    """How long an instance produced by a registered factory is reused."""

    TRANSIENT = 'transient'  # a new instance on every resolve()
    SINGLETON = 'singleton'  # one instance shared by a root container and all its child scopes
    SCOPED = 'scoped'        # one instance per container/scope


class ScopedContainer:
    """DI container with transient, singleton and scoped lifetimes.

    A child created by ``create_scope`` shares its parent's registry and
    singleton cache (the very same dict objects) but keeps its own cache of
    SCOPED instances, which are closed when the scope ends.
    """

    def __init__(self, parent=None):
        has_parent = parent is not None
        self._registry = parent._registry if has_parent else {}
        self._scoped_instances = {}
        self._singleton_instances = parent._singleton_instances if has_parent else {}

    def register(self, interface, factory, scope=Scope.TRANSIENT):
        """Register *factory* (called with the resolving container) for *interface*."""
        self._registry[interface] = {'factory': factory, 'scope': scope}

    def resolve(self, interface):
        """Return an instance for *interface* according to its registered scope.

        Raises:
            KeyError: if *interface* was never registered.
        """
        if interface not in self._registry:
            raise KeyError(f"未注册: {interface}")
        entry = self._registry[interface]
        scope = entry['scope']
        factory = entry['factory']
        if scope == Scope.SINGLETON:
            cache = self._singleton_instances
        elif scope == Scope.SCOPED:
            cache = self._scoped_instances
        else:  # TRANSIENT
            return factory(self)
        if interface not in cache:
            cache[interface] = factory(self)
        return cache[interface]

    @contextmanager
    def create_scope(self):
        """Yield a child scope and close its SCOPED instances on exit.

        Every scoped instance with a callable ``close`` is closed, in reverse
        order of creation, even if an earlier ``close()`` raises; such an
        error is re-raised once all closes have been attempted.
        """
        from contextlib import ExitStack

        child = ScopedContainer(parent=self)
        try:
            yield child
        finally:
            # ExitStack runs callbacks LIFO and keeps going when one raises,
            # so a single failing close() cannot leak the remaining resources.
            with ExitStack() as stack:
                for instance in child._scoped_instances.values():
                    close = getattr(instance, 'close', None)
                    if callable(close):
                        stack.callback(close)
# Usage (web-request scenario)
container = ScopedContainer()
container.register(Database, lambda c: DatabaseConnection(), scope=Scope.SCOPED)
container.register(
    UserService,
    lambda c: UserService(c.resolve(Database)),
    scope=Scope.SCOPED,
)
# NOTE(review): the UserService class defined earlier in this article takes
# (db, email, logger) and has no `process` method; this snippet presumably
# targets a different UserService — confirm.


def handle_request(request):
    """Handle one request inside its own DI scope (one scope per request)."""
    with container.create_scope() as request_scope:
        # SCOPED resolves within this block share a single Database instance.
        user_service = request_scope.resolve(UserService)
        result = user_service.process(request)
    return result
六、依赖注入与测试
class TestUserService:
    """Unit tests for UserService using hand-written test doubles.

    NOTE(review): MockEmailSender, MockLogger, DatabaseError and pytest are
    not defined/imported in this snippet — they must come from elsewhere.
    """

    def setup_method(self):
        """Create fresh doubles and a service wired to them before each test."""
        self.mock_db = MockDatabase()
        self.mock_email = MockEmailSender()
        self.mock_logger = MockLogger()
        self.service = UserService(
            db=self.mock_db,
            email=self.mock_email,
            logger=self.mock_logger,
        )

    def test_register_success(self):
        self.mock_db.insert_returns = {'id': 1, 'name': 'Alice'}

        created = self.service.register("Alice", "alice@example.com")

        assert created['name'] == 'Alice'
        assert self.mock_email.sent_count == 1
        assert self.mock_logger.messages[-1] == "注册用户: Alice"

    def test_register_db_failure(self):
        self.mock_db.should_fail = True

        with pytest.raises(DatabaseError):
            self.service.register("Alice", "alice@example.com")
        # register() raised at db.insert, before reaching email.send.
        assert self.mock_email.sent_count == 0
class MockDatabase:
    """In-memory Database double that records calls and can simulate failure."""

    def __init__(self):
        self.insert_returns = {}  # value handed back by insert()
        self.should_fail = False  # when truthy, insert() raises DatabaseError
        self.queries = []         # call log: ('insert', table, data) / ('query', sql)

    def insert(self, table, data):
        if not self.should_fail:
            self.queries.append(('insert', table, data))
            return self.insert_returns
        raise DatabaseError("模拟数据库错误")

    def query(self, sql):
        self.queries.append(('query', sql))
        return list()
七、FastAPI的依赖注入
from fastapi import FastAPI, Depends
app = FastAPI()


# Dependency functions
def get_db():
    """Yield a DatabaseSession for the request and always close it afterwards."""
    session = DatabaseSession()
    try:
        yield session
    finally:
        # Runs once FastAPI is done with the dependency, even on errors.
        session.close()
def get_current_user(token: str = Depends(get_token)):
    """Return the user for *token*, or reject the request with HTTP 401.

    NOTE(review): get_token and verify_token are not defined in this snippet.
    """
    # This section only imports FastAPI and Depends at module level; without
    # this import the 401 branch would fail with NameError instead.
    from fastapi import HTTPException

    user = verify_token(token)
    if not user:
        raise HTTPException(status_code=401)
    return user
def get_user_service(
    db: Database = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build a UserService from the request's DB session and authenticated user."""
    # NOTE(review): the UserService defined earlier in this article takes
    # (db, email, logger) and has no `user` parameter — confirm which
    # UserService this snippet is meant to construct.
    service = UserService(db=db, user=current_user)
    return service
# Using the dependencies in a route
@app.post("/orders")
async def create_order(
    order_data: OrderCreate,
    service: UserService = Depends(get_user_service)
):
    """Create an order through the injected service.

    ``service.create_order`` is a synchronous call (it is not awaited), and
    the OrderService.create_order shown in this article does database and
    mail I/O. Calling it directly inside ``async def`` would block the event
    loop, so it is run in FastAPI's threadpool instead.
    """
    from fastapi.concurrency import run_in_threadpool

    # NOTE(review): the annotated UserService has no create_order method in
    # this article; presumably an order service is intended — confirm.
    return await run_in_threadpool(service.create_order, order_data)
# Dependency override (for tests)
def override_get_db():
    """Provide a FakeDatabase wherever get_db would be injected."""
    fake_db = FakeDatabase()
    return fake_db


# NOTE(review): assigned at module level, this replaces get_db for the whole
# app rather than only in tests; consider setting it inside a test fixture
# and clearing app.dependency_overrides afterwards.
app.dependency_overrides[get_db] = override_get_db
八、实际项目中的DI架构
# 项目结构
# src/
# interfaces/ # 抽象接口定义
# services/ # 业务逻辑
# repositories/ # 数据访问
# infrastructure/ # 具体实现
# container.py # DI配置
# container.py
def create_container(config):
    """Build the application's DI container.

    Reads config.db_url, config.redis_url and config.smtp_host when the
    corresponding factories run.
    """
    container = Container()

    # (interface, factory, singleton), registered in this order.
    registrations = (
        # Infrastructure
        (Database, lambda c: PostgresDatabase(config.db_url), True),
        (Cache, lambda c: RedisCache(config.redis_url), True),
        (EmailSender, lambda c: SMTPMailer(config.smtp_host), False),
        # Repositories
        (UserRepository, lambda c: SQLUserRepository(c[Database]), False),
        (OrderRepository, lambda c: SQLOrderRepository(c[Database]), False),
        # Services
        # NOTE(review): the UserService/OrderService classes shown earlier in
        # this article take different arguments (db, email, logger) and
        # (db, mailer); confirm the signatures these factories target.
        (UserService,
         lambda c: UserService(c[UserRepository], c[EmailSender]), False),
        (OrderService,
         lambda c: OrderService(c[OrderRepository], c[UserService], c[Cache]), False),
    )
    for interface, factory, singleton in registrations:
        container.register(interface, factory, singleton=singleton)

    return container
九、总结
依赖注入的好处:
1. 松耦合:组件不依赖具体实现
2. 可测试:轻松替换为Mock对象
3. 可配置:运行时决定使用哪个实现
4. 单一职责:每个类只关注自己的逻辑
使用建议:
- 小项目:手动构造函数注入即可
- 中型项目:简单的DI容器
- 大型项目:使用成熟的DI框架(如 dependency-injector)
- Web框架:利用框架内置的DI(FastAPI Depends)
- 不要过度设计:只在确实需要灵活性时使用DI
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)