手把手教你用 TensorFlow.js 在浏览器里跑机器学习模型
手把手教你用 TensorFlow.js 在浏览器里跑机器学习模型
引言
你有没有想过,在不依赖后端服务器的情况下,直接在浏览器里训练和运行机器学习模型?这听起来像是科幻小说里的情节,但 TensorFlow.js 让这一切变成了现实。
TensorFlow.js 是 Google 推出的 JavaScript 机器学习库,它允许开发者在浏览器或 Node.js 环境中构建、训练和部署 ML 模型。无论你是前端开发者想拓展技能树,还是机器学习工程师想探索新的部署方式,这篇文章都能帮你快速上手。
本文将带你从零开始,通过一个完整的 TensorFlow.js 实战教程,掌握核心概念和实操步骤。我还会分享一些 TensorFlow.js 踩坑记录,帮你避开初学者常见的坑。
核心概念:张量是 TensorFlow.js 的基石
什么是张量(Tensor)?
在 TensorFlow.js 中,张量是最核心的数据结构。简单来说,张量就是多维数组的泛化:
- 标量:0 维张量,就是一个单独的数字
- 向量:1 维张量,如
[1, 2, 3] - 矩阵:2 维张量,如
[[1, 2], [3, 4]] - 更高维:3 维以上,常用于图像数据(宽度 × 高度 × 通道数)
创建张量的基本方法
让我们直接看 TensorFlow.js 代码示例:
// 导入TensorFlow.js库
import * as tf from '@tensorflow/tfjs';
// 创建一个1D张量(一维数组)
const tensor1d = tf.tensor([1, 2, 3, 4]);
tensor1d.print(); // 输出: [1, 2, 3, 4]
// 创建一个2D张量(矩阵)
const tensor2d = tf.tensor2d([[1, 2], [3, 4]]);
tensor2d.print();
// 输出:
// [[1, 2],
// [3, 4]]
小贴士:tf.tensor() 是最通用的方法,可以创建任意维度的张量。而 tf.tensor2d() 是创建二维张量的快捷方式,代码更简洁。
张量运算:机器学习的基础操作
张量最强大的地方在于支持各种数学运算,这些运算构成了机器学习模型的基础:
// 张量运算:加法
const sum = tensor1d.add(tf.tensor([5, 6, 7, 8]));
sum.print(); // 输出: [6, 8, 10, 12]
// 张量运算:乘法(逐元素相乘)
const product = tensor2d.mul(tf.tensor2d([[2, 0], [1, 3]]));
product.print();
// 输出:
// [[2, 0],
// [3, 12]]
重要提醒:TensorFlow.js 中的运算都是不可变的——它们不会修改原始张量,而是返回新的张量。
内存管理的痛点(踩坑记录)
这是很多初学者会忽略的问题。TensorFlow.js 使用 WebGL 或 WebGPU 在 GPU 上执行计算,这意味着张量会占用 GPU 内存。如果不手动释放,会导致内存泄漏和性能下降。
// 释放内存(避免内存泄漏)
tensor1d.dispose();
tensor2d.dispose();
sum.dispose();
product.dispose();
最佳实践:使用 tf.tidy() 自动管理内存。它会在函数执行完毕后自动释放所有中间张量:
const result = tf.tidy(() => {
const a = tf.tensor([1, 2, 3]);
const b = tf.tensor([4, 5, 6]);
return a.add(b);
});
// 这里 a 和 b 自动被释放了
result.print(); // 输出: [5, 7, 9]
实战步骤:构建一个简单的线性回归模型
现在让我们通过一个完整的 TensorFlow.js 实操步骤,构建一个线性回归模型来预测房价。
步骤 1:准备数据
假设我们有一些房屋面积(平方米)和对应价格(万元)的数据:
// 训练数据:房屋面积 (x) 和 价格 (y)
const xs = tf.tensor2d([50, 60, 70, 80, 90, 100, 110, 120], [8, 1]);
const ys = tf.tensor2d([150, 180, 210, 240, 270, 300, 330, 360], [8, 1]);
// 数据归一化(重要!)
const xMean = xs.mean();
const xStd = xs.std();
const yMean = ys.mean();
const yStd = ys.std();
const normalizedXs = xs.sub(xMean).div(xStd);
const normalizedYs = ys.sub(yMean).div(yStd);
踩坑记录:数据归一化是初学者容易忽略的步骤。如果不归一化,模型可能无法收敛或训练速度极慢。
步骤 2:定义模型
// 创建一个序列模型
const model = tf.sequential();
// 添加一个全连接层(输入维度1,输出维度1)
model.add(tf.layers.dense({
units: 1,
inputShape: [1],
activation: 'linear'
}));
// 编译模型:指定优化器和损失函数
model.compile({
optimizer: tf.train.sgd(0.1), // 随机梯度下降,学习率0.1
loss: 'meanSquaredError' // 均方误差
});
步骤 3:训练模型
async function trainModel() {
console.log('开始训练...');
const history = await model.fit(normalizedXs, normalizedYs, {
epochs: 100,
batchSize: 4,
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}`);
}
}
});
console.log('训练完成!
');
return history;
}
// 执行训练
trainModel();
实操提示:训练过程是异步的,所以需要使用 async/await。回调函数 onEpochEnd 可以帮助你监控训练进度。
步骤 4:使用模型进行预测
function predict(area) {
// 对输入进行同样的归一化
const normalizedInput = tf.tensor2d([area]).sub(xMean).div(xStd);
// 进行预测
const normalizedOutput = model.predict(normalizedInput);
// 反归一化得到真实价格
const price = normalizedOutput.mul(yStd).add(yMean);
console.log(`预测结果:${area}平方米的房价约为 ${price.dataSync()[0].toFixed(0)} 万元`);
// 清理内存
normalizedInput.dispose();
normalizedOutput.dispose();
price.dispose();
}
// 预测85平方米的房价
predict(85); // 输出:预测结果:85平方米的房价约为 255 万元
步骤 5:保存和加载模型
// 保存模型到浏览器本地存储
async function saveModel() {
await model.save('localstorage://my-house-price-model');
console.log('模型已保存');
}
// 从浏览器本地存储加载模型
async function loadModel() {
const loadedModel = await tf.loadLayersModel('localstorage://my-house-price-model');
console.log('模型已加载');
return loadedModel;
}
TensorFlow.js 踩坑记录总结
- 内存泄漏:张量操作后务必释放内存,推荐使用
tf.tidy() - 数据归一化:训练前对数据进行归一化,预测时也要做同样的归一化
- 异步操作:
model.fit()是异步的,需要使用await - 浏览器兼容性:WebGL 后端在某些老旧浏览器上可能不支持,建议使用 WebGPU 或 CPU 后端作为备选
- 模型大小:复杂的模型文件可能很大,注意加载时间
总结
通过本文的 TensorFlow.js 实战教程,你已经学会了:
- 张量的创建和基本运算
- 内存管理的最佳实践
- 从数据准备到模型训练、预测的完整流程
- 模型的保存与加载
TensorFlow.js 的真正魅力在于,它让前端开发者也能轻松拥抱机器学习。你可以在浏览器中实现图像分类、语音识别、自然语言处理等高级功能,而无需搭建复杂的后端服务。
现在,打开你的编辑器,开始你的第一个 TensorFlow.js 项目吧!记住:机器学习的核心不是算法有多复杂,而是你能用它解决什么实际问题。
总结
通过本文的学习,相信你已经对「TensorFlow.js」有了更深入的理解。建议结合实际项目多加练习,在实践中巩固所学知识。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)