TensorFlow 2.0 Getting Started Tutorial 9: Loading Datasets with TensorFlow Datasets
Two ways to load datasets:
- 1. With the tensorflow_datasets package: https://tensorflow.google.cn/datasets/overview
- 2. With tf.keras.datasets
I. Loading datasets with tensorflow_datasets
Install the package
pip install tensorflow_datasets
Import the packages
import tensorflow as tf
import tensorflow_datasets as tfds
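List the available datasets (the output below comes from the tfds version used when this was written; newer versions include more datasets):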
tfds.list_builders()
['abstract_reasoning',
'aeslc',
'aflw2k3d',
'amazon_us_reviews',
'arc',
'bair_robot_pushing_small',
'big_patent',
'bigearthnet',
'billsum',
'binarized_mnist',
'binary_alpha_digits',
'c4',
'caltech101',
'caltech_birds2010',
'caltech_birds2011',
'cars196',
'cassava',
'cats_vs_dogs',
'celeb_a',
'celeb_a_hq',
'chexpert',
'cifar10',
'cifar100',
'cifar10_1',
'cifar10_corrupted',
'citrus_leaves',
'cityscapes',
'civil_comments',
'clevr',
'cmaterdb',
'cnn_dailymail',
'coco',
'coil100',
'colorectal_histology',
'colorectal_histology_large',
'cos_e',
'curated_breast_imaging_ddsm',
'cycle_gan',
'deep_weeds',
'definite_pronoun_resolution',
'diabetic_retinopathy_detection',
'dmlab',
'downsampled_imagenet',
'dsprites',
'dtd',
'duke_ultrasound',
'dummy_dataset_shared_generator',
'dummy_mnist',
'emnist',
'esnli',
'eurosat',
'fashion_mnist',
'flic',
'flores',
'food101',
'gap',
'gigaword',
'glue',
'groove',
'higgs',
'horses_or_humans',
'i_naturalist2017',
'image_label_folder',
'imagenet2012',
'imagenet2012_corrupted',
'imagenet_resized',
'imagenette',
'imdb_reviews',
'iris',
'kitti',
'kmnist',
'lfw',
'lm1b',
'lost_and_found',
'lsun',
'malaria',
'math_dataset',
'mnist',
'mnist_corrupted',
'movie_rationales',
'moving_mnist',
'multi_news',
'multi_nli',
'multi_nli_mismatch',
'newsroom',
'nsynth',
'omniglot',
'open_images_v4',
'oxford_flowers102',
'oxford_iiit_pet',
'para_crawl',
'patch_camelyon',
'pet_finder',
'places365_small',
'plant_leaves',
'plant_village',
'plantae_k',
'quickdraw_bitmap',
'reddit_tifu',
'resisc45',
'rock_paper_scissors',
'rock_you',
'scan',
'scene_parse150',
'scicite',
'scientific_papers',
'shapes3d',
'smallnorb',
'snli',
'so2sat',
'squad',
'stanford_dogs',
'stanford_online_products',
'starcraft_video',
'sun397',
'super_glue',
'svhn_cropped',
'ted_hrlr_translate',
'ted_multi_translate',
'tf_flowers',
'the300w_lp',
'titanic',
'trivia_qa',
'uc_merced',
'ucf101',
'vgg_face2',
'visual_domain_decathlon',
'voc',
'wider_face',
'wikihow',
'wikipedia',
'wmt14_translate',
'wmt15_translate',
'wmt16_translate',
'wmt17_translate',
'wmt18_translate',
'wmt19_translate',
'wmt_t2t_translate',
'wmt_translate',
'xnli',
'xsum']
Load the dataset you need with tfds.load
When a single split is requested, tfds.load returns a tf.data.Dataset object. Some important parameters:
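- as_supervised: if True, each element is returned as an (input, label) tuple based on the dataset's supervised keys; otherwise each element is a dict containing all features.
- split: which part of the dataset to return. Common options are tfds.Split.TRAIN (training set) and tfds.Split.TEST (test set); the strings "train" and "test" work too. If split is omitted, tfds.load returns a dict mapping each split name to its own tf.data.Dataset.
- download: boolean (default True) that controls whether the data is downloaded. Once the data has been prepared, later load calls reuse it instead of downloading it again. Pass data_dir= (default ~/tensorflow_datasets/) to choose where the data is saved and loaded from.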
# Training set
train_dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
# mnist_train = tfds.load(name="mnist", split="train")
# Test set
# test_dataset = tfds.load("mnist", split=tfds.Split.TEST)
# mnist_test = tfds.load(name="mnist", split="test")
# Other datasets are loaded the same way, e.g.:
# dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
# dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
train_dataset
<DatasetV1Adapter shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
dict_dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=False)
dict_dataset
<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
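To make the difference between the two formats concrete, here is a minimal sketch that roughly mimics what as_supervised=True does. It builds a dict-style dataset from synthetic tensors instead of the real MNIST download (an assumption that keeps it offline) and maps each feature dict to an (image, label) tuple.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the dict-style MNIST dataset (assumption: no download):
# 8 blank 28x28x1 uint8 images with int64 labels 0..7
images = np.zeros((8, 28, 28, 1), dtype=np.uint8)
labels = np.arange(8, dtype=np.int64)
dict_ds = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# Turn each feature dict into an (input, label) tuple, similar to as_supervised=True
supervised_ds = dict_ds.map(lambda ex: (ex["image"], ex["label"]))

for image, label in supervised_ds.take(1):
    print(image.shape, label.numpy())  # (28, 28, 1) 0
```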
Once you have a tf.data.Dataset, you can use tf.data to preprocess it and read from it. For example:
# Resize the images and scale them to [0, 1], then shuffle and batch the dataset
dataset = train_dataset.map(lambda img, label: (tf.image.resize(img, [28,28]) / 255.0, label))
dataset = dataset.shuffle(1024)
dataset = dataset.batch(128)
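This sketch checks what the pipeline produces without downloading MNIST. It runs the same map, shuffle and batch steps on 300 synthetic images (an assumption: every pixel set to 255). Because batch keeps the remainder by default, the last batch is smaller than 128.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST (assumption): 300 white 28x28x1 images, labels 0..9 repeated
images = np.full((300, 28, 28, 1), 255, dtype=np.uint8)
labels = np.arange(300, dtype=np.int64) % 10
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Same steps as above: resize and scale to [0, 1], shuffle, then batch
ds = ds.map(lambda img, label: (tf.image.resize(img, [28, 28]) / 255.0, label))
ds = ds.shuffle(1024).batch(128)

# Collect the size of every batch the pipeline yields
batch_sizes = [int(imgs.shape[0]) for imgs, _ in ds]
print(batch_sizes)  # [128, 128, 44]
```

Passing drop_remainder=True to batch would discard the final 44-example batch, which is useful when a model needs a fixed batch size.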
Iterate over the data with a for loop:
for images, labels in dataset:
    print(images.shape)
    # Work with images and labels here, e.g. train a model
(128, 28, 28, 1)
(128, 28, 28, 1)
(128, 28, 28, 1)
(128, 28, 28, 1)
... (output truncated; the loop was interrupted manually)
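Instead of stopping the loop with a keyboard interrupt, you can cap the number of batches with Dataset.take. The sketch below uses a hypothetical zero-filled dataset in place of MNIST so it runs without a download. take(3) yields at most three batches, and then the loop exits on its own.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the batched MNIST pipeline: 1000 zero images, batch size 128
images = np.zeros((1000, 28, 28, 1), dtype=np.float32)
labels = np.zeros((1000,), dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(128)

# take(3) limits iteration to the first 3 batches, so no manual interrupt is needed
seen_shapes = []
for images_batch, labels_batch in dataset.take(3):
    seen_shapes.append(images_batch.shape.as_list())
    print(images_batch.shape, labels_batch.shape)
```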
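Alternatively, iterate with an explicit iterator. In TF 2.x eager mode, use iter(dataset): dataset.make_one_shot_iterator() is a TF1 API that tf.data.Dataset no longer has in TF 2.x (it survives only as tf.compat.v1.data.make_one_shot_iterator):
iterator = iter(dataset)
try:
    while True:
        images, labels = iterator.get_next()
        # Work with images and labels here, e.g. train a model
        print(images.shape, labels.shape)
except tf.errors.OutOfRangeError:
    print("end!")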
(128, 28, 28, 1) (128,)
(128, 28, 28, 1) (128,)
(128, 28, 28, 1) (128,)
(128, 28, 28, 1) (128,)
end!
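In TF 2.x the iterator returned by iter(dataset) also works with Python's built-in next(). next() raises StopIteration, not tf.errors.OutOfRangeError, once the data is exhausted. Here is a minimal sketch on hypothetical zero-filled data: 300 examples at batch size 128 give three batches.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in: 300 examples batched by 128 -> batches of 128, 128, 44
images = np.zeros((300, 28, 28, 1), dtype=np.float32)
labels = np.zeros((300,), dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(128)

iterator = iter(dataset)
batch_count = 0
while True:
    try:
        images_batch, labels_batch = next(iterator)
    except StopIteration:  # raised by next() when the dataset is exhausted
        print("end!")
        break
    batch_count += 1
    print(images_batch.shape, labels_batch.shape)
```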
II. Loading datasets with tf.keras.datasets
Available datasets:
boston_housing module: Boston housing price regression dataset.
cifar10 module: CIFAR10 small images classification dataset.
cifar100 module: CIFAR100 small images classification dataset.
fashion_mnist module: Fashion-MNIST dataset.
imdb module: IMDB sentiment classification dataset.
mnist module: MNIST handwritten digits dataset.
reuters module: Reuters topic classification dataset.
Load the data with load_data
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.imdb.load_data()
(x_train, y_train) and (x_test, y_test) are the returned training and test sets, as NumPy arrays.
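The arrays returned by load_data can go through the same kind of tf.data pipeline as the first method, via tf.data.Dataset.from_tensor_slices. So that the sketch runs offline, it uses random stand-in arrays shaped like tf.keras.datasets.mnist output: uint8 images of shape 28x28 and one uint8 label per image. With real data, use the x_train and y_train returned by load_data instead. MNIST is used here rather than IMDB because its images fit the earlier image preprocessing.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tf.keras.datasets.mnist.load_data() output (no download):
# 256 grayscale 28x28 uint8 images and uint8 labels in 0..9
x_train = np.random.randint(0, 256, size=(256, 28, 28), dtype=np.uint8)
y_train = np.random.randint(0, 10, size=(256,), dtype=np.uint8)

# Wrap the NumPy arrays in a tf.data.Dataset, add a channel axis, scale to [0, 1], batch
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.map(lambda x, y: (tf.cast(x, tf.float32)[..., tf.newaxis] / 255.0, y))
train_ds = train_ds.shuffle(1024).batch(128)

for xb, yb in train_ds.take(1):
    print(xb.shape, yb.shape)  # (128, 28, 28, 1) (128,)
```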