Machine Learning in Action Notes: Decision Trees
2018-10-22
Abstract: This post is a set of notes on the decision-tree chapter of Machine Learning in Action, presented through the complete, commented code; hopefully it is useful to readers getting started with machine learning.
tree.py code (build the tree by information gain, classify new samples, and persist the tree with pickle):
#encoding:utf-8
from math import log
import operator
import treePlotter as tp


def createDataSet():  # create a small test data set
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return dataSet, labels


def calcShannonEnt(dataSet):  # compute the Shannon entropy of a data set
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# split the data set on a given feature
def splitDataSet(dataSet, axis, value):  # dataSet: data set  axis: feature index  value: feature value to match
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)  # reducedFeatVec no longer contains the chosen feature; note the difference between append and extend
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # the i-th column
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # unique values of this feature
        newEntropy = 0.0
        for value in uniqueVals:  # compute the entropy of each branch and accumulate it; a feature can take several values
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy  # information gain
        if (infoGain > bestInfoGain):  # keep the best split seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # index (an integer) of the best feature to split on


def majorityCnt(classList):
    classCount = {}  # works like a map
    for vote in classList:  # count how often each class name appears
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  # sort by count
    return sortedClassCount[0][0]  # return the most frequent class name


def createTree(dataSet, labels):  # build the decision tree
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # stop splitting when all classes are identical
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # best split
    bestFeatLabel = labels[bestFeat]  # name of the best split feature
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])  # remove that feature
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # all values this feature takes in the data set
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree  # return the decision tree


def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]  # root label of inputTree
    secondDict = inputTree[firstStr]  # looks like 0: 'yes', 1: {subtree}
    featIndex = featLabels.index(firstStr)  # turn the label string into an index
    key = testVec[featIndex]  # value of that feature in testVec
    valueOfFeat = secondDict[key]  # subtree reached through key
    if isinstance(valueOfFeat, dict):  # if valueOfFeat is a dict, recurse
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat  # otherwise it is the classification result
    return classLabel


def storeTree(inputTree, filename):  # store the decision tree
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):  # load the decision tree
    import pickle
    fr = open(filename)
    return pickle.load(fr)


if __name__ == '__main__':
    # dataSet, labels = createDataSet()
    # print dataSet
    # print labels
    # shannonEnt = calcShannonEnt(dataSet)
    # print "Shannon entropy: %f" % (shannonEnt)
    # myMat = splitDataSet(dataSet, 0, 1)
    # print myMat
    # index = chooseBestFeatureToSplit(dataSet)
    # print index
    # mytree = createTree(dataSet, labels)
    # print "decision tree:"
    # print mytree
    # myTree = tp.retrieveTree(0)
    # print myTree
    # storeTree(myTree, 'myTree.txt')
    # myTree = grabTree('myTree.txt')
    # print myTree
    # print classify(myTree, labels, [1, 0])

    # predict contact-lens type with the decision tree
    fr = open('lenses.txt')
    lenses = [line.strip().split('\t') for line in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print lensesTree
    tp.createPlot(lensesTree)
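
The short check below is not part of the original article; it is a minimal sketch (the file name sanity_check.py and the hand-computed expected values are my own) of what the tree.py functions above should return on the toy data set from createDataSet, assuming they behave as written.

#encoding:utf-8
# sanity_check.py -- hypothetical driver; assumes tree.py above sits in the same folder
import tree

dataSet, labels = tree.createDataSet()

# 2 of the 5 samples are 'yes' and 3 are 'no', so the entropy is
# -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.9709...
print(tree.calcShannonEnt(dataSet))            # expected: about 0.971

# keep the rows whose feature 0 equals 1 and drop that column
print(tree.splitDataSet(dataSet, 0, 1))        # expected: [[1, 'yes'], [1, 'yes'], [0, 'no']]

# feature 0 ('no surfacing') has the larger information gain (about 0.42 vs 0.17)
print(tree.chooseBestFeatureToSplit(dataSet))  # expected: 0

# createTree deletes entries from the label list it receives,
# so pass a copy and keep `labels` intact for classify below
myTree = tree.createTree(dataSet, labels[:])
print(myTree)  # expected: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

# a sample with 'no surfacing' = 1 and 'flippers' = 0 ends up in the 'no' leaf
print(tree.classify(myTree, labels, [1, 0]))   # expected: 'no'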
treePlotter.py code (draw the decision tree with Matplotlib annotations):
#encoding:utf-8
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):  # count the leaf nodes of the tree
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test whether the node is a dictionary; if not, it is a leaf node
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # compute the depth of the tree
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test whether the node is a dictionary; if not, it is a leaf node
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # draw an annotation with an arrow
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):  # fill in text between a parent and a child node
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


# plot the tree
def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]  # the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test whether the node is a dictionary; if not, it is a leaf node
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node, print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# if you do get a dictionary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# def createPlot():
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()
#     createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]


# createPlot(thisTree)

if __name__ == '__main__':
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args = dict(arrowstyle="<-")
    # createPlot()
    myTree = retrieveTree(0)
    createPlot(myTree)
    # print myTree
    # print getNumLeafs(myTree)
    # print getTreeDepth(myTree)
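
As a quick illustration (the file name plot_demo.py is my own, not from the article), the snippet below shows the bookkeeping createPlot relies on: getNumLeafs fixes the horizontal spacing and getTreeDepth fixes the vertical spacing before any node is drawn. The expected counts for the two hard-coded trees in retrieveTree are worked out by hand.

#encoding:utf-8
# plot_demo.py -- hypothetical driver; assumes treePlotter.py above sits in the same folder
import treePlotter as tp

tree0 = tp.retrieveTree(0)
tree1 = tp.retrieveTree(1)

# createPlot uses these two numbers to size the drawing:
# total leaves -> x spacing, total depth -> y spacing
print(tp.getNumLeafs(tree0))   # expected: 3
print(tp.getTreeDepth(tree0))  # expected: 2
print(tp.getNumLeafs(tree1))   # expected: 4
print(tp.getTreeDepth(tree1))  # expected: 3

# draw the second hard-coded tree; a Matplotlib window should open
tp.createPlot(tree1)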