tensorflow学习笔记——获取训练数据集和测试数据集
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
训练神经网络模型之前,需要先获取训练数据集和测试数据集,本文介绍的获取数据集(get_data_train_test)的方法包括以下步骤:
1 在数据集文件夹中,不同类别图像分别放在以各自类别名称命名的文件夹中;
2 获取所有图像路径以及分类;
3 将分类转为字典格式;
4 将所有图像路径打乱;
5 将所有图像路径切分为训练部分和测试部分;
6 获取x部分
6.1 获取图像;
6.2 图像尺寸调整;
6.3 图像降维;
6.4 图像像素值取反;
6.5 图像像素值归一化;
7 获取y部分
7.1 获取图像的类别名称;
7.2 找到类别名称对应的id;
7.3 列表推到;
import os
import random
import math
import sys
import cv2
import numpy as np
from PIL import Image
#数据集路径
DATASET_TRAIN_TEST_DIR = 'D:/word/data_train_test'
DATASET_TEST_DIR = 'D:/word/data_test'
#随机种子
RANDOM_SEED = 0
#验证集数量
NUM_TEST = 20
#分类数量
NUM_CLASS = 10
#获取所有文件以及分类
def get_filenames_and_classes(dataset_dir):
#数据目录
directories = []
#分类名称
class_names = []
for filename in os.listdir(dataset_dir):
#合并文件路径
path = os.path.join(dataset_dir, filename)
#判断该路径是否为目录
if os.path.isdir(path):
#加入数据目录
directories.append(path)
#加入类别名称
class_names.append(filename)
photo_filenames = []
#循环每个分类的文件夹
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
#把图片加入图片列表
photo_filenames.append(path)
return photo_filenames, class_names
def get_xs(filenames):
xs = []
for i in range(len(filenames)):
image = Image.open(filenames[i]).convert('L')
blank = Image.new('L',[28,28],(255))
max_length = np.max(image.size)
w = int(image.size[0]*28/max_length)
h = int(image.size[1]*28/max_length)
#图像尺寸不超过28*28
image = image.resize((w,h), Image.NEAREST)
#图像尺寸调整为28*28
blank.paste(image, ((28-w)//2, (28-h)//2))
#图像尺寸调整为1*784
x = blank.resize((1,784))
#图像转换为数组
x = np.array(x)
#图像降维,如[[1],[2],[3]]变为[1,2,3]
x = x.squeeze()
#图像像素值取反
x = np.full(784, 255) - x
#图像像素值归一化
max = np.max(x)
x = x / np.full(784, max)
#获取多幅图像数据
xs.append(x)
return xs
def get_ys(filenames, class_names_to_ids):
ys = []
for i in range(len(filenames)):
#获得图片的类别名称
class_name = os.path.basename(os.path.dirname(filenames[i]))
#找到类别名称对应的id
class_id = class_names_to_ids[class_name]
#列表推到
y=[1 if id==class_id else 0 for id in range(NUM_CLASS)]
ys.append(y)
return ys
def get_data_train_test():
#获得所有图片路径以及分类
photo_filenames, class_names = get_filenames_and_classes(DATASET_TRAIN_TEST_DIR)
#把分类转为字典格式,类似于{'A':0, 'B':1, 'C':2}
class_names_to_ids = dict(zip(class_names, range(len(class_names))))
#把数据切分为训练集和测试集
random.seed(RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[NUM_TEST:]
testing_filenames = photo_filenames[:NUM_TEST]
train_xs = get_xs(training_filenames)
train_ys = get_ys(training_filenames, class_names_to_ids)
test_xs = get_xs(testing_filenames)
test_ys = get_ys(testing_filenames, class_names_to_ids)
return train_xs, train_ys, test_xs, test_ys
def get_data_test():
filenames = []
for filename in os.listdir(DATASET_TEST_DIR):
#合并文件路径
path = os.path.join(DATASET_TEST_DIR, filename)
filenames.append(path)
xs = get_xs(filenames)
return xs
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献9条内容
所有评论(0)