Python协议与抽象基类

一、鸭子类型

Python遵循"鸭子类型":如果一个对象走起来像鸭子、叫起来像鸭子,那它就是鸭子。

class Duck:
    """A real duck: it can quack and walk."""

    def quack(self):
        """Print the duck's call."""
        print("嘎嘎嘎")

    def walk(self):
        """Print how the duck walks."""
        print("摇摇摆摆走")

class Person:
    """Not a duck, but exposes the same ``quack``/``walk`` methods."""

    def quack(self):
        """Print the person's imitation of a duck call."""
        print("我在模仿鸭子叫")

    def walk(self):
        """Print the person's imitation of a duck walk."""
        print("我在模仿鸭子走")

def make_it_quack(thing):
    """Make *thing* quack and then walk.

    No type check is performed; only the behaviour matters, so any object
    exposing ``quack()`` and ``walk()`` is accepted (duck typing).
    """
    thing.quack()
    thing.walk()

make_it_quack(Duck()) # works
make_it_quack(Person()) # works too


二、Protocol(结构化子类型)

from typing import Protocol, runtime_checkable

@runtime_checkable
class Readable(Protocol):
    """Structural type for objects with a ``read(size=-1) -> str`` method."""

    def read(self, size: int = -1) -> str: ...


@runtime_checkable
class Writable(Protocol):
    """Structural type for objects with a ``write(data) -> int`` method."""

    def write(self, data: str) -> int: ...


class Closeable(Protocol):
    """Structural type for objects with a ``close()`` method.

    Not decorated with ``@runtime_checkable``, so it is meant for static
    type checking only.
    """

    def close(self) -> None: ...


class ReadWriteCloseable(Readable, Writable, Closeable, Protocol):
    """Protocol combining ``read``, ``write`` and ``close``.

    ``Protocol`` must be listed explicitly as a base; otherwise the class
    would be an ordinary subclass instead of a protocol.
    """

# 任何实现了这些方法的类都满足协议,无需显式继承
class StringBuffer:
    """In-memory text buffer exposing ``read``, ``write`` and ``close``.

    It inherits from no protocol class; it matches protocols structurally
    simply by having methods with the right names.
    """

    def __init__(self):
        self._buffer = ""

    def read(self, size: int = -1) -> str:
        """Consume and return text from the front of the buffer.

        Args:
            size: Maximum number of characters to return. ``None`` or any
                negative value drains the whole buffer.

        Returns:
            The consumed text (possibly empty).
        """
        # Treat every negative size (and None) as "read all"; comparing with
        # -1 only would let e.g. -2 slice from the end of the buffer.
        if size is None or size < 0:
            result, self._buffer = self._buffer, ""
        else:
            result = self._buffer[:size]
            self._buffer = self._buffer[size:]
        return result

    def write(self, data: str) -> int:
        """Append *data* to the buffer and return the number of characters written."""
        self._buffer += data
        return len(data)

    def close(self) -> None:
        """Discard any buffered text."""
        self._buffer = ""

# 类型检查通过(结构化子类型)
def process_stream(stream: ReadWriteCloseable) -> None:
    """Write ``"hello"`` to *stream*, read it back, then close the stream."""
    stream.write("hello")
    stream.read()  # the value read back is not used
    stream.close()

buf = StringBuffer()
process_stream(buf) # OK: StringBuffer satisfies the protocol

# runtime_checkable enables isinstance() checks (only member presence is checked, not signatures)
print(isinstance(buf, Readable)) # True
print(isinstance(buf, Writable)) # True


三、抽象基类(ABC)

from abc import ABC, abstractmethod, abstractproperty
from typing import Iterator

class Collection(ABC):
    """Abstract base class for a custom collection.

    Subclasses must implement ``__len__``, ``__contains__`` and ``__iter__``;
    ``is_empty`` and ``count`` are provided on top of those.
    """

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements."""

    @abstractmethod
    def __contains__(self, item) -> bool:
        """Return whether *item* is present."""

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Return an iterator over the elements."""

    # Concrete methods (default implementations).
    def is_empty(self) -> bool:
        """Return True when the collection has no elements."""
        return len(self) == 0

    def count(self, item) -> int:
        """Return how many elements compare equal to *item*."""
        return sum(1 for x in self if x == item)

