目录

一、问题背景

二、遇到的问题

1. 数据传输效率低下

2. Python处理开销大

3. 模型加载重复

三、问题根因

四、解决方案

技术实现

原理

执行位置:计算向数据移动

四大高效引擎

五、适用场景

六、扩展优化


一、问题背景

在最近项目中,需要将深度学习算法部署到spark集群中运行,部署过程中遇到一个坑,在此记录分享一下。

我是用pandas_udf的方式进行ONNX推理时报错,报错原因:

  • ONNX 模型推理时报错,输入数据形状不匹配

  • 批量预测时数据维度混乱

二、遇到的问题

在分布式机器学习推理场景中,PySpark需要将预处理后的特征数据传输到Python UDF中进行模型推理。原始特征数据通常是多维时间序列(如形状为(batch_size, sequence_length, feature_size)),在Spark和Python之间传输时遇到性能瓶颈。

1. 数据传输效率低下

  • 嵌套数组序列化慢:复杂嵌套结构(如array<array<double>>)在JVM和Python间传输时需要递归序列化

  • 内存占用高:字符串序列化方案产生额外编码开销

  • 网络带宽浪费:非紧凑的数据格式增加传输量

2. Python处理开销大

  • 逐条类型检查:传统UDF对每行数据单独处理,Python调用开销显著

  • 循环效率低:无法利用NumPy向量化操作

  • 内存碎片化:非连续内存布局降低CPU缓存命中率

3. 模型加载重复

  • 每个任务或批次可能重复加载ONNX模型,造成资源浪费

三、问题根因

ONNX 模型期望输入格式:(batch_size, seq_length=6, input_size=4) 的 float32 数组

但问题出在 PySpark Arrow 传递复杂嵌套类型时的数据格式:

  1. PySpark 中 feature_seq 列类型是 array<array<double>>

  2. 通过 pandas_udf 传递时,数据可能被序列化为字符串形式或包含 Row 对象

  3. 直接用 np.array() 转换时失败,因为元素不是纯数值

性能瓶颈点

数据传输瓶颈 → 序列化/反序列化 → Python处理瓶颈 → 模型推理瓶颈

四、解决方案

方案1: UDF中处理复杂类型

方案2: 字符串序列化

方案3: 展平为一维数组

Arrow传输效率

❌ 低 (嵌套数组序列化慢)

✅ 高 (字符串原生支持)

✅ 最高 (一维数组最优)

内存占用

中等

较高 (字符串开销)

最低

Python处理开销

高 (逐条类型检查)

中 (字符串解析)

低 (直接reshape)

网络传输

中等

较高

最低

评估下来,方案3:展平为一维数组传输是最合适的。

下面具体分析下方案三:

技术实现

# 关键步骤1:在Spark端展平
df.withColumn(
    "feature_flat",
    expr("""
        flatten(
            transform(
                sequence(0, {seq_length-1}), 
                i -> array(Tdb_list[i], Te_list[i], C_list[i], F_list[i])
            )
        )
    """)
)

# 关键步骤2:在UDF中还原
batch_array = np.array(valid_data, dtype=np.float32).reshape(-1, SEQUENCE_LENGTH, INPUT_SIZE)

举例

