请参考:YOLOv5(ultralytics) 训练自己的数据集,VOC2007为例

Fire Dataset:

https://download.csdn.net/download/W1995S/20666141

将数据集下载到yolov5/my_data文件夹下,解压并合并文件夹,整理成如下目录结构(本文中数据集根目录为my_data/fire,其中包含Annotations、JPEGImages等文件夹):
[图:数据集目录结构]
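1、在ImageSets/Main文件夹下生成train.txt、val.txt、test.txt和trainval.txt四个文件(存放不含扩展名的图片名)。
在数据集根目录(my_data/fire,即Annotations所在目录)下创建split_train_val.py文件并运行。脚本中trainval_percent=1.0,因此test.txt为空;train_percent=0.9,即训练集与验证集按9:1划分。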

# coding:utf-8

import os
import random
import argparse
import time

def split_names(xml_files, trainval_percent=1.0, train_percent=0.9, rng=random):
    """Randomly split annotation file names into trainval/train/val/test.

    Args:
        xml_files: File names as returned by ``os.listdir`` on the annotation
            directory. Entries that do not end in ``.xml`` (case-insensitive)
            are ignored so stray files such as ``.DS_Store`` are not listed.
        trainval_percent: Fraction of all samples placed in trainval; the
            remainder goes to test.
        train_percent: Fraction of trainval placed in train; the remainder
            goes to val.
        rng: Object providing ``sample`` -- the ``random`` module or a
            ``random.Random`` instance (use a seeded one for a reproducible
            split).

    Returns:
        dict mapping ``'trainval'``, ``'train'``, ``'val'`` and ``'test'`` to
        lists of base names (extension stripped), in input order.
    """
    names = [os.path.splitext(f)[0] for f in xml_files if f.lower().endswith('.xml')]
    num = len(names)
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval_idx = rng.sample(range(num), tv)
    # Sample train from the trainval list *before* converting to a set:
    # random.sample requires a sequence. Sets make the membership tests
    # below O(1) instead of O(n) per lookup.
    train_idx = set(rng.sample(trainval_idx, tr))
    trainval_idx = set(trainval_idx)

    splits = {'trainval': [], 'train': [], 'val': [], 'test': []}
    for i, name in enumerate(names):
        if i in trainval_idx:
            splits['trainval'].append(name)
            splits['train' if i in train_idx else 'val'].append(name)
        else:
            splits['test'].append(name)
    return splits


def main():
    """Write trainval.txt / train.txt / val.txt / test.txt into --txt_path."""
    parser = argparse.ArgumentParser()
    # Directory holding the VOC xml annotations (usually Annotations).
    parser.add_argument('--xml_path', default='Annotations', type=str, help='input xml label path')
    # Output directory for the split lists (usually ImageSets/Main).
    parser.add_argument('--txt_path', default='ImageSets/Main', type=str, help='output txt label path')
    parser.add_argument('--seed', default=None, type=int, help='random seed for a reproducible split')
    opt = parser.parse_args()

    trainval_percent = 1.0
    train_percent = 0.9

    splits = split_names(os.listdir(opt.xml_path), trainval_percent, train_percent,
                         rng=random.Random(opt.seed))

    os.makedirs(opt.txt_path, exist_ok=True)
    for split, names in splits.items():
        # Every list file is (re)created, even when its split is empty.
        with open(os.path.join(opt.txt_path, split + '.txt'), 'w') as f:
            f.writelines(name + '\n' for name in names)


if __name__ == '__main__':
    main()

2、创建YOLO格式的标签:在数据集根目录(my_data/fire)下创建xml2yolo_label.py文件并运行,会生成labels文件夹:
[图:生成的labels文件夹]

#!/usr/bin/env python
# -*- coding: utf8 -*-
import os
import time
import sys
from xml.etree import ElementTree
from xml.etree.ElementTree import Element, SubElement
from lxml import etree
import codecs
import cv2
from glob import glob

# Required extension of annotation files, and the encoding used to read them.
XML_EXT, ENCODE_METHOD = '.xml', 'utf-8'


class PascalVocReader:
    """Parse one Pascal VOC XML annotation file into a list of shapes.

    Each shape is a tuple ``(label, points, path, difficult)`` where
    ``points`` are the four box corners
    ``[(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]`` as ints.
    Parsing happens in the constructor; failures are reported on stderr and
    leave whatever shapes were read so far.
    """

    def __init__(self, filepath):
        # shapes: list of (label, [(x1, y1), (x2, y2), (x3, y3), (x4, y4)], path, difficult)
        self.shapes = []
        self.filepath = filepath
        self.verified = False
        try:
            self.parseXML()
        except Exception as e:
            # Best effort: one bad file should not stop a batch conversion,
            # but the failure must not be silent either (it would otherwise
            # produce an empty label file with no hint why).
            print('Warning: failed to parse %s: %s' % (filepath, e), file=sys.stderr)

    def getShapes(self):
        """Return the list of parsed shapes."""
        return self.shapes

    def addShape(self, label, bndbox, filename, difficult):
        """Append the box described by a VOC ``<bndbox>`` element.

        Coordinates go through ``float`` first because some VOC-style
        annotations store decimals (e.g. ``"48.5"``), which ``int()`` alone
        rejects; the value is then truncated to an int.
        """
        xmin, ymin, xmax, ymax = (int(float(bndbox.find(tag).text))
                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        points = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
        self.shapes.append((label, points, filename, difficult))

    def parseXML(self):
        """Parse ``self.filepath``, filling ``self.shapes`` and ``self.verified``.

        Returns:
            True on success.

        Raises:
            ValueError: if the path does not end with ``XML_EXT``.
        """
        if not self.filepath.endswith(XML_EXT):
            raise ValueError("Unsupport file format")
        parser = etree.XMLParser(encoding=ENCODE_METHOD)
        xmltree = ElementTree.parse(self.filepath, parser=parser).getroot()

        filename_el = xmltree.find('filename')
        filename = filename_el.text if filename_el is not None else None
        # <path> is written by labelImg but absent from stock VOC2007 files;
        # fall back to <filename> instead of failing the whole file.
        path_el = xmltree.find('path')
        path = path_el.text if path_el is not None else filename

        self.verified = xmltree.attrib.get('verified') == 'yes'

        for object_iter in xmltree.findall('object'):
            bndbox = object_iter.find('bndbox')
            label = object_iter.find('name').text
            difficult_el = object_iter.find('difficult')
            difficult = difficult_el is not None and bool(int(difficult_el.text))
            self.addShape(label, bndbox, path, difficult)
        return True


def voc_box_to_yolo(box, width, height):
    """Convert a box's corner points to normalized YOLO ``(xcen, ycen, w, h)``.

    Args:
        box: Corner points in the order produced by
            ``PascalVocReader.addShape``: ``box[0]`` is ``(xmin, ymin)`` and
            ``box[2]`` is ``(xmax, ymax)``.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        Tuple of four floats: box centre and size divided by the image size.
    """
    (xmin, ymin), (xmax, ymax) = box[0], box[2]
    xcen = float(xmin + xmax) / 2 / width
    ycen = float(ymin + ymax) / 2 / height
    w = float(xmax - xmin) / width
    h = float(ymax - ymin) / height
    return xcen, ycen, w, h


def main():
    """Convert every Annotations/*.xml under the cwd into labels/*.txt.

    Run from the dataset root (the directory containing Annotations and
    JPEGImages). Also writes the discovered class names, in index order, to
    classes.txt in that directory.
    """
    parentpath = os.getcwd()  # e.g. .../yolov5_ultralytics_v5/my_data/fire
    addxmlpath = os.path.join(parentpath, 'Annotations')  # XML files
    addimgpath = os.path.join(parentpath, 'JPEGImages')  # image files
    outputpath = os.path.join(parentpath, 'labels')  # YOLO-format output
    # Optional predefined class order (whitespace-separated names).
    classes_txt = os.path.join(parentpath, 'fire_classes.txt')
    ext = '.jpg'  # image file extension ('.jpg' or '.png')

    os.makedirs(outputpath, exist_ok=True)

    classes = {}
    if os.path.isfile(classes_txt):
        with open(classes_txt, 'r') as f:
            class_list = f.read().strip().split()
        classes = {name: idx for idx, name in enumerate(class_list)}
    # Number new classes after any preloaded ones so indices never collide.
    num_classes = max(classes.values(), default=-1) + 1

    for xmlPath in glob(os.path.join(addxmlpath, '*.xml')):
        shapes = PascalVocReader(xmlPath).getShapes()
        # splitext on the basename only: names like "fire.001.xml" must map
        # to "fire.001.jpg", not "fire.jpg".
        stem = os.path.splitext(os.path.basename(xmlPath))[0]
        img_file = os.path.join(addimgpath, stem + ext)

        with open(os.path.join(outputpath, stem + '.txt'), 'w') as f:
            if not shapes:
                continue
            # Read the image once per file (only its size is needed).
            img = cv2.imread(img_file)
            if img is None:
                raise FileNotFoundError('cannot read image %s for %s' % (img_file, xmlPath))
            height, width = img.shape[:2]  # also works for grayscale images

            for class_name, box, _path, _difficult in shapes:
                if class_name not in classes:
                    classes[class_name] = num_classes
                    num_classes += 1
                class_idx = classes[class_name]

                xcen, ycen, w, h = voc_box_to_yolo(box, width, height)
                f.write("%d %.06f %.06f %.06f %.06f\n" % (class_idx, xcen, ycen, w, h))
                print(class_idx, xcen, ycen, w, h)

    with open(os.path.join(parentpath, 'classes.txt'), 'w') as f:
        for key in classes:
            f.write("%s\n" % key)
            print(key)


if __name__ == '__main__':
    main()

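3、生成训练、验证、测试图片路径列表
在数据集根目录(my_data/fire)下创建firedata_path.py文件并运行,会生成train.txt、val.txt、test.txt三个txt文件(每行一张图片的绝对路径)。注意:该脚本同样会把YOLO格式标签写入labels文件夹,并覆盖第2步生成的同名文件。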

# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
import os

# Split names; ImageSets/Main/<name>.txt is read for each entry.
sets = 'train val test'.split()
# Class names -- the position in this list becomes the YOLO class id.
# Replace with your own classes, e.g. ["a", "b"].
classes = ["fire"]
# Inputs and outputs are resolved relative to the directory the script runs in.
abs_path = os.getcwd()


def convert(size, box):
    """Map a VOC box to normalized YOLO ``(x_center, y_center, width, height)``.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` -- note the x, x, y, y ordering.

    Returns:
        Tuple of four floats scaled by the image size. The centre is shifted
        by -1 pixel before scaling (presumably to account for VOC's 1-based
        pixel coordinates).
    """
    img_w, img_h = size[0], size[1]
    xmin, xmax, ymin, ymax = box[0], box[1], box[2], box[3]
    inv_w = 1. / img_w
    inv_h = 1. / img_h
    center_x = (xmin + xmax) / 2.0 - 1
    center_y = (ymin + ymax) / 2.0 - 1
    return (center_x * inv_w,
            center_y * inv_h,
            (xmax - xmin) * inv_w,
            (ymax - ymin) * inv_h)


def convert_annotation(image_id):
    """Write ``labels/<image_id>.txt`` from ``Annotations/<image_id>.xml``.

    One line per kept object: ``<class_id> <x> <y> <w> <h>`` as produced by
    ``convert``. Objects whose class is not in ``classes`` or that are marked
    difficult are skipped. Box coordinates outside the image are clipped to
    its bounds before conversion.
    """
    xml_file = os.path.join(abs_path, 'Annotations', '%s.xml' % image_id)
    txt_file = os.path.join(abs_path, 'labels', '%s.txt' % image_id)

    with open(xml_file, encoding='UTF-8') as in_file:
        root = ET.parse(in_file).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open(txt_file, 'w') as out_file:
        for obj in root.iter('object'):
            # Some tools write <Difficult> instead of <difficult>; a missing
            # tag means "not difficult" rather than a crash.
            difficult_el = obj.find('difficult')
            if difficult_el is None:
                difficult_el = obj.find('Difficult')
            difficult = int(difficult_el.text) if difficult_el is not None else 0
            cls = obj.find('name').text
            if cls not in classes or difficult == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            xmin, xmax, ymin, ymax = (float(xmlbox.find(tag).text)
                                      for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
            # Clip out-of-bounds annotations to the image on all four sides.
            xmin, ymin = max(xmin, 0.0), max(ymin, 0.0)
            xmax, ymax = min(xmax, w), min(ymax, h)
            bb = convert((w, h), (xmin, xmax, ymin, ymax))
            out_file.write(str(cls_id) + " " + " ".join(str(a) for a in bb) + '\n')


# Label files are written by convert_annotation into labels/; create it once.
os.makedirs(os.path.join(abs_path, 'labels'), exist_ok=True)

for image_set in sets:
    with open(os.path.join(abs_path, 'ImageSets', 'Main', '%s.txt' % image_set)) as f:
        # One image id per line. Strip only the line ending (ids may contain
        # spaces), tolerate '\r\n', skip blank lines, and keep the last id
        # even when the file has no trailing newline.
        image_ids = [line.rstrip('\r\n') for line in f]
    image_ids = [image_id for image_id in image_ids if image_id]
    print(image_ids)

    # <set>.txt lists the absolute path of each image, one per line.
    with open(os.path.join(abs_path, '%s.txt' % image_set), 'w') as list_file:
        for image_id in image_ids:
            list_file.write(os.path.join(abs_path, 'JPEGImages', '%s.jpg' % image_id) + '\n')
            convert_annotation(image_id)

[图:生成的train/val/test列表文件][图:列表文件内容示例]
4、配置数据集文件
在my_data/fire文件夹下新建fire.yaml文件(与下面训练命令中--data my_data/fire/fire.yaml的路径保持一致):

# YOLOv5 dataset config, passed to train.py via --data.
# train/val are list files with one absolute image path per line (written by
# firedata_path.py into the dataset root); change these paths to match your machine.
train: /home/cv/PycharmProjects/YOLOv5/yolov5_ultralytics_v5/my_data/fire/train.txt  # train images  8966 images
val: /home/cv/PycharmProjects/YOLOv5/yolov5_ultralytics_v5/my_data/fire/val.txt  # val images  997 images
test:  # test images (optional)

# Classes
nc: 1  # number of classes

# Must list classes in the same order as the class ids in labels/ (here 'fire' -> 0).
names: [ 'fire' ]   # class names

训练

可以直接通过命令行参数指定训练配置(无需改动train.py):训练100轮,batch size根据显存大小调整:

python train.py --batch 16 --epochs 100 --weights weights/yolov5s.pt --data my_data/fire/fire.yaml --cfg models/yolov5s.yaml

检测

视频:

python detect.py --source data/videos/fire.mp4 --weights runs/train/exp4/weights/best.pt

[图:检测结果1]
[图:检测结果2]
[图:检测结果3]
由于数据集较小,模型似乎只能检测红色火焰。

GitHub 加速计划 / ul / ultralytics
27.23 K
5.39 K
下载
ultralytics - 提供 YOLOv8 模型,用于目标检测、图像分割、姿态估计和图像分类,适合机器学习和计算机视觉领域的开发者。
最近提交(Master分支:2 个月前 )
e48a42ec Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> 4 天前
1790ca0f Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> 4 天前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