class SortedSet(Collection):
    """Collection of unique items kept in ascending order in a list."""

    def __init__(self, items=None):
        # Deduplicate, then sort; None or an empty iterable gives an empty set.
        self._data = sorted(set(items)) if items else []

    def __len__(self):
        return len(self._data)

    def __contains__(self, item):
        # Binary search: the item is present iff it sits at its own
        # left-most insertion point.
        import bisect
        index = bisect.bisect_left(self._data, item)
        return index < len(self._data) and self._data[index] == item

    def __iter__(self):
        return iter(self._data)

    def add(self, item):
        """Insert *item* at its sorted position unless it is already present."""
        import bisect
        # One binary search serves both the membership test and the insert
        # position, instead of searching once for `in` and again for insort.
        index = bisect.bisect_left(self._data, item)
        if index == len(self._data) or self._data[index] != item:
            self._data.insert(index, item)

# An abstract class cannot be instantiated
# Collection() # TypeError: Can't instantiate abstract class

ss = SortedSet([3, 1, 4, 1, 5, 9])
print(len(ss)) # 5
print(4 in ss) # True
print(ss.is_empty()) # False


四、注册虚拟子类

from abc import ABC, abstractmethod

class Serializer(ABC):
    """Abstract interface for converting data to and from bytes."""

    @abstractmethod
    def serialize(self, data) -> bytes:
        """Encode *data* as bytes."""

    @abstractmethod
    def deserialize(self, raw: bytes):
        """Decode *raw* bytes back into data."""

# 第三方类无法修改,但可以注册为虚拟子类
class ThirdPartyJsonCodec:
    """JSON codec standing in for a third-party class that cannot be modified.

    It has ``serialize``/``deserialize`` methods but does not inherit from
    any base class.
    """

    def serialize(self, data) -> bytes:
        """Return *data* dumped as JSON text and encoded to bytes."""
        import json
        return json.dumps(data).encode()

    def deserialize(self, raw: bytes):
        """Return the object parsed from the JSON text in *raw*."""
        import json
        return json.loads(raw.decode())

# Register it as a virtual subclass of Serializer
Serializer.register(ThirdPartyJsonCodec)

print(isinstance(ThirdPartyJsonCodec(), Serializer)) # True
print(issubclass(ThirdPartyJsonCodec, Serializer)) # True

# Note: a virtual subclass is not forced to implement the abstract methods;
# registration is a "declaration of trust"


五、collections.abc中的协议

from collections.abc import (
Iterable, Iterator, Sequence, MutableSequence,
Mapping, MutableMapping, Set, Callable, Hashable
)

# Check whether an object satisfies a given protocol
print(isinstance([], Sequence)) # True
print(isinstance({}, Mapping)) # True
print(isinstance(lambda: None, Callable)) # True

# 实现自定义序列
class InfiniteSequence(Sequence):
    """Conceptually infinite sequence whose item at index n is ``func(n)``.

    Only ``__getitem__`` and ``__len__`` are implemented here; ``Sequence``
    supplies the remaining methods as mixins.
    """

    def __init__(self, func):
        self.func = func

    def __getitem__(self, index):
        """Return ``func(index)``, or a list of values for a slice.

        A slice needs an explicit ``stop``; a missing ``start`` defaults to 0
        and a missing ``step`` to 1.
        """
        if isinstance(index, slice):
            start = index.start or 0
            stop = index.stop
            step = index.step or 1
            return [self.func(i) for i in range(start, stop, step)]
        return self.func(index)

    def __len__(self):
        # len() requires an int no larger than sys.maxsize; returning
        # float('inf') makes len() and bool() raise TypeError. Report the
        # largest length Python can represent instead.
        import sys
        return sys.maxsize

squares = InfiniteSequence(lambda n: n ** 2)  # item n is n squared
print(squares[5]) # 25
print(squares[2:7]) # [4, 9, 16, 25, 36]

