基于改进UNET的油页岩图像石油含油量基于UNet及其改进模型开发油页岩含油量智能识别系统,结合PyQt实现可视化界面
·
基于改进UNET的油页岩图像石油含油量基于UNet及其改进模型开发油页岩含油量智能识别系统,结合PyQt实现可视化界面
文章目录
以下文字及代码仅供参考学习使用。

基于改进UNET的油页岩图像石油含油量基于UNet及其改进模型开发油页岩含油量智能识别系统,结合PyQt实现可视化界面
🌿 基于改进 UNet 的油页岩图像石油含油量智能识别系统:结合 PyQt 实现可视化界面
📝 项目概述
本项目旨在开发一个基于 UNet 及其改进模型 的油页岩含油量智能识别系统,并通过 PyQt 实现用户友好的可视化界面。系统能够对油页岩图像进行二分类分割,识别出含油区域,并计算含油量百分比。
功能特点:
- 图像上传:支持用户上传油页岩图像。
- 自动检测:使用训练好的 UNet 模型进行图像分割和含油量识别。
- 结果展示:在界面上显示分割结果和含油量百分比。
- 结果下载:允许用户下载处理后的图像和结果报告。
🛠️ 技术栈
- 深度学习框架:PyTorch
- 模型架构:UNet 及其改进版本(如 ResUNet、Attention UNet 等)
- 前端界面:PyQt5
- 数据处理:OpenCV, NumPy
💻 系统实现
1. 数据准备与预处理
数据集结构
oil_shale_dataset/
├── images/
│ ├── train/
│ ├── val/
│ └── test/
├── masks/
│ ├── train/
│ ├── val/
│ └── test/
└── data.yaml
预处理代码
import cv2
import numpy as np
from torch.utils.data import Dataset
class OilShaleDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
img_path = os.path.join(self.image_dir, self.images[index])
mask_path = os.path.join(self.mask_dir, self.images[index].replace('.jpg', '.png'))
image = cv2.imread(img_path)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if self.transform:
augmented = self.transform(image=image, mask=mask)
image = augmented['image']
mask = augmented['mask']
return image, mask
2. 改进的 UNet 模型构建
UNet 架构
import torch
import torch.nn as nn
import torchvision.models as models
class UNet(nn.Module):
def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
super(UNet, self).__init__()
self.encoder = models.vgg16(pretrained=True).features
self.upconv1 = nn.ConvTranspose2d(features[-1], features[-2], kernel_size=2, stride=2)
self.decoder1 = nn.Sequential(
nn.Conv2d(features[-2] * 2, features[-2], kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(features[-2], features[-2], kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
self.upconv2 = nn.ConvTranspose2d(features[-2], features[-3], kernel_size=2, stride=2)
self.decoder2 = nn.Sequential(
nn.Conv2d(features[-3] * 2, features[-3], kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(features[-3], features[-3], kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
self.upconv3 = nn.ConvTranspose2d(features[-3], features[-4], kernel_size=2, stride=2)
self.decoder3 = nn.Sequential(
nn.Conv2d(features[-4] * 2, features[-4], kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(features[-4], features[-4], kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def forward(self, x):
enc1 = self.encoder[:4](x)
enc2 = self.encoder[4:9](enc1)
enc3 = self.encoder[9:16](enc2)
enc4 = self.encoder[16:](enc3)
dec1 = self.upconv1(enc4)
dec1 = torch.cat((dec1, enc3), dim=1)
dec1 = self.decoder1(dec1)
dec2 = self.upconv2(dec1)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec3 = self.upconv3(dec2)
dec3 = torch.cat((dec3, enc1), dim=1)
dec3 = self.decoder3(dec3)
return self.final_conv(dec3)
3. 训练与推理
训练代码
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
def train(model, dataloader, criterion, optimizer, device):
model.train()
running_loss = 0.0
for images, masks in tqdm(dataloader):
images = images.to(device)
masks = masks.to(device).unsqueeze(1).float()
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
epoch_loss = running_loss / len(dataloader.dataset)
print(f'Training Loss: {epoch_loss:.4f}')
def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
dataset = OilShaleDataset('data/images/train', 'data/masks/train')
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
print(f'Epoch {epoch + 1}/{num_epochs}')
train(model, dataloader, criterion, optimizer, device)
if __name__ == '__main__':
main()
4. PyQt5 可视化界面
主界面设计
import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QPushButton, QVBoxLayout, QWidget, QLabel, QFileDialog, QMessageBox
from PyQt5.QtGui import QPixmap
import cv2
import numpy as np
class OilShaleApp(QMainWindow):
def __init__(self):
super().__init__()
self.initUI()
def initUI(self):
self.setWindowTitle('油页岩含油量智能识别系统')
self.setGeometry(100, 100, 800, 600)
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
layout = QVBoxLayout()
self.upload_button = QPushButton('上传图片', self)
self.upload_button.clicked.connect(self.upload_image)
layout.addWidget(self.upload_button)
self.detect_button = QPushButton('开始检测', self)
self.detect_button.clicked.connect(self.start_detection)
layout.addWidget(self.detect_button)
self.download_button = QPushButton('结果下载', self)
self.download_button.clicked.connect(self.download_result)
layout.addWidget(self.download_button)
self.result_label = QLabel('含油量识别结果:', self)
layout.addWidget(self.result_label)
self.image_label = QLabel(self)
layout.addWidget(self.image_label)
self.central_widget.setLayout(layout)
def upload_image(self):
options = QFileDialog.Options()
file_name, _ = QFileDialog.getOpenFileName(self, "选择油页岩图像", "", "Images (*.png *.jpg *.bmp);;All Files (*)", options=options)
if file_name:
self.image_path = file_name
pixmap = QPixmap(file_name)
self.image_label.setPixmap(pixmap.scaled(600, 400))
def start_detection(self):
if hasattr(self, 'image_path'):
# 调用模型进行检测
oil_percentage = self.detect_oil_percentage(self.image_path)
self.result_label.setText(f'含油量识别结果:{oil_percentage:.2f}%')
else:
QMessageBox.warning(self, '警告', '请先上传一张图像!')
def detect_oil_percentage(self, image_path):
# 这里调用你的模型进行预测
# 示例代码:
image = cv2.imread(image_path)
# 使用模型预测
# ...
# 返回含油量百分比
return 25.487
def download_result(self):
if hasattr(self, 'result_label'):
result_text = self.result_label.text()
options = QFileDialog.Options()
file_name, _ = QFileDialog.getSaveFileName(self, "保存结果", "", "Text Files (*.txt);;All Files (*)", options=options)
if file_name:
with open(file_name, 'w') as f:
f.write(result_text)
QMessageBox.information(self, '成功', '结果已保存!')
else:
QMessageBox.warning(self, '警告', '没有可保存的结果!')
if __name__ == '__main__':
app = QApplication(sys.argv)
ex = OilShaleApp()
ex.show()
sys.exit(app.exec_())
以上文字及代码仅供参考学习使用。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)