What tf.data.Dataset is for

TensorFlow provides the tf.data module, a flexible set of dataset-building APIs that help us construct data input pipelines quickly and efficiently, especially in scenarios with very large amounts of data.

tf.data.Dataset API

The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.
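A minimal sketch of this three-step pattern (the values are just illustrative):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])  # 1. create a source dataset
dataset = dataset.map(lambda x: x * 2)                            # 2. apply a transformation
for elem in dataset:                                              # 3. iterate and process the elements
    print(elem.numpy())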

Building a Dataset

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])  # a simple dataset for testing and verifying operations

test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))  # the common (data, label) form of training data

Common dataset transformations

Printing the elements of a dataset

The dataset.as_numpy_iterator() method shown in the documentation's examples did not work in my environment, most likely because as_numpy_iterator() was only added in TensorFlow 2.1 and is missing from earlier releases.

list(dataset.as_numpy_iterator())

But iterating directly works in any case: for elem in dataset: print(elem.numpy())
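A defensive sketch that prefers as_numpy_iterator() when it exists and falls back to plain iteration otherwise (the hasattr check is my own hedge, not an official recipe):

if hasattr(dataset, 'as_numpy_iterator'):  # as_numpy_iterator() was added in TF 2.1
    print(list(dataset.as_numpy_iterator()))
else:
    for elem in dataset:
        print(elem.numpy())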

Verifying what dataset.take()/skip() do

In [89]: dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

In [90]: for elem in dataset:
    ...:     print (elem.numpy())
    ...:     
8
3
0
8
2
1

In [91]: t1 = dataset.take(1)  # build a dataset from the first element (it is the first element, not a random one)

In [92]: for elem in t1:
    ...:     print (elem.numpy())
    ...:     
8

In [93]: t1 = dataset.skip(2)  # the dataset built after skipping the first 2 elements

In [94]: for elem in t1:
    ...:     print (elem.numpy())
    ...:     
0
8
2
1

In [95]: t1 = dataset.skip(2).take(3)  # skip the first few elements, then take a few to form a new dataset; useful for splitting a dataset into training, test, and validation sets

In [96]: for elem in t1:
    ...:     print (elem.numpy())
    ...:     
0
8
2

t1 = dataset.skip(1).take(1)  # fetches one specific element
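Wrapped as a small helper (nth_element is a hypothetical name, not a tf.data API):

def nth_element(ds, n):
    # return the n-th element (0-based) of a dataset as a numpy value
    for elem in ds.skip(n).take(1):
        return elem.numpy()

nth_element(dataset, 1)  # -> 3 for the dataset [8, 3, 0, 8, 2, 1]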

A concrete example

In [147]: _x, _y , _n= create_segments_for_rnn(df, 64, 128)

In [148]: _x.shape
Out[148]: (26173, 128, 3)

In [149]: _y.shape
Out[149]: (26173, 5)

In [150]: _n.shape
Out[150]: (26173,)

In [151]: dataset = tf.data.Dataset.from_tensor_slices((_x, _y))
 

# train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)  # [60000, 28, 28, 1]

In [151]: def _decode_and_resize(data, label):                  
     ...:   x = tf.reshape(data, [128, 3, 1])                                 
     ...:   y = tf.cast(x, tf.float32)        
     ...:   return y, label 


In [154]: dataset = dataset.map(_decode_and_resize)

In [155]: dataset = dataset.shuffle(buffer_size=23000)

In [156]: dataset.element_spec
Out[156]: 
(TensorSpec(shape=(128, 3, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(5,), dtype=tf.float32, name=None))

In [157]: dataset = dataset.batch(32)

In [158]: dataset.element_spec
Out[158]: 
(TensorSpec(shape=(None, 128, 3, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 5), dtype=tf.float32, name=None))
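
In a real pipeline it is common to finish the chain with prefetch, so the input pipeline overlaps with training. A sketch of the whole chain above in one expression (prefetch and AUTOTUNE are standard tf.data features, not part of the original session):

dataset = (tf.data.Dataset.from_tensor_slices((_x, _y))
           .map(_decode_and_resize)
           .shuffle(buffer_size=23000)
           .batch(32)
           .prefetch(tf.data.experimental.AUTOTUNE))  # let TF pick the prefetch buffer size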
 

Splitting a dataset into training, test, and validation sets

Splits can be built with take and skip along these lines:

train_dataset = dataset.take(train_size)

test_dataset = dataset.skip(train_size).take(test_size)

val_dataset = dataset.skip(train_size+test_size).take(val_size)

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE, reshuffle_each_iteration=False)  # shuffle() requires a buffer_size; reshuffle_each_iteration=False keeps the splits disjoint across epochs
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
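The same logic wrapped as a small reusable helper (split_dataset is a hypothetical name, not part of tf.data):

def split_dataset(ds, ds_size, train_frac=0.7, val_frac=0.15):
    # assumes ds was shuffled once with reshuffle_each_iteration=False,
    # so the three splits stay disjoint across epochs
    train_size = int(train_frac * ds_size)
    val_size = int(val_frac * ds_size)
    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size + val_size)
    return train_ds, val_ds, test_ds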

Building the model

This generally has two parts: the architecture and compile().

from tensorflow.keras import layers, models

model_depth = models.Sequential()
model_depth.add(layers.DepthwiseConv2D((4, 6), activation='relu', padding='same', depth_multiplier=8, input_shape=(128, 3, 1)))
#model_depth.add(layers.MaxPooling2D((3, 3)))
model_depth.add(layers.AveragePooling2D((3, 3)))
model_depth.add(layers.Conv2D(16, (4, 1), padding='same', activation='relu'))
#model_depth.add(layers.MaxPooling2D((4, 1)))
model_depth.add(layers.AveragePooling2D((4, 1)))
model_depth.add(layers.Flatten())
model_depth.add(layers.Dropout(0.5))
model_depth.add(layers.Dense(16, activation='relu'))
model_depth.add(layers.Dense(5, activation='softmax'))
 

model_depth.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Training and evaluation

The whole point of tf.data.Dataset is to make training convenient.

model_depth.fit(train_dataset, epochs=2, validation_data=test_dataset)

model_depth.evaluate(val_dataset)

One pitfall here: the datasets passed to fit and evaluate must have gone through a batch operation, otherwise the shapes will not match. As shown above, after batch the element spec gains an extra leading (batch) dimension.
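To make the pitfall concrete (a sketch based on the element specs shown above):

# dataset elements:           (128, 3, 1)       -> model_depth.fit(dataset) fails with a shape mismatch
# dataset.batch(32) elements: (None, 128, 3, 1) -> model_depth.fit(dataset.batch(32)) works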

 

Using tf.data.Dataset for quantization

sensor_ds = tf.data.Dataset.from_tensor_slices((x_train_rs)).batch(1)
def representative_data_gen():
  for input_value in sensor_ds.take(100):
    yield [input_value]
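
This generator is meant to be handed to the TFLite converter as the representative dataset for post-training quantization. A minimal sketch, assuming model_depth from above (the output filename is hypothetical):

converter = tf.lite.TFLiteConverter.from_keras_model(model_depth)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
tflite_quant_model = converter.convert()

with open('har_quant_model.tflite', 'wb') as f:  # hypothetical output path
    f.write(tflite_quant_model)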

Using tf.data.Dataset with the TFLite Interpreter

# evaluate the quantized model
def eval_model(interpreter, mnist_ds):
  total_seen = 0
  num_correct = 0

  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  for img, label, name in mnist_ds:
    total_seen += 1
    interpreter.set_tensor(input_index, img)
    interpreter.invoke()
    predictions = interpreter.get_tensor(output_index)
    if np.argmax(predictions) == np.argmax(label):  # this comparison is the key step
      num_correct += 1
    else:
      print ('error'*30)
      print(name)
      print(label)
      print(predictions)
    if total_seen % 1000 == 0:
      print("Accuracy after %i images: %f" %
            (total_seen, float(num_correct) / float(total_seen)))

  return float(num_correct) / float(total_seen)

interpreter = tf.lite.Interpreter(model_path='har_float_model.tflite')
interpreter.allocate_tensors()
sensor_eval_ds_1 = tf.data.Dataset.from_tensor_slices((x_test_rs, y_test, n_test)).batch(1).take(10)

print("==="*10 + " float model " + "==="*10)
print(eval_model(interpreter, sensor_eval_ds_1))
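The same loop evaluates the quantized model by pointing the Interpreter at the quantized file (reusing the hypothetical path written above):

interpreter_quant = tf.lite.Interpreter(model_path='har_quant_model.tflite')
interpreter_quant.allocate_tensors()
print("==="*10 + " quant model " + "==="*10)
print(eval_model(interpreter_quant, sensor_eval_ds_1))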
 

Accessing a specific element of a dataset

In [166]: dataset = tf.data.Dataset.from_tensor_slices((_x, _y))

In [167]: t1 = dataset.take(1)

In [168]: t1.element_spec
Out[168]: 
(TensorSpec(shape=(128, 3), dtype=tf.float64, name=None),
 TensorSpec(shape=(5,), dtype=tf.float32, name=None))

In [169]: for ele in t1:
     ...:     print ele.numpy()
     ...:     
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-169-cbefee7a71fd> in <module>()
      1 for ele in t1:
----> 2     print ele.numpy()
      3 

AttributeError: 'tuple' object has no attribute 'numpy'
 

Each element of this dataset is a tuple (_x, _y), so it cannot be accessed as ele.numpy().

It can be accessed through the tuple members instead:

In [174]: dataset = tf.data.Dataset.from_tensor_slices((_x, _y))

In [175]: t1 = dataset.take(1)

In [176]: for ele in t1:
     ...:     print (ele[0].numpy())
     ...:     
[[-0.46880656  1.15753446  0.16205443]
 [-0.6429841   1.34653488  0.28153576]
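
Alternatively, the tuple can be unpacked directly in the loop header:

for x, y in t1:
    print(x.numpy())
    print(y.numpy())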
