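The MNIST dataset, collected by Yann LeCun, is a large database of handwritten digits. It is commonly used to train image-processing systems and is widely used for training and testing in machine learning. MNIST is modest in size and consists of single-channel grayscale images, which makes it simple enough for deep-learning beginners to practice building, training, and evaluating models. The MNIST images are a combination of two NIST (National Institute of Standards and Technology) databases: Special Database 1 and Special Database 3. The training-set digits were written by about 250 people, roughly half of them high school students and half employees of the US Census Bureau.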

MNIST contains 60,000 training examples and 10,000 test examples. Each image is 28×28 pixels and grayscale, with a bit depth of 8 (pixel values range from 0 to 255).
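A quick worked calculation of what these numbers imply for the size of the decompressed files. It assumes one byte per pixel and one byte per label, plus the IDX header sizes: 16 bytes for the image files (four 32-bit integers, discussed further down) and 8 bytes for the label files (two 32-bit integers). The variable names are illustrative only.

```python
NUM_TRAIN = 60000          # training examples
NUM_TEST = 10000           # test examples
ROWS = COLS = 28           # image size in pixels
BITS_PER_PIXEL = 8         # grayscale bit depth

pixels_per_image = ROWS * COLS               # 784 pixels per image
max_pixel_value = 2 ** BITS_PER_PIXEL - 1    # 255, so pixel values are 0..255

# One byte per pixel / per label, plus the IDX header
# (16 bytes for image files, 8 bytes for label files).
expected_sizes = {
    "train-images-idx3-ubyte": 16 + NUM_TRAIN * pixels_per_image,
    "train-labels-idx1-ubyte": 8 + NUM_TRAIN,
    "t10k-images-idx3-ubyte": 16 + NUM_TEST * pixels_per_image,
    "t10k-labels-idx1-ubyte": 8 + NUM_TEST,
}

for name, size in expected_sizes.items():
    print(f"{name}: {size:,} bytes")
# train-images-idx3-ubyte: 47,040,016 bytes
# train-labels-idx1-ubyte: 60,008 bytes
# t10k-images-idx3-ubyte: 7,840,016 bytes
# t10k-labels-idx1-ubyte: 10,008 bytes
```

Comparing these figures with the actual size of each decompressed file (for example via os.path.getsize) is a cheap way to spot a truncated download.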

Downloading and reading the MNIST data

  • Manual download

Download URL: http://yann.lecun.com/exdb/mnist/

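The MNIST dataset consists of four files. Download the four compressed files and decompress them. After decompression you will find that they are not in a standard image format: the image data is stored in binary files. Files prefixed with train belong to the training set and files prefixed with t10k belong to the test set; the images files contain the images and the labels files contain the corresponding labels.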

train-images-idx3-ubyte.gz: training set images (9912422 bytes)

train-labels-idx1-ubyte.gz: training set labels (28881 bytes)

t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)

t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
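The manual steps above can also be scripted. This is a minimal sketch using only the Python standard library, assuming the four files are still served under the download URL given above (if that host is unreachable, base_url would need to point at a mirror). download_file, gunzip_file and fetch_mnist are hypothetical helper names for this example, not part of any library.

```python
import gzip
import os
import shutil
import urllib.request

# Base URL from the text above; may need to be replaced with a mirror.
BASE_URL = "http://yann.lecun.com/exdb/mnist/"
FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def download_file(url, dest_path):
    """Download url to dest_path, skipping the download if the file already exists."""
    if not os.path.exists(dest_path):
        urllib.request.urlretrieve(url, dest_path)
    return dest_path


def gunzip_file(gz_path):
    """Decompress foo.gz into foo (same directory) and return the new path."""
    out_path = gz_path[:-3] if gz_path.endswith(".gz") else gz_path + ".raw"
    with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return out_path


def fetch_mnist(data_dir="mnist_data", base_url=BASE_URL):
    """Download and decompress all four MNIST files into data_dir."""
    os.makedirs(data_dir, exist_ok=True)
    paths = []
    for name in FILES:
        gz_path = download_file(base_url + name, os.path.join(data_dir, name))
        paths.append(gunzip_file(gz_path))
    return paths
```

Calling fetch_mnist() should leave the four decompressed binary files (train-images-idx3-ubyte and so on) in mnist_data/. Because download_file skips existing files, rerunning the script does not download them again.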

  • Downloading with TensorFlow

Reading the data

  • Reading a single image

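MNIST images are 28×28 pixels. Let's start by reading the first image in the training set. Note: the train-images-idx3-ubyte file begins with a header of four integers (big-endian 32-bit: magic number, number of images, number of rows, number of columns), which must be skipped.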
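A minimal sketch of this step with struct and NumPy. It assumes the file has already been decompressed and that the header is the four big-endian 32-bit integers described above, with magic number 2051 for image files. read_idx_image is a hypothetical helper name.

```python
import struct

import numpy as np

IMAGE_HEADER = ">IIII"                             # four big-endian unsigned 32-bit ints
IMAGE_HEADER_SIZE = struct.calcsize(IMAGE_HEADER)  # 16 bytes


def read_idx_image(path, index=0):
    """Return image number `index` of an idx3-ubyte file as a (rows, cols) uint8 array."""
    with open(path, "rb") as f:
        magic, count, rows, cols = struct.unpack(IMAGE_HEADER, f.read(IMAGE_HEADER_SIZE))
        if magic != 2051:
            raise ValueError(f"not an idx3 image file (magic={magic})")
        if not 0 <= index < count:
            raise IndexError(f"index {index} out of range for {count} images")
        # Skip the header and all earlier images, then read exactly one image.
        f.seek(IMAGE_HEADER_SIZE + index * rows * cols)
        data = f.read(rows * cols)
    return np.frombuffer(data, dtype=np.uint8).reshape(rows, cols)
```

Usage: img = read_idx_image("train-images-idx3-ubyte") returns the first training image; with matplotlib installed, plt.imshow(img, cmap="gray") followed by plt.show() displays it.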


  • Reading multiple images

Read the first 100 test images and their labels from the t10k files and display them.
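A minimal sketch of this step, assuming the decompressed t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte files are in the working directory. The labels file has a shorter header: two big-endian 32-bit integers (magic number 2049 and the label count), followed by one byte per label. The helper names read_images, read_labels and show_grid and the 10-column grid layout are illustrative choices; displaying requires matplotlib.

```python
import struct

import numpy as np


def read_images(path, n):
    """Read the first n images of an idx3-ubyte file as an (n, rows, cols) uint8 array."""
    with open(path, "rb") as f:
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError(f"unexpected magic number {magic} for an image file")
        n = min(n, count)
        data = f.read(n * rows * cols)
    return np.frombuffer(data, dtype=np.uint8).reshape(n, rows, cols)


def read_labels(path, n):
    """Read the first n labels of an idx1-ubyte file as a uint8 array."""
    with open(path, "rb") as f:
        magic, count = struct.unpack(">II", f.read(8))
        if magic != 2049:
            raise ValueError(f"unexpected magic number {magic} for a label file")
        n = min(n, count)
        data = f.read(n)
    return np.frombuffer(data, dtype=np.uint8)


def show_grid(images, labels, cols=10):
    """Plot images in a grid, titling each subplot with its label (needs matplotlib)."""
    import matplotlib.pyplot as plt  # lazy import: reading works without matplotlib

    rows = (len(images) + cols - 1) // cols
    _, axes = plt.subplots(rows, cols, figsize=(cols, rows * 1.2), squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")
    for ax, img, label in zip(axes.ravel(), images, labels):
        ax.imshow(img, cmap="gray")
        ax.set_title(str(label), fontsize=8)
    plt.tight_layout()
    plt.show()
```

Usage: images = read_images("t10k-images-idx3-ubyte", 100) and labels = read_labels("t10k-labels-idx1-ubyte", 100), then show_grid(images, labels) draws a 10×10 grid with each digit's label above it.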


Using MNIST in TensorFlow

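MNIST is widely used in machine learning, for example for training Softmax regression on MNIST or for training and visualizing CNNs on MNIST. In TensorFlow (1.x), MNIST can be loaded directly: just import the input_data module (input_data.py), with no need to convert the binary files into images yourself, and load the data with read_data_sets from tensorflow.contrib.learn.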
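A minimal sketch of the TensorFlow 1.x usage described above. The call read_data_sets("/tmp/data", one_hot=True) is the one that appears in the traceback further down; the import path tensorflow.examples.tutorials.mnist is the TF 1.x tutorial module, which (like tf.contrib) is not available in TensorFlow 2.x. The wrapper name load_mnist_tf1 is hypothetical, and the import is placed inside it so the function can be defined without TensorFlow installed.

```python
def load_mnist_tf1(data_dir="/tmp/data"):
    """Load MNIST through TensorFlow 1.x's tutorial helper (not available in TF 2.x)."""
    # Lazy import: this module only ships with TensorFlow 1.x.
    from tensorflow.examples.tutorials.mnist import input_data

    # Downloads the four .gz files into data_dir if they are missing, then parses them.
    # one_hot=True turns each label into a length-10 one-hot vector.
    return input_data.read_data_sets(data_dir, one_hot=True)
```

With TF 1.x installed, mnist = load_mnist_tf1() returns an object with train, validation and test splits. By default 5,000 of the 60,000 training images are held out as mnist.validation, so mnist.train.images has shape (55000, 784) and mnist.train.labels has shape (55000, 10).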

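Note: this approach loads the binary files directly. If the dataset you downloaded consists of image files instead, you must preprocess it before use (for example, flatten each 2D image matrix into a 1D vector and convert each label digit into a one-hot vector).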
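The preprocessing mentioned in the note can be sketched in NumPy. This assumes the images are already loaded as an (N, 28, 28) uint8 array and the labels as N integers; flatten_images and to_one_hot are illustrative names. The scaling to [0, 1] mirrors the float32 conversion read_data_sets applies by default, but is a choice, not a requirement.

```python
import numpy as np


def flatten_images(images):
    """(N, 28, 28) uint8 images -> (N, 784) float32 vectors scaled to [0, 1]."""
    n = images.shape[0]
    return images.reshape(n, -1).astype(np.float32) / 255.0


def to_one_hot(labels, num_classes=10):
    """N integer labels -> (N, num_classes) float32 one-hot matrix."""
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set a single 1 per row, at the column given by that row's label.
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot
```

For example, to_one_hot([3, 0]) yields two rows: the first has its 1 in column 3, the second in column 0.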

Problem: when downloading MNIST with TensorFlow's input_data, the following error is raised:

Traceback (most recent call last):
  File "alex_mnist.py", line 5, in <module>
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 250, in new_func
    return func(*args, **kwargs)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py", line 260, in read_data_sets
    source_url + TRAIN_IMAGES)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 250, in new_func
    return func(*args, **kwargs)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 252, in maybe_download
    temp_file_name, _ = urlretrieve_with_retry(source_url)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 250, in new_func
    return func(*args, **kwargs)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 205, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 233, in urlretrieve_with_retry
    return urllib.request.urlretrieve(url, filename)
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 248, in urlretrieve
    with contextlib.closing(urlopen(url, data)) as fp:
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 223, in urlopen
    return opener.open(url, data, timeout)
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 526, in open
    response = self._open(req, data)
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 544, in _open
    '_open', req)
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 504, in _call_chain
    result = func(*args)
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 1361, in https_open
    context=self._context, check_hostname=self._check_hostname)
  File "/home/cnu105/anaconda3/lib/python3.6/urllib/request.py", line 1321, in do_open
    r = h.getresponse()
  File "/home/cnu105/anaconda3/lib/python3.6/http/client.py", line 1331, in getresponse
    response.begin()
  File "/home/cnu105/anaconda3/lib/python3.6/http/client.py", line 297, in begin
    version, status, reason = self._read_status()
  File "/home/cnu105/anaconda3/lib/python3.6/http/client.py", line 258, in _read_status
    line = str(self.fp.readline(_MAXLINE + 1), "iso-8859-1")
  File "/home/cnu105/anaconda3/lib/python3.6/socket.py", line 586, in readinto
    return self._sock.recv_into(b)
  File "/home/cnu105/anaconda3/lib/python3.6/ssl.py", line 1009, in recv_into
    return self.read(nbytes, buffer)
  File "/home/cnu105/anaconda3/lib/python3.6/ssl.py", line 871, in read
    return self._sslobj.read(len, buffer)
  File "/home/cnu105/anaconda3/lib/python3.6/ssl.py", line 631, in read
    v = self._sslobj.read(len, buffer)
ConnectionResetError: [Errno 104] Connection reset by peer

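Investigation showed that input_data.read_data_sets downloads from https://storage.googleapis.com/cvdf-datasets/mnist/, and access to that address is restricted, so the connection is reset. The fix is to point the download source used by read_data_sets at http://yann.lecun.com/exdb/mnist/ by changing DEFAULT_SOURCE_URL in mnist.py. The modified code is as follows: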

#/home/cnu105/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
# DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
DEFAULT_SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

After this change, rerunning the script works normally.
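Editing a file under site-packages works, but the change is lost whenever TensorFlow is reinstalled. A sketch of two alternatives, under stated assumptions: first, the source_url variable visible in the traceback suggests that read_data_sets in the installed TF 1.x version accepts a source_url keyword, so the mirror can be passed per call (check your version's signature); second, in the TF 1.x source, maybe_download only downloads a file that is not already present in the data directory, so placing the four manually downloaded .gz files there beforehand should avoid the network entirely. missing_mnist_files and load_mnist_from are hypothetical helper names.

```python
import os

MNIST_FILES = (
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
)


def missing_mnist_files(data_dir):
    """Return the MNIST .gz files not yet present in data_dir (these would be downloaded)."""
    return [name for name in MNIST_FILES
            if not os.path.exists(os.path.join(data_dir, name))]


def load_mnist_from(data_dir="/tmp/data", source_url="http://yann.lecun.com/exdb/mnist/"):
    """TF 1.x only: pass the mirror URL per call instead of editing DEFAULT_SOURCE_URL.

    Assumes the installed read_data_sets accepts a source_url keyword argument.
    """
    from tensorflow.examples.tutorials.mnist import input_data  # TF 1.x only

    return input_data.read_data_sets(data_dir, one_hot=True, source_url=source_url)
```

If missing_mnist_files("/tmp/data") returns an empty list, read_data_sets should find all four files locally and skip the download.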

That concludes the introduction to the MNIST dataset. The following are examples that use MNIST:

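Classifying MNIST with an AlexNet model in TensorFlow

Recognizing handwritten digits with a Keras multilayer perceptron

Recognizing handwritten digits with a Keras convolutional neural network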
