03_broadcastingUfunc.py - 广播机制与通用函数完全指南

学习路径第 3 步 (共 10 步) | 难度:中级

概述

深入理解 NumPy 性能优势的根源:广播 (Broadcasting) 机制和通用函数 (ufunc)。这是编写高效 NumPy 代码的关键。

学习目标

  • 掌握广播的 4 条规则,能判断任意两个数组是否可广播
  • 理解 ufunc 的向量化运算原理
  • 学会使用 reduce / accumulate / outer / einsum 等 ufunc 高级操作
  • 了解自定义 ufunc 的创建方法

核心内容 (6 个模块)

模块 核心知识点
1. 广播规则 形状兼容性判断、逐步对齐规则图解
2. 广播实战应用 数据归一化、距离矩阵计算
3. 一元/二元 ufunc 数学运算、比较运算、逻辑运算
4. ufunc 高级用法 reduce(规约)、accumulate(累积)、outer(外积)
5. Einstein 求和 np.einsum() 复杂张量运算的简洁写法
6. 自定义 ufunc np.frompyfunc() / np.vectorize()

代码学习

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
=====================================
NumPy 广播机制与通用函数完全指南 (Broadcasting & UFuncs)
=====================================

本案例深入介绍 NumPy 的两大核心引擎:

第一部分: 广播机制 (Broadcasting)
  - 广播规则详解 (形状兼容性)
  - 常见广播模式与可视化
  - 广播的实际应用 (归一化、距离计算等)

第二部分: 通用函数 (Universal Functions / ufunc)
  - 一元 ufunc (数学运算)
  - 二元 ufunc (比较与逻辑)
  - ufunc 的高级特性 (reduce/accumulate/outer/at)
  - 自定义 ufunc

【核心价值】
  广播 + ufunc 是 NumPy 性能优势的根本来源。
  理解它们等于理解了为什么 NumPy 能比纯Python快几十倍。

