YOLO26 自定义数据集 自动识别 train / val / test 子目录
·
YOLO26 自定义数据集 自动识别 train / val / test 子目录
flyfish
ultralytics/data/utils.py
自动识别 train/val/test 子目录、自动提取类别的逻辑
data
|
|-- train/
| |-- class1
| | |-- image1.png
| | |-- image2.png
| | |-- ...
| |
| |-- class2
| | |-- image8.png
| | |-- image9.png
| | |-- ...
| |-- ...
|
|-- test/(optional)
| |-- class1
| | |-- image1.png
| | |-- image2.png
| | |-- ...
| |
| |-- class2
| | |-- image8.png
| | |-- image9.png
| | |-- ...
| |-- ...
|
|-- val/(optional)
| |-- class1
| | |-- image1.png
| | |-- image2.png
| | |-- ...
| |
| |-- class2
| | |-- image8.png
| | |-- image9.png
| | |-- ...
| |-- ...
1. 自动识别 train / val / test 子目录
位置:check_cls_dataset 函数内,数据集路径解析完成之后。
train_set = data_dir / "train"
val_set = (
data_dir / "val"
if (data_dir / "val").exists()
else data_dir / "validation"
if (data_dir / "validation").exists()
else None
) # data/test or data/val
test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test
data_dir 就是 model.train(data="xxx") 传入的数据集根目录路径
训练集:固定拼接 train 子目录,这是分类数据集的必填目录
验证集:优先查找 val 命名的子目录;找不到则兼容 validation 命名;两者都不存在则设为 None
测试集:存在 test 子目录则启用,不存在则设为 None
2. 从 train 目录自动提取类别
位置:在上面的目录识别代码以及 split 缺失检查(告警)之后。
nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # 统计类别数量
names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # 提取所有类别名
names = dict(enumerate(sorted(names))) # 生成 索引→类别名 的映射字典
只读取 train 目录下的一级子文件夹,文件夹名称即为类别名称
统计一级子文件夹总数,得到类别数 nc
对类别名排序后,生成 {0: "class1", 1: "class2", ...} 的标准映射字典
3. 最终返回:把识别结果交给训练器
函数末尾会把所有解析结果打包返回,这个返回值就是 ClassificationTrainer 中 self.data 的内容:
return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
def check_cls_dataset(dataset, split=""):
    """
    Check a classification dataset such as Imagenet and resolve its split directories.

    This function accepts a `dataset` name or path and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
    Class names are taken from the first-level sub-directories of the `train` directory.

    Args:
        dataset (str | Path): The name of the dataset, a path to a dataset directory, an archive, or a URL.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
            When the requested split directory is missing, the other one is used in its place.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path | None): The validation set directory (`val` or `validation`), or None if absent.
            - 'test' (Path | None): The test set directory, or None if absent.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary mapping class index to class name (sorted by name).

    Raises:
        FileNotFoundError: If the `train` directory is missing or contains no images.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            # Argument list instead of a shell string so that a ROOT containing spaces still works
            subprocess.run(["bash", str(ROOT / "data/scripts/get_imagenet.sh")], check=True)
        else:
            url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)

    # Resolve split directories: `train` is always expected, `val`/`validation` and `test` are optional
    train_set = data_dir / "train"
    if (data_dir / "val").exists():
        val_set = data_dir / "val"
    elif (data_dir / "validation").exists():
        val_set = data_dir / "validation"
    else:
        val_set = None
    test_set = data_dir / "test" if (data_dir / "test").exists() else None

    # Actually perform the fallback that the warning announces
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
        val_set = test_set
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
        test_set = val_set

    # Each first-level sub-directory of train/ is one class; scan once and derive both values
    class_dirs = sorted(x.name for x in train_set.iterdir() if x.is_dir())
    nc = len(class_dirs)  # number of classes
    names = dict(enumerate(class_dirs))  # index -> class name

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f'{colorstr(f"{k}:")} {v}...'
        if v is None:
            LOGGER.info(prefix)
            continue
        files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
        nf = len(files)  # number of files
        nd = len({file.parent for file in files})  # number of directories
        if nf == 0:
            if k == "train":
                raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
            LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
        elif nd != nc:
            LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
        else:
            LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)