机器学习入门之机器学习算法:k近邻
小标 2019-03-26 来源 : 阅读 984 评论 0

摘要:本文主要向大家介绍了机器学习入门之机器学习算法:k近邻,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

本文主要向大家介绍了机器学习入门之机器学习算法:k近邻,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

机器学习入门之机器学习算法:k近邻

k近邻(k-Nearest Neighbor,KNN)算法,应该是机器学习里最基础的算法,其核心思想是:给定一个未知分类的样本,如果与它最相似的k个已知样本中的多数属于某一个分类,那么这个未知样本也属于这个分类


所谓相似,是指两个样本之间的欧氏距离小,其计算公式为:

机器学习入门之机器学习算法:k近邻

其中Xi为样本X的第i个特征。

k近邻算法的优点在于实现简单,缺点在于时间和空间复杂度高。


上C#版代码,这里取k=1,即只根据最相近的一个点确定分类:

首先是DataVector,包含N维数据和分类标签,用于表示一个样本。

using System;namespace MachineLearning{    /// <summary>    /// 数据向量    /// </summary>    /// <typeparam name=""T""></typeparam>
    public class DataVector<T>    {        /// <summary>        /// N维数据        /// </summary>
        public T[] Data { get; private set; }        /// <summary>        /// 分类标签        /// </summary>
        public string Label { get; set; }        /// <summary>        /// 构造        /// </summary>        /// <param name=""dimension"">数据维度</param>
        public DataVector(int dimension)        {
            Data = new T[dimension];        }
        
        public int Dimension        {
            get { return this.Data.Length; }        }    }}


然后是核心算法:

using System;using System.Collections.Generic;namespace MachineLearning{    /// <summary>    /// k近邻法    /// </summary>
    public class NearestNeighbour    {
        private int m_K;
        private List<DataVector<double>> m_TrainingSet;
        
        public NearestNeighbour(int k = 1)        {
            m_K = k;        }        
        /// <summary>        /// 训练        /// </summary>        /// <param name=""trainingSet""></param>
        public void Train(List<DataVector<double>> trainingSet)        {
            m_TrainingSet = trainingSet;        }        /// <summary>        /// 分类        /// </summary>        /// <param name=""vector""></param>        /// <returns></returns>
        public string Classify(DataVector<double> vector)        {            //K=1时可简化处理提高效率            if(m_K == 1)            {                double minDist = double.PositiveInfinity;                int targetIndex = -1;                for(int i = 0;i < m_TrainingSet.Count;i++)                {                    //计算距离                    double distance = ComputeDistance(vector, m_TrainingSet[i], minDist);                    //找最小值                    if(distance < minDist)                    {
                        minDist = distance;
                        targetIndex = i;                    }                }            
                return m_TrainingSet[targetIndex].Label;            }            else            {
                var dict = new SortedDictionary<double, string>();                
                for(int i = 0;i < m_TrainingSet.Count;i++)                {                    //计算距离并记录                    double distance = ComputeDistance(vector, m_TrainingSet[i]);
                    dict[distance] = m_TrainingSet[i].Label;                }                
                //找最多的Label
                var labels = new List<string>();                int count = 0;                foreach(var label in dict.Values)                {
                    labels.Add(label);                    if(++count > m_K - 1)                        break;                }                return GetMajorLabel(labels);            }        }    
        /// <summary>        /// 计算距离        /// </summary>        /// <param name=""v1""></param>        /// <param name=""v2""></param>        /// <param name=""minValue""></param>        /// <returns></returns>
        private double ComputeDistance(DataVector<double> v1, DataVector<double> v2, double minValue = double.PositiveInfinity)        {            double distance = 0.0;
            minValue = minValue * minValue;            for(int i = 0;i < v1.Data.Length;++i)            {                double diff = v1.Data[i] - v2.Data[i];
                distance += diff * diff;            
                //如果当前累加的距离已经大于给定的最小值,不用继续计算了                if(distance > minValue)                    return double.PositiveInfinity;            }            return Math.Sqrt(distance);        }        
        /// <summary>        /// 取多数        /// </summary>        /// <param name=""dataSet""></param>        /// <returns></returns>
        private string GetMajorLabel(List<string> labels)        {
            var dict = new Dictionary<string, int>();            foreach(var item in labels)            {                if(!dict.ContainsKey(item))
                    dict[item] = 0;
                dict[item]++;            }

            string label = string.Empty;            int count = -1;            foreach(var key in dict.Keys)            {                if(dict[key] > count)                {
                    label = key;
                    count = dict[key];                }            }            
            return label;        }    }}


需要注意的是,计算距离时,数量级大的维度会对距离影响大,因此大多数情况下,不能直接计算,要对原始数据做归一化,并根据重要性进行加权。归一化可以使用公式:value = (old-min)/(max-min),其中old是原始值,max是所有数据的最大值,min是所有数据的最小值。这样计算得到的value将落在0至1的区间上。


这个算法太简单,暂时不上测试代码了,有时间再补吧。


本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!

本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式IT培训就业服务领导者 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved

208小时内训课程