作者:bloxed
日期:2026-05-19
"""

import numpy as np


def separator(title):
    print(f"\n{'='*60}")
    print(f"  {title}")
    print('='*60)


# ============================================================
# 第一部分: 广播机制 (Broadcasting)
# ============================================================
separator("一、广播机制 (Broadcasting) 原理")

print("""
┌──────────────────────────────────────────────────────────────┐
│                    广播规则 (Broadcasting Rules)              │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  当两个数组形状不同时,NumPy 如何进行逐元素运算?              │
│                                                              │
│  【规则】从最后一个维度开始,逐个向前匹配:                   │
│                                                              │
│    1. 维度相同 → 直接运算                                   │
│    2. 其中一个为 1 → 扩展复制以匹配另一个                    │
│    3. 其中一个缺失 → 在前面补 1 再扩展                       │
│    4. 都不满足 → 报错! Shape mismatch                       │
│                                                              │
│  [记忆法] 两个形状从右往左对齐,每个维度要么相等,要么有一方为1│
│                                                              │
│  示例:                                                       │
│    (3, 4) + (4,)     → OK: (3,4)+(1,4)→(3,4)               │
│    (3, 1) + (1, 4)   → OK: (3,4)+    (3,4)→(3,4)           │
│    (3, 4) + (3, )     → ERROR! 4 != 3                       │
└──────────────────────────────────────────────────────────────┘
""")

# --- 规则演示 ---
print("【规则验证】\n")

# Case 1: 标量与数组 (最常见)
print("Case 1: 标量广播")
arr = np.array([[1, 2, 3], [4, 5, 6]])  # (2, 3)
result = arr + 10  # 10 被广播为 [[10,10,10],[10,10,10]]
print(f"  (2,3) + scalar → {(arr + 10).shape}")
print(f"  result:\n{result}")

# Case 2: (m,n) + (n,)
print("\nCase 2: 行向量广播 (常用: 归一化)")
matrix = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]], dtype=float)  # (3, 3)
row_vec = np.array([10, 20, 30])  # (3,)
result = matrix + row_vec  # row_vec → (1,3) → (3,3)
print(f"  (3,3) + (3,) → {result.shape}")
print(f"  result:\n{result}")

# Case 3: (m,1) + (1,n) → (m,n) [外积模式]
print("\nCase 3: 列向量 + 行向量 → 矩阵 (外积/笛卡尔积模式)")
col = np.array([[1], [2], [3]])  # (3, 1)
row = np.array([10, 20])          # (2,)
outer_result = col * row  # (3,1)*(2,) → (3,2)
print(f"  (3,1) * (2,) → {outer_result.shape}")
print(f"  result (每行为row倍数):\n{outer_result}")

# Case 4: 广播失败
print("\nCase 4: 不兼容的形状 → ERROR")
try:
    bad = np.array([[1, 2], [3, 4]])  # (2, 2)
    vec = np.array([1, 2, 3])          # (3,)
    _ = bad + vec
except ValueError as e:
    print(f"  Error: {e}")
    print("  原因: (2,2) 和 (3,) 在最后维度不匹配 (2 vs 3)")


# ============================================================
# 第二部分: 广播的实际应用
# ============================================================
separator("二、广播的实际应用场景")

# 应用1: 数据归一化 (Z-Score Standardization)
print("\n【应用1】Z-Score 标准化 (每列独立)")

# 模拟数据: 5个样本, 3个特征
data = np.array([
    [170, 65, 28],  # 身高cm, 体重kg, 年龄
    [180, 80, 35],
    [160, 50, 22],
    [175, 70, 30],
    [165, 55, 25],
], dtype=float)

print(f"原始数据 (5样本 x 3特征):")
print(data)

means = data.mean(axis=0)   # (3,) 每特征均值
stds = data.std(axis=0)     # (3,) 每特征标准差
normalized = (data - means) / stds  # 广播! (5,3)-(3,) / (3,)

print(f"\n各特征均值: {means.round(2)}")
print(f"各特征标准差: {stds.round(2)}")
print(f"\n标准化后数据 (均值≈0, 标准差≠1):")
print(normalized.round(3))
print(f"验证标准化后的均值: {normalized.mean(axis=0).round(10)}")
print(f"验证标准化后的标准差: {normalized.std(axis=0)}")

# 应用2: 距离计算 (欧氏距离矩阵)
print("\n\n【应用2】欧氏距离矩阵 (点对点距离)")

points = np.array([
    [0, 0],    # A: 原点
    [3, 4],    # B: (3,4)
    [1, 1],    # C: (1,1)
    [-2, 1],   # D: (-2,1)
])

print(f"点位:\n{points}")

# 利用广播计算所有点对的距离
# ||a-b||^2 = (a1-b1)^2 + (a2-b2)^2
diff = points[:, np.newaxis, :] - points[np.newaxis, :, :]  # (4,1,2) - (1,4,2) = (4,4,2)
dist_matrix = np.sqrt((diff ** 2).sum(axis=2))

print(f"\n距离矩阵 (对称阵):")
labels = ['A(0,0)', 'B(3,4)', 'C(1,1)', 'D(-2,1)']
print(f"{'':>12}", end='')
for l in labels:
    print(f"{l:>10}", end='')
print()
for i, label_i in enumerate(labels):
    print(f"{label_i:>12}", end='')
    for j in range(len(labels)):
        print(f"{dist_matrix[i, j]:>10.2f}", end='')
    print()

# 应用3: 图像处理中的广播
print("\n\n【应用3】图像亮度调整 (广播到RGB三通道)")
img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)  # 模拟小图像
brightness_factor = np.float32(1.2)  # 提亮 20%
adjusted = np.clip(img.astype(np.float32) * brightness_factor, 0, 255).astype(np.uint8)
print(f"原图均值: {img.mean():.1f}, 调整后均值: {adjusted.mean():.1f}")
print(f"(标量广播到 H×W×C 全部像素)")


# ============================================================
# 第三部分: 通用函数 (UFuncs) —— 一元函数
# ============================================================
separator("三、通用函数 (ufunc): 一元运算")

print("""
【什么是 ufunc?】
  Universal Function — 能够对数组的每个元素逐个操作的函数。
  NumPy 内置了大量的 ufunc,底层用 C 语言实现,极快。

  特点:
  - 输入: 0~N 个数组 (标量视为0-d数组)
  - 输出: 一个数组 (或多个数组)
  - 支持 broadcast (自动形状适配)
