机器学习入门之《机器学习实战》(1)——kNN算法
小标 2019-03-26 来源 : 阅读 1536 评论 0

摘要:本文主要向大家介绍了机器学习入门之《机器学习实战》(1)——kNN算法,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

本文主要向大家介绍了机器学习入门之《机器学习实战》(1)——kNN算法,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

机器学习入门之《机器学习实战》(1)——kNN算法

   学习《机器学习实战》,对python语言不怎么熟悉,决定一段程序一段程序来学习,既学习算法,也顺便学习python的基础知识。最后,我将把python代码用Java重写一遍。

第一个算法kNNk-邻近算法)

这个算法的理论很简单,很容易理解。如果学过KMeans聚类算法,那么学这个算法会感觉更简单。

我对这个算法的过程理解如下:

第一步:把所有的训练集读入到内存中,这也是这个算法为什么会有空间复杂度高的原因了。

第二步:读入待分类的向量(如果是文本,要处理成向量的方式,VSM模型在这里起作用了)

第三步:计算待分类向量到所有训练集的距离。(既然是向量计算距离,一般用欧式距离就OK了)

第四步:对距离进行从小到大排序,取前k个训练集的Label

第五步:对前K个训练集的Label进行统计。把待分类向量分到Label个数最多的那一个类别。

第六步:算法结束。


学习了过程,再来学习代码的实现。

算法的过程了解后,就很容易得出需要输入的参数:待分类文本,训练集,训练集标签,k值。输出的参数:待分类文本的分类标签。

输入输出解决了的话,至少解决了问题的三分之一。

def classify(inX,dataSet,labels,k):    //
    dataSetSize=dataSet.shape[0]
    diffMat=tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1)
    distances=sqDistances**0.5
    sortedDistIndicies=distances.argsort();
    classCount={}    for i in range(k):
        voteIlabel=labels[sortedDistIndicies[i]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
    sortedClassCount=sorted(classCount.iteritems(),
                            key=operator.itemgetter(1),reverse=True)    return sortedClassCount[0][0]

python版的翻译成java版的:

package com.vancl.knn;import java.util.ArrayList;import java.util.Arrays;import java.util.Collections;import java.util.Comparator;import java.util.HashMap;import java.util.Map;import java.util.Map.Entry;public class KNN {    /*
     * @param inX 待分类的文本
     * @param dataSet 训练集
     * @param labels 训练集的分类标签
     * @param k值
     * @return 分类器得到的分类标签
     * */    public char classify(double[] inX,double[][] dataSet,char[] labels,int k){        //对应python 的dataSet.shape()[0]        int dataSetSize=dataSet.length;        //对应python 的tile(inX,(dataSetSize,1))-dataSet        //和diffMat**2 两行代码        double[][] sqDiffMat=createDiffMat(inX,dataSet,dataSetSize);        //对应python 的sqdiffMat.sum(axis=1);        //和distances=sqDistances**0.5两行代码        double[] distances=sum(sqDiffMat);
                       
        Node[] disNode=new Node[distances.length];        for(int i=0;i<distances.length;i++){
            Node node=new Node(distances[i],i);
            disNode[i]=node;        }        //对应python中 sortedDistaIndicies=distances.argsort(),排序得到下标
        Arrays.sort(disNode,new KNNCompartor());        //选择距离最小的k个点        //对应pyhton的classCount={}
        Map<Character,Integer> classCount=new HashMap<Character,Integer>();                       
        char voteLabel;        //对应python的 for i in range(k):        for(int i=0;i<k;i++){            //对应python的voteIlabel=labels[sortedDistIndicies[i]]
            voteLabel=labels[disNode[i].idx];            //对应 classCount[voteIlabel]=classCount.get(voteIlabel,0)+1            add(voteLabel,classCount);        }        //sortedClassCount=sorted(classCount.iteritems(),        //       key=operator.itemgetter(1),reverse=True)
        ArrayList<Map.Entry<Character,Integer>> l = new ArrayList<Map.Entry<Character,Integer>>(classCount.entrySet()); 
        Collections.sort(l,new Comparator<Map.Entry<Character,Integer>>(){            @Override            public int compare(Entry<Character, Integer> o1,
                    Entry<Character, Integer> o2) {                               
                return o2.getValue()-o1.getValue();            }        });        //对应 return sortedClassCount[0][0]        return l.get(0).getKey();    }                   
    public void add(char voteLabel,Map<Character,Integer> classCount){
        Integer id=classCount.get(voteLabel);        if(id==null) id=0;
        classCount.put(voteLabel, id+1);    }                   
    private double[] sum(double[][] sqDiffMat) {        int i,j;        double[] sqDistances=new double[sqDiffMat.length];        for(i=0;i<sqDiffMat.length;i++){
            sqDistances[i]=0;            for(j=0;j<sqDiffMat[i].length;j++){
                sqDistances[i]+=sqDiffMat[i][j];            }
            sqDistances[i]=Math.sqrt(sqDistances[i]);        }        return sqDistances;    }    private double[][] createDiffMat(double[] inX, double[][] dataSet,int dataSetSize) {        double[][] diffMat=new double[dataSetSize][inX.length];        for(int i=0;i<dataSetSize;i++){
            System.arraycopy(inX, 0, diffMat[i], 0, inX.length);            for(int j=0;j<inX.length;j++){
                diffMat[i][j]=diffMat[i][j]-dataSet[i][j];
                diffMat[i][j]=Math.pow(diffMat[i][j], 2);            }        }                           
        return diffMat;    }    class Node{        public Node(double value, int idx) {            super();            this.value = value;            this.idx = idx;        }        double value;        int idx;    }    class KNNCompartor implements Comparator<Node>{        @Override        public int compare(Node o1, Node o2) {            return Double.compare(o1.value, o2.value);        }                       
    }    public static void main(String[] args) {
        KNN knn=new KNN();        double[] inX={1,1};
       "

本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!

本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式AI+学习就业服务平台 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved