Two ways to load datasets

  • 1. Load with the tensorflow_datasets package: https://tensorflow.google.cn/datasets/overview
  • 2. Load with tf.keras.datasets

1. Loading datasets with tensorflow_datasets

Install the package:

pip install tensorflow_datasets

Import the packages:

import tensorflow as tf
import tensorflow_datasets as tfds

List the available datasets:

tfds.list_builders()
['abstract_reasoning',
 'aeslc',
 'aflw2k3d',
 'amazon_us_reviews',
 'arc',
 'bair_robot_pushing_small',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'cos_e',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'deep_weeds',
 'definite_pronoun_resolution',
 'diabetic_retinopathy_detection',
 'dmlab',
 'downsampled_imagenet',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'dummy_dataset_shared_generator',
 'dummy_mnist',
 'emnist',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'gap',
 'gigaword',
 'glue',
 'groove',
 'higgs',
 'horses_or_humans',
 'i_naturalist2017',
 'image_label_folder',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet_resized',
 'imagenette',
 'imdb_reviews',
 'iris',
 'kitti',
 'kmnist',
 'lfw',
 'lm1b',
 'lost_and_found',
 'lsun',
 'malaria',
 'math_dataset',
 'mnist',
 'mnist_corrupted',
 'movie_rationales',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'newsroom',
 'nsynth',
 'omniglot',
 'open_images_v4',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'pet_finder',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'quickdraw_bitmap',
 'reddit_tifu',
 'resisc45',
 'rock_paper_scissors',
 'rock_you',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'shapes3d',
 'smallnorb',
 'snli',
 'so2sat',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tf_flowers',
 'the300w_lp',
 'titanic',
 'trivia_qa',
 'uc_merced',
 'ucf101',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'wider_face',
 'wikihow',
 'wikipedia',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'xnli',
 'xsum']

Loading a dataset with tfds.load

tfds.load returns a tf.data.Dataset object. Some of its important parameters:

  • as_supervised: if True, the dataset is returned in (input, label) format according to the dataset's features; otherwise a dictionary of all features is returned.

  • split: specifies which portion of the dataset to return; if omitted, the whole dataset is returned. The common options are tfds.Split.TRAIN (training set) and tfds.Split.TEST (test set).

  • download: boolean, whether to download the data. Once the data has been prepared, subsequent load calls will not re-download it and can reuse the prepared data. You can customize where data is saved/loaded by specifying data_dir= (default: ~/tensorflow_datasets/).

# Training set
train_dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
# mnist_train = tfds.load(name="mnist", split="train")
# Test set
# test_dataset = tfds.load("mnist", split=tfds.Split.TEST)
# mnist_test = tfds.load(name="mnist", split="test")
# dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
# dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
train_dataset
<DatasetV1Adapter shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
train_dataset = tfds.load("mnist", split=tfds.Split.TRAIN,as_supervised=False)
train_dataset
<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

Once we have a dataset of type tf.data.Dataset, we can use the tf.data API to preprocess it and read the data. For example:

# Resize/rescale, shuffle, and batch the dataset
dataset = train_dataset.map(lambda img, label: (tf.image.resize(img, [28,28]) / 255.0, label))
dataset = dataset.shuffle(1024)
dataset = dataset.batch(128)
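To make the pipeline above easy to try without downloading MNIST, here is a minimal self-contained sketch that runs the same map/shuffle/batch steps on synthetic arrays (an assumption: random uint8 images with MNIST's (28, 28, 1) shape stand in for the real data):

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST: 256 random uint8 images and int64 labels
images = np.random.randint(0, 256, size=(256, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(256,), dtype=np.int64)

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
# Same pipeline as above: resize + scale to [0, 1], shuffle, batch
dataset = dataset.map(lambda img, label: (tf.image.resize(img, [28, 28]) / 255.0, label))
dataset = dataset.shuffle(1024)
dataset = dataset.batch(128)

for batch_images, batch_labels in dataset:
    print(batch_images.shape, batch_labels.shape)  # (128, 28, 28, 1) (128,)
```

With 256 examples and a batch size of 128, the loop yields exactly two batches; on the real MNIST training split it would yield 469 (the last one smaller than 128).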

Read the data by iterating in a loop:

for images, labels in dataset:
    print(images.shape)
    # operate on images and labels here, e.g. model training
(128, 28, 28, 1)
(128, 28, 28, 1)
(128, 28, 28, 1)
(128, 28, 28, 1)

Or iterate over the data with a (TF 1.x-style) one-shot iterator:

iterator = dataset.make_one_shot_iterator()
try:
    while True:
        images, labels = iterator.get_next()
        # operate on images and labels here, e.g. model training
        print(images.shape,labels.shape)
except tf.errors.OutOfRangeError:
    print("end!")
(128, 28, 28, 1) (128,)
(128, 28, 28, 1) (128,)
(128, 28, 28, 1) (128,)
(128, 28, 28, 1) (128,)
end!
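Note that make_one_shot_iterator belongs to the TF 1.x API. In TF 2.x eager mode a Dataset is an ordinary Python iterable, so the explicit iterator is unnecessary; the built-in iter()/next() play the role of make_one_shot_iterator()/get_next(). A small sketch (the tiny range dataset is just a stand-in for the batched MNIST pipeline):

```python
import tensorflow as tf

# A small dataset standing in for the batched pipeline above
dataset = tf.data.Dataset.range(10).batch(4)

# TF 2.x: a Dataset is directly iterable
for batch in dataset:
    print(batch.numpy())

# Or drive it manually; next() raises StopIteration when exhausted
iterator = iter(dataset)
first = next(iterator)
print(first.numpy())  # [0 1 2 3]
```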

2. Loading datasets with tf.keras.datasets

Available datasets:

boston_housing module: Boston housing price regression dataset.

cifar10 module: CIFAR10 small images classification dataset.

cifar100 module: CIFAR100 small images classification dataset.

fashion_mnist module: Fashion-MNIST dataset.

imdb module: IMDB sentiment classification dataset.

mnist module: MNIST handwritten digits dataset.

Loading data with load_data

(x_train, y_train),(x_test, y_test) = tf.keras.datasets.imdb.load_data()

(x_train, y_train), (x_test, y_test) are the returned training and test sets.
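The arrays returned by load_data are plain NumPy arrays, so typical preprocessing is done with NumPy/Keras utilities before (or instead of) building a tf.data pipeline. A hedged sketch, using synthetic arrays with the shapes that tf.keras.datasets.mnist.load_data() returns (an assumption, to avoid the download; note MNIST images come back as (28, 28) with no channel axis):

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for (x_train, y_train) from mnist.load_data()
x_train = np.random.randint(0, 256, size=(600, 28, 28), dtype=np.uint8)
y_train = np.random.randint(0, 10, size=(600,), dtype=np.uint8)

# Typical preprocessing: scale to [0, 1], add a channel axis, one-hot labels
x_train = x_train.astype("float32") / 255.0
x_train = x_train[..., np.newaxis]                            # (600, 28, 28, 1)
y_onehot = tf.keras.utils.to_categorical(y_train, num_classes=10)

print(x_train.shape, y_onehot.shape)  # (600, 28, 28, 1) (600, 10)
```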
