数据加载与预处理

在深度学习模型的训练过程中,数据加载和预处理是至关重要的步骤,直接影响模型的性能和训练效率。

DJL 数据处理核心组件

DJL 组件 功能描述 对应 DL4J/DataVec 组件
Dataset接口 定义数据集的核心规范,所有自定义数据集需实现该接口 通用数据源定义
RandomAccessDataset 支持随机访问的数据集(如 CSV、图像数据集) RecordReader + FileSplit
Transform接口 定义数据预处理的单步操作(如归一化、编码、增强) TransformProcess
Pipeline 组合多个 Transform,实现预处理流程的链式执行 TransformProcess
NDArray 核心数据结构,承载张量数据并提供标准化、归一化、One-Hot 编码等操作 DataNormalization + 各类 Transform
ImageTransform 图像专用预处理工具(旋转、缩放、裁剪、数据增强) DataAugmentation

DJL 数据处理工作流程

DJL 的数处理流程遵循 “数据源定义 → 数据集封装 → 预处理流水线 → 迭代器加载” 的核心逻辑,具体步骤如下:

  1. 定义数据源:指定数据文件路径 / 存储位置、数据格式;
  2. 构建数据集:基于RandomAccessDataset实现自定义数据集,读取原始数据;
  3. 定义预处理流水线(Pipeline):组合多个Transform,实现数据清洗、标准化、编码等操作;
  4. 执行预处理:通过Pipeline自动对数据进行批量预处理;
  5. 加载数据到模型:通过DatasetIterator将预处理后的数据批量输入模型训练 / 评估。

DJL 数据处理核心概念和架构

核心组件详解
Dataset 接口

DJL 的Dataset是所有数据集的基类,定义了数据集的基本行为(如获取样本数、获取单个样本)。常用实现类:

  • RandomAccessDataset:支持随机访问的数据集(推荐用于 CSV、图像等本地文件数据);
  • ArrayDataset:基于内存数组的轻量级数据集(适合小批量数据);
  • ImageFolder:图像分类专用数据集(自动按文件夹分类加载图像)。
Transform 与 Pipeline
  • Transform:单步预处理操作的接口,自定义预处理需实现transform(NDList)方法;
  • Pipeline:将多个Transform按顺序组合,形成完整的预处理流水线,支持批量数据的自动处理。
NDArray

DJL 的核心数据结构,替代 DL4J 的DataSet,提供以下关键预处理能力:

  • 数值标准化 / 归一化(sub()/div()/mean()/std());
  • 类别特征编码(oneHot());
  • 数据类型转换(toType());
  • 缺失值处理(where()/fill())。
数据迭代器
  • DatasetIterator:将数据集转换为批量迭代器,支持按批次加载数据到模型;
  • Batch:封装单批次的特征和标签数据,直接输入模型训练。

加载和预处理 CSV 数据

准备 CSV 样例数据

resources目录下创建data.csv

id,name,age,gender,income
1,Alice,34,F,50000
2,Bob,45,M,60000
3,Charlie,23,M,45000
4,Diana,56,F,70000
5,Eva,29,F,55000
6,Frank,38,M,65000
7,Grace,42,F,58000
8,Henry,51,M,72000
9,Ivy,26,F,48000
10,Jack,33,M,62000

数据字段说明:

  • id: 唯一标识符(整型);
  • name: 姓名(字符串);
  • age: 年龄(整型);
  • gender: 性别(F/M,类别型);
  • income: 收入(浮点型)。
加载 CSV 数据

pom.xml中添加 DJL 核心依赖和 CSV 解析依赖:

<!-- CSV解析依赖 -->
<dependency>
    <groupId>com.opencsv</groupId>
    <artifactId>opencsv</artifactId>
    <version>5.6</version>
</dependency>

自定义 CSV 数据集加载数据

package com.woniuxy.base.load;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.util.Progress;
import com.opencsv.CSVReader;
import com.opencsv.exceptions.CsvException;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.List;

/**
 * 自定义CSV数据集
 */
public class CSVDataset extends RandomAccessDataset {
    // 存储CSV解析后的数据
    private List<String[]> csvData;
    // 数据集实际大小(对应availableSize())
    private long dataSize;

    /**
     * 构建器
     */
    public static class Builder extends BaseBuilder<Builder> {
        private NDManager manager;

        public Builder setManager(NDManager manager) {
            this.manager = manager;
            return self();
        }

        @Override
        protected Builder self() {
            return this;
        }

        public CSVDataset build() throws IOException, URISyntaxException, CsvException {
            return new CSVDataset(this);
        }
    }

