TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
一 why tfrecord?
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords
参考 :TensorFlow高效读取数据
二 代码详解
1导入库
import tensorflow as tf
import numpy
2 构建writer,用于写入数据
writer = tf.python_io.TFRecordWriter('test.tfrecord')
3 分俩步创建a,b,c三个不同格式的列表并保存到writer中
for i in range(0, 2):
a = 0.618 + i
b = [2016 + i, 2017+i]
c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
c = c.astype(numpy.uint8)
c_raw = c.tostring()#这里是把c换了一种格式存储
print 'i:',i
print ' a:',a
print ' b:',b
print ' c:',c
example = tf.train.Example(
features = tf.train.Features( #固定模式,字典格式保存
feature = {'a':tf.train.Feature(float_list = tf.train.FloatList(value=[a])),
'b':tf.train.Feature(int64_list = tf.train.Int64List(value = b)),
'c':tf.train.Feature(bytes_list = tf.train.BytesList(value = [c_raw]))}))
serialized = example.SerializeToString()
writer.write(serialized)
print ' writer',i,'DOWN!'
writer.close()
i: 0
a: 0.618
b: [2016, 2017]
c: [[0 1 2]
[3 4 5]]
writer 0 DOWN!
i: 1
a: 1.618
b: [2017, 2018]
c: [[1 2 3]
[4 5 6]]
writer 1 DOWN!
4 创建文件读取队列并读取其中内容(字典格式)
# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
features = tf.parse_single_example(serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32),
'b': tf.FixedLenFeature([2], tf.int64),
'c': tf.FixedLenFeature([], tf.string)
}
)
5 读取内容
a_out = features['a']
b_out = features['b']
c_raw_out = features['c']
c_out = tf.decode_raw(c_raw_out, tf.uint8)
c_out = tf.reshape(c_out, [2, 3])
6 显示格式
print a_out
print b_out
print c_out
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("Reshape:0", shape=(2, 3), dtype=uint8)
7 通过shuffle_batch喂入数据
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out], batch_size=3,
capacity=200, min_after_dequeue=100, num_threads=2)
8 构建sess,读入数据并显示
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print ' a_val:',a_val
print ' b_val:',b_val
print ' c_val:',c_val
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print ' a_val:',a_val
print ' b_val:',b_val
print ' c_val:',c_val
first batch:
a_val: [ 0.61799997 1.61800003 0.61799997]
b_val: [[2016 2017]
[2017 2018]
[2016 2017]]
c_val: [[[0 1 2]
[3 4 5]]
[[1 2 3]
[4 5 6]]
[[0 1 2]
[3 4 5]]]
second batch:
a_val: [ 0.61799997 0.61799997 1.61800003]
b_val: [[2016 2017]
[2016 2017]
[2017 2018]]
c_val: [[[0 1 2]
[3 4 5]]
[[0 1 2]
[3 4 5]]
[[1 2 3]
[4 5 6]]]
之前定义了batch=3,所以每个batch输入三个数据,并且是随机读入的。
三 完整代码
for i in range(0,2):
print i
0
1
1 把数据写成tfrecord文件
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy
writer = tf.python_io.TFRecordWriter('test1111.tfrecord')
for i in range(0, 2):
a = 0.618 + i
b = [2016 + i, 2017+i]
#c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
#c = c.astype(numpy.uint8)
c = "你好哦"+str(i)
#c_raw = c.tostring()#这里是把c换了一种格式存储
c_raw = c
print 'i:',i
print ' a:',a
print ' b:',b
print ' c:',c
example = tf.train.Example(
features = tf.train.Features(
feature = {'a':tf.train.Feature(float_list = tf.train.FloatList(value=[a])),
'b':tf.train.Feature(int64_list = tf.train.Int64List(value = b)),
'c':tf.train.Feature(bytes_list = tf.train.BytesList(value = [c_raw]))}))
serialized = example.SerializeToString()
writer.write(serialized)
print ' writer',i,'DOWN!'
writer.close()
i: 0
a: 0.618
b: [2016, 2017]
c: 你好哦0
writer 0 DOWN!
i: 1
a: 1.618
b: [2017, 2018]
c: 你好哦1
writer 1 DOWN!
2数据提取及显示
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test1111.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
yyp
features = tf.parse_single_example(serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32),
'b': tf.FixedLenFeature([2], tf.int64),
'c': tf.FixedLenFeature([],tf.string)
}
)
a_out = features['a']
b_out = features['b']
c_out = features['c']
#c_raw_out = features['c']
#c_raw_out = tf.sparse_to_dense(features['c'])
#c_out = tf.decode_raw(c_raw_out, tf.uint8)
print a_out
print b_out
print c_out
#c_out = tf.reshape(c_out, [2, 3])
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out], batch_size=3,
capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print ' a_val:',a_val
print ' b_val:',b_val
print ' c_val:',c_val[0].decode('utf-8')
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print ' a_val:',a_val
print ' b_val:',b_val
print ' c_val:',str(c_val).decode('utf-8')
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("ParseSingleExample/Squeeze_c:0", shape=(), dtype=string)
first batch:
a_val: [ 0.61799997 0.61799997 1.61800003]
b_val: [[2016 2017]
[2016 2017]
[2017 2018]]
c_val: 你好哦0
second batch:
a_val: [ 1.61800003 0.61799997 1.61800003]
b_val: [[2017 2018]
[2016 2017]
[2017 2018]]
c_val: ['\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61'
'\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'
'\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61']
不解码输出的是unicode格式的
print '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'.decode('utf-8')
你好哦0
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
2 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
2 个月前
更多推荐
已为社区贡献10条内容
所有评论(0)