# 实现自定义映射
class CaseInsensitiveMapping(MutableMapping):
    """Mutable mapping whose keys are matched case-insensitively.

    Every key is lower-cased with ``key.lower()`` before it touches the
    underlying dict, so iteration yields the lower-cased keys.
    """

    def __init__(self, data=None):
        self._store = {}
        if data:
            # update() is a MutableMapping mixin that goes through
            # __setitem__, so the initial keys are lower-cased as well.
            self.update(data)

    def __getitem__(self, key):
        return self._store[key.lower()]

    def __setitem__(self, key, value):
        self._store[key.lower()] = value

    def __delitem__(self, key):
        del self._store[key.lower()]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)


六、Protocol vs ABC的选择

# Protocol:结构化子类型(鸭子类型的形式化)
# - 不需要显式继承
# - 适合定义接口契约
# - 主要用于类型检查
# - 更灵活,与第三方代码兼容

class Sortable(Protocol):
    """Structural type for objects that support the ``<`` comparison."""

    def __lt__(self, other) -> bool: ...


def sort_items(items: list[Sortable]) -> list[Sortable]:
    """Return a new list holding *items* in ascending order."""
    return sorted(items)

# ABC:名义子类型(传统OOP)
# - 需要显式继承
# - 可以提供默认实现
# - 运行时强制实现抽象方法
# - 适合框架内部的类层次

class Plugin(ABC):
    """Abstract plugin base; subclasses must implement ``execute`` and ``name``."""

    @abstractmethod
    def execute(self, context: dict) -> dict:
        """Run the plugin on *context* and return a dict."""

    def validate(self, context: dict) -> bool:
        """Default validation (subclasses may override): True for a non-empty context."""
        return bool(context)

    @abstractmethod
    def name(self) -> str:
        """Return the plugin's name."""


七、__subclasshook__

class MyIterable(ABC):
    """ABC that recognises any class defining ``__iter__`` as a subclass."""

    @classmethod
    def __subclasshook__(cls, C):
        """Customise ``isinstance``/``issubclass`` for this ABC.

        Returns True when some class in ``C.__mro__`` defines ``__iter__``,
        and NotImplemented otherwise so the normal ABC machinery decides.
        """
        # Only apply the structural check to MyIterable itself, not to
        # classes derived from it.
        if cls is MyIterable:
            for B in C.__mro__:
                if "__iter__" in B.__dict__:
                    # Setting a special method to None is the conventional
                    # way to declare it unsupported; do not count that.
                    if B.__dict__["__iter__"] is None:
                        return NotImplemented
                    return True
        return NotImplemented

class SimpleRange:
    """Plain class (no base, no registration) that iterates over 0..9."""

    def __iter__(self):
        return iter(range(10))

# No registration or inheritance needed
print(isinstance(SimpleRange(), MyIterable)) # True


八、混合使用Protocol和ABC

from typing import Protocol, runtime_checkable
from abc import ABC, abstractmethod

# 对外接口用Protocol(不强制继承)
@runtime_checkable
class StorageBackend(Protocol):
def get(self, key: str) -> bytes | None: ...
def set(self, key: str, value: bytes, ttl: int = 0) -> None: ...
def delete(self, key: str) -> None: ...

# 内部基类用ABC(提供共享实现)
class BaseStorage(ABC):
def __init__(self, prefix=""):
self.prefix = prefix

def _make_key(self, key):
return f"{self.prefix}:{key}" if self.prefix else key

@abstractmethod
def get(self, key: str) -> bytes | None:
pass

@abstractmethod
def set(self, key: str, value: bytes, ttl: int = 0) -> None:
pass

@abstractmethod
def delete(self, key: str) -> None:
pass

def exists(self, key: str) -> bool:
return self.get(key) is not None

