小标
2019-03-26
来源 :
阅读 1484
评论 0
摘要:本文主要向大家介绍了机器学习入门之机器学习算法:决策树,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。
本文主要向大家介绍了机器学习入门之机器学习算法:决策树,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助。

决策树(Decision Tree)的核心思想是:根据训练样本构建这样一棵树,使得其叶节点是分类标签,非叶节点是判断条件,这样对于一个未知样本,能在树上找到一条路径到达叶节点,就得到了它的分类。
举个简单的例子,如何识别有毒的蘑菇?如果能够得到一棵这样的决策树,那么对于一个未知的蘑菇就很容易判断出它是否有毒了。
它是什么颜色的?
|
-------鲜艳---------浅色----
| |
有毒 有什么气味?
|
-----刺激性--------无味-----
| |
有毒 安全构建决策树有很多算法,常用的有ID3、C4.5等。本篇以ID3为研究算法。
构建决策树的关键在于每一次分支时选择哪个特征作为分界条件。这里的原则是:选择最能把数据变得有序的特征作为分界条件。所谓有序,是指划分后,每一个分支集合的分类尽可能一致。用信息论的方式表述,就是选择信息增益最大的方式划分集合。
所谓信息增益(information gain),是指变化前后熵(entropy)的增加量。为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
其中H为熵,n为分类数目,p(xi)是选择该分类的概率。
根据公式,计算一个集合熵的方式为:
计算每个分类出现的次数
foreach(每一个分类)
{
计算出现概率
根据概率计算熵
累加熵
}
return 累加结果判断如何划分集合,方式为:
foreach(每一个特征)
{
计算按此特征切分时的熵
计算与切分前相比的信息增益
保留能产生最大增益的特征为切分方式
}
return 选定的特征构建树节点的方法为:
if(集合没有特征可用了)
{
按多数原则决定此节点的分类
}
else if(集合中所有样本的分类都一致)
{
此标签就是节点分类
}
else
{
以最佳方式切分集合
每一种可能形成当前节点的一个分支
递归
}OK,上C#版代码,DataVector和上篇文章一样,不放了,只放核心算法:
using System;
using System.Collections.Generic;
namespace MachineLearning
{
/// <summary>
/// 决策树节点
/// </summary>
public class DecisionNode
{
/// <summary>
/// 此节点的分类标签,为空表示此节点不是叶节点
/// </summary>
public string Label { get; set; }
/// <summary>
/// 此节点的划分特征,为-1表示此节点是叶节点
/// </summary>
public int FeatureIndex { get; set; }
/// <summary>
/// 分支
/// </summary>
public Dictionary<string, DecisionNode> Child { get; set; }
public DecisionNode()
{
this.FeatureIndex = -1;
this.Child = new Dictionary<string, DecisionNode>();
}
}
}using System;
using System.Collections.Generic;
using System.Linq;
namespace MachineLearning
{
/// <summary>
/// 决策树(ID3算法)
/// </summary>
public class DecisionTree
{
private DecisionNode m_Tree;
/// <summary>
/// 训练
/// </summary>
/// <param name="trainingSet"></param>
public void Train(List<DataVector<string>> trainingSet)
{
var features = new List<int>(trainingSet[0].Dimension);
for(int i = 0;i < trainingSet[0].Dimension;++i)
features.Add(i);
//生成决策树
m_Tree = CreateTree(trainingSet, features);
}
/// <summary>
/// 分类
/// </summary>
/// <param name="vector"></param>
/// <returns></returns>
public string Classify(DataVector<string> vector)
{
return Classify(vector, m_Tree);
}
/// <summary>
/// 分类
/// </summary>
/// <param name="vector"></param>
/// <param name="node"></param>
/// <returns></returns>
private string Classify(DataVector<string> vector, DecisionNode node)
{
var label = string.Empty;
if(!string.IsNullOrEmpty(node.Label))
{
//是叶节点,直接返回结果
label = node.Label;
}
else
{
//取需要分类的字段,继续深入
var key = vector.Data[node.FeatureIndex];
if(node.Child.ContainsKey(key))
label = Classify(vector, node.Child[key]);
else
label = "[UNKNOWN]";
}
return label;
}
/// <summary>
/// 创建决策树
/// </summary>
/// <param name="dataSet"></param>
/// <param name="features"></param>
/// <returns></returns>
private DecisionNode CreateTree(List<DataVector<string>> dataSet, List<int> features)
{
var node = new DecisionNode();
if(dataSet[0].Dimension == 0)
{
//所有字段已用完,按多数原则决定Label,结束分类
node.Label = GetMajorLabel(dataSet);
}
else if(dataSet.Count == dataSet.Count(d => string.Equals(d.Label, dataSet[0].Label)))
{
//如果数据集中的Label相同,结束分类
node.Label = dataSet[0].Label;
}
else
{
//挑选一个最佳分类,分割集合,递归
int featureIndex = ChooseBestFeature(dataSet);
node.FeatureIndex = features[featureIndex];
var uniqueValues = GetUniqueValues(dataSet, featureIndex);
features.RemoveAt(featureIndex);
foreach(var value in uniqueValues)
{
node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value), new List<int>(features));
}
}
return node;
}
/// <summary>
/// 计算给定集合的香农熵
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private double ComputeShannon(List<DataVector<string>> dataSet)
{
double shannon = 0.0;
var dict = new Dictionary<string, int>();
foreach(var item in dataSet)
{
if(!dict.ContainsKey(item.Label))
dict[item.Label] = 0;
dict[item.Label] += 1;
}
foreach(var label in dict.Keys)
{
double prob = dict[label] * 1.0 / dataSet.Count;
shannon -= prob * Math.Log(prob, 2);
}
return shannon;
}
/// <summary>
/// 用给定的方式切分出数据子集
/// </summary>
/// <param name="dataSet"></param>
/// <param name="splitIndex"></param>
/// <param name="value"></param>
/// <returns></returns>
private List<DataVector<string>> SplitDataSet(List<DataVector<string>> dataSet, int splitIndex, string value)
{
var newDataSet = new List<DataVector<string>>();
foreach(var item in dataSet)
{
//只保留指定维度上符合给定值的项
if(item.Data[splitIndex] == value)
{
var newItem = new DataVector<string>(item.Dimension - 1);
newItem.Label = item.Label;
Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0);
Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1);
newDataSet.Add(newItem);
}
}
return newDataSet;
}
/// <summary>
/// 在给定的数据集上选择一个最好的切分方式
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private int ChooseBestFeature(List<DataVector<string>> dataSet)
{
int bestFeature = 0;
double bestInfoGain = 0.0;
double baseShannon = ComputeShannon(dataSet);
//遍历每一个维度来寻找
for(int i = 0;i < dataSet[0].Dimension;++i)
{
var uniqueValues = GetUniqueValues(dataSet, i);
double newShannon = 0.0<spa 本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标人工智能机器学习频道!
喜欢 | 0
不喜欢 | 0
您输入的评论内容中包含违禁敏感词
我知道了

请输入正确的手机号码
请输入正确的验证码
您今天的短信下发次数太多了,明天再试试吧!
我们会在第一时间安排职业规划师联系您!
您也可以联系我们的职业规划师咨询:
版权所有 职坐标-一站式AI+学习就业服务平台 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
沪公网安备 31011502005948号