""")

arr = np.array([-2, -1, 0, 1, 2, 3.14, np.inf, np.nan])
print(f"测试数组: {arr}\n")

# 数学类 ufunc
print("【数学类 ufunc】")
print(f"  abs(x)     = {np.abs(arr)}")
print(f"  sqrt(|x|)  = {np.sqrt(np.abs(arr))}")
print(f"  square(x)  = {np.square(arr)}")
print(f"  exp(x)     = {np.exp(np.clip(arr, -5, 5))}")  # clip防止溢出
print(f"  log(|x|+1) = {np.log(np.abs(arr) + 1)}")

# 三角函数
angles = np.linspace(0, 2*np.pi, 8)
print(f"\n角度序列: {angles.round(3)}")
print(f"  sin = {np.sin(angles).round(4)}")
print(f"  cos = {np.cos(angles).round(4)}")

# 取整类
decimals = np.array([1.2, 2.7, -1.5, 3.14, -2.8])
print(f"\n取整测试: {decimals}")
print(f"  floor (向下取整) = {np.floor(decimals)}")
print(f"  ceil  (向上取整) = {np.ceil(decimals)}")
print(f"  round (四舍五入) = {np.round(decimals)}")
print(f"  trunc (截断小数) = {np.trunc(decimals)}")

# 比较与逻辑类
print("\n【比较 & 逻辑类 ufunc】")
a = np.array([1, 2, 3, 4, 5])
b = np.array([3, 2, 5, 4, 1])
print(f"  a = {a}")
print(f"  b = {b}")
print(f"  a > b  = {np.greater(a, b)}")
print(f"  a == b = {np.equal(a, b)}")
print(f"  np.maximum(a,b) = {np.maximum(a, b)}  # 逐元素取最大")
print(f"  np.minimum(a,b) = {np.minimum(a, b)}  # 逐元素取最小")

# 逻辑运算
bool_a = np.array([True, True, False, False])
bool_b = np.array([True, False, True, False])
print(f"\n  bool_a = {bool_a}")
print(f"  bool_b = {bool_b}")
print(f"  logical_and = {np.logical_and(bool_a, bool_b)}")
print(f"  logical_or  = {np.logical_or(bool_a, bool_b)}")
print(f"  logical_not(bool_a) = {np.logical_not(bool_a)}")


# ============================================================
# 第四部分: UFuncs 高级特性
# ============================================================
separator("四、ufunc 高级方法: reduce / accumulate / outer / at")

arr = np.array([3, 1, 4, 1, 5, 9, 2, 6])
print(f"操作数组: {arr}\n")

# reduce: 沿轴递归应用二元操作
print("【reduce】沿轴降维折叠")
print(f"  np.add.reduce(arr)  = {np.add.reduce(arr)}      # 等价 sum()")
print(f"  np.multiply.reduce = {np.multiply.reduce(arr)}   # 等价 prod()")

# accumulate: 显示每步累积结果
print(f"\n【accumulate】显示累积过程")
print(f"  add.accumulate      = {np.add.accumulate(arr)}    # 累加和 [3,4,8,...]")
print(f"  multiply.accumulate= {np.multiply.accumulate(arr)} # 累积乘 [3,3,12,...]")

# outer: 外积 (两两配对)
print(f"\n【outer】外积运算")
a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
print(f"  a = {a}, b = {b}")
print(f"  np.add.outer(a, b):\n{np.add.outer(a, b)}")
print(f"  np.multiply.outer(a, b):\n{np.multiply.outer(a, b)}")
print("  [解释] outer 将两个数组的每个元素两两配对运算")

# at: 无缓冲的原地操作
print(f"\n【at】原地无缓冲操作")
bins = np.zeros(10, dtype=int)
indices = np.array([0, 3, 1, 3, 2, 3, 0, 5])
np.add.at(bins, indices, 1)  # 在指定位置累加
print(f"  索引计数: indices = {indices}")
print(f"  统计直方图: bins = {bins}")
print(f"  [!] at 解决了重复索引被覆盖的问题 (普通赋值会丢失)")


# ============================================================
# 第五部分: 自定义 ufunc (可选了解)
# ============================================================
separator("五、自定义 ufunc (进阶)")

print("""
除了内置 ufunc,NumPy 还支持创建自定义 ufunc。

