Tensorflow中创建自己的TFRecord格式数据集
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
参考文献《TensorFlow实战Google深度学习框架》
TFRecord格式介绍
TFRecord文件中的数据都是以tf.train.Example Protocol Buffer的格式序列化成二进制后存储的,其定义如下:
// Example: one record stored in a TFRecord file; it only wraps a Features map.
message Example{
Features features = 1;
};
// Features: a dictionary from feature name (string) to its Feature value.
message Features{
map<string,Feature> feature = 1;
};
// Feature: holds exactly one of the three list kinds below (oneof).
// In Python each list is built as e.g. tf.train.Int64List(value=[...]).
message Feature{
oneof kind{
BytesList bytes_list = 1;   // list of byte strings (e.g. raw image pixels)
FloatList float_list = 2;   // list of floating-point numbers
Int64List int64_list = 3;   // list of 64-bit integers (e.g. class labels)
}
};
它实际上存储了一个从属性名到取值的字典。其中属性名为一个字符串,属性取值可以为字符串列表(BytesList)、实数列表(FloatList)或整数列表(Int64List)。比如对于一幅图像而言,可以将图像的像素信息保存成一个字符串,将图像对应的标签保存成整数列表。
创建TFRecord文件
先导入一些必要的库(以下代码在Jupyter Notebook中运行):
import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt
%matplotlib inline
数据预处理
我自己从网上下载了10张图片(3张猫、4张狗、3张马),分别存放在cat、dog和horse文件夹下。因为从网上下载的图片尺寸和格式不统一,需要先对这些图片做预处理,函数如下(这里只附上函数部分的代码,文末会附上完整的测试代码):
def preprocess(imageRawDir, imageDir, size=(256, 256)):
    """
    Resize every raw image of one class and save it as an RGB JPEG.

    Arguments:
    imageRawDir -- directory of primary (downloaded) images.
    imageDir -- directory of processed images, e.g. "./data/cat/"; its last
                path component ("cat") is used as the file-name prefix, so
                images are saved as cat_0.jpg, cat_1.jpg, ...
    size -- (width, height) every image is resized to; default (256, 256).
    Return: none.
    """
    # basename(normpath(...)) gives "cat" for both "./data/cat/" and
    # "./data/cat"; split("/")[-2] only worked with a trailing slash.
    label = os.path.basename(os.path.normpath(imageDir))
    if not os.path.isdir(imageDir):
        os.makedirs(imageDir)
    # Skip hidden files (.DS_Store, ...) and sort so the numbering is reproducible.
    imageNames = sorted(name for name in os.listdir(imageRawDir)
                        if not name.startswith("."))
    for index, imageName in enumerate(imageNames):
        with Image.open(os.path.join(imageRawDir, imageName)) as rawImage:
            # Downloaded images may be RGBA/palette/grayscale: JPEG cannot store
            # alpha, and the TFRecord reader expects exactly 3 channels.
            image = rawImage.convert("RGB").resize(size)
        savePath = os.path.join(imageDir, "%s_%d.jpg" % (label, index))
        image.save(savePath)
预处理后的图片会保存在另一个指定的文件夹下。
写入到TFRecord文件
下面两个函数会在创建TFRecord文件的时候用到。因为如果不写成函数的形式,代码会很长,看起来也很头疼。
def _int64_feature(value):
    """
    Wrap one integer as a tf.train.Feature.

    Arguments:
    value -- the integer to store (e.g. a class label).
    Return: tf.train.Feature whose int64_list holds [value].
    """
    int64List = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64List)
def _bytes_feature(value):
    """
    Wrap one byte string as a tf.train.Feature.

    Arguments:
    value -- the bytes to store (e.g. raw image pixels).
    Return: tf.train.Feature whose bytes_list holds [value].
    """
    bytesList = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytesList)
下面这个函数完成的功能是读取我们之前预处理后的所有图片,并依次将每张图片写入到TFRecord文件中,这里因为数据很少,只写入到了一个TFRecord文件中,当数据量很大时,也可以写入多个文件中。
def createRecord(imageDir):
    """
    Pack the preprocessed images of every class into one TFRecord file.

    Each image under imageDir/<className>/ becomes one tf.train.Example with
    two features: "label" (class index: 0 cat, 1 dog, 2 horse) and
    "image_raw" (the RGB pixel bytes). Output: imageDir/train.tfrecords.

    Arguments:
    imageDir -- directory containing one sub-directory per class.
    Return: none.
    """
    # create a writer to write TFRecord file
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    try:
        classNames = ["cat", "dog", "horse"]
        for classIndex, className in enumerate(classNames):
            print("class name = %s" % className)
            currentClassDir = os.path.join(imageDir, className)
            print("current dir = %s" % currentClassDir)
            # Skip hidden files and sort so the record order is reproducible.
            imageNames = sorted(name for name in os.listdir(currentClassDir)
                                if not name.startswith("."))
            for index, imageName in enumerate(imageNames):
                with Image.open(os.path.join(currentClassDir, imageName)) as image:
                    # Force 3 channels so every record holds H*W*3 bytes, matching
                    # the reader's reshape; a grayscale JPEG would otherwise be
                    # stored with one byte per pixel.
                    image_raw = image.convert("RGB").tobytes()
                print("%d %s" % (index, imageName))
                # write image data (pixel values and label) to Example Protocol Buffer
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": _int64_feature(classIndex),
                    "image_raw": _bytes_feature(image_raw),
                }))
                writer.write(example.SerializeToString())
    finally:
        # Close the file even if an image fails to load.
        writer.close()
