一 why tfrecord?

对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords
参考 :TensorFlow高效读取数据

二 代码详解

1导入库
import tensorflow as tf
import numpy
2 构建writer,用于写入数据
writer = tf.python_io.TFRecordWriter('test.tfrecord')
3 分俩步创建a,b,c三个不同格式的列表并保存到writer中
for i in range(0, 2):
    a = 0.618 + i
    b = [2016 + i, 2017+i]
    c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
    c = c.astype(numpy.uint8)
    c_raw = c.tostring()#这里是把c换了一种格式存储
    print 'i:',i
    print '  a:',a
    print '  b:',b
    print '  c:',c
    example = tf.train.Example(
        features = tf.train.Features( #固定模式,字典格式保存
            feature = {'a':tf.train.Feature(float_list = tf.train.FloatList(value=[a])),
                       'b':tf.train.Feature(int64_list = tf.train.Int64List(value = b)),
                       'c':tf.train.Feature(bytes_list = tf.train.BytesList(value = [c_raw]))}))
    serialized = example.SerializeToString()
    writer.write(serialized)
    print '   writer',i,'DOWN!'
writer.close()
i: 0
  a: 0.618
  b: [2016, 2017]
  c: [[0 1 2]
 [3 4 5]]
   writer 0 DOWN!
i: 1
  a: 1.618
  b: [2017, 2018]
  c: [[1 2 3]
 [4 5 6]]
   writer 1 DOWN!
4 创建文件读取队列并读取其中内容(字典格式)
# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example

features = tf.parse_single_example(serialized_example,
        features={
            'a': tf.FixedLenFeature([], tf.float32),
            'b': tf.FixedLenFeature([2], tf.int64),
            'c': tf.FixedLenFeature([], tf.string)
        }
    )
5 读取内容
a_out = features['a']
b_out = features['b']
c_raw_out = features['c']
c_out = tf.decode_raw(c_raw_out, tf.uint8)
c_out = tf.reshape(c_out, [2, 3])
6 显示格式
print a_out
print b_out
print c_out
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("Reshape:0", shape=(2, 3), dtype=uint8)
7 通过shuffle_batch喂入数据
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out], batch_size=3, 
                                capacity=200, min_after_dequeue=100, num_threads=2)
8 构建sess,读入数据并显示
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print '  a_val:',a_val
print '  b_val:',b_val
print '  c_val:',c_val
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print '  a_val:',a_val
print '  b_val:',b_val
print '  c_val:',c_val
first batch:
  a_val: [ 0.61799997  1.61800003  0.61799997]
  b_val: [[2016 2017]
 [2017 2018]
 [2016 2017]]
  c_val: [[[0 1 2]
  [3 4 5]]

 [[1 2 3]
  [4 5 6]]

 [[0 1 2]
  [3 4 5]]]
second batch:
  a_val: [ 0.61799997  0.61799997  1.61800003]
  b_val: [[2016 2017]
 [2016 2017]
 [2017 2018]]
  c_val: [[[0 1 2]
  [3 4 5]]

 [[0 1 2]
  [3 4 5]]

 [[1 2 3]
  [4 5 6]]]

之前定义了batch=3,所以每个batch输入三个数据,并且是随机读入的。

三 完整代码

for i in range(0,2):
    print i
0
1
1 把数据写成tfrecord文件
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy
writer = tf.python_io.TFRecordWriter('test1111.tfrecord')
for i in range(0, 2):
    a = 0.618 + i
    b = [2016 + i, 2017+i]
    #c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
    #c = c.astype(numpy.uint8)
    c = "你好哦"+str(i)
    #c_raw = c.tostring()#这里是把c换了一种格式存储
    c_raw = c
    print 'i:',i
    print '  a:',a
    print '  b:',b
    print '  c:',c
    example = tf.train.Example(
        features = tf.train.Features(
            feature = {'a':tf.train.Feature(float_list = tf.train.FloatList(value=[a])),
                       'b':tf.train.Feature(int64_list = tf.train.Int64List(value = b)),
                       'c':tf.train.Feature(bytes_list = tf.train.BytesList(value = [c_raw]))}))
    serialized = example.SerializeToString()
    writer.write(serialized)
    print '   writer',i,'DOWN!'
writer.close()
i: 0
  a: 0.618
  b: [2016, 2017]
  c: 你好哦0
   writer 0 DOWN!
i: 1
  a: 1.618
  b: [2017, 2018]
  c: 你好哦1
   writer 1 DOWN!
2数据提取及显示
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test1111.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
yyp
features = tf.parse_single_example(serialized_example,
        features={
            'a': tf.FixedLenFeature([], tf.float32),
            'b': tf.FixedLenFeature([2], tf.int64),
            'c': tf.FixedLenFeature([],tf.string)
        }
    )
a_out = features['a']
b_out = features['b']
c_out = features['c']
#c_raw_out = features['c']
#c_raw_out = tf.sparse_to_dense(features['c'])
#c_out = tf.decode_raw(c_raw_out, tf.uint8)
print a_out
print b_out
print c_out
#c_out = tf.reshape(c_out, [2, 3])
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out], batch_size=3, 
                                capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print '  a_val:',a_val
print '  b_val:',b_val
print '  c_val:',c_val[0].decode('utf-8')
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print '  a_val:',a_val
print '  b_val:',b_val
print '  c_val:',str(c_val).decode('utf-8')
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("ParseSingleExample/Squeeze_c:0", shape=(), dtype=string)
first batch:
  a_val: [ 0.61799997  0.61799997  1.61800003]
  b_val: [[2016 2017]
 [2016 2017]
 [2017 2018]]
  c_val: 你好哦0
second batch:
  a_val: [ 1.61800003  0.61799997  1.61800003]
  b_val: [[2017 2018]
 [2016 2017]
 [2017 2018]]
  c_val: ['\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61'
 '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'
 '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61']

不解码输出的是unicode格式的

print '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'.decode('utf-8')
你好哦0
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:3 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 3 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 3 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