方式一: frompyfunc (简单包装 Python 函数,速度一般)
方式二: Numba/C 扩展 (高性能,需额外安装)
""")

# frompyfunc 示例
def custom_op(x):
    """自定义操作: 返回 x^2 如果 x>0, 否则返回 0"""
    return x**2 if x > 0 else 0

# 创建 ufunc: nin=1个输入, nout=1个输出
custom_ufunc = np.frompyfunc(custom_op, 1, 1)

test_vals = np.array([-3, -1, 0, 2, 4])
result = custom_ufunc(test_vals)
print(f"  输入: {test_vals}")
print(f"  自定义ufunc输出: {result}")
print(f"  类型: {type(result[0])}  [注意: frompyfunc 返回 object 类型]")


# ============================================================
# 第六部分: 综合实战 —— 向量化替代循环
# ============================================================
separator("六、实战对比: 循环 vs 广播+ufunc")

print("任务: 计算 10000 个二维点到原点的距离\n")

np.random.seed(123)
points = np.random.randn(10000, 2)

# 方式1: Python 循环 (慢)
import time
start = time.perf_counter()
distances_loop = []
for p in points:
    distances_loop.append(np.sqrt(p[0]**2 + p[1]**2))
loop_time = time.perf_counter() - start

# 方式2: NumPy 广播+ufunc (快)
start = time.perf_counter()
distances_np = np.sqrt(points[:, 0]**2 + points[:, 1]**2)
numpy_time = time.perf_counter() - start

# 方式3: 更优写法 (利用 np.linalg.norm)
start = time.perf_counter()
distances_norm = np.linalg.norm(points, axis=1)
norm_time = time.perf_counter() - start

print(f"{'方法':<25} {'耗时(ms)':<12} {'加速比':<10}")
print("-" * 47)
print(f"{'Python 循环':<22} {loop_time*1000:<12.3f} {'1.0x':<10}")
print(f"{'NumPy 广播':<22} {numpy_time*1000:<12.3f} {loop_time/max(numpy_time, 1e-9):<10.1f}x")
print(f"{'np.linalg.norm':<19} {norm_time*1000:<12.3f} {loop_time/max(norm_time, 1e-9):<10.1f}x")

# 验证一致性
print(f"\n结果一致: {np.allclose(distances_loop, distances_np)}")
print(f"结果一致: {np.allclose(distances_loop, distances_norm)}")


# ============================================================
# 总结
# ============================================================
separator("总结: 广播 & ufunc 核心要点")

summary = """
+------------------------------------------------------------+
|  广播 (Broadcasting)                                        |
|                                                            |
|  - 标量自动扩展到任意形状                                   |
|  - (m,n) + (n,)   → (m,n)  [行向量广播]                   |
|  - (m,1) + (1,n)  → (m,n)  [外积模式]                    |
|  - 规则: 从右往左对齐,维度相等 或 有一方为1               |
|                                                            |
|  常用 ufunc                                                |
|                                                            |
|  [数学] abs/sqrt/exp/log/sin/cos/floor/ceil/round         |
|  [算术] add/subtract/multiply/divide/power/modulo         |
|  [比较] greater/less/equal/not_equal                       |
|  [逻辑] logical_and/or/not                                 |
|  [极值] maximum/minimum/fmax/fmin                          |
|                                                            |
|  ufunc 高级方法                                            |
|                                                            |
|  reduce()     → 折叠为一维 (如求和/求积)                   |
|  accumulate() → 显示每步累积                               |
|  outer()      → 两两外积                                   |
|  at()         → 原地操作 (解决重复索引问题)                 |
|                                                            |
|  [核心原则] 尽可能用 ufunc+广播 替代显式 for 循环!          |
+------------------------------------------------------------+
"""
print(summary)

print("\n运行完毕!")

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