    /**
     * 私有化构造器,通过Builder初始化(
     */
    private CSVDataset(Builder builder) throws IOException, URISyntaxException, CsvException {
        super(builder); // 必须调用父类构造器初始化采样器、Pipeline等参数
        // 1. 加载CSV文件
        URL resourceUrl = getClass().getClassLoader().getResource("data.csv");
        if (resourceUrl == null) {
            throw new IOException("CSV文件 data.csv 未找到");
        }
        File csvFile = Paths.get(resourceUrl.toURI()).toFile();

        // 2. 解析CSV
        try (CSVReader reader = new CSVReader(new FileReader(csvFile))) {
            List<String[]> allLines = reader.readAll();
            // 跳过表头(第一行)
            this.csvData = allLines.subList(1, allLines.size());
            this.dataSize = csvData.size();
        }
    }

    /**
     * 核心抽象方法:获取数据集实际大小
     */
    @Override
    protected long availableSize() {
        return dataSize;
    }

    /**
     * 核心抽象方法:读取单个样本(必须声明throws IOException + 适配Record构造器)
     */
    @Override
    public Record get(NDManager manager, long index) throws IOException {
        // 校验索引合法性
        if (index < 0 || index >= dataSize) {
            throw new IndexOutOfBoundsException("CSV索引超出范围: " + index + ", 数据集大小: " + dataSize);
        }

        String[] line = csvData.get((int) index);
        // 解析字段:id(0), name(1), age(2), gender(3), income(4)
        int id = Integer.parseInt(line[0]);
        int age = Integer.parseInt(line[2]);
        String gender = line[3];
        double income = Double.parseDouble(line[4]);

        // 将数据转换为NDArray(特征:id, age, genderCode, income)
        // 手动编码性别:F→0,M→1
        int genderCode = "F".equals(gender) ? 0 : 1;
        NDArray featuresArray = manager.create(new float[]{id, age, genderCode, (float) income});
        
        // 将单个NDArray封装为NDList
        NDList features = new NDList(featuresArray);
        
        // 标签:暂设为0(封装为NDList),可根据业务需求修改(如预测收入则设为income)
        NDArray labelArray = manager.create(0);
        NDList labels = new NDList(labelArray);

        // 返回Record
        return new Record(features, labels);
    }

    /**
     * 可选:数据集预处理准备(如加载数据、初始化资源)
     */
    @Override
    public void prepare(Progress progress) {
        // 此处可添加数据加载进度展示等逻辑
        System.out.println("CSV数据集加载完成,共" + dataSize + "条样本");
    }

    /**
     * 测试CSV加载
     */
    public static void main(String[] args) throws Exception {
        try (NDManager manager = NDManager.newBaseManager()) {
            // 构建数据集(设置采样器:批量大小2,顺序读取)
            CSVDataset dataset = new CSVDataset.Builder()
                    .setManager(manager)
                    .setSampling(2, false) // 批量大小2,非随机采样
                    .optDataBatchifier(Batchifier.STACK) // 数据批处理方式
                    .optLabelBatchifier(Batchifier.STACK) // 标签批处理方式
                    .optDevice(Device.cpu()) // 指定CPU设备
                    .build();

            // 准备数据集
            dataset.prepare(null);

            // 遍历所有样本
            for (long i = 0; i < dataset.size(); i++) {
                Record record = dataset.get(manager, i);
                // 读取特征(NDList的第一个元素)
                NDArray feature = record.getData().get(0);
                // 读取标签(NDList的第一个元素)
                NDArray label = record.getLabels().get(0);
                System.out.println("样本" + (i + 1) + " - 特征:" + feature + " | 标签:" + label);
            }
        }
    }
}

运行结果

CSV数据集加载完成,共4条样本
样本1 - 特征:ND: (4) cpu() float32
[ 1.00000000e+00,  3.40000000e+01,  0.00000000e+00,  5.00000000e+04]
 | 标签:ND: () cpu() int32
0

样本2 - 特征:ND: (4) cpu() float32
[ 2.00000000e+00,  4.50000000e+01,  1.00000000e+00,  6.00000000e+04]
 | 标签:ND: () cpu() int32
0

样本3 - 特征:ND: (4) cpu() float32
[ 3.00000000e+00,  2.30000000e+01,  1.00000000e+00,  4.50000000e+04]
 | 标签:ND: () cpu() int32
0

样本4 - 特征:ND: (4) cpu() float32
[ 4.00000000e+00,  5.60000000e+01,  0.00000000e+00,  7.00000000e+04]
 | 标签:ND: () cpu() int32
0
数据预处理
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

import java.io.IOException;
import java.io.Reader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

public class SimpleCSVDemo {
    
    public static void main(String[] args) throws URISyntaxException, IOException {

        // 1. 模拟CSV数据(不用真的读文件)
        List<Person> people = createSampleData();
        for (Person p : people) {
            System.out.println(p);
        }
        
        // 2. 创建NDManager(理解成"数据管理器")
        try (NDManager manager = NDManager.newBaseManager()) {
            
            // 3. 预处理:把文字变成数字
            System.out.println("\n2. 预处理:文字变数字");
            SimpleData data = simplePreprocess(manager, people);
            
            // 4. 查看处理后的张量
            System.out.println("\n3. 处理后的张量:");
            System.out.println("特征(年龄,性别编码):\n" + data.features);
            System.out.println("\n标签(收入):\n" + data.labels);
            
            // 5. 简单计算演示
            System.out.println("\n4. 简单统计:");
            showSimpleStats(data);
            
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    
    // 1. 数据结构(简单的Java类)
    static class Person {
        int id;
        String name;
        int age;
        String gender;  // "男" 或 "女"
        double income;
        
        Person(int id, String name, int age, String gender, double income) {
            this.id = id;
            this.name = name;
            this.age = age;
            this.gender = gender;
            this.income = income;
        }
        
        @Override
        public String toString() {
                return id + "," + name + ",年龄:" + age + ",性别:" + gender + ",收入:" + income;
        }
    }
    
    // 2. 简单的数据容器
    static class SimpleData {
        NDArray features;  // 特征矩阵
        NDArray labels;    // 标签
        
        SimpleData(NDArray features, NDArray labels) {
            this.features = features;
            this.labels = labels;
        }
    }
    
    // 3. 创建模拟数据(代替读CSV文件)
    private static List<Person> createSampleData() throws URISyntaxException, IOException {
        List<Person> people = new ArrayList<>();

        URL resourceUrl = CSVDataProcessingExample.class.getClassLoader().getResource("data.csv");
        Path path = Paths.get(resourceUrl.toURI());
        try (Reader reader = Files.newBufferedReader(path);
             CSVParser csvParser = new CSVParser(reader,
                     CSVFormat.DEFAULT.withFirstRecordAsHeader())) {

            for (CSVRecord record : csvParser) {
                int id = Integer.parseInt(record.get("id"));
                String name = record.get("name");
                int age = Integer.parseInt(record.get("age"));
                String gender = record.get("gender");
                double income = Double.parseDouble(record.get("income"));

                people.add(new Person(id, name, age, gender, income));
            }
        }
        return people;
    }
    
    // 4. 简化版预处理
    private static SimpleData simplePreprocess(NDManager manager, List<Person> people) {
        int count = people.size();  // 4个人
        
        // 4.1 准备数组(更容易理解)
        double[][] features = new double[count][2];  // 4行×2列
        double[] labels = new double[count];        // 4个收入
        
        // 4.2 手动填充数据
        for (int i = 0; i < count; i++) {
            Person p = people.get(i);
            
            // 特征1:年龄(原样保留)
            features[i][0] = p.age;  // 第一列:年龄
            
            // 特征2:性别编码(男→0,女→1)
            int genderCode = p.gender.equals("F") ? 0 : 1;
            features[i][1] = genderCode;  // 第二列:性别编码
            
            // 标签:收入
            labels[i] = p.income;
        }
        
        // 4.3 创建张量
        NDArray featuresTensor = manager.create(features);
        NDArray labelsTensor = manager.create(labels);
        
        return new SimpleData(featuresTensor, labelsTensor);
    }
    
    // 5. 显示简单统计
    private static void showSimpleStats(SimpleData data) {
        NDArray features = data.features;
        NDArray labels = data.labels;
        
        // 5.1 计算总和
        double totalIncome = labels.sum().getDouble();
        System.out.println("总收入:" + totalIncome);
        
        // 5.2 计算平均
        double avgAge = features.get(":, 0").mean().getDouble();  // 第0列是年龄
        double avgIncome = labels.mean().getDouble();
        System.out.println("平均年龄:" + avgAge);
        System.out.println("平均收入:" + avgIncome);
        
        // 5.3 找出最大值
        double maxIncome = labels.max().getDouble();
        System.out.println("最高收入:" + maxIncome);
        
        // 5.4 按性别统计
        System.out.println("\n按性别统计:");
        
        // 男性数据(性别编码=0)
        NDArray maleMask = features.get(":, 1").eq(0);  // 第1列是性别,等于0的是男性
        NDArray maleIncomes = labels.get(maleMask);
        
        // 女性数据(性别编码=1)
        NDArray femaleMask = features.get(":, 1").eq(1);
        NDArray femaleIncomes = labels.get(femaleMask);
        
        System.out.println("男性平均收入:" + maleIncomes.mean().getDouble());
        System.out.println("女性平均收入:" + femaleIncomes.mean().getDouble());
    }
}

运行结果

1,Alice,年龄:34,性别:F,收入:50000.0
2,Bob,年龄:45,性别:M,收入:60000.0
3,Charlie,年龄:23,性别:M,收入:45000.0
4,Diana,年龄:56,性别:F,收入:70000.0
5,Eva,年龄:29,性别:F,收入:55000.0
6,Frank,年龄:38,性别:M,收入:65000.0
7,Grace,年龄:42,性别:F,收入:58000.0
8,Henry,年龄:51,性别:M,收入:72000.0
9,Ivy,年龄:26,性别:F,收入:48000.0
10,Jack,年龄:33,性别:M,收入:62000.0

2. 预处理:文字变数字

3. 处理后的张量:
特征(年龄,性别编码):
ND: (10, 2) cpu() float64
[[34.,  0.],
 [45.,  1.],
 [23.,  1.],
 [56.,  0.],
 [29.,  0.],
 [38.,  1.],
 [42.,  0.],
 [51.,  1.],
 [26.,  0.],
 [33.,  1.],
]


标签(收入):
ND: (10) cpu() float64
[50000., 60000., 45000., 70000., 55000., 65000., 58000., 72000., 48000., 62000.]


4. 简单统计:
总收入:585000.0
平均年龄:37.7
平均收入:58500.0
最高收入:72000.0

按性别统计:
男性平均收入:56200.0
女性平均收入:60800.0
模块化数据预处理

上面我们写了一个完整的数据处理程序,把所有代码都放在一个main方法里。这样写容易理解,但不方便复用。

模块化的意思就是把代码按功能拆分成几个部分:

原来的:一个大文件做所有事
现在的:
├── DataLoader.java    只负责读CSV
├── DataPreprocessor.java 只负责处理数据  
├── DataAnalyzer.java  只负责统计分析
└── Main.java         把上面几个组合起来

好处:

  • 每部分代码更短,更容易看懂
  • 可以在其他项目里直接使用这些模块
  • 修改一个功能不影响其他部分

就像做菜:原来一个人负责所有步骤,现在分工明确:有人洗菜、有人切菜、有人炒菜。

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

import java.io.IOException;
import java.io.Reader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

public class ModularCSVDemo {
    
    public static void main(String[] args) throws URISyntaxException, IOException {
        System.out.println("=== 模块化CSV数据处理 ===\n");
        
        try (NDManager manager = NDManager.newBaseManager()) {
            
            // 1. 分模块处理数据
            System.out.println("1. 加载数据...");
            List<Person> people = DataLoader.loadCSV("data.csv");
            
            System.out.println("\n2. 预处理数据...");
            ProcessedData data = DataPreprocessor.process(manager, people);
            
            System.out.println("\n3. 分析数据...");
            DataAnalyzer.analyze(data);
            
            System.out.println("\n4. 显示结果...");
            displayResults(data, people);
            
        }
    }
    
    // ==================== 1. 数据加载模块 ====================
    static class DataLoader {
        static List<Person> loadCSV(String fileName) throws URISyntaxException, IOException {
            List<Person> people = new ArrayList<>();
            
            URL resourceUrl = ModularCSVDemo.class.getClassLoader().getResource(fileName);
            Path path = Paths.get(resourceUrl.toURI());
            
            try (Reader reader = Files.newBufferedReader(path);
                 CSVParser csvParser = new CSVParser(reader,
                         CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
                
                for (CSVRecord record : csvParser) {
                    int id = Integer.parseInt(record.get("id"));
                    String name = record.get("name");
                    int age = Integer.parseInt(record.get("age"));
                    String gender = record.get("gender");
                    double income = Double.parseDouble(record.get("income"));
                    
                    people.add(new Person(id, name, age, gender, income));
                }
            }
            
            System.out.println("  加载完成:" + people.size() + "条记录");
            return people;
        }
    }
    
    // ==================== 2. 数据预处理模块 ====================
    static class DataPreprocessor {
        static ProcessedData process(NDManager manager, List<Person> people) {
            int count = people.size();
            
            // 准备数组
            double[][] features = new double[count][2];
            double[] labels = new double[count];
            
            // 填充数据
            for (int i = 0; i < count; i++) {
                Person p = people.get(i);
                
                // 特征1:年龄
                features[i][0] = p.age;
                
                // 特征2:性别编码(M→0, F→1)
                features[i][1] = p.gender.equals("M") ? 0 : 1;
                
                // 标签:收入
                labels[i] = p.income;
            }
            
            // 创建张量
            NDArray featuresTensor = manager.create(features);
            NDArray labelsTensor = manager.create(labels);
            
            System.out.println("  预处理完成:特征" + featuresTensor.getShape() + 
                              ",标签" + labelsTensor.getShape());
            
            return new ProcessedData(featuresTensor, labelsTensor, people);
        }
    }
    
    // ==================== 3. 数据分析模块 ====================
    static class DataAnalyzer {
        static void analyze(ProcessedData data) {
            NDArray features = data.features;
            NDArray labels = data.labels;
            
            // 基本统计
            double totalIncome = labels.sum().getDouble();
            double avgAge = features.get(":, 0").mean().getDouble();
            double avgIncome = labels.mean().getDouble();
            double maxIncome = labels.max().getDouble();
            
            // 按性别统计
            double[] genders = features.get(":, 1").toDoubleArray();
            double[] incomes = labels.toDoubleArray();
            
            double maleSum = 0, femaleSum = 0;
            int maleCount = 0, femaleCount = 0;
            
            for (int i = 0; i < genders.length; i++) {
                if (genders[i] == 0) {
                    maleSum += incomes[i];
                    maleCount++;
                } else {
                    femaleSum += incomes[i];
                    femaleCount++;
                }
            }
            
            System.out.println("  分析完成:");
            System.out.println("  - 平均年龄: " + String.format("%.1f", avgAge));
            System.out.println("  - 平均收入: " + String.format("%,.0f", avgIncome));
            System.out.println("  - 最高收入: " + String.format("%,.0f", maxIncome));
            System.out.println("  - 男性平均收入: " + 
                String.format("%,.0f", maleCount > 0 ? maleSum/maleCount : 0));
            System.out.println("  - 女性平均收入: " + 
                String.format("%,.0f", femaleCount > 0 ? femaleSum/femaleCount : 0));
        }
    }
    
    // ==================== 4. 结果显示模块 ====================
    static void displayResults(ProcessedData data, List<Person> people) {
        System.out.println("\n=== 最终结果 ===");
        
        System.out.println("\n原始数据(前5条):");
        for (int i = 0; i < Math.min(5, people.size()); i++) {
            System.out.println("  " + people.get(i));
        }
        
        System.out.println("\n特征张量(前5行):");
        for (int i = 0; i < Math.min(5, data.features.getShape().get(0)); i++) {
            System.out.printf("  样本%d: %s%n", i+1, 
                    Arrays.toString(data.features.get(i).toDoubleArray()));
        }
        
        System.out.println("\n标签张量(前5行):");
        for (int i = 0; i < Math.min(5, data.labels.getShape().get(0)); i++) {
            System.out.printf("  样本%d: %.0f%n", i+1, 
                    data.labels.getDouble(i));
        }
    }
    
    // ==================== 数据结构 ====================
    static class Person {
        int id;
        String name;
        int age;
        String gender;
        double income;
        
        Person(int id, String name, int age, String gender, double income) {
            this.id = id;
            this.name = name;
            this.age = age;
            this.gender = gender;
            this.income = income;
        }
        
        @Override
        public String toString() {
            return String.format("ID:%d %-10s 年龄:%2d 性别:%s 收入:%,.0f", 
                    id, name, age, gender, income);
        }
    }
    
    static class ProcessedData {
        NDArray features;
        NDArray labels;
        List<Person> rawData;
        
        ProcessedData(NDArray features, NDArray labels, List<Person> rawData) {
            this.features = features;
            this.labels = labels;
            this.rawData = rawData;
        }
    }
}

运行结果

=== 模块化CSV数据处理 ===

1. 加载数据...
  加载完成:10条记录

2. 预处理数据...
  预处理完成:特征(10, 2),标签(10)

3. 分析数据...
  分析完成:
  - 平均年龄: 37.7
  - 平均收入: 58,500
  - 最高收入: 72,000
  - 男性平均收入: 60,800
  - 女性平均收入: 56,200

4. 显示结果...

=== 最终结果 ===

原始数据(前5条):
  ID:1 Alice      年龄:34 性别:F 收入:50,000
  ID:2 Bob        年龄:45 性别:M 收入:60,000
  ID:3 Charlie    年龄:23 性别:M 收入:45,000
  ID:4 Diana      年龄:56 性别:F 收入:70,000
  ID:5 Eva        年龄:29 性别:F 收入:55,000

特征张量(前5行):
  样本1: [34.0, 1.0]
  样本2: [45.0, 0.0]
  样本3: [23.0, 0.0]
  样本4: [56.0, 1.0]
  样本5: [29.0, 1.0]

标签张量(前5行):
  样本1: 50000
  样本2: 60000
  样本3: 45000
  样本4: 70000
  样本5: 55000
组合Pipeline预处理

上一个我们把代码分成了几个模块,结构清晰多了。但还有个问题:处理步骤是固定的。

现在我们要用Pipeline(流水线)模式,让数据处理流程可以灵活配置。

核心思想

  • 每个处理步骤都是一个"小插件"
  • 可以随意组合这些插件
  • 不需要改代码就能调整处理流程

就像乐高积木:每个步骤是一个积木块,你可以按需要拼出不同的形状。

Pipeline 设计的好处

  • 模块化设计:每个步骤独立,易于理解和维护

  • 灵活配置:可以轻松添加、删除或重新排序步骤

  • 可重用性:步骤可以在不同项目间重用

示例代码

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

import java.io.IOException;
import java.io.Reader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.Normalizer;
import java.util.*;

public class SimpleCSVDemoWithPipeline {
    
    public static void main(String[] args) throws Exception {
        System.out.println("=== 带Pipeline的CSV数据处理 ===\n");
        
        // 1. 创建数据处理Pipeline
        DataPipeline pipeline = new DataPipeline();
        
        // 2. 添加处理步骤
        pipeline.addStep(new CSVLoader());        // 步骤1:加载CSV
        pipeline.addStep(new GenderEncoder());    // 步骤2:性别编码
        pipeline.addStep(new FeatureSelector());  // 步骤3:选择特征
        pipeline.addStep(new Statistics());       // 步骤4:统计分析
        
        // 3. 执行Pipeline
        try (NDManager manager = NDManager.newBaseManager()) {
            ProcessedResult result = pipeline.execute(manager, "data.csv");
            
            // 4. 显示结果
            result.showResults();
        }
    }
    
    // ==================== 数据结构 ====================
    static class Person {
        int id;
        String name;
        int age;
        String gender;
        double income;
        
        Person(int id, String name, int age, String gender, double income) {
            this.id = id;
            this.name = name;
            this.age = age;
            this.gender = gender;
            this.income = income;
        }
        
        @Override
        public String toString() {
            return String.format("ID:%d %-10s 年龄:%2d 性别:%s 收入:%,.0f", 
                    id, name, age, gender, income);
        }
    }
    
    // ==================== Pipeline处理结果 ====================
    static class ProcessedResult {
        List<Person> rawData;           // 原始数据
        List<Person> encodedData;       // 编码后的数据(内存中)
        NDArray features;               // 特征张量
        NDArray labels;                 // 标签张量
        Map<String, Object> stats;      // 统计信息
        
        void showResults() {
            System.out.println("\n=== 数据处理结果 ===");
            
            System.out.println("\n1. 原始数据(" + rawData.size() + "条):");
            rawData.forEach(p -> System.out.println("  " + p));
            
            System.out.println("\n2. 特征张量(" + features.getShape() + "):");
            printArray(features, 5);  // 只显示前5行
            
            System.out.println("\n3. 标签张量(" + labels.getShape() + "):");
            printArray(labels, 5);
            
            System.out.println("\n4. 统计信息:");
            stats.forEach((key, value) -> System.out.printf("  %-20s: %s%n", key, value));
        }
        
        private void printArray(NDArray arr, int maxRows) {
            long rows = Math.min(arr.getShape().get(0), maxRows);
            for (int i = 0; i < rows; i++) {
                System.out.printf("  样本%d: %s%n", i+1, 
                        Arrays.toString(arr.get(i).toDoubleArray()));
            }
            if (arr.getShape().get(0) > maxRows) {
                System.out.printf("  ... 还有%d行未显示%n", arr.getShape().get(0) - maxRows);
            }
        }
    }
    
    // ==================== Pipeline接口 ====================
    interface PipelineStep {
        void process(ProcessedResult result, NDManager manager) throws Exception;
        String getStepName();
    }
    
    // ==================== 数据处理Pipeline ====================
    static class DataPipeline {
        private List<PipelineStep> steps = new ArrayList<>();
        
        void addStep(PipelineStep step) {
            steps.add(step);
            System.out.println("添加步骤: " + step.getStepName());
        }
        
        ProcessedResult execute(NDManager manager, String csvPath) throws Exception {
            ProcessedResult result = new ProcessedResult();
            result.stats = new LinkedHashMap<>();  // 保持插入顺序
            
            System.out.println("\n开始执行Pipeline:");
            for (int i = 0; i < steps.size(); i++) {
                PipelineStep step = steps.get(i);
                System.out.printf("%d. %s...%n", i+1, step.getStepName());
                
                long startTime = System.currentTimeMillis();
                step.process(result, manager);
                long time = System.currentTimeMillis() - startTime;
                
                System.out.printf("   完成 (耗时: %d ms)%n", time);
            }
            
            return result;
        }
    }
    
    // ==================== 步骤1:CSV加载器 ====================
    static class CSVLoader implements PipelineStep {
        @Override
        public void process(ProcessedResult result, NDManager manager) throws Exception {
            URL resourceUrl = SimpleCSVDemoWithPipeline.class.getClassLoader()
                    .getResource("data.csv");
            Path path = Paths.get(resourceUrl.toURI());
            
            List<Person> people = new ArrayList<>();
            try (Reader reader = Files.newBufferedReader(path);
                 CSVParser csvParser = new CSVParser(reader,
                         CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
                
                for (CSVRecord record : csvParser) {
                    int id = Integer.parseInt(record.get("id"));
                    String name = record.get("name");
                    int age = Integer.parseInt(record.get("age"));
                    String gender = record.get("gender");
                    double income = Double.parseDouble(record.get("income"));
                    
                    people.add(new Person(id, name, age, gender, income));
                }
            }
            
            result.rawData = people;
            result.stats.put("加载完成, 数据条数", people.size());
        }
        
        @Override
        public String getStepName() {
            return "加载CSV文件";
        }
    }
    
    // ==================== 步骤2:性别编码器 ====================
    static class GenderEncoder implements PipelineStep {
        @Override
        public void process(ProcessedResult result, NDManager manager) {
            List<Person> encodedPeople = new ArrayList<>();
            int maleCount = 0, femaleCount = 0;
            
            for (Person p : result.rawData) {
                // 创建编码后的副本(在实际应用中可能会修改原对象)
                Person encoded = new Person(
                    p.id, p.name, p.age, 
                    p.gender.equals("M") ? "0" : "1",  // 编码为字符串"0"/"1"
                    p.income
                );
                encodedPeople.add(encoded);
                
                if (p.gender.equals("M")){
                    maleCount++;
                }else {
                    femaleCount++;
                }
            }
            
            result.encodedData = encodedPeople;
            result.stats.put("男性人数", maleCount);
            result.stats.put("女性人数", femaleCount);
        }
        
        @Override
        public String getStepName() {
            return "性别编码(M→0, F→1)";
        }
    }
    
    // ==================== 步骤3:特征选择器 ====================
    static class FeatureSelector implements PipelineStep {
        @Override
        public void process(ProcessedResult result, NDManager manager) {
            int count = result.rawData.size();
            
            // 创建特征矩阵:使用double保证类型一致
            double[][] features = new double[count][2];  // [年龄, 性别编码]
            double[] labels = new double[count];         // 收入
            
            for (int i = 0; i < count; i++) {
                Person p = result.rawData.get(i);
                Person encoded = result.encodedData.get(i);
                
                features[i][0] = p.age;                    // 年龄
                features[i][1] = Double.parseDouble(encoded.gender);  // 性别编码(转为double)
                labels[i] = p.income;                      // 收入
            }
            
            // 创建张量
            result.features = manager.create(features);
            result.labels = manager.create(labels).reshape(-1, 1);  // 转为列向量
            
            result.stats.put("特征维度", features[0].length);
            result.stats.put("样本数量", count);
        }
        
        @Override
        public String getStepName() {
            return "特征选择与张量创建";
        }
    }

    // ==================== 步骤5:统计分析 ====================
    static class Statistics implements PipelineStep {
        @Override
        public void process(ProcessedResult result, NDManager manager) {
            NDArray features = result.features;
            NDArray labels = result.labels;

            // 基本统计
            double totalIncome = labels.sum().getDouble();
            double avgAge = features.get(":, 0").mean().getDouble();
            double avgIncome = labels.mean().getDouble();
            double maxIncome = labels.max().getDouble();
            double minIncome = labels.min().getDouble();

            // 按性别统计
            NDArray genderCol = features.get(":, 1");
            NDArray maleMask = genderCol.eq(0);
            NDArray femaleMask = genderCol.eq(1);

            double maleAvgIncome = 0, femaleAvgIncome = 0;
            int maleCount = 0, femaleCount = 0;

            // 使用更安全的方法统计
            for (int i = 0; i < labels.getShape().get(0); i++) {
                double gender = genderCol.getDouble(i);
                double income = labels.getDouble(i);

                if (gender == 0) {
                    maleAvgIncome += income;
                    maleCount++;
                } else {
                    femaleAvgIncome += income;
                    femaleCount++;
                }
            }

            if (maleCount > 0){
                maleAvgIncome /= maleCount;
            }
            if (femaleCount > 0){
                femaleAvgIncome /= femaleCount;
            }

            // 存储统计结果
            result.stats.put("总收入", String.format("%,.0f", totalIncome));
            result.stats.put("平均年龄", String.format("%.1f岁", avgAge));
            result.stats.put("平均收入", String.format("%,.0f", avgIncome));
            result.stats.put("最高收入", String.format("%,.0f", maxIncome));
            result.stats.put("最低收入", String.format("%,.0f", minIncome));
            result.stats.put("男性人数", maleCount);
            result.stats.put("女性人数", femaleCount);
            result.stats.put("男性平均收入", String.format("%,.0f", maleAvgIncome));
            result.stats.put("女性平均收入", String.format("%,.0f", femaleAvgIncome));

            // 计算年龄与收入的相关性(使用修正的方法)
            NDArray ages = features.get(":, 0");
            NDArray incomes = labels.flatten();

            double corr = calculateCorrelation(ages, incomes);
            result.stats.put("年龄-收入相关性", String.format("%.3f", corr));

            // 添加更多统计信息
            result.stats.put("年龄范围", String.format("%.0f-%.0f岁",
                    ages.min().getDouble(), ages.max().getDouble()));
            result.stats.put("收入标准差", String.format("%,.0f",
                    calculateStd(incomes)));
            result.stats.put("年龄标准差", String.format("%.1f岁",
                    calculateStd(ages)));

        }

        // 计算标准差的方法
        private double calculateStd(NDArray arr) {
            double mean = arr.mean().getDouble();
            NDArray centered = arr.sub(mean);
            double variance = centered.pow(2).mean().getDouble();
            return Math.sqrt(variance);
        }

        // 计算相关系数的方法
        private double calculateCorrelation(NDArray x, NDArray y) {
            // 1. 计算均值
            double xMean = x.mean().getDouble();
            double yMean = y.mean().getDouble();

            // 2. 中心化
            NDArray xCentered = x.sub(xMean);
            NDArray yCentered = y.sub(yMean);

            // 3. 计算协方差
            double cov = xCentered.mul(yCentered).mean().getDouble();

            // 4. 计算标准差
            double xStd = calculateStd(x);
            double yStd = calculateStd(y);

            // 5. 计算相关系数
            if (xStd == 0 || yStd == 0) {
                return 0;  // 避免除零
            }
            return cov / (xStd * yStd);
        }

        @Override
        public String getStepName() {
            return "数据统计分析";
        }
    }
    
    // ==================== 额外步骤示例:数据分割 ====================
    static class DataSplitter implements PipelineStep {
        private double trainRatio = 0.8;
        
        public DataSplitter(double trainRatio) {
            this.trainRatio = trainRatio;
        }
        
        @Override
        public void process(ProcessedResult result, NDManager manager) {
            int total = (int) result.features.getShape().get(0);
            int trainSize = (int) (total * trainRatio);
            int testSize = total - trainSize;
            
            // 分割特征
            NDArray trainFeatures = result.features.get("0:" + trainSize);
            NDArray testFeatures = result.features.get(trainSize + ":" + total);
            
            // 分割标签
            NDArray trainLabels = result.labels.get("0:" + trainSize);
            NDArray testLabels = result.labels.get(trainSize + ":" + total);
            
            // 在实际应用中,这里会将数据存储到result中
            result.stats.put("训练集大小", trainSize);
            result.stats.put("测试集大小", testSize);
            result.stats.put("分割比例", String.format("%.0f%%/%.0f%%", 
                    trainRatio*100, (1-trainRatio)*100));
        }
        
        @Override
        public String getStepName() {
            return String.format("数据分割(%.0f%%训练)", trainRatio*100);
        }
    }
}

运行结果

== 带Pipeline的CSV数据处理 ===

添加步骤: 加载CSV文件
添加步骤: 性别编码(M→0, F→1)
添加步骤: 特征选择与张量创建
添加步骤: 数据统计分析

开始执行Pipeline:
1. 加载CSV文件...
   完成 (耗时: 31 ms)
2. 性别编码(M→0, F→1)...
   完成 (耗时: 0 ms)
3. 特征选择与张量创建...
   完成 (耗时: 25 ms)
4. 数据统计分析...
   完成 (耗时: 32 ms)

=== 数据处理结果 ===

1. 原始数据(10条):
  ID:1 Alice      年龄:34 性别:F 收入:50,000
  ID:2 Bob        年龄:45 性别:M 收入:60,000
  ID:3 Charlie    年龄:23 性别:M 收入:45,000
  ID:4 Diana      年龄:56 性别:F 收入:70,000
  ID:5 Eva        年龄:29 性别:F 收入:55,000
  ID:6 Frank      年龄:38 性别:M 收入:65,000
  ID:7 Grace      年龄:42 性别:F 收入:58,000
  ID:8 Henry      年龄:51 性别:M 收入:72,000
  ID:9 Ivy        年龄:26 性别:F 收入:48,000
  ID:10 Jack       年龄:33 性别:M 收入:62,000

2. 特征张量((10, 2)):
  样本1: [34.0, 1.0]
  样本2: [45.0, 0.0]
  样本3: [23.0, 0.0]
  样本4: [56.0, 1.0]
  样本5: [29.0, 1.0]
  ... 还有5行未显示

3. 标签张量((10, 1)):
  样本1: [50000.0]
  样本2: [60000.0]
  样本3: [45000.0]
  样本4: [70000.0]
  样本5: [55000.0]
  ... 还有5行未显示

4. 统计信息:
  加载完成, 数据条数          : 10
  男性人数                : 5
  女性人数                : 5
  特征维度                : 2
  样本数量                : 10
  总收入                 : 585,000
  平均年龄                : 37.7岁
  平均收入                : 58,500
  最高收入                : 72,000
  最低收入                : 45,000
  男性平均收入              : 60,800
  女性平均收入              : 56,200
  年龄-收入相关性            : 0.867
  年龄范围                : 23-56岁
  收入标准差               : 8,652
  年龄标准差               : 10.2岁
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