训练神经网络模型之前,需要先获取训练数据集和测试数据集,本文介绍的获取数据集(get_data_train_test)的方法包括以下步骤:
1 在数据集文件夹中,不同类别图像分别放在以各自类别名称命名的文件夹中;
2 获取所有图像路径以及分类;
3 将分类转为字典格式;
4 将所有图像路径打乱;
5 将所有图像路径切分为训练部分和测试部分;
6 获取x部分
6.1 获取图像;
6.2 图像尺寸调整;
6.3 图像降维;
6.4 图像像素值取反;
6.5 图像像素值归一化;
7 获取y部分
7.1 获取图像的类别名称;
7.2 找到类别名称对应的id;
7.3 列表推到;

import os
import random
import math
import sys
import cv2
import numpy as np
from PIL import Image

#数据集路径
DATASET_TRAIN_TEST_DIR = 'D:/word/data_train_test'
DATASET_TEST_DIR = 'D:/word/data_test'
#随机种子
RANDOM_SEED = 0
#验证集数量
NUM_TEST = 20
#分类数量
NUM_CLASS = 10

#获取所有文件以及分类
def get_filenames_and_classes(dataset_dir):
	#数据目录
	directories = []
	#分类名称
	class_names = []
	for filename in os.listdir(dataset_dir):
		#合并文件路径
		path = os.path.join(dataset_dir, filename)
		#判断该路径是否为目录
		if os.path.isdir(path):
			#加入数据目录
			directories.append(path)
			#加入类别名称
			class_names.append(filename)

	photo_filenames = []
	#循环每个分类的文件夹
	for directory in directories:
		for filename in os.listdir(directory):
			path = os.path.join(directory, filename)
			#把图片加入图片列表
			photo_filenames.append(path)

	return photo_filenames, class_names

def get_xs(filenames):
	xs = []
	for i in range(len(filenames)):
		image = Image.open(filenames[i]).convert('L')
		blank = Image.new('L',[28,28],(255))
		max_length = np.max(image.size)
		w = int(image.size[0]*28/max_length)
		h = int(image.size[1]*28/max_length)
		#图像尺寸不超过28*28
		image = image.resize((w,h), Image.NEAREST)
		#图像尺寸调整为28*28
		blank.paste(image, ((28-w)//2, (28-h)//2))
		#图像尺寸调整为1*784
		x = blank.resize((1,784))
		#图像转换为数组
		x = np.array(x)
		#图像降维,如[[1],[2],[3]]变为[1,2,3]
		x = x.squeeze()
		#图像像素值取反
		x = np.full(784, 255) - x
		#图像像素值归一化
		max = np.max(x)
		x = x / np.full(784, max)
		#获取多幅图像数据
		xs.append(x)
	return xs

def get_ys(filenames, class_names_to_ids):
	ys = []
	for i in range(len(filenames)):
		#获得图片的类别名称
		class_name = os.path.basename(os.path.dirname(filenames[i]))
		#找到类别名称对应的id
		class_id = class_names_to_ids[class_name]
		#列表推到
		y=[1 if id==class_id else 0 for id in range(NUM_CLASS)]
		ys.append(y)
	return ys

def get_data_train_test():
	#获得所有图片路径以及分类
	photo_filenames, class_names = get_filenames_and_classes(DATASET_TRAIN_TEST_DIR)

	#把分类转为字典格式,类似于{'A':0, 'B':1, 'C':2}
	class_names_to_ids = dict(zip(class_names, range(len(class_names))))

	#把数据切分为训练集和测试集
	random.seed(RANDOM_SEED)
	random.shuffle(photo_filenames)
	training_filenames = photo_filenames[NUM_TEST:]
	testing_filenames = photo_filenames[:NUM_TEST]
	train_xs = get_xs(training_filenames)
	train_ys = get_ys(training_filenames, class_names_to_ids)
	test_xs = get_xs(testing_filenames)
	test_ys = get_ys(testing_filenames, class_names_to_ids)

	return train_xs, train_ys, test_xs, test_ys

def get_data_test():
	filenames = []
	for filename in os.listdir(DATASET_TEST_DIR):
		#合并文件路径
		path = os.path.join(DATASET_TEST_DIR, filename)
		filenames.append(path)
	xs = get_xs(filenames)
	return xs

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