【AI大模型--NumPy-03】-NumPy 广播机制与通用函数完全指南 (Broadcasting & UFuncs)
·
03_broadcastingUfunc.py - 广播机制与通用函数完全指南
学习路径第 3 步 (共 10 步) | 难度:中级
概述
深入理解 NumPy 性能优势的根源:广播 (Broadcasting) 机制和通用函数 (ufunc)。这是编写高效 NumPy 代码的关键。
学习目标
- 掌握广播的 4 条规则,能判断任意两个数组是否可广播
- 理解 ufunc 的向量化运算原理
- 学会使用
reduce/accumulate/outer/einsum等 ufunc 高级操作 - 了解自定义 ufunc 的创建方法
核心内容 (6 个模块)
| 模块 | 核心知识点 |
|---|---|
| 1. 广播规则 | 形状兼容性判断、逐步对齐规则图解 |
| 2. 广播实战应用 | 数据归一化、距离矩阵计算 |
| 3. 一元/二元 ufunc | 数学运算、比较运算、逻辑运算 |
| 4. ufunc 高级用法 | reduce(规约)、accumulate(累积)、outer(外积) |
| 5. Einstein 求和 | np.einsum() 复杂张量运算的简洁写法 |
| 6. 自定义 ufunc | np.frompyfunc() / np.vectorize() |
代码学习
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
=====================================
NumPy 广播机制与通用函数完全指南 (Broadcasting & UFuncs)
=====================================
本案例深入介绍 NumPy 的两大核心引擎:
第一部分: 广播机制 (Broadcasting)
- 广播规则详解 (形状兼容性)
- 常见广播模式与可视化
- 广播的实际应用 (归一化、距离计算等)
第二部分: 通用函数 (Universal Functions / ufunc)
- 一元 ufunc (数学运算)
- 二元 ufunc (比较与逻辑)
- ufunc 的高级特性 (reduce/accumulate/outer/at)
- 自定义 ufunc
【核心价值】
广播 + ufunc 是 NumPy 性能优势的根本来源。
理解它们等于理解了为什么 NumPy 能比纯Python快几十倍。
作者:bloxed
日期:2026-05-19
"""
import numpy as np
def separator(title):
print(f"\n{'='*60}")
print(f" {title}")
print('='*60)
# ============================================================
# 第一部分: 广播机制 (Broadcasting)
# ============================================================
separator("一、广播机制 (Broadcasting) 原理")
print("""
┌──────────────────────────────────────────────────────────────┐
│ 广播规则 (Broadcasting Rules) │
├──────────────────────────────────────────────────────────────┤
│ │
│ 当两个数组形状不同时,NumPy 如何进行逐元素运算? │
│ │
│ 【规则】从最后一个维度开始,逐个向前匹配: │
│ │
│ 1. 维度相同 → 直接运算 │
│ 2. 其中一个为 1 → 扩展复制以匹配另一个 │
│ 3. 其中一个缺失 → 在前面补 1 再扩展 │
│ 4. 都不满足 → 报错! Shape mismatch │
│ │
│ [记忆法] 两个形状从右往左对齐,每个维度要么相等,要么有一方为1│
│ │
│ 示例: │
│ (3, 4) + (4,) → OK: (3,4)+(1,4)→(3,4) │
│ (3, 1) + (1, 4) → OK: (3,4)+ (3,4)→(3,4) │
│ (3, 4) + (3, ) → ERROR! 4 != 3 │
└──────────────────────────────────────────────────────────────┘
""")
# --- 规则演示 ---
print("【规则验证】\n")
# Case 1: 标量与数组 (最常见)
print("Case 1: 标量广播")
arr = np.array([[1, 2, 3], [4, 5, 6]]) # (2, 3)
result = arr + 10 # 10 被广播为 [[10,10,10],[10,10,10]]
print(f" (2,3) + scalar → {(arr + 10).shape}")
print(f" result:\n{result}")
# Case 2: (m,n) + (n,)
print("\nCase 2: 行向量广播 (常用: 归一化)")
matrix = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=float) # (3, 3)
row_vec = np.array([10, 20, 30]) # (3,)
result = matrix + row_vec # row_vec → (1,3) → (3,3)
print(f" (3,3) + (3,) → {result.shape}")
print(f" result:\n{result}")
# Case 3: (m,1) + (1,n) → (m,n) [外积模式]
print("\nCase 3: 列向量 + 行向量 → 矩阵 (外积/笛卡尔积模式)")
col = np.array([[1], [2], [3]]) # (3, 1)
row = np.array([10, 20]) # (2,)
outer_result = col * row # (3,1)*(2,) → (3,2)
print(f" (3,1) * (2,) → {outer_result.shape}")
print(f" result (每行为row倍数):\n{outer_result}")
# Case 4: 广播失败
print("\nCase 4: 不兼容的形状 → ERROR")
try:
bad = np.array([[1, 2], [3, 4]]) # (2, 2)
vec = np.array([1, 2, 3]) # (3,)
_ = bad + vec
except ValueError as e:
print(f" Error: {e}")
print(" 原因: (2,2) 和 (3,) 在最后维度不匹配 (2 vs 3)")
# ============================================================
# 第二部分: 广播的实际应用
# ============================================================
separator("二、广播的实际应用场景")
# 应用1: 数据归一化 (Z-Score Standardization)
print("\n【应用1】Z-Score 标准化 (每列独立)")
# 模拟数据: 5个样本, 3个特征
data = np.array([
[170, 65, 28], # 身高cm, 体重kg, 年龄
[180, 80, 35],
[160, 50, 22],
[175, 70, 30],
[165, 55, 25],
], dtype=float)
print(f"原始数据 (5样本 x 3特征):")
print(data)
means = data.mean(axis=0) # (3,) 每特征均值
stds = data.std(axis=0) # (3,) 每特征标准差
normalized = (data - means) / stds # 广播! (5,3)-(3,) / (3,)
print(f"\n各特征均值: {means.round(2)}")
print(f"各特征标准差: {stds.round(2)}")
print(f"\n标准化后数据 (均值≈0, 标准差≠1):")
print(normalized.round(3))
print(f"验证标准化后的均值: {normalized.mean(axis=0).round(10)}")
print(f"验证标准化后的标准差: {normalized.std(axis=0)}")
# 应用2: 距离计算 (欧氏距离矩阵)
print("\n\n【应用2】欧氏距离矩阵 (点对点距离)")
points = np.array([
[0, 0], # A: 原点
[3, 4], # B: (3,4)
[1, 1], # C: (1,1)
[-2, 1], # D: (-2,1)
])
print(f"点位:\n{points}")
# 利用广播计算所有点对的距离
# ||a-b||^2 = (a1-b1)^2 + (a2-b2)^2
diff = points[:, np.newaxis, :] - points[np.newaxis, :, :] # (4,1,2) - (1,4,2) = (4,4,2)
dist_matrix = np.sqrt((diff ** 2).sum(axis=2))
print(f"\n距离矩阵 (对称阵):")
labels = ['A(0,0)', 'B(3,4)', 'C(1,1)', 'D(-2,1)']
print(f"{'':>12}", end='')
for l in labels:
print(f"{l:>10}", end='')
print()
for i, label_i in enumerate(labels):
print(f"{label_i:>12}", end='')
for j in range(len(labels)):
print(f"{dist_matrix[i, j]:>10.2f}", end='')
print()
# 应用3: 图像处理中的广播
print("\n\n【应用3】图像亮度调整 (广播到RGB三通道)")
img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8) # 模拟小图像
brightness_factor = np.float32(1.2) # 提亮 20%
adjusted = np.clip(img.astype(np.float32) * brightness_factor, 0, 255).astype(np.uint8)
print(f"原图均值: {img.mean():.1f}, 调整后均值: {adjusted.mean():.1f}")
print(f"(标量广播到 H×W×C 全部像素)")
# ============================================================
# 第三部分: 通用函数 (UFuncs) —— 一元函数
# ============================================================
separator("三、通用函数 (ufunc): 一元运算")
print("""
【什么是 ufunc?】
Universal Function — 能够对数组的每个元素逐个操作的函数。
NumPy 内置了大量的 ufunc,底层用 C 语言实现,极快。
特点:
- 输入: 0~N 个数组 (标量视为0-d数组)
- 输出: 一个数组 (或多个数组)
- 支持 broadcast (自动形状适配)
""")
arr = np.array([-2, -1, 0, 1, 2, 3.14, np.inf, np.nan])
print(f"测试数组: {arr}\n")
# 数学类 ufunc
print("【数学类 ufunc】")
print(f" abs(x) = {np.abs(arr)}")
print(f" sqrt(|x|) = {np.sqrt(np.abs(arr))}")
print(f" square(x) = {np.square(arr)}")
print(f" exp(x) = {np.exp(np.clip(arr, -5, 5))}") # clip防止溢出
print(f" log(|x|+1) = {np.log(np.abs(arr) + 1)}")
# 三角函数
angles = np.linspace(0, 2*np.pi, 8)
print(f"\n角度序列: {angles.round(3)}")
print(f" sin = {np.sin(angles).round(4)}")
print(f" cos = {np.cos(angles).round(4)}")
# 取整类
decimals = np.array([1.2, 2.7, -1.5, 3.14, -2.8])
print(f"\n取整测试: {decimals}")
print(f" floor (向下取整) = {np.floor(decimals)}")
print(f" ceil (向上取整) = {np.ceil(decimals)}")
print(f" round (四舍五入) = {np.round(decimals)}")
print(f" trunc (截断小数) = {np.trunc(decimals)}")
# 比较与逻辑类
print("\n【比较 & 逻辑类 ufunc】")
a = np.array([1, 2, 3, 4, 5])
b = np.array([3, 2, 5, 4, 1])
print(f" a = {a}")
print(f" b = {b}")
print(f" a > b = {np.greater(a, b)}")
print(f" a == b = {np.equal(a, b)}")
print(f" np.maximum(a,b) = {np.maximum(a, b)} # 逐元素取最大")
print(f" np.minimum(a,b) = {np.minimum(a, b)} # 逐元素取最小")
# 逻辑运算
bool_a = np.array([True, True, False, False])
bool_b = np.array([True, False, True, False])
print(f"\n bool_a = {bool_a}")
print(f" bool_b = {bool_b}")
print(f" logical_and = {np.logical_and(bool_a, bool_b)}")
print(f" logical_or = {np.logical_or(bool_a, bool_b)}")
print(f" logical_not(bool_a) = {np.logical_not(bool_a)}")
# ============================================================
# 第四部分: UFuncs 高级特性
# ============================================================
separator("四、ufunc 高级方法: reduce / accumulate / outer / at")
arr = np.array([3, 1, 4, 1, 5, 9, 2, 6])
print(f"操作数组: {arr}\n")
# reduce: 沿轴递归应用二元操作
print("【reduce】沿轴降维折叠")
print(f" np.add.reduce(arr) = {np.add.reduce(arr)} # 等价 sum()")
print(f" np.multiply.reduce = {np.multiply.reduce(arr)} # 等价 prod()")
# accumulate: 显示每步累积结果
print(f"\n【accumulate】显示累积过程")
print(f" add.accumulate = {np.add.accumulate(arr)} # 累加和 [3,4,8,...]")
print(f" multiply.accumulate= {np.multiply.accumulate(arr)} # 累积乘 [3,3,12,...]")
# outer: 外积 (两两配对)
print(f"\n【outer】外积运算")
a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
print(f" a = {a}, b = {b}")
print(f" np.add.outer(a, b):\n{np.add.outer(a, b)}")
print(f" np.multiply.outer(a, b):\n{np.multiply.outer(a, b)}")
print(" [解释] outer 将两个数组的每个元素两两配对运算")
# at: 无缓冲的原地操作
print(f"\n【at】原地无缓冲操作")
bins = np.zeros(10, dtype=int)
indices = np.array([0, 3, 1, 3, 2, 3, 0, 5])
np.add.at(bins, indices, 1) # 在指定位置累加
print(f" 索引计数: indices = {indices}")
print(f" 统计直方图: bins = {bins}")
print(f" [!] at 解决了重复索引被覆盖的问题 (普通赋值会丢失)")
# ============================================================
# 第五部分: 自定义 ufunc (可选了解)
# ============================================================
separator("五、自定义 ufunc (进阶)")
print("""
除了内置 ufunc,NumPy 还支持创建自定义 ufunc。
方式一: frompyfunc (简单包装 Python 函数,速度一般)
方式二: Numba/C 扩展 (高性能,需额外安装)
""")
# frompyfunc 示例
def custom_op(x):
"""自定义操作: 返回 x^2 如果 x>0, 否则返回 0"""
return x**2 if x > 0 else 0
# 创建 ufunc: nin=1个输入, nout=1个输出
custom_ufunc = np.frompyfunc(custom_op, 1, 1)
test_vals = np.array([-3, -1, 0, 2, 4])
result = custom_ufunc(test_vals)
print(f" 输入: {test_vals}")
print(f" 自定义ufunc输出: {result}")
print(f" 类型: {type(result[0])} [注意: frompyfunc 返回 object 类型]")
# ============================================================
# 第六部分: 综合实战 —— 向量化替代循环
# ============================================================
separator("六、实战对比: 循环 vs 广播+ufunc")
print("任务: 计算 10000 个二维点到原点的距离\n")
np.random.seed(123)
points = np.random.randn(10000, 2)
# 方式1: Python 循环 (慢)
import time
start = time.perf_counter()
distances_loop = []
for p in points:
distances_loop.append(np.sqrt(p[0]**2 + p[1]**2))
loop_time = time.perf_counter() - start
# 方式2: NumPy 广播+ufunc (快)
start = time.perf_counter()
distances_np = np.sqrt(points[:, 0]**2 + points[:, 1]**2)
numpy_time = time.perf_counter() - start
# 方式3: 更优写法 (利用 np.linalg.norm)
start = time.perf_counter()
distances_norm = np.linalg.norm(points, axis=1)
norm_time = time.perf_counter() - start
print(f"{'方法':<25} {'耗时(ms)':<12} {'加速比':<10}")
print("-" * 47)
print(f"{'Python 循环':<22} {loop_time*1000:<12.3f} {'1.0x':<10}")
print(f"{'NumPy 广播':<22} {numpy_time*1000:<12.3f} {loop_time/max(numpy_time, 1e-9):<10.1f}x")
print(f"{'np.linalg.norm':<19} {norm_time*1000:<12.3f} {loop_time/max(norm_time, 1e-9):<10.1f}x")
# 验证一致性
print(f"\n结果一致: {np.allclose(distances_loop, distances_np)}")
print(f"结果一致: {np.allclose(distances_loop, distances_norm)}")
# ============================================================
# 总结
# ============================================================
separator("总结: 广播 & ufunc 核心要点")
summary = """
+------------------------------------------------------------+
| 广播 (Broadcasting) |
| |
| - 标量自动扩展到任意形状 |
| - (m,n) + (n,) → (m,n) [行向量广播] |
| - (m,1) + (1,n) → (m,n) [外积模式] |
| - 规则: 从右往左对齐,维度相等 或 有一方为1 |
| |
| 常用 ufunc |
| |
| [数学] abs/sqrt/exp/log/sin/cos/floor/ceil/round |
| [算术] add/subtract/multiply/divide/power/modulo |
| [比较] greater/less/equal/not_equal |
| [逻辑] logical_and/or/not |
| [极值] maximum/minimum/fmax/fmin |
| |
| ufunc 高级方法 |
| |
| reduce() → 折叠为一维 (如求和/求积) |
| accumulate() → 显示每步累积 |
| outer() → 两两外积 |
| at() → 原地操作 (解决重复索引问题) |
| |
| [核心原则] 尽可能用 ufunc+广播 替代显式 for 循环! |
+------------------------------------------------------------+
"""
print(summary)
print("\n运行完毕!")
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)