寸先生的AI道路 2018-07-19
尝试使用过各大公司推出的植物识别APP吗?比如微软识花、花伴侣等这些APP。当你看到一朵不知道学名的花时,只需要打开植物识别APP,拍摄一张你所想辨认的植物照片并上传,APP会自动识别出该花的品种及详细介绍,感觉手机中装了一个知识渊博的生物学家,是不是很神奇?其实,背后的原理很简单,是一个图像分类的过程,将上传的图像与手机中预存的数据集或联网数据进行匹配,将其分类到对应的类别即可。随着深度学习方法的应用,图像分类的精度越来越高,在部分数据集上已经超越了人眼的能力。
相对于传统神经网络的方法而言,深度学习方法一般对数据集规模、硬件平台有着比较高的要求,如果只是单纯的想尝试了解图像分类任务的基本流程,建议采用小数据集样本及传统的神经网络方法实现。本文将带领读者采用鸢尾属植物数据集(Iris Data Set)来实现一个分类任务,整个鸢尾属植物数据集是机器学习中历史悠久的数据集,比现在常用的数字手写体数据集(Mnist Data Set)数据集还要早得多,该数据集来源于英国著名的统计学家、生物学家Ronald Fiser。本文在不使用相关软件库的情况下,从头开始构建针对鸢尾属植物数据的神经网络模型,对其进行训练并获得好的结果。
鸢尾属植物数据集是用于测试机器学习算法的最常用数据集。该数据包含四种特征,萼片长度、萼片宽度、花瓣长度和花瓣宽度,用于鸢尾属植物的不同物种(versicolor, virginica和setosa)。此外,每个物种有50个实例(数据行),下面让我们看看样本数据分布情况。
我们将在这个数据集上使用神经网络构建分类模型。为了简单起见,使用花瓣长度和花瓣宽度作为特征,且只有两类物种:versicolor和virginica。下面就让我们在Python中逐步训练针对该样本数据集的神经网络:
将Iris数据集导入python并对数据进行子集划分以保留行之间的相关性:
蓝色点代表Versicolor物种,红色点代表Virginica物种。本文构建的神经网络将在这些数据上进行训练,以期最后能正确地分类物种。
下面构建一个具有单个隐藏层的神经网络。此外,将隐藏图层的大小设置为6:
在前向传播过程中,使用tanh激活函数作为第一层的激活函数,使用sigmoid激活函数作为第二层的激活函数:
目标是使得计算的代价函数小化,本文采用交叉熵(cross-entropy)作为代价函数:
计算反向传播过程,主要是计算代价函数的导数:
使用反向传播过程中计算的梯度来更新权重和偏置:
步骤7:建立神经网络
将以上所有函数组合起来以创建设计的神经网络模型。总而言之,下面是模型函数的整体顺序:
将隐藏层节点设置为6,最大迭代次数设置为10,000次,并每隔1000次打印出训练的结果:
从图中可以观察到,只有四个点被错误分类。虽然我们可以调整模型来进一步地提高模型训练精度,但该些操作显然会导致过拟合现象的出现。
Rohan Joseph,数据科学家
本文由阿里云云栖社区组织翻译。
文章原标题《Neural network on iris data》,译者:海棠,审校:Uncle_LLD。