教程地址:TensorFlow中文社区

MNIST数据下载

源码: tensorflow/g3doc/tutorials/mnist/

本教程的目标是展示如何下载用于手写数字分类问题所要用到的(经典)MNIST数据集。

教程 文件

本教程需要使用以下文件:

文件目的
input_data.py下载用于训练和测试的MNIST数据集的源码

备注:

input_data.py 文件路径为:tensorflow\examples\tutorials\mnist,

内容为:

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import

 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

你会发现,该文件主要引用该目录下tensorflow\contrib\learn\python\learn\datasets\的mnist.py文件里面的read_data_sets函数

该目录结构:

准备数据

MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.

MNIST Digits

更多详情, 请参考 Yann LeCun's MNIST page 或 Chris Olah's visualizations of MNIST.

下载

Yann LeCun's MNIST page 也提供了训练集与测试集数据的下载。

文件内容
train-images-idx3-ubyte.gz训练集图片 - 55000 张 训练图片, 5000 张 验证图片
train-labels-idx1-ubyte.gz训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz测试集图片 - 10000 张 图片
t10k-labels-idx1-ubyte.gz测试集图片对应的数字标签

在 input_data.py 文件中, maybe_download() 函数可以确保这些训练数据下载到本地文件夹中。

文件夹的名字在 fully_connected_feed.py 文件的顶部由一个标记变量指定,你可以根据自己的需要进行修改。

解压 与 重构

这些文件本身并没有使用标准的图片格式储存,并且需要使用input_data.py文件中extract_images()extract_labels()函数来手动解压(页面中有相关说明)。

图片数据将被解压成2维的tensor:[image index, pixel index] 其中每一项表示某一图片中特定像素的强度值, 范围从 [0, 255] 到 [-0.5, 0.5]。 "image index"代表数据集中图片的编号, 从0到数据集的上限值。"pixel index"代表该图片中像素点得个数, 从0到图片的像素上限值。

train-*开头的文件中包括60000个样本,其中分割出55000个样本作为训练集,其余的5000个样本作为验证集。因为所有数据集中28x28像素的灰度图片的尺寸为784,所以训练集输出的tensor格式为[55000, 784]

数字标签数据被解压称1维的tensor: [image index],它定义了每个样本数值的类别分类。对于训练集的标签来说,这个数据规模就是:[55000]

数据集 对象

底层的源码将会执行下载、解压、重构图片和标签数据来组成以下的数据集对象:

数据集目的
data_sets.train55000 组 图片和标签, 用于训练。
data_sets.validation5000 组 图片和标签, 用于迭代验证训练的准确性。
data_sets.test10000 组 图片和标签, 用于最终测试训练的准确性。

执行read_data_sets()函数将会返回一个DataSet实例,其中包含了以上三个数据集。函数DataSet.next_batch()是用于获取以batch_size为大小的一个元组,其中包含了一组图片和标签,该元组会被用于当前的TensorFlow运算会话中。

images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)

 MNIST数据读取 

在TensorFlow的源码中,MNIST数据集的读取操作在contrib\learn\python\learn\datasets\data\mnist.py中,函数是read_data_sets。

read_data_sets函数:

def read_data_sets(train_dir,
    fake_data=False,
    one_hot=False,
    ype=dtypes.float32,
    reshape=True,
    validation_size=5000):

train_dir:为数据集在文件夹的位置,在这里为tensorflow\examples\tutorials\mnist\MNIST_data;

fake_data: 在官方教程中提到fake_data标记是用于单元测试的,读者可以不必理会;

one_hot:为one_hot编码,即独热码,作用是将状态值编码成状态向量,例如,数字状态共有0~9这10种,对于数字7,将它进行one_hot编码后为[0 0 0 0 0 0 0 1 0 0],这样使得状态对于计算机来说更加明确,对于矩阵操作也更加高效。

dtype:的作用是将图像像素点的灰度值从[0, 255]转变为[0.0, 1.0]。

reshape:的作用是将图像的形状从[num examples, rows, columns, depth]转变为[num examples, rows*columns] (对于二维图片,depth为1)。

validation_size:即为从训练集中抽取这么多来作为验证集。

变量定义好之后,接下来提取数据集。

with open(local_file, 'rb') as f:
    train_images = extract_images(f)

看extract_images函数:

with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                       (magic, f.name))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

如果这么看代码可能很难理解,但是如果清楚MNIST数据集文件的结构之后就好理解得多,对于MNIST的images文件:

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
offsettypevaluedescription
000032 bit integer0x00000803(2051)magic number
000432 bit integer60000number of images
000832 bit integer28number of rows
001232 bit integer28number of columns
0016unsigned byte??pixel
0017unsigned byte??pixel
0018unsigned byte??pixel
......   
xxxxunsigned byte??pixel


代码中_read32()的作用是从文件流中动态读取4位数据并转换为uint32的数据。

image文件的前四位为魔术码(magic number),只有检测到这4位数据的值和2051相等时,才代表这是正确的image文件,才会继续往下读取。接下来继续读取之后的4位,代表着image文件中,所包含的图片的数量(num_images)。再接着读4位,为每一幅图片的行数(rows),再后4位,为每一幅图片的列数(cols)。最后再读接下来的rows * cols * num_images位,即为所有图片的像素值。最后再将读取到的所有像素值装换为[index, rows, cols, depth]的4D矩阵。这样就将全部的image数据读取了出来。

同理,对于MNIST的labels文件:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
offsettypevaluedescription
000032 bit integer0x00000801(2049)magic number
000432 bit integer60000number of items
0008unsigned byte??label
0009unsigned byte??label
......   
xxxxunsigned byte??label

再看代码:

def extract_labels(f, one_hot=False, num_classes=10):
  """Extract the labels into a 1D uint8 numpy array [index].

  Args:
    f: A file object that can be passed into a gzip reader.
    one_hot: Does one hot encoding for the result.
    num_classes: Number of classes for the one hot encoding.

  Returns:
    labels: a 1D uint8 numpy array.

  Raises:
    ValueError: If the bystream doesn't start with 2049.
  """
  print('Extracting', f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                       (magic, f.name))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels, num_classes)
    return labels

同样的也是依次读取文件的魔术码以及标签总数,最后把所有图片的标签读取出来,成一个长度为num_items的1D的向量。不过代码中还有一个one_hot的部分,dense_to_one_hot的代码为:

def dense_to_one_hot(labels_dense, num_classes):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot


正如文章开头提到one_hot的作用,这里将1D向量中的每一个值,编码成一个长度为num_classes的向量,向量中对应于该值的位置为1,其余为0,所以one_hot将长度为num_labels的向量编码为一个[num_labels, num_classes]的2D矩阵。

以上就是如何将MNIST数据文件中的images和labels分别提取出来的过程。

备注:

以上函数都有,“@deprecated(None, 'Please use tf.data to implement this functionality.')”。

以后的新版本估计将没有这些函数。 

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