博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【机器学习】sklearn机器学习入门案例——使用k近邻算法进行鸢尾花分类
阅读量:2019 次
发布时间:2019-04-28

本文共 4699 字,大约阅读时间需要 15 分钟。

1 背景

这个案例恐怕已经被说的很烂了,机器学习方面不同程度的人对该案例还是有着不同的感觉。有的人追求先理解机器学习背后的理论甚至自己推导一遍相关数学公式,再用代码实现;有的人则满足于能够实现相关功能即可。凡是都有两面性,理解算法背后原理,再去实现相关算法,这个对算法理解深刻,更能融会贯通,拓展性强,但是需要有一定的数学基础以及要花费一段时间;若能够实现相关算法,知道各个参数的意义,也是能够尽快处理相关的任务,但是可拓展性就不那么好了。

当前,对于传统的机器学习算法进行很好的实现的Python包也当属sklearn了,本文更注重使用sklearn提供的算法包去完成鸢尾花分类任务,也不用去把相关算法去逐一实现(就不要去造轮子了)。对于相关算法理论只作简要介绍。

2 任务背景

  1. 假设你有一份数据集:有很多不同类别的鸢尾花,每一条数据有多个特征:花瓣(petal)长、宽花萼(sepal)长、宽,以及这条数据对应的类别,也就是说有4个特征,1个标签。
  2. 任务是:使用这些数据采用监督学习的机器学习相关算法训练一个模型,然后对没有类别,只有花瓣长宽、花萼长宽的数据预测其所属类别
  3. 数据背景:这里使用的sklearn中自带的数据集,其中鸢尾花的类别有三种setosa、versicolor、virginica,相关数据将在后面的内容详细介绍。
  4. 实验环境:Python3.7, sklearn 0.22.1
    图来自:introduction_to_ml_with_python

3 了解一下数据

from sklearn.datasets import load_irisiris_dataset = load_iris()print("keys of iris_dataset:{}\n".format(iris_dataset.keys()))# 结果"""keys of iris_dataset:dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])"""

可通过以下方式查看数据形式:

import numpy as npfor key in iris_dataset.keys():    print('current key:"{}", key type:{}'.format(key, type(iris_dataset[key])))    # 如果为np.ndarray 可以说明是训练数据以及对应的标签    if isinstance(iris_dataset[key], np.ndarray):        print(iris_dataset[key][0])    elif isinstance(iris_dataset[key], str):        print(iris_dataset[key][:150])    else:        print(iris_dataset[key]))

结果如如下:

current key:"data", key type:
[5.1 3.5 1.4 0.2]current key:"target", key type:
0current key:"target_names", key type:
setosacurrent key:"DESCR", key type:
.. _iris_dataset:Iris plants dataset--------------------**Data Set Characteristics:** :Number of Instances: 150 (50 in each of three classescurrent key:"feature_names", key type:
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']current key:"filename", key type:
C:\software\Development\Python\Anaconda3\envs\commEnv\lib\site-packages\sklearn\datasets\data\iris.csv

数据的具体介绍可以通过iris_dataset[“DESCR”]查看,其他数据也比较清楚了,如:每一条数据如:[5.1 3.5 1.4 0.2]对应的特征名称为:花萼长、宽,花瓣长、宽,数据样本一共有150条;iris_dataset[“target”]是使用0,1,2分别代表setosa,versicolor,virginica这三个类别的。

4 数据预处理

数据只有150,需要将一部分数据拿出来训练得到一个模型,剩余的数据留出来验证模型的泛化能力,也就是看这个模型对不在训练集中的数据的识别能力。当然有时候将还会将数据集划再划分一个验证集,这是因为在训练集中可以构建多个模型,通过训练集选出一个泛化能力更好的模型,然后在使用这个更好的模型在测试集数据上进行测试,这些复杂的过程(根据哪些指标在备选的模型中挑选最好等)暂且不谈。

这里使用scikit-learn中的train_test_split函数打乱数据集(原始数据是每一类都在一起,需要将这个顺序打乱)并拆分为两个部分,默认将75%的数据划分为训练集train_data,25%的数据划分为测试集test_data。操作如下

from sklearn.model_selection import train_test_split# 划分数据时设置随机数种子,random_state 便于实验的复现X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)  # 查看划分后的数据集情况print("X_train shape:{}".format(X_train.shape))print("y_train shape:{}".format(y_train.shape))print("X_test shape:{}".format(X_test.shape))print("y_test shape:{}".format(y_test.shape))"""X_train shape:(112, 4)y_train shape:(112,)X_test shape:(38, 4)y_test shape:(38,)"""

数据划分好了,也不要急着去构建机器学习模型,我们先来观察一下数据,可视化一下。由于数据集有4个特征,我们绘制散点图矩阵,也就是查看数据的两两特征之间的情况。绘制代码如下:

import pandas as pdiris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)iris_dataframe.head(5)# 绘图grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={
'bins':20}, s=60, alpha=0.8)