┌─────────────────────────────────────────────────────────────────────────────┐
│                           完整数据流转过程                                    │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  原始格式 (在PySpark中):                                                     │
│  feature_seq = [[32.0, 25.0, 68.0, 600.0],   ← 时间步 t-5                   │
│                 [31.5, 17.0, 75.0, 600.0],   ← 时间步 t-4                   │
│                 [30.5, 16.0, 75.0, 600.0],   ← 时间步 t-3                   │
│                 [29.5, 16.0, 75.0, 600.0],   ← 时间步 t-2                   │
│                 [29.0, 15.0, 75.0, 600.0],   ← 时间步 t-1                   │
│                 [28.5, 15.0, 75.0, 600.0]]   ← 时间步 t (当前)              │
│                                                                             │
│                         ↓  展平 (flatten) 只为传输                          │
│                                                                             │
│  展平格式 (Arrow传输):                                                       │
│  feature_flat = [32.0, 25.0, 68.0, 600.0, 31.5, 17.0, 75.0, 600.0,         │
│                  30.5, 16.0, 75.0, 600.0, 29.5, 16.0, 75.0, 600.0,         │
│                  29.0, 15.0, 75.0, 600.0, 28.5, 15.0, 75.0, 600.0]         │
│                                                                             │
│                         ↓  reshape 还原形状                                 │
│                                                                             │
│  模型输入格式 (ONNX推理时):                                                   │
│  batch_array.shape = (batch_size, 6, 4)                                     │
│  ┌─────────────────────────────────┐                                        │
│  │ [[32.0, 25.0, 68.0, 600.0],    │  ← 时间步 t-5                           │
│  │  [31.5, 17.0, 75.0, 600.0],    │  ← 时间步 t-4                           │
│  │  [30.5, 16.0, 75.0, 600.0],    │  ← 时间步 t-3                           │
│  │  [29.5, 16.0, 75.0, 600.0],    │  ← 时间步 t-2                           │
│  │  [29.0, 15.0, 75.0, 600.0],    │  ← 时间步 t-1                           │
│  │  [28.5, 15.0, 75.0, 600.0]]    │  ← 时间步 t (当前)                      │
│  └─────────────────────────────────┘                                        │
│                                                                             │
│  ✅ 与原始格式完全一致,模型结果不受任何影响                                   │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│                    Arrow 数据传输性能                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  array<array<double>>    ────►  慢 (需要递归序列化)              │
│                                                                 │
│  string                  ────►  中等 (字符串编码开销)             │
│                                                                 │
│  array<double>           ────►  快 (Arrow原生支持,零拷贝)        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

原理

在了解了技术实现之后,来看下技术原理,即“为什么”这么做有效

执行位置:计算向数据移动

  • Driver端:仅负责任务解析与调度,不执行UDF代码,避免了单点瓶颈。

  • Executor端:真正的算力所在。每个Executor节点持有独立的数据分区,predict_udf直接在节点内存中执行,遵循“数据本地性”原则,避免了数据在网络中来回传输。

┌─────────────────────────────────────────────────────────────────────────────────────┐
│                           Spark 分布式执行架构                                        │
├─────────────────────────────────────────────────────────────────────────────────────┤
│                                                                                     │
│  ┌─────────────────────────────────────────────────────────────────────────────┐   │
│  │                          Driver (主节点)                                     │   │
│  │                                                                             │   │
│  │   • 解析代码,生成执行计划                                                    │   │
│  │   • 调度任务到 Executor                                                      │   │
│  │   • 收集最终结果                                                             │   │
│  │   • ❌ 不执行 UDF 代码                                                       │   │
│  └─────────────────────────────────────────────────────────────────────────────┘   │
│                                       │                                             │
│                    ┌──────────────────┼──────────────────┐                         │
│                    │                  │                  │                         │
│                    ▼                  ▼                  ▼                         │
│  ┌─────────────────────┐  ┌─────────────────────┐  ┌─────────────────────┐        │
│  │   Executor 1        │  │   Executor 2        │  │   Executor N        │        │
│  │   (工作节点)         │  │   (工作节点)         │  │   (工作节点)         │        │
│  │                     │  │                     │  │                     │        │
│  │  ┌───────────────┐  │  │  ┌───────────────┐  │  │  ┌───────────────┐  │        │
│  │  │  Partition 1  │  │  │  │  Partition 2  │  │  │  │  Partition N  │  │        │
│  │  │  (数据分区)    │  │  │  │  (数据分区)    │  │  │  │  (数据分区)    │  │        │
│  │  │  10000 rows   │  │  │  │  10000 rows   │  │  │  │  10000 rows   │  │        │
│  │  └───────┬───────┘  │  │  └───────┬───────┘  │  │  └───────┬───────┘  │        │
│  │          │          │  │          │          │  │          │          │        │
│  │          ▼          │  │          ▼          │  │          ▼          │        │
│  │  ┌───────────────┐  │  │  ┌───────────────┐  │  │  ┌───────────────┐  │        │
│  │  │ predict_udf   │  │  │  │ predict_udf   │  │  │  │ predict_udf   │  │        │
│  │  │ ✅ 在此执行    │  │  │  │ ✅ 在此执行    │  │  │  │ ✅ 在此执行    │  │        │
│  │  │               │  │  │  │               │  │  │  │               │  │        │
│  │  │ ONNX Session  │  │  │  │ ONNX Session  │  │  │  │ ONNX Session  │  │        │
│  │  │ (单例模式)     │  │  │  │ (单例模式)     │  │  │  │ (单例模式)     │  │        │
│  │  └───────────────┘  │  │  └───────────────┘  │  │  └───────────────┘  │        │
│  │                     │  │                     │  │                     │        │
│  └─────────────────────┘  └─────────────────────┘  └─────────────────────┘        │
│                                                                                     │
│   并行执行:所有 Executor 同时处理各自的分区,互不干扰                                  │
│                                                                                     │
└─────────────────────────────────────────────────────────────────────────────────────┘

