深度学习Hello World --- 手写体识别 实战

   日期:2020-10-01     浏览:109    评论:0    
核心提示:最近因为学校事情比较多,也开始准备研究出一些深度学习方面的教程,但总被一些大大小小的原因在往后拖进度,这期用Python写一篇从零到一的手写体识别算法实战课来教各位如何入门深度学习。准备数据集首先准备一个 mnist 数据集。这是下载地址四个数据集分别是训练图集、训练结果、测试图集、测试结果。下载后存到一个文件夹中备用。Tensorflow 数据流图框架首先先调用Python第三方库,将数据集全部调用进程序(在这里使用 Tensorflow2.3.0 以及 scipy==1.2.1)

最近因为学校事情比较多,也开始准备研究出一些深度学习方面的教程,但总被一些大大小小的原因在往后拖进度,这期用Python写一篇从零到一的手写体识别算法实战课来教各位如何入门深度学习。

准备数据集

首先准备一个 mnist 数据集。
这是下载地址

四个数据集分别是训练图集、训练结果、测试图集、测试结果。
下载后存到一个文件夹中备用。

Tensorflow 数据流图框架

首先先调用Python第三方库,将数据集全部调用进程序
(在这里使用 Tensorflow2.3.0 以及 scipy==1.2.1)

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior() # 这两句话是为了避免 Tensorflow 1.x与2.x的区别而引起的错误
from tensorflow.examples.tutorials.mnist import input_data 
import numpy as np
import os
import scipy.misc

mnist = input_data.read_data_sets('mnist',one_hot=True)

先看一下训练集的图片的结构

print(mnist.train.images.shape)
# -> (55000, 784) 五万五千张图片,每张图片含有784向量
print(mnist.train.labels.shape)
# -> (55000, 10) 五万五千张图片,由0到9展示的十维向量
print(mnist.train.labels[0,:])
# -> [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] 十个数字组成,表示第一个数为7

标出图片的位置以及不存在则创建

dir_path = 'mnist/data/'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)

提取其中五张训练图作为实验,检查训练结果与训练图片的真实对应。

for i in range(5):
    image_array = mnist.train.images[i, :] # 提取第i张图片
    image_array = image_array.reshape(28,28) # 将图片转换为28*28像素的图片
    image_file = dir_path + 'mnist_train %d.jpg' % i # 放置图片的保存位置和图片名称
    scipy.misc.toimage(image_array,cmin=0.0,cmax=1.0).save(image_file) # 下载图片到本地,基本的图片格式设置

我们在for循环里面观察一下image和label对应的输出。看看训练集的每张图片是否对应。

	image_lable = mnist.train.labels[i, :]
    label = np.argmax(image_lable)
    print("image_train %d label is : %d" %(i,label))

首先先定义一下 Tensorflow 中的每个参数的变量。

x = tf.placeholder(tf.float32,[None,784]) # 占位符表示,第二个参数中第一个值为None代表不固定个数,维数为784
w = tf.Variable(tf.zeros([784,10])) # 定义初始化变量,从784层向量转化为10层的向量的过程,神经网络一层的结构
b = tf.Variable(tf.zeros([10])) # 偏执向量
y_ = tf.placeholder(tf.float32,[None,10])

# 原理 y = softmax(x*w+b) 
y = tf.nn.softmax(tf.matmul(x, w) + h)

接下来构建损失函数,在这里使用交叉熵损失函数,这是 Tensorflow 非常经典的已经封装好的函数,相当于构建真实的Y和输出的Y值所对应的交叉熵。

# 在这里labels与logits绝对不能弄混
cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y))

梯度下降的迭代使得损失函数最小,在这里使用随机梯度下降,设置初始学习速率

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

Tensorflow 会话 Session

在之前的准备工作中,我们仅仅定义了一些 Tensorflow 所必需的一些数据变量,但是我们如果希望 Tensorflow 跑起来的话必须得使用 Tensorflow 的会话工作。在 Session 才是数据真正的开始流,创建之前并没有真正的数据在里面。初始化所有的变量。因为数据量比较小,我们迭代一千次梯度下降。读取batch批次,只有到 Session.run 的时候才是真正的数据跑起来。然后我们定义准确率去查看准确度大致多少。最终使用 test 测试数据集来验证准确率的大小。

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step,feed_dict={ x:batch_xs,y_:batch_ys})

    curr = tf.equal(tf.argmax(y, 1),tf.argmax(y_, 1))
    acc = tf.reduce_mean(tf.cast(curr,tf.float32))
    print(sess.run(acc,feed_dict={ x:mnist.test.images,y_:mnist.test.labels}))
# -> 0.9144

总的来说,Tensorflow 入门级别也并不是很容易,但是每个人都得学的手写体识别,堪称神经网络的Hello World算法。希望每个大佬都能耐心的学下去,变得更强更秃。

全部代码

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import scipy.misc
mnist = input_data.read_data_sets('mnist',one_hot=True)
dir_path = 'mnist/data/'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
x = tf.placeholder(tf.float32,[None,784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y_ = tf.placeholder(tf.float32,[None,10])
y = tf.nn.softmax(tf.matmul(x, w) + b)
cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step,feed_dict={ x:batch_xs,y_:batch_ys})
    curr = tf.equal(tf.argmax(y, 1),tf.argmax(y_, 1))
    acc = tf.reduce_mean(tf.cast(curr,tf.float32))
    print(sess.run(acc,feed_dict={ x:mnist.test.images,y_:mnist.test.labels}))

最后还是希望你们能给我点一波小小的关注。

奉上自己诚挚的爱心

 
打赏
 本文转载自:网络 
所有权利归属于原作者,如文章来源标示错误或侵犯了您的权利请联系微信13520258486
更多>最近资讯中心
更多>最新资讯中心
0相关评论

推荐图文
推荐资讯中心
点击排行
最新信息
新手指南
采购商服务
供应商服务
交易安全
关注我们
手机网站:
新浪微博:
微信关注:

13520258486

周一至周五 9:00-18:00
(其他时间联系在线客服)

24小时在线客服