基于改进UNET的油页岩图像石油含油量基于UNet及其改进模型开发油页岩含油量智能识别系统,结合PyQt实现可视化界面


以下文字及代码仅供参考学习使用。

在这里插入图片描述
基于改进UNET的油页岩图像石油含油量基于UNet及其改进模型开发油页岩含油量智能识别系统,结合PyQt实现可视化界面在这里插入图片描述

🌿 基于改进 UNet 的油页岩图像石油含油量智能识别系统:结合 PyQt 实现可视化界面


📝 项目概述

本项目旨在开发一个基于 UNet 及其改进模型 的油页岩含油量智能识别系统,并通过 PyQt 实现用户友好的可视化界面。系统能够对油页岩图像进行二分类分割,识别出含油区域,并计算含油量百分比。

功能特点:

  • 图像上传:支持用户上传油页岩图像。
  • 自动检测:使用训练好的 UNet 模型进行图像分割和含油量识别。
  • 结果展示:在界面上显示分割结果和含油量百分比。
  • 结果下载:允许用户下载处理后的图像和结果报告。

🛠️ 技术栈

  • 深度学习框架:PyTorch
  • 模型架构:UNet 及其改进版本(如 ResUNet、Attention UNet 等)
  • 前端界面:PyQt5
  • 数据处理:OpenCV, NumPy

💻 系统实现

1. 数据准备与预处理

数据集结构
oil_shale_dataset/
├── images/
│   ├── train/
│   ├── val/
│   └── test/
├── masks/
│   ├── train/
│   ├── val/
│   └── test/
└── data.yaml
预处理代码
import cv2
import numpy as np
from torch.utils.data import Dataset

class OilShaleDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace('.jpg', '.png'))
        
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

2. 改进的 UNet 模型构建

UNet 架构
import torch
import torch.nn as nn
import torchvision.models as models

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.encoder = models.vgg16(pretrained=True).features
        
        self.upconv1 = nn.ConvTranspose2d(features[-1], features[-2], kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(features[-2] * 2, features[-2], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features[-2], features[-2], kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        self.upconv2 = nn.ConvTranspose2d(features[-2], features[-3], kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(
            nn.Conv2d(features[-3] * 2, features[-3], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features[-3], features[-3], kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        self.upconv3 = nn.ConvTranspose2d(features[-3], features[-4], kernel_size=2, stride=2)
        self.decoder3 = nn.Sequential(
            nn.Conv2d(features[-4] * 2, features[-4], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features[-4], features[-4], kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder[:4](x)
        enc2 = self.encoder[4:9](enc1)
        enc3 = self.encoder[9:16](enc2)
        enc4 = self.encoder[16:](enc3)
        
        dec1 = self.upconv1(enc4)
        dec1 = torch.cat((dec1, enc3), dim=1)
        dec1 = self.decoder1(dec1)
        
        dec2 = self.upconv2(dec1)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        
        dec3 = self.upconv3(dec2)
        dec3 = torch.cat((dec3, enc1), dim=1)
        dec3 = self.decoder3(dec3)
        
        return self.final_conv(dec3)

3. 训练与推理

训练代码
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(dataloader):
        images = images.to(device)
        masks = masks.to(device).unsqueeze(1).float()

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(dataloader.dataset)
    print(f'Training Loss: {epoch_loss:.4f}')

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    dataset = OilShaleDataset('data/images/train', 'data/masks/train')
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    num_epochs = 10
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train(model, dataloader, criterion, optimizer, device)

if __name__ == '__main__':
    main()

4. PyQt5 可视化界面

主界面设计
import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QPushButton, QVBoxLayout, QWidget, QLabel, QFileDialog, QMessageBox
from PyQt5.QtGui import QPixmap
import cv2
import numpy as np

class OilShaleApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.initUI()

    def initUI(self):
        self.setWindowTitle('油页岩含油量智能识别系统')
        self.setGeometry(100, 100, 800, 600)

        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)

        layout = QVBoxLayout()

        self.upload_button = QPushButton('上传图片', self)
        self.upload_button.clicked.connect(self.upload_image)
        layout.addWidget(self.upload_button)

        self.detect_button = QPushButton('开始检测', self)
        self.detect_button.clicked.connect(self.start_detection)
        layout.addWidget(self.detect_button)

        self.download_button = QPushButton('结果下载', self)
        self.download_button.clicked.connect(self.download_result)
        layout.addWidget(self.download_button)

        self.result_label = QLabel('含油量识别结果:', self)
        layout.addWidget(self.result_label)

        self.image_label = QLabel(self)
        layout.addWidget(self.image_label)

        self.central_widget.setLayout(layout)

    def upload_image(self):
        options = QFileDialog.Options()
        file_name, _ = QFileDialog.getOpenFileName(self, "选择油页岩图像", "", "Images (*.png *.jpg *.bmp);;All Files (*)", options=options)
        if file_name:
            self.image_path = file_name
            pixmap = QPixmap(file_name)
            self.image_label.setPixmap(pixmap.scaled(600, 400))

    def start_detection(self):
        if hasattr(self, 'image_path'):
            # 调用模型进行检测
            oil_percentage = self.detect_oil_percentage(self.image_path)
            self.result_label.setText(f'含油量识别结果:{oil_percentage:.2f}%')
        else:
            QMessageBox.warning(self, '警告', '请先上传一张图像!')

    def detect_oil_percentage(self, image_path):
        # 这里调用你的模型进行预测
        # 示例代码:
        image = cv2.imread(image_path)
        # 使用模型预测
        # ...
        # 返回含油量百分比
        return 25.487

    def download_result(self):
        if hasattr(self, 'result_label'):
            result_text = self.result_label.text()
            options = QFileDialog.Options()
            file_name, _ = QFileDialog.getSaveFileName(self, "保存结果", "", "Text Files (*.txt);;All Files (*)", options=options)
            if file_name:
                with open(file_name, 'w') as f:
                    f.write(result_text)
                QMessageBox.information(self, '成功', '结果已保存!')
        else:
            QMessageBox.warning(self, '警告', '没有可保存的结果!')

if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = OilShaleApp()
    ex.show()
    sys.exit(app.exec_())

以上文字及代码仅供参考学习使用。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