摘要:本文主要向大家介绍了机器学习入门之机器学习算法:k近邻,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。
本文主要向大家介绍了机器学习入门之机器学习算法:k近邻,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。
k近邻(k-Nearest Neighbor,KNN)算法,应该是机器学习里最基础的算法,其核心思想是:给定一个未知分类的样本,如果与它最相似的k个已知样本中的多数属于某一个分类,那么这个未知样本也属于这个分类。
所谓相似,是指两个样本之间的欧氏距离小,其计算公式为:
其中Xi为样本X的第i个特征。
k近邻算法的优点在于实现简单,缺点在于时间和空间复杂度高。
上C#版代码,这里取k=1,即只根据最相近的一个点确定分类:
首先是DataVector,包含N维数据和分类标签,用于表示一个样本。
using System;namespace MachineLearning{ /// <summary> /// 数据向量 /// </summary> /// <typeparam name=""T""></typeparam> public class DataVector<T> { /// <summary> /// N维数据 /// </summary> public T[] Data { get; private set; } /// <summary> /// 分类标签 /// </summary> public string Label { get; set; } /// <summary> /// 构造 /// </summary> /// <param name=""dimension"">数据维度</param> public DataVector(int dimension) { Data = new T[dimension]; } public int Dimension { get { return this.Data.Length; } } }}
然后是核心算法:
using System;using System.Collections.Generic;namespace MachineLearning{ /// <summary> /// k近邻法 /// </summary> public class NearestNeighbour { private int m_K; private List<DataVector<double>> m_TrainingSet; public NearestNeighbour(int k = 1) { m_K = k; } /// <summary> /// 训练 /// </summary> /// <param name=""trainingSet""></param> public void Train(List<DataVector<double>> trainingSet) { m_TrainingSet = trainingSet; } /// <summary> /// 分类 /// </summary> /// <param name=""vector""></param> /// <returns></returns> public string Classify(DataVector<double> vector) { //K=1时可简化处理提高效率 if(m_K == 1) { double minDist = double.PositiveInfinity; int targetIndex = -1; for(int i = 0;i < m_TrainingSet.Count;i++) { //计算距离 double distance = ComputeDistance(vector, m_TrainingSet[i], minDist); //找最小值 if(distance < minDist) { minDist = distance; targetIndex = i; } } return m_TrainingSet[targetIndex].Label; } else { var dict = new SortedDictionary<double, string>(); for(int i = 0;i < m_TrainingSet.Count;i++) { //计算距离并记录 double distance = ComputeDistance(vector, m_TrainingSet[i]); dict[distance] = m_TrainingSet[i].Label; } //找最多的Label var labels = new List<string>(); int count = 0; foreach(var label in dict.Values) { labels.Add(label); if(++count > m_K - 1) break; } return GetMajorLabel(labels); } } /// <summary> /// 计算距离 /// </summary> /// <param name=""v1""></param> /// <param name=""v2""></param> /// <param name=""minValue""></param> /// <returns></returns> private double ComputeDistance(DataVector<double> v1, DataVector<double> v2, double minValue = double.PositiveInfinity) { double distance = 0.0; minValue = minValue * minValue; for(int i = 0;i < v1.Data.Length;++i) { double diff = v1.Data[i] - v2.Data[i]; distance += diff * diff; //如果当前累加的距离已经大于给定的最小值,不用继续计算了 if(distance > minValue) return double.PositiveInfinity; } return Math.Sqrt(distance); } /// <summary> /// 取多数 /// </summary> /// <param name=""dataSet""></param> /// <returns></returns> private string GetMajorLabel(List<string> labels) { var dict = new Dictionary<string, int>(); foreach(var item in labels) { if(!dict.ContainsKey(item)) dict[item] = 0; dict[item]++; } string label = string.Empty; int count = -1; foreach(var key in dict.Keys) { if(dict[key] > count) { label = key; count = dict[key]; } } return label; } }}
需要注意的是,计算距离时,数量级大的维度会对距离影响大,因此大多数情况下,不能直接计算,要对原始数据做归一化,并根据重要性进行加权。归一化可以使用公式:value = (old-min)/(max-min),其中old是原始值,max是所有数据的最大值,min是所有数据的最小值。这样计算得到的value将落在0至1的区间上。
这个算法太简单,暂时不上测试代码了,有时间再补吧。
本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!
您输入的评论内容中包含违禁敏感词
我知道了
请输入正确的手机号码
请输入正确的验证码
您今天的短信下发次数太多了,明天再试试吧!
我们会在第一时间安排职业规划师联系您!
您也可以联系我们的职业规划师咨询:
版权所有 职坐标-一站式IT培训就业服务领导者 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
沪公网安备 31011502005948号