利用Tensorflow实现手写数字识别(附python代码)
手写识别的应用场景有很多,智能手机、掌上电脑的信息工具的普及,手写文字输入,机器识别感应输出;还可以用来识别银行支票,如果准确率不够高,可能会引起严重的后果。当然,手写识别也是机器学习领域的一个Hello World任务,感觉每一个初识神经网络的人,搭建的第一个项目十之八九都是它。
我们来尝试搭建下手写识别中最基础的手写数字识别,与手写识别的不同是数字识别只需要识别0-9的数字,样本数据集也只需要覆盖到绝大部分包含数字0-9的字体类型,说白了就是简单,样本特征少,难度小很多。
一、目标
预期目标:传入一张数字图片给机器,机器通过识别,最后返回给用户图片上的数字
传入图片:
机器识别输出:
二、搭建(全连接神经网络)
环境:python3.6 tensorflow1.14
工具:pycharm
数据源:来自手写数据机器视觉数据库mnist数据集,包含7万张黑底白字手写数字图片,其中55000张为训练集,5000张为验证集,10000张为测试集。每张图片大小为28*28像素,图片纯黑色像素值为0,纯白色像素值为1。数据集的标签是长度为10的一维数组,数组中的每个元素索引号表示对应数字出现的概率。
可通过input_data模块中的read_data_sets()函数直接加载mnist数据集(详情见mnist_backward.py):
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)
一、定义网络模型,神经网络的前向传播(mnist_forward.py)
import tensorflow as tf
INPUT_NODE=784 # 输入节点
OUTPUT_NODE=10 # 输出节点
LAYER1_NODE=500 # 隐藏节点
def get_weight(shape,regularizer):
w=tf.Variable(tf.truncated_normal(shape,stddev=0.1))
if regularizer !=None:
tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))
return w
def get_bias(shape):
b=tf.Variable(tf.zeros(shape))
return b
def forward(x,regularizer):
w1=get_weight([INPUT_NODE,LAYER1_NODE],regularizer)
b1=get_bias(LAYER1_NODE)
y1=tf.nn.relu(tf.matmul(x,w1)+b1)
w2=get_weight([LAYER1_NODE,OUTPUT_NODE],regularizer)
b2=get_bias([OUTPUT_NODE])
y=tf.matmul(y1,w2)+b2
return y
这里定义了网络模型输入输出节点的个数、隐藏层节点数、同时定义get_weigt()函数实现对参数w的设置,包括参数的形状和是否正则化的标志,从输入层到隐藏层的参数w1形状为[784,500],由隐藏层到输出层的参数w2形状为[500,10]。定义get_bias()实现对偏置b的设置。由输入层到隐藏层的偏置b1形状长度为500的一维数组,由隐藏层到输出层的偏置b2形状长度为10的一维数组,初始化值为全0。
二、神经网络的反向传播(mnist_backward.py)
利用训练数据集对神经网络进行训练,通过降低损失函数值,实现网络模型参数的优化,从而得到准确率高且泛化能力强的神经网络模型。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
batch_size=200
learning_rate_base=0.1 # 初始学习率
learning_rate_decay=0.99 # 学习率衰减率
regularizer=0.0001 # 正则化系数
steps=50000 # 训练轮数
moving_average_decay=0.99
model_save_path="./model/" # 模型保存路径
model_name="mnist_model"
def backward(mnist):
x=tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
y_=tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
y=mnist_forward.forward(x,regularizer) # 调用forward()函数,设置正则化,计算y
global_step=tf.Variable(0,trainable=False) # 当前轮数计数器设定为不可训练类型
# 调用包含所有参数正则化损失的损失函数loss
ce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
cem=tf.reduce_mean(ce)
loss=cem+tf.add_n(tf.get_collection('losses'))
# 设定指数衰减学习率learning_rate
learning_rate=tf.train.exponential_decay(
learning_rate_base,
global_step,
mnist.train.num_examples/batch_size,
learning_rate_decay,
staircase=True
)
# 梯度衰减对模型优化,降低损失函数
train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
# 定义参数的滑动平均
ema=tf.train.ExponentialMovingAverage(moving_average_decay,global_step)
ema_op=ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step,ema_op]):
train_op=tf.no_op(name='train')
saver=tf.train.Saver()
with tf.Session() as sess:
init_op=tf.global_variables_initializer() # 所有参数初始化
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(model_save_path) # 加载指定路径下的滑动平均
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
for i in range(steps): # 循环迭代steps轮
xs,ys=mnist.train.next_batch(batch_size)
_,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
if i %1000==0:
print("After %d training step(s),loss on training batch is %g."%(step,loss_value))
saver.save(sess,os.path.join(model_save_path,model_name),global_step=global_step) # 当前会话加载到指定路径
if __name__=='__main__':
mnist = input_data.read_data_sets("./data/", one_hot=True)
backward(mnist)
反向传播中,首先定义了每轮喂入神经网络的图片数batch_size、初始学习率learning_rate_base、学习率衰减率learning_rate_decay、正则化系数regularizer、训练轮数steps、模型保存路径以及模型保存名称等相关信息。反向传播backward()函数中,先传入minist数据集,用tf.placeholder(dtype,shape)函数实现训练样本x和样本标签y_占位。y表示定义的前向传播函数forward; tf.Variable(0,trainable=False)给当前轮数赋值,定义为不可训练类型。接着,loss表示定义的损失函数,一般为预测值与样本标签的交叉熵与正则化损失之和;train_step表示利用优化算法对模型参数进行优化,常用的优化算法有GradientDescentOptimizer、AdamOptimizer、MomentumOptimizer算法,这里使用的GradientDescentOptimizer梯度衰减算法。接着初始化saver对象,其中利用tf.global_variables_initializer()函数初始化所有模型参数,利用sess.run()函数实现模型的训练优化过程,并每隔一定轮数保存一次模型,模型训练好之后保存在ckpt中。
三、测试数据集,验证模型性能(mnist_test.py)
给神经网络模型输入测试集验证网络的准确性和泛化性(测试集和训练集是相互独立的)
# coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
test_interval_secs=5 # 程序循环间隔时间5秒
def test(mnist):
with tf.Graph().as_default() as g: # 复现计算图
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
y = mnist_forward.forward(x, None)
# 实例化滑动平均的saver对象
ema = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)
# 计算准确率
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
while True:
with tf.Session() as sess:
ckpt=tf.train.get_checkpoint_state(mnist_backward.model_save_path) # 加载指定路径下的滑动平均
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
accuracy_score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
print("After %s training step(s),test accuracy= %g."%(global_step,accuracy_score))
else:
print('No checkpoint file found')
return
time.sleep(test_interval_secs)
if __name__=='__main__':
mnist = input_data.read_data_sets("./data/", one_hot=True)
test(mnist)
首先,制定模型测试函数test(),通过tf.placeholder()给x,y_占位,调用mnist_forward文件中的前向传播过程forward()函数计算y,mnist_backward.moving_average_decay表示滑动衰减率。在with结构中,ckpt是加载训练好的模型,如果已有ckpt模型则恢复会话、轮数等。其次,制定main()函数,加载测试数据集,调用定义好的测试函数test()就行。
通过对测试数据的预测得到准确率,从而判断出训练出的神经网络模型性能的好坏。当准确率低时,可能原因有模型需要改进,或者是训练数据量太少导致过拟合等。
运行以上三个文件,运行结果如下:
从终端显示的运行结果可以看出,随着训练轮数的增加,网络模型的损失函数值在不断降低,在测试集上的准确率也在不断提升,具有较好的泛化能力。
四、输入真实图片,输出预测结果(mnist_app.py)
任务分两个函数完成:
(1)pre_pic()函数,对手写数字图片做预处理
(2)restore_model()函数,将符合神经网络输入要求的图片喂给复现的神经网络模型,输出预测值。
# coding:utf-8
import tensorflow as tf
import mnist_forward
import mnist_backward
from PIL import Image
import numpy as np
def restore_model(testPicArr):
with tf.Graph().as_default() as tg: # 创建一个默认图
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, None)
preValue=tf.argmax(y,1) # 得到概率最大的预测值
'''
实现滑动平均模型,参数moving_average_decay用于控制模型的更新速度,训练过程会对每一个变量维护一个影子变量
这个影子变量的初始值就是相应变量的初始值,每次变量更新时,影子变量随之更细
'''
variable_averages=tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)
variable_to_restore=variable_averages.variables_to_restore()
saver=tf.train.Saver(variable_to_restore)
with tf.Session() as sess:
# 通过checkpoint文件定位到最新保存的模型
ckpt=tf.train.get_checkpoint_state(mnist_backward.model_save_path)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
preValue=sess.run(preValue,feed_dict={x:testPicArr})
return preValue
else:
print("no checkpoint file found")
return -1
# 预处理函数,包括resize,转变灰度图,二值化操作等
def pre_pic(picName):
img=Image.open(picName)
reIm=img.resize((28,28),Image.ANTIALIAS)
im_arr=np.array(reIm.convert('L'))
threshold=50 # 设定合理的阈值
for i in range(28):
for j in range(28):
im_arr[i][j]=255-im_arr[i][j] # 模型要求黑底白字,输入图为白底黑字,对每个像素点的值改为255-原值=互补的反色
if (im_arr[i][j]<threshold):
im_arr[i][j]=0
else:
im_arr[i][j]=255
nm_arr=im_arr.reshape([1,784]) # 1行784列
nm_arr=nm_arr.astype(np.float32)
img_ready=np.multiply(nm_arr,1.0/255.0) # 从0-255之间的数变为0-1之间的浮点数
return img_ready
if __name__=='__main__':
testNum=int(input("input the number of test pictures:"))
for i in range(testNum):
testPic=input("the path of test picture:")
testPicArr=pre_pic(testPic)
preValue=restore_model(testPicArr)
print("the prediction number is",preValue)
在pre_pic()函数中,网络要求输入是28*28像素点的值,先将图片尺寸resize,模型要求的是黑底白字,但输入的图是白底黑字,则每个像素点的值改为255减去原值得到互补的反色。再对图片做二值化处理,这样可以滤掉噪声。nm_arr把图片拉成1行784列,并把值变为浮点数。restore_model()函数,计算输出y,网络输出的是一个一维数组(10个可能性概率),数组中最大的那个元素所对应的索引号就是预测的结果。
运行mnist_app.py文件,结果如下:
先输入需要识别的图片number数,然后传入图片路径,最后返回识别结果。我们传入的图片2.jpg,5.jpg如下所示:
预测结果也是2,5,说明模型还可以。但是,前面我们也提到过,如果数字识别用来识别银行支票97%的准确率不算高,然后卷积神经网络就开始大放异彩了...........................
最后,本人微信公众号:
放心,不用出钱也不用报班,只是单纯的想多两个粉丝罢了。
更多推荐
所有评论(0)