绘图结果

从上面的图可以看出,使用数据集中的几个特征是可以将各个类型的数据区别开的,也就是说使用机器学习算法也就能够区分开了。

5 模型构建

在机器学习中,分类的算法有很多,这里使用一个简单比较容易理解的KNN算法来分类。这里的K的含义是,待预测的数据与训练集中最近的任意k个邻居,根据这k个邻居的类别确定这个预测数据的类别。下面是构建模型的代码:

from sklearn.neighbors import KNeighborsClassifier# 1 构建模型对象,设置相关参数knn = KNeighborsClassifier(n_neighbors=1)  # 根据最近的一个训练数据来确定类别# 2 训练模型knn.fit(X_train, y_train)# knn会被返回,如下:"""KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',                     metric_params=None, n_jobs=None, n_neighbors=1, p=2,                     weights='uniform')"""

当然训练后的一个knn模型包含很多参数,很多参数用于速度优化或非常特殊的用途。

6 模型评估

这里我们使用一个简单的指标精度

y_pred = knn.predict(X_test)print('Test set predictions:{}\n'.format(y_pred))print("Test set accuracy:{:.2f}%".format(100*np.mean(y_pred==y_test)))

结果如下:

Test set predictions:[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2]Test set accuracy:97.37%

总体来看效果还是可以的。

7 预测数据

我们根据训练模型的数据构建预测数据的格式,例如:X_new=[[5, 2.9, 1, 0.2]],预测代码如下:

X_new=[[5, 2.9, 1, 0.2]]# 开始预测prediction=knn.predict(X_new)# 预测结果print("Prediction:{}".format(prediction))print("Predicted target name:{}".format(iris_dataset['target_names'][prediction]))"""Prediction:[0]Predicted target name:['setosa']"""

8 总结

其实上面是一个使用sklearn的一个简单入门,不管是理论还是模型调参所涉及的内容都还没介绍太多,不过上面也涉及使用sklearn进行机器学习的大部分步骤。当然也不要被当前深度学习“遍地飞”所震慑,传统的机器学习算法依然在一些领域中发挥着很重要的作用。当然,需要在这个领域中还要进行更深一步的学习和深入。当然本部分内容是参考《Python机器学习基础教程》内容并结合自己的理解写出,所以我还是推荐​一下这本书,或者可以在订阅号“AIAS编程有道”中回复“Python机器学习基础教程”获取电子档后决定​是否要购买,建议购买正版书籍。​

课程推荐

图解Python数据结构与算法-实战篇

转载地址:http://lmlxf.baihongyu.com/

你可能感兴趣的文章
jquery实现星级评分
查看>>
Hogan模板引擎
查看>>
转:vue+canvas如何实现b站萌系登录界面
查看>>
vue和echarts 封装的 v-charts 图表组件
查看>>
转:vue+element实现树形组件
查看>>
mpVue小程序全栈开发
查看>>
树形插件 --- zTree
查看>>
zTree demo
查看>>
Vue的双向数据绑定
查看>>
Vue项目启动后首页URL带的#该怎么去掉?
查看>>
如何从现有版本升级到element UI2.0?使用npm-check-updates
查看>>
vue-router两种模式,到底什么情况下用hash,什么情况下用history模式呢?
查看>>
转:MySQL如何修改密码
查看>>
Lodash JavaScript 实用工具库
查看>>
转:Session,Token相关区别
查看>>
面试图谱
查看>>
实现自己的Koa2
查看>>
小程序商城笔记
查看>>
vue-router 基本使用
查看>>
使用VW时,图片的问题
查看>>