What is PyTorch Lightning

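The main advantages of PyTorch Lightning (PL) are:

  • Automated training: PyTorch Lightning takes over the training loop. It iterates over the batches from your dataloader, calls the forward pass and loss computation you define, and runs backpropagation and the optimizer step. A complete deep learning project can be written in roughly 100 lines of code.
  • Distributed training support: PyTorch Lightning can train on multiple GPUs or multiple machines to speed up training, and the setup is simple, usually a matter of a few Trainer arguments.
  • Reproducibility: PyTorch Lightning provides APIs that make it easy to fix the random seeds and the training environment, so that repeated runs produce reproducible results.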

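In short, PyTorch Lightning is a flexible framework that makes training and developing deep learning models more efficient. Its tools help you organize and manage the training code instead of rewriting the same loop in every project.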

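Common features

The basic workflow of a PL deep learning project is:

  1. Define a PyTorch Lightning module (a LightningModule)
  2. Define a Trainer
  3. Call the Trainer to train and validate the module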

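Automatic training logs

One of PL's conveniences is that it records training logs, such as the training error and the test error, for you. By default PL logs to TensorBoard.

To view the logs, run the following command in a terminal: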

tensorboard --logdir=lightning_logs/

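When calling self.log inside the LightningModule, the on_epoch argument controls whether the metric is accumulated and logged once per epoch. The Trainer itself is created as follows:

trainer = pl.Trainer(max_epochs=MAX_EPOCHS,
                     num_sanity_val_steps=0)  # skip the pre-training validation sanity check; 0 because of va_spo_list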

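Evaluating a model in one line with torchmetrics

torchmetrics is a metrics computation and evaluation toolkit for PyTorch. It provides a collection of common evaluation metrics for measuring model performance on different tasks, including classification, regression, segmentation and generation.

torchmetrics supports the usual metrics such as accuracy, precision, recall, F1 score, AUC, mean absolute error and root mean squared error. It also offers more advanced ones such as the multiclass confusion matrix, the Jaccard index, the Dice coefficient and IoU.

torchmetrics is designed to compute and record model performance metrics in a concise, flexible and extensible way. It is tightly integrated with PyTorch and fits directly into PyTorch training and validation loops, as the code at the end of this article shows.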

Loading a trained checkpoint

# load the model  
CHECKPOINT_PATH = 'lightning_logs/version_9/checkpoints/epoch=59-step=120000.ckpt'  
TEMP_VIDEO_PATH = 'tmp_video'  
MODEL_TYPE = 'slowfast'  
classifier = SignLanguageClassifier.load_from_checkpoint(CHECKPOINT_PATH, strict=False, model_type=MODEL_TYPE)  
classifier.model_type = MODEL_TYPE  
trainer = pl.Trainer()

# make inference 
trainer.test(classifier, test_dataloader)

Choosing the device: CUDA or CPU?

trainer = pl.Trainer(max_epochs=MAX_EPOCH,
                     devices='auto', accelerator='auto',  # to run on the CPU only, set accelerator='cpu'
                     logger=tensorboard_logger)

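A deep learning project template

Below is image classification code I wrote with PL for a real project. It runs 5-fold cross-validation over the whole dataset and reports the average accuracy and the confusion matrix. The code gives a direct view of PL's logic and overall workflow.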

"""  
a Python script to train ResNet-18 using PyTorch Lightning. The dataset includes 5 categories.  
Report the classification accuracy and confusion matrix with torch-metrics.  
  
Use 5-fold stratified sampling.  
Report the final average classification accuracies at the end of the program.  
"""  
  
  
import numpy as np  
import pytorch_lightning as pl  
from pytorch_lightning.loggers import TensorBoardLogger  
import torch  
from torch.nn import functional as F  
from torch.utils.data import DataLoader, TensorDataset  
from torchvision import models, transforms  
import torchmetrics  
from sklearn.model_selection import StratifiedKFold  
import seaborn  
import matplotlib.pyplot as plt  
  
  
MAX_EPOCH = 100  
  
  
class Classifier(pl.LightningModule):  
    def __init__(self, num_classes: int, model_type: str = 'resnet18'):  
        super().__init__()  
        self.model_type = model_type  
        if model_type == 'resnet18':  
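            # `pretrained=True` is deprecated in torchvision >= 0.13; use the weights enum instead
            self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)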
            self.model.fc = torch.nn.Sequential(  
                    torch.nn.Linear(self.model.fc.in_features, 128),  
                    torch.nn.ReLU(),  
                    torch.nn.Linear(128, 64),  
                    torch.nn.ReLU(),  
                    torch.nn.Linear(64, num_classes)  
                    )  
        elif model_type == 'mlp':  
            self.model = torch.nn.Sequential(  
                torch.nn.Linear(40, 128),  
                torch.nn.ReLU(),  
                torch.nn.Linear(128, 64),  
                torch.nn.ReLU(),  
                torch.nn.Linear(64, num_classes)  
            )  
        else:  
            raise ValueError(f'Invalid model_type: {model_type}')  
        self.accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes)  
        self.conf_mat = torchmetrics.classification.MulticlassConfusionMatrix(num_classes, normalize='true')  
  
    def forward(self, x):  
        if self.model_type == 'resnet18':  
            x = x.view(x.size(0), 1, -1, 1)    # Reshape 1D data into a single-channel "image"  
            x = torch.repeat_interleave(x, repeats=3, dim=1)  
        return self.model(x.float())  
  
    def training_step(self, batch, batch_idx):  
        x, y = batch  
        y_hat = self(x)  
        loss = F.cross_entropy(y_hat, y.long())  
        self.log('train_loss', loss)
        return loss  
  
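    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.long()
        y_hat = self(x)
        # Update the metric states before logging the metric object
        self.accuracy.update(y_hat, y)
        self.conf_mat.update(y_hat, y)
        self.log('val_accuracy', self.accuracy, on_epoch=True, prog_bar=True)
        self.log('val_loss', F.cross_entropy(y_hat, y), on_step=True, prog_bar=True)

    def on_validation_end(self):
        conf_matrix = self.conf_mat.compute()
        # conf_mat is never logged, so Lightning does not reset it; reset it here
        # so that each validation run reports only its own predictions
        self.conf_mat.reset()
        print(conf_matrix)
        plt.figure()
        seaborn.heatmap(conf_matrix.cpu(), annot=True)
        plt.savefig(f'conf_mat_{fold_id}.png')
        plt.close()  # release the figure; this hook runs once per validation run
        accuracy_computed = self.accuracy.compute()
        print(f'Fold Accuracy={accuracy_computed}')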
  
    def configure_optimizers(self):  
        return torch.optim.Adam(self.parameters(), lr=0.00001)  
  
# Load data and labels from .npy file  
data_and_labels = np.load('data/data_and_labels.npy', allow_pickle=True).item()  
X = data_and_labels['X']  
y = data_and_labels['y']  
  
# Prepare 5-fold stratified sampling  
skf = StratifiedKFold(n_splits=5, shuffle=True)  
  
# Initialize list for storing classification accuracies  
accuracies = []  
fold_id = 0  
# Perform 5-fold stratified sampling  
for train_index, val_index in skf.split(X, y):  
    X_train, X_val = X[train_index], X[val_index]  
    y_train, y_val = y[train_index], y[val_index]  
  
    # Create TensorDatasets  
    train_data = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))  
    val_data = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))  
  
    # Create DataLoaders  
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)  
    val_loader = DataLoader(val_data, batch_size=64)  
  
    # Model  
    model = Classifier(num_classes=5)  
  
    # Training  
    tensorboard_logger = TensorBoardLogger(save_dir='.', version=fold_id)  
    trainer = pl.Trainer(max_epochs=MAX_EPOCH, devices='auto', accelerator='auto', logger=tensorboard_logger)  
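    trainer.fit(model, train_loader, val_loader)

    # Store this fold's final validation accuracy (logged as 'val_accuracy')
    accuracies.append(trainer.callback_metrics['val_accuracy'].item())

    fold_id += 1

# Report the average classification accuracy over all folds
print(f'Average accuracy over {len(accuracies)} folds: {np.mean(accuracies):.4f}')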