四大高效引擎

Arrow零拷贝:JVM数据直接映射为Python内存对象,省去序列化步骤。

┌─────────────────────────────────────────────────────────────────┐
│                 传统序列化 vs Arrow                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  传统方式 (pickle):                                             │
│  ┌─────────┐    序列化    ┌─────────┐    反序列化    ┌───────┐ │
│  │ JVM数据 │ ──────────► │ 网络传输 │ ──────────► │Python │ │
│  │ (Row)   │   慢!      │  字节流  │   慢!      │ 对象   │ │
│  └─────────┘             └─────────┘             └───────┘ │
│                                                                 │
│  Arrow 方式:                                                    │
│  ┌─────────┐              ┌─────────┐              ┌───────┐ │
│  │ JVM数据 │   直接共享   │   内存   │   零拷贝    │ Pandas│ │
│  │ (列式)  │ ──────────► │  映射    │ ──────────► │Series │ │
│  └─────────┘    快!     └─────────┘    快!      └───────┘ │
│                                                                 │
│  性能提升: 10-50倍                                              │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

  向量化批量处理:pandas_udf一次处理成千上万行数据

# ❌ 传统 UDF - 逐行处理 (慢)
@udf(returnType=DoubleType())
def slow_udf(feature):
    # 每行都要:Python调用开销 + 模型推理
    return model.predict(feature)  # 10000次调用 = 10000次推理

# ✅ pandas_udf - 批量处理 (快)
@pandas_udf(returnType=DoubleType())
def fast_udf(features: pd.Series):
    # 一次处理10000行
    batch = np.array(features.tolist()).reshape(-1, 6, 4)
    return model.predict(batch)  # 1次调用 = 1次批量推理

单例模式:ONNX模型在每个Executor进程启动时仅加载一次,后续推理复用该实例,消除IO开销。

分布式并行:所有Executor同时处理各自分区,线性扩展计算能力。

关键优势:

  1. Arrow 零拷贝传输:一维 array<double> 是 Arrow 原生类型,可直接映射到 pandas/numpy

  2. 向量化处理:避免了 Python 循环,充分利用 NumPy 向量化操作

  3. 内存连续:连续内存布局,CPU 缓存命中率高

  4. 网络带宽低:数据紧凑,传输量最小

性能提升预估:相比方案1可提升30-50%的整体吞吐量。

五、适用场景

  • ✅ 分布式机器学习推理

  • ✅ 大规模特征数据传输

  • ✅ 时间序列或图像等多维数据

  • ✅ Spark + Python混合架构

六、扩展优化

# 进一步优化:使用PySpark的向量化UDF
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, DoubleType

# 使用更高效的Arrow格式
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