class MemoryStorage(BaseStorage):
    """Storage that keeps values in a dict held by the instance."""

    def __init__(self, prefix=""):
        super().__init__(prefix)
        self._data = {}

    def get(self, key: str) -> bytes | None:
        """Return the value stored for *key*, or None when absent."""
        return self._data.get(self._make_key(key))

    def set(self, key: str, value: bytes, ttl: int = 0) -> None:
        """Store *value* under *key*.

        NOTE(review): ``ttl`` is accepted for interface compatibility but is
        not used; entries never expire.
        """
        self._data[self._make_key(key)] = value

    def delete(self, key: str) -> None:
        """Remove *key*; a missing key is ignored."""
        self._data.pop(self._make_key(key), None)

# MemoryStorage satisfies both the Protocol and the ABC
storage: StorageBackend = MemoryStorage(prefix="app")
print(isinstance(storage, StorageBackend)) # True


九、实际应用:插件系统

# runtime_checkable is required because plugins are validated with
# isinstance(); without it that check raises TypeError for every object.
@runtime_checkable
class PluginInterface(Protocol):
    """Interface every plugin must implement.

    The runtime isinstance() check only verifies that these members exist;
    it does not check their types or signatures.
    """

    name: str
    version: str

    def initialize(self, config: dict) -> None: ...
    def execute(self, data: dict) -> dict: ...
    def cleanup(self) -> None: ...

class PluginManager:
    """Registry that initialises plugins and runs them as a pipeline.

    NOTE: ``register`` relies on ``isinstance(plugin, PluginInterface)``.
    isinstance() against a Protocol only works when that protocol is
    decorated with ``@runtime_checkable``; otherwise the check itself
    raises TypeError.
    """

    def __init__(self):
        # Maps plugin.name -> plugin. Dicts keep insertion order, so the
        # pipeline visits plugins in registration order.
        self._plugins: dict[str, PluginInterface] = {}

    def register(self, plugin: PluginInterface):
        """Store *plugin* under ``plugin.name``, replacing any same-named entry.

        Raises:
            TypeError: If *plugin* does not satisfy ``PluginInterface``.
        """
        if not isinstance(plugin, PluginInterface):
            raise TypeError(f"{type(plugin).__name__} 不满足PluginInterface协议")
        self._plugins[plugin.name] = plugin

    def initialize_all(self, config: dict):
        """Call ``initialize(config)`` on every registered plugin."""
        for plugin in self._plugins.values():
            plugin.initialize(config)

    def execute_pipeline(self, data: dict) -> dict:
        """Feed *data* through every plugin in turn and return the final result."""
        result = data
        for plugin in self._plugins.values():
            # Each plugin receives the previous plugin's output.
            result = plugin.execute(result)
        return result

# 实现插件(无需继承任何基类)
class LoggingPlugin:
    """Plugin that prints the keys of the data passing through it.

    It inherits from no base class; it only provides the expected members.
    """

    name = "logging"
    version = "1.0"
    # Class-level fallback so execute() also works when initialize() has not
    # been called yet (otherwise self.log_level would not exist).
    log_level = 'INFO'

    def initialize(self, config):
        """Read ``log_level`` from *config*, defaulting to ``'INFO'``."""
        self.log_level = config.get('log_level', 'INFO')

    def execute(self, data):
        """Print the keys of *data* and return *data* unchanged."""
        print(f"[{self.log_level}] 处理数据: {list(data.keys())}")
        return data

    def cleanup(self):
        """Nothing to release."""
        pass


十、总结

选择指南:
- 定义外部接口/契约 -> Protocol
- 需要默认实现 -> ABC
- 与第三方代码集成 -> Protocol(结构化匹配,无需改动对方代码)或 ABC.register(注册为虚拟子类)
- 框架内部类层次 -> ABC
- 运行时类型检查 -> @runtime_checkable Protocol(isinstance只检查成员是否存在,不检查签名和类型)
- 纯静态类型检查 -> Protocol(不需要runtime_checkable)

设计原则:
1. 优先使用Protocol(更Pythonic,更灵活)
2. ABC适合需要共享实现的场景
3. 不要过度使用抽象,简单场景直接用鸭子类型
4. Protocol是对鸭子类型的形式化,不是对它的替代

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