读取TFRecord文件
def readRecord(recordName, imageShape=(256, 256, 3)):
    """
    Build graph ops that read one (image, label) example from a TFRecord file.

    Uses the TF1 queue-based input pipeline (string_input_producer), so the
    returned tensors only produce data once queue runners have been started
    in a session. NOTE: TF1-only APIs; under TF2 use tf.data.TFRecordDataset.

    Arguments:
    recordName -- path of the TFRecord file to be read.
    imageShape -- (height, width, channels) the raw uint8 bytes are reshaped
                  to; must match what was written. Default 256x256 RGB.
    Return: (image, label) -- uint8 tensor of shape imageShape and int32 scalar tensor.
    """
    filenameQueue = tf.train.string_input_producer([recordName])
    reader = tf.TFRecordReader()
    _, serializedExample = reader.read(filenameQueue)
    features = tf.parse_single_example(serializedExample, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string),
    })
    # image_raw is a flat byte string; decode it to uint8 pixels, then restore H x W x C.
    image = tf.decode_raw(features["image_raw"], tf.uint8)
    image = tf.reshape(image, list(imageShape))
    label = tf.cast(features["label"], tf.int32)
    return image, label
注意,这里我们得到的返回值都是张量,需要在tensorflow中创建session后才能得到实际的数据。如下:
##test code
image, label = readRecord("./data/train.tfrecords")
print("%s %s" % (image, label))
# Mix single examples into shuffled batches of 4; min_after_dequeue keeps at
# least 5 examples buffered after each dequeue so batches are well mixed.
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# A Coordinator lets the queue-runner threads be stopped and joined cleanly
# instead of being left running after we are done with the session.
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for i in range(10):
        images, labels = sess.run([imageBatch, labelBatch])
        print("batch shape = %s labels = %s" % (images.shape, labels))
finally:
    coord.request_stop()
    coord.join(thread)
    sess.close()
print("label = %s" % labels)
# Show the last batch with its labels (0: cat, 1: dog, 2: horse).
for i in range(len(images)):
    plt.subplot(1, len(images), i + 1)
    plt.axis("off")
    plt.title(str(labels[i]))
    plt.imshow(images[i])
plt.show()
输出结果:
batch shape = (4, 256, 256, 3) labels = [0 1 1 0]
batch shape = (4, 256, 256, 3) labels = [1 0 0 0]
batch shape = (4, 256, 256, 3) labels = [1 2 2 1]
batch shape = (4, 256, 256, 3) labels = [1 2 1 2]
batch shape = (4, 256, 256, 3) labels = [0 0 1 0]
batch shape = (4, 256, 256, 3) labels = [2 1 2 1]
batch shape = (4, 256, 256, 3) labels = [1 1 0 0]
batch shape = (4, 256, 256, 3) labels = [0 2 1 1]
batch shape = (4, 256, 256, 3) labels = [2 2 1 0]
batch shape = (4, 256, 256, 3) labels = [1 2 0 1]
label = [1 2 0 1]
可以发现,最后一个batch中的图像和标签是一一对应的(0: cat; 1: dog; 2: horse),说明我们已经成功地从TFRecord文件中读出了数据。
完整样例代码
import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt
# Relative paths used below ("./data_raw/...", "./data/...") are resolved
# against the directory the script is launched from. (The former
# os.chdir(os.getcwd()) was a no-op and has been removed.)
currentDir = os.getcwd()
print(currentDir)
def preprocess(imageRawDir, imageDir, size=(256, 256)):
    """
    Resize every raw image of one class and save it as an RGB JPEG.

    Arguments:
    imageRawDir -- directory of primary (downloaded) images.
    imageDir -- directory of processed images, e.g. "./data/cat/"; its last
                path component ("cat") is used as the file-name prefix, so
                images are saved as cat_0.jpg, cat_1.jpg, ...
    size -- (width, height) every image is resized to; default (256, 256).
    Return: none.
    """
    # basename(normpath(...)) gives "cat" for both "./data/cat/" and
    # "./data/cat"; split("/")[-2] only worked with a trailing slash.
    label = os.path.basename(os.path.normpath(imageDir))
    if not os.path.isdir(imageDir):
        os.makedirs(imageDir)
    # Skip hidden files (.DS_Store, ...) and sort so the numbering is reproducible.
    imageNames = sorted(name for name in os.listdir(imageRawDir)
                        if not name.startswith("."))
    for index, imageName in enumerate(imageNames):
        with Image.open(os.path.join(imageRawDir, imageName)) as rawImage:
            # Downloaded images may be RGBA/palette/grayscale: JPEG cannot store
            # alpha, and the TFRecord reader expects exactly 3 channels.
            image = rawImage.convert("RGB").resize(size)
        savePath = os.path.join(imageDir, "%s_%d.jpg" % (label, index))
        image.save(savePath)
##test code
# (raw directory, processed directory) for each class; processed images land
# in ./data/<class>/ named after the class.
catRawDir, catDir = "./data_raw/cat/", "./data/cat/"
dogRawDir, dogDir = "./data_raw/dog/", "./data/dog/"
horseRawDir, horseDir = "./data_raw/horse/", "./data/horse/"
for rawDir, processedDir in [(catRawDir, catDir), (dogRawDir, dogDir), (horseRawDir, horseDir)]:
    preprocess(rawDir, processedDir)
def _int64_feature(value):
    """
    Wrap one integer as a tf.train.Feature.

    Arguments:
    value -- the integer to store (e.g. a class label).
    Return: tf.train.Feature whose int64_list holds [value].
    """
    int64List = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64List)
def _bytes_feature(value):
    """
    Wrap one byte string as a tf.train.Feature.

    Arguments:
    value -- the bytes to store (e.g. raw image pixels).
    Return: tf.train.Feature whose bytes_list holds [value].
    """
    bytesList = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytesList)
def createRecord(imageDir):
    """
    Pack the preprocessed images of every class into one TFRecord file.

    Each image under imageDir/<className>/ becomes one tf.train.Example with
    two features: "label" (class index: 0 cat, 1 dog, 2 horse) and
    "image_raw" (the RGB pixel bytes). Output: imageDir/train.tfrecords.

    Arguments:
    imageDir -- directory containing one sub-directory per class.
    Return: none.
    """
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    try:
        classNames = ["cat", "dog", "horse"]
        for classIndex, className in enumerate(classNames):
            print("class name = %s" % className)
            currentClassDir = os.path.join(imageDir, className)
            print("current dir = %s" % currentClassDir)
            # Skip hidden files and sort so the record order is reproducible.
            imageNames = sorted(name for name in os.listdir(currentClassDir)
                                if not name.startswith("."))
            for index, imageName in enumerate(imageNames):
                with Image.open(os.path.join(currentClassDir, imageName)) as image:
                    # Force 3 channels so every record holds H*W*3 bytes, matching
                    # the reader's reshape; a grayscale JPEG would otherwise be
                    # stored with one byte per pixel.
                    image_raw = image.convert("RGB").tobytes()
                print("%d %s" % (index, imageName))
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": _int64_feature(classIndex),
                    "image_raw": _bytes_feature(image_raw),
                }))
                writer.write(example.SerializeToString())
    finally:
        # Close the file even if an image fails to load.
        writer.close()
##test code
# Build ./data/train.tfrecords from the per-class folders under ./data/.
dataDir = os.path.join(currentDir, "data/")
createRecord(dataDir)
def readRecord(recordName, imageShape=(256, 256, 3)):
    """
    Build graph ops that read one (image, label) example from a TFRecord file.

    Uses the TF1 queue-based input pipeline (string_input_producer), so the
    returned tensors only produce data once queue runners have been started
    in a session. NOTE: TF1-only APIs; under TF2 use tf.data.TFRecordDataset.

    Arguments:
    recordName -- path of the TFRecord file to be read.
    imageShape -- (height, width, channels) the raw uint8 bytes are reshaped
                  to; must match what was written. Default 256x256 RGB.
    Return: (image, label) -- uint8 tensor of shape imageShape and int32 scalar tensor.
    """
    filenameQueue = tf.train.string_input_producer([recordName])
    reader = tf.TFRecordReader()
    _, serializedExample = reader.read(filenameQueue)
    features = tf.parse_single_example(serializedExample, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string),
    })
    # image_raw is a flat byte string; decode it to uint8 pixels, then restore H x W x C.
    image = tf.decode_raw(features["image_raw"], tf.uint8)
    image = tf.reshape(image, list(imageShape))
    label = tf.cast(features["label"], tf.int32)
    return image, label
##test code
image, label = readRecord("./data/train.tfrecords")
print("%s %s" % (image, label))
# Mix single examples into shuffled batches of 4; min_after_dequeue keeps at
# least 5 examples buffered after each dequeue so batches are well mixed.
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)
##test code
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# A Coordinator lets the queue-runner threads be stopped and joined cleanly
# instead of being left running after we are done with the session.
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for i in range(10):
        images, labels = sess.run([imageBatch, labelBatch])
        print("batch shape = %s labels = %s" % (images.shape, labels))
finally:
    coord.request_stop()
    coord.join(thread)
    sess.close()
print("label = %s" % labels)
# Show the last batch with its labels (0: cat, 1: dog, 2: horse).
for i in range(len(images)):
    plt.subplot(1, len(images), i + 1)
    plt.axis("off")
    plt.title(str(labels[i]))
    plt.imshow(images[i])
# Needed when run as a plain script (no %matplotlib inline here).
plt.show()
GitHub 加速计划 / te / tensorflow
184.54 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:1 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献6条内容
所有评论(0)