手把手教你用 TensorFlow.js 在浏览器里跑机器学习模型

引言

你有没有想过,在不依赖后端服务器的情况下,直接在浏览器里训练和运行机器学习模型?这听起来像是科幻小说里的情节,但 TensorFlow.js 让这一切变成了现实。

TensorFlow.js 是 Google 推出的 JavaScript 机器学习库,它允许开发者在浏览器或 Node.js 环境中构建、训练和部署 ML 模型。无论你是前端开发者想拓展技能树,还是机器学习工程师想探索新的部署方式,这篇文章都能帮你快速上手。

本文将带你从零开始,通过一个完整的 TensorFlow.js 实战教程,掌握核心概念和实操步骤。我还会分享一些 TensorFlow.js 踩坑记录,帮你避开初学者常见的坑。

核心概念:张量是 TensorFlow.js 的基石

什么是张量(Tensor)?

在 TensorFlow.js 中,张量是最核心的数据结构。简单来说,张量就是多维数组的泛化:

  • 标量:0 维张量,就是一个单独的数字
  • 向量:1 维张量,如 [1, 2, 3]
  • 矩阵:2 维张量,如 [[1, 2], [3, 4]]
  • 更高维:3 维以上,常用于图像数据(宽度 × 高度 × 通道数)

创建张量的基本方法

让我们直接看 TensorFlow.js 代码示例:

// 导入TensorFlow.js库
import * as tf from '@tensorflow/tfjs';

// 创建一个1D张量(一维数组)
const tensor1d = tf.tensor([1, 2, 3, 4]);
tensor1d.print(); // 输出: [1, 2, 3, 4]

// 创建一个2D张量(矩阵)
const tensor2d = tf.tensor2d([[1, 2], [3, 4]]);
tensor2d.print(); 
// 输出: 
// [[1, 2],
//  [3, 4]]

小贴士tf.tensor() 是最通用的方法,可以创建任意维度的张量。而 tf.tensor2d() 是创建二维张量的快捷方式,代码更简洁。

张量运算:机器学习的基础操作

张量最强大的地方在于支持各种数学运算,这些运算构成了机器学习模型的基础:

// 张量运算:加法
const sum = tensor1d.add(tf.tensor([5, 6, 7, 8]));
sum.print(); // 输出: [6, 8, 10, 12]

// 张量运算:乘法(逐元素相乘)
const product = tensor2d.mul(tf.tensor2d([[2, 0], [1, 3]]));
product.print();
// 输出:
// [[2, 0],
//  [3, 12]]

重要提醒:TensorFlow.js 中的运算都是不可变的——它们不会修改原始张量,而是返回新的张量。

内存管理的痛点(踩坑记录)

这是很多初学者会忽略的问题。TensorFlow.js 使用 WebGL 或 WebGPU 在 GPU 上执行计算,这意味着张量会占用 GPU 内存。如果不手动释放,会导致内存泄漏和性能下降。

// 释放内存(避免内存泄漏)
tensor1d.dispose();
tensor2d.dispose();
sum.dispose();
product.dispose();

最佳实践:使用 tf.tidy() 自动管理内存。它会在函数执行完毕后自动释放所有中间张量:

const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3]);
  const b = tf.tensor([4, 5, 6]);
  return a.add(b);
});
// 这里 a 和 b 自动被释放了
result.print(); // 输出: [5, 7, 9]

实战步骤:构建一个简单的线性回归模型

现在让我们通过一个完整的 TensorFlow.js 实操步骤,构建一个线性回归模型来预测房价。

步骤 1:准备数据

假设我们有一些房屋面积(平方米)和对应价格(万元)的数据:

// 训练数据:房屋面积 (x) 和 价格 (y)
const xs = tf.tensor2d([50, 60, 70, 80, 90, 100, 110, 120], [8, 1]);
const ys = tf.tensor2d([150, 180, 210, 240, 270, 300, 330, 360], [8, 1]);

// 数据归一化(重要!)
const xMean = xs.mean();
const xStd = xs.std();
const yMean = ys.mean();
const yStd = ys.std();

const normalizedXs = xs.sub(xMean).div(xStd);
const normalizedYs = ys.sub(yMean).div(yStd);

踩坑记录数据归一化是初学者容易忽略的步骤。如果不归一化,模型可能无法收敛或训练速度极慢。

步骤 2:定义模型

// 创建一个序列模型
const model = tf.sequential();

// 添加一个全连接层(输入维度1,输出维度1)
model.add(tf.layers.dense({
  units: 1,
  inputShape: [1],
  activation: 'linear'
}));

// 编译模型:指定优化器和损失函数
model.compile({
  optimizer: tf.train.sgd(0.1),  // 随机梯度下降,学习率0.1
  loss: 'meanSquaredError'        // 均方误差
});

步骤 3:训练模型

async function trainModel() {
  console.log('开始训练...');
  
  const history = await model.fit(normalizedXs, normalizedYs, {
    epochs: 100,
    batchSize: 4,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}`);
      }
    }
  });
  
  console.log('训练完成!

');
  return history;
}

// 执行训练
trainModel();

实操提示:训练过程是异步的,所以需要使用 async/await。回调函数 onEpochEnd 可以帮助你监控训练进度。

步骤 4:使用模型进行预测

function predict(area) {
  // 对输入进行同样的归一化
  const normalizedInput = tf.tensor2d([area]).sub(xMean).div(xStd);
  
  // 进行预测
  const normalizedOutput = model.predict(normalizedInput);
  
  // 反归一化得到真实价格
  const price = normalizedOutput.mul(yStd).add(yMean);
  
  console.log(`预测结果:${area}平方米的房价约为 ${price.dataSync()[0].toFixed(0)} 万元`);
  
  // 清理内存
  normalizedInput.dispose();
  normalizedOutput.dispose();
  price.dispose();
}

// 预测85平方米的房价
predict(85); // 输出:预测结果:85平方米的房价约为 255 万元

步骤 5:保存和加载模型

// 保存模型到浏览器本地存储
async function saveModel() {
  await model.save('localstorage://my-house-price-model');
  console.log('模型已保存');
}

// 从浏览器本地存储加载模型
async function loadModel() {
  const loadedModel = await tf.loadLayersModel('localstorage://my-house-price-model');
  console.log('模型已加载');
  return loadedModel;
}

TensorFlow.js 踩坑记录总结

  1. 内存泄漏:张量操作后务必释放内存,推荐使用 tf.tidy()
  2. 数据归一化:训练前对数据进行归一化,预测时也要做同样的归一化
  3. 异步操作model.fit() 是异步的,需要使用 await
  4. 浏览器兼容性:WebGL 后端在某些老旧浏览器上可能不支持,建议使用 WebGPU 或 CPU 后端作为备选
  5. 模型大小:复杂的模型文件可能很大,注意加载时间

总结

通过本文的 TensorFlow.js 实战教程,你已经学会了:

  • 张量的创建和基本运算
  • 内存管理的最佳实践
  • 从数据准备到模型训练、预测的完整流程
  • 模型的保存与加载

TensorFlow.js 的真正魅力在于,它让前端开发者也能轻松拥抱机器学习。你可以在浏览器中实现图像分类、语音识别、自然语言处理等高级功能,而无需搭建复杂的后端服务。

现在,打开你的编辑器,开始你的第一个 TensorFlow.js 项目吧!记住:机器学习的核心不是算法有多复杂,而是你能用它解决什么实际问题。

总结

通过本文的学习,相信你已经对「TensorFlow.js」有了更深入的理解。建议结合实际项目多加练习,在实践中巩固所学知识。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