机器学习实践:基于支持向量机算法对鸢尾花进行分类
摘要:List item使用scikit-learn机器学习包的支持向量机算法,使用全部特征对鸢尾花进行分类。
本文分享自华为云社区《支持向量机算法之鸢尾花特征分类【机器学习】》,作者:上进小菜猪。
一.前言
1.1 本文原理
支持向量机(SVM)是一种二元分类模型。它的基本模型是在特征空间中定义最大区间的线性分类器,这使它不同于感知器;支持向量机还包括核技术,这使得它本质上是一个非线性分类器。支持向量机的学习策略是区间最大化,它可以形式化为求解凸二次规划的问题,等价于正则化铰链损失函数的最小化。支持向量机的学习算法是求解凸二次规划的优化算法。Scikit learn(sklearn)是机器学习中常见的第三方模块。它封装了常见的机器学习方法,包括回归、降维、分类、聚类等。
1.2 本文目的
- List item使用scikit-learn机器学习包的支持向量机算法,使用全部特征对鸢尾花进行分类;
- 使用scikit-learn机器学习包的支持向量机算法,设置SVM对象的参数,包括kernel、gamma和C,分别选择一个特征、两个特征、三个特征,写代码对鸢尾花进行分类;
- 使用scikit-learn机器学习包的支持向量机算法,选择特征0和特征2对鸢尾花分类并画图,gamma参数分别设置为1、10、100,运行程序并截图,观察gamma参数对训练分数(score)的影响,请说明如果错误调整gamma参数会产生什么问题?
二.实验过程
2.1 支持向量机算法SVM
实例的特征向量(以2D为例)映射到空间中的一些点,如下图中的实心点和空心点,它们属于两个不同的类别。支持向量机的目的是画一条线来“最好”区分这两类点,这样,如果将来有新的点,这条线也可以很好地进行分类。
![](https://pic1.zhimg.com/80/v2-aed97645c0cae81d5a579a1d5c50a1cc_720w.jpg)
2.2List item使用scikit-learn机器学习包的支持向量机算法,使用全部特征对鸢尾花进行分类;
首先引入向量机算法svm模块:
from sklearn import svm
还是老样子,使用load_iris模块,里面有150组鸢尾花特征数据,我们可以拿来进行学习特征分类。
如下代码:
from sklearn.datasets import load_iris iris = load_iris() X = iris.data print(X.shape, X) y = iris.target print(y.shape, y)
下面使用sklearn.svm.SVC()函数。
C-支持向量分类器如下:
svm=svm.SVC(kernel='rbf',C=1,gamma='auto')
使用全部特征对鸢尾花进行分类
svm.fit(X[:,:4],y)
输出训练得分:
print("training score:",svm.score(X[:,:4],y)) print("predict: ",svm.predict([[7,5,2,0.5],[7.5,4,7,2]]))
使用全部特征对鸢尾花进行分类训练得分如下:
![](https://pic3.zhimg.com/80/v2-9f1703034dc71aa95eb56eaa7c72a946_720w.jpg)
2.3 使用scikit-learn机器学习包的支持向量机算法,设置SVM对象的参数,包括kernel、gamma和C,分别选择一个特征、两个特征、三个特征,写代码对鸢尾花进行分类;
2.3.1 使用一个特征对鸢尾花进行分类
上面提过的基础就不再写了。如下代码:
![](https://pic4.zhimg.com/80/v2-02c66db24cbfdeb3f6ea2a2435f4bc37_720w.jpg)
使用一个特征对鸢尾花进行分类,如下代码:
svm=svm.SVC()
svm.fit(X,y)
输出训练得分:
print("training score:",svm.score(X,y)) print("predict: ",svm.predict([[7,5,2,0.5],[7.5,4,7,2]]))
使用一个特征对鸢尾花进行分类训练得分如下:
![](https://pic3.zhimg.com/80/v2-94be876eae7902378a9835a238fadb2a_720w.jpg)
2.3.2 使用两个特征对鸢尾花进行分类
使用两个特征对鸢尾花进行分类,如下代码:
svm=svm.SVC() svm.fit(X[:,:1],y)
输出训练得分:
print("training score:",svm.score(X[:,:1],y)) print("predict: ",svm.predict([[7],[7.5]]))
使用两个特征对鸢尾花进行分类训练得分如下:
![](https://pic2.zhimg.com/80/v2-e8c78aa47c363977d7ed41adcf798cdd_720w.jpg)
2.3.3 使用三个特征对鸢尾花进行分类
使用三个特征对鸢尾花进行分类,如下代码:
svm=svm.SVC(kernel='rbf',C=1,gamma='auto') svm.fit(X[:,1:3],y)
输出训练得分:
print("training score:",svm.score(X[:,1:3],y)) print("predict: ",svm.predict([[7,5],[7.5,4]]))
使用三个特征对鸢尾花进行分类训练得分如下:
![](https://pic4.zhimg.com/80/v2-6c567b4ae366a9e56a8e2fffa9981917_720w.jpg)
2.3.4 可视化三个特征分类结果
使用plt.subplot()函数用于直接指定划分方式和位置进行绘图。
x_min,x_max=X[:,1].min()-1,X[:,1].max()+1 v_min,v_max=X[:,2].min()-1,X[:,2].max()+1 h=(x_max/x_min)/100 xx,vy =np.meshgrid(np.arange(x_min,x_max,h),np.arange(v_min,v_max,h)) plt.subplot(1,1,1) Z=svm.predict(np.c_[xx.ravel(),vy.ravel()]) Z=Z.reshape(xx.shape)
绘图,输出可视化。如下代码
plt.contourf(xx,vy,Z,cmap=plt.cm.Paired,alpha=0.8) plt.scatter(X[:, 1], X[:, 2], c=y, cmap=plt.cm.Paired) plt.xlabel('Sepal width') plt.vlabel('Petal length') plt.xlim(xx.min(), xx.max()) plt.title('SVC with linear kernel') plt.show()
可视化三个特征分类结果图:
![](https://pic2.zhimg.com/80/v2-4ea23e87bf7006323854924f9fd99079_720w.jpg)
2.4使用scikit-learn机器学习包的支持向量机算法,选择特征0和特征2对鸢尾花分类并画图,gamma参数分别设置为1、10、100,运行程序并截图,观察gamma参数对训练分数(score)的影响,请说明如果错误调整gamma参数会产生什么问题?
2.4.1当gamma为1时:
讲上文的gamma='auto‘ 里的auto改为1,得如下代码:
svm=svm.SVC(kernel='rbf',C=1,gamma='1') svm.fit(X[:,1:3],y)
运行上文可视化代码,得如下结果:
![](https://pic2.zhimg.com/80/v2-0fa8054d5427739f189b3ac48e7f1521_720w.jpg)
![](https://pic3.zhimg.com/80/v2-40d0a8ee03a548c76667e0bb8d76d796_720w.jpg)
2.4.2当gamma为10时:
讲上文的gamma='auto‘ 里的auto改为10,得如下代码:
svm=svm.SVC(kernel='rbf',C=1,gamma='10') svm.fit(X[:,:3:2],y)
运行上文可视化代码,得如下结果:
![](https://pic3.zhimg.com/80/v2-c146fe92fb576c521f7c713a9c2d5aba_720w.jpg)
![](https://pic1.zhimg.com/80/v2-5cbdd9b1b0c8ee89b0dc8a38486edabc_720w.jpg)
2.4.3当gamma为100时:
讲上文的gamma='auto‘ 里的auto改为100,得如下代码:
svm=svm.SVC(kernel='rbf',C=1,gamma='100') svm.fit(X[:,:3:2],y)
运行上文可视化代码,得如下结果:
![](https://pic3.zhimg.com/80/v2-11be9c73fad654c1aef41fc0dcbf91a6_720w.jpg)
![](https://pic3.zhimg.com/80/v2-639d51a6f7504408dc565f348a5a9742_720w.jpg)
2.4.4 结论
参数gamma主要是对低维的样本进行高度度映射,gamma值越大映射的维度越高,训练的结果越好,但是越容易引起过拟合,即泛化能力低。通过上面的图可以看出gamma值越大,分数(score)越高。错误使用gamma值可能会引起过拟合,太低可能训练的结果太差。
相关文章
- 【Laravel】在企业级项目中使用Laravel框架中的工厂状态下的页面方法 Code Verifier以及错误处理
- 结构建模设计——Solidworks软件之使用钣金折弯功能做一个带折弯固定口的铝合金面板
- Adobe Acrobat Pro DC 世界上最优秀的桌面版 PDF 文档创建 / 编辑 / 审查软件
- 计算机等级二级Mysql综合应用题汇总2022.9.24
- 说说前端面试比较好的回答
- promise执行顺序面试题令我头秃
- 结合RocketMQ 源码,带你了解并发编程的三大神器
- 云小课|云小课带你玩转可视化分析ELB日志
- 打造无证服务化:这个政务服务平台有点不一样
- 【云享·人物】华为云AI高级专家白小龙:AI如何释放应用生产力,向AI工程化前行?
- 云小课|云小课教您如何选择Redis实例类型
- DTSE Tech Talk 第13期:Serverless凭什么被誉为未来云计算范式?
- 云原生微服务治理技术朝无代理架构的演进之路
- 华夏天信携手华为云开天aPaaS,打造安全、高效、节能的主煤流运输系统
- 要想后期修改少,代码重构要趁早
- 千年荒漠变绿洲,看沙漠“卫士”携手昇腾AI植起绿色希望
- 探讨Morest在RESTful API测试的行业实践
- FCOS论文复现:通用物体检测算法
- 基于OpenHarmony L2设备,如何用IoTDeviceSDKTiny对接华为云
- 网站停服、秒杀大促…解析高可用网站架构云化