使用MNIST数据集在10分钟内进行图像分类!

DCDLIN 2018-08-20

点击上方关注,All in AI中国

作者:Orhan Gazi Yalçın

当你开始用神经网络学习深度学习时,你会意识到最强大的监督深度学习技术之一是卷积神经网络(简称CNN)。它最终的结构非常类似于正则化的神经网络(RegularNets),其中有带有权重和偏差的神经元。此外,在CNNs中,我们还使用了损失函数(如交叉熵或softmax)、优化器(如adam优化器)和全连接层[2]。另一方面,在CNNs中,有卷积层、池化层和扁平层。CNNs主要用于图像分类,但你在其他应用领域也会发现它的"身影",如自然语言处理(本教程将重点介绍图像分类)。

使用MNIST数据集在10分钟内进行图像分类!

MNIST数据集和数字分类[1]

RegularNets的主要结构特征是所有神经元的相互连接。例如,当我们有一个28×28像素,只有灰度的图像时。我们最终会得到一个有784(28×28×1)个神经元,看起来易于管理的层。然而,大多数图像有更多的像素,而且不仅仅是灰度图。因此,假设我们有一组4K超高清的彩色图像,我们将有26542080 (4096 x 2160 x 3)个不同的神经元在第一层相互连接。这无疑很难管理。因此,我们可以说RegularNets(正则网络)对于图像分类是不可扩展的。特别是两个单独的像素之间除了它们彼此接近以外,并没有什么相关性或联系。这也引出了卷积层和池化层的概念。

CNN的分层

我们能够在一个卷积神经网络中使用许多不同的层。然而,卷积层、池化层和全连接层是最重要的。因此,在实现这些层之前,我会快速介绍它们。

卷积层

卷积层是我们从数据集图像中提取特征的第一层。由于像素只与相邻和相近的像素相关,因此卷积允许我们保持图像不同部分之间的关系。卷积基本上就是用一个更小的像素滤波器来过滤图像,以减少图像大小的同时而不丢失像素之间的关系。当我们使用带有1x1步长(每步移动1个像素)的3x3滤波器对5x5图像进行卷积时。我们最终会得到一个3x3的输出(复杂度降低64%)。

使用MNIST数据集在10分钟内进行图像分类!

用3x3像素滤波器卷积的5x5像素图像(步长=1x1像素)

池化层

在构建CNNs时,通常在每个卷积层之后插入池化层,以减小表示的空间大小,减少参数计数,从而降低计算复杂度。此外,池化层也有助于解决过度拟合问题。基本上,我们通过选择这些像素内的最大值、平均值或和值来选择池大小以减少参数的数量。最大池化是最常见的池化技术之一,可以演示如下:

使用MNIST数据集在10分钟内进行图像分类!

2x2的最大池

全连接层

一个完全连接的层是我们的正则网络,其中每个参数相互连接,以确定每个参数在标签上的真实关系和效果。由于卷积层和池化层大大降低了复杂度,因此我们可以构建一个全连接层来对图像进行分类。一组全连接的层如下所示:

使用MNIST数据集在10分钟内进行图像分类!

具有两个隐藏层的一个全连接层

现在你已经对我们将要使用的各个层有了一些了解,我认为是时候对一个完整的卷积神经网络做出了解了。

使用MNIST数据集在10分钟内进行图像分类!

卷积神经网络实例[3]

现在你对卷积神经网络已经有了基本的了解,你可以建立图像分类,这里我们将使用最老套的分类数据集:MNIST数据集,它代表着国家标准修改后的技术研究所数据库。它是一个大型的手写数字数据库,通常用于训练各种图像处理系统。

下载MNIST数据

MNIST数据集是用于图像分类的最常见的数据集之一,可从许多不同来源访问。实际上,甚至TensorFlow和Keras也允许我们从它们的API中直接导入和下载MNIST数据集。因此,我将从以下两行开始,在KerasAPI下导入TensorFlow和MNIST数据集。

使用MNIST数据集在10分钟内进行图像分类!

MNIST的数据库包含6万张训练图像和1万张测试图像,这些图像来自美国人口普查局的雇员和美国高中生[4]。因此,在第二行中,我将这两组分别作为训练集和测试集,并将标签和图片分开。x_train和x_test部分包含灰度RGB代码(从0到255),y_train和y_test部分包含从0到9的标签,表示它们实际上是哪个数字。为了使这些数字形象化,我们可以从matplotlib中得到帮助。

使用MNIST数据集在10分钟内进行图像分类!

使用MNIST数据集在10分钟内进行图像分类!

当我们运行上面的代码时,我们将得到RGB代码的灰度可视化,如下所示。

使用MNIST数据集在10分钟内进行图像分类!

对索引7777样本图像进行可视化

我们还需要知道数据集的形状,以便将其传输到卷积神经网络。因此,我将使用numpy数组的"form"属性,代码如下:

使用MNIST数据集在10分钟内进行图像分类!

你将得到(60000,28,28)这组数字。正如你可能已经猜到的,60000表示训练数据集中的图像数量,(28,28)表示图像的大小:28x28像素。

图像的重塑和规范化

为了能够在Keras API中使用数据集,我们需要4-dims numpy数组。然而,如上所示,我们的数组是3-dims。此外,我们必须规范化我们的数据,因为它将应用到神经网络模型中。我们可以通过将RGB代码除以255(最大RGB代码减去最小RGB代码)来实现这一点。这可以通过以下代码来实现:

使用MNIST数据集在10分钟内进行图像分类!

使用MNIST数据集在10分钟内进行图像分类!

使用MNIST数据集在10分钟内进行图像分类!

使用MNIST数据集在10分钟内进行图像分类!

建立卷积神经网络

我们将使用高级Keras API来构建模型,Keras API将在后端使用TensorFlow或Theano。我想说的是,有一些TensorFlow API,比如Layer,Keras,和Estimators,这些API可以帮助我们创建具有高级别的神经网络。但是,如果把它们混合起来使用,这可能会导致混淆,因为它们的实现结构各不相同。因此,如果你看到相同神经网络的代码却不尽相同(而且它们都使用了TensorFlow),这就是原因。在本文我将使用最直接的API,即Keras。因此,我将从Keras导入顺序模型并添加Conv2D函数、最大池、Flatten()函数、Dropout和致密层。我已经讨论过Conv2D、最大池和致密层。此外,在构建全连接层之前,dropout层会在训练时忽略一些神经元,将二维数组拉平至一维数组,从而与过度拟合进行斗争。

使用MNIST数据集在10分钟内进行图像分类!

使用MNIST数据集在10分钟内进行图像分类!

使用MNIST数据集在10分钟内进行图像分类!

对于第一致密层,我们可以用任意数进行实验。但是,最后的致密层必须有10个神经元,因为我们有10个编号类(0,1,2,…,9),你可以尝试内核大小、池大小、激活函数、dropout率和第一致密层下神经元的数目以得到更好的结果。

模型的编译和拟合

通过上面的代码,我们创建了一个尚未优化的CNN。现在是时候使用一个给定的损失函数来设置优化器了。让该函数选择一个度量然后,利用训练数据拟合模型。对于这些任务,我们将使用以下代码:

使用MNIST数据集在10分钟内进行图像分类!

你可以在优化器、损失函数、度量和时间方面做出更多的尝试。但是,我认为Adam优化器通常优于其他优化器。我不确定你是否真的能改变损失函数的多类分类。Epoch的确有点小,但它并不影响你的测试精度。也就是说达到98%-99%的测试精度也很正常。因为MNIST数据集不需要很强的计算能力。

评估模型

最后,你可以使用x_test和y_test对经过训练的模型进行评估:

使用MNIST数据集在10分钟内进行图像分类!

对于经过10个epoch的模型来说,这一结果是相当好的。

使用MNIST数据集在10分钟内进行图像分类!

该模型的准确率为98.5%。坦率地说,在很多情况下(例如自动驾驶汽车),我们甚至不能容忍0.1%的误差,因为用一个比喻来说,它会在1000起事故中造成1起事故。然而,对于我们的第一个模型,我想说,结果仍然是相当好的。我们还可以使用以下代码进行单独的预测:

使用MNIST数据集在10分钟内进行图像分类!

我们的模型将图像分类为"9",下面是图像的视觉效果:

使用MNIST数据集在10分钟内进行图像分类!

我们的模型正确地将图像归到数字9这一类

虽然这并不是一个好的手写体数字9,我们的模型能够分类为9。

祝贺你

你已经成功地构建了一个卷积神经网络来使用TensorFlow的KerasAPI对手写数字进行分类(https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model)。你已经达到了98%以上的准确性,现在你甚至可以保存这个模型或创建一个数字分类器应用程序!如果你对这方面抱有极大的热情,我建议你点击学习更多关于这方面的知识。

资源

[1]KataKoda,

https//www.katacoda.com/basiafusinska/courses/tensorflow-getting-started/tensorflow-mnist-beginner

[2]CS231n卷积神经网络用于视觉识别

http://cs231n.github.io/convolutional-networks/

[3]关于深度学习的介绍,https://www.mathworks.com/content/dam/mathworks/tag-team/Objects/d/80879v00_Deep_Learning_ebook.pdf

[4]维基百科,MNIST数据库,https:/en.wikipara.org/wiki/MNIST_DATABASE

使用MNIST数据集在10分钟内进行图像分类!

相关推荐