Machine Learning: Decision Trees (Python Implementation)

### What Is a Decision Tree

  • A decision tree is a decision model with a tree-shaped branching structure.
  • It uses Shannon entropy from information theory as its basic splitting criterion.
  • Its strengths include models that are easy to interpret, strong classification performance, and straightforward production deployment.
  • Well-known algorithms: ID3, C4.5, C5.0, CART, among others
  • R: C50, rpart
  • Python: sklearn (tree); see the sketch after this list
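
As a quick taste of the sklearn route listed above, here is a minimal sketch (the toy data, labels, and the entropy criterion are illustrative assumptions, not part of this post's implementation):

from sklearn import tree

# hypothetical toy data: two binary features, two classes
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = ['no', 'no', 'yes', 'yes']

# the entropy criterion mirrors the ID3-style splits implemented below
clf = tree.DecisionTreeClassifier(criterion='entropy')
clf.fit(X, y)
print(clf.predict([[1, 0]]))  # -> ['yes']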

### Result

  • Figure: the plotted decision tree (output of createPlot below).

### Python Implementation

trees.py

#!/usr/bin/python
# -*- coding: utf-8 -*-

import operator
from math import log
# Compute the Shannon entropy of a given data set
def calcShannonEnt(dataSet):
    # total number of training samples
    numEntries = len(dataSet)
    # maps each class label to the number of samples in that class
    labelCounts = {}
    # iterate over every sample
    for featVec in dataSet:
        # the last column of each sample is its class label
        currentLabel = featVec[-1]
        # build a key/value entry for every class value seen:
        # key = class label, value = number of occurrences
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        # count this occurrence
        labelCounts[currentLabel] += 1

    # accumulator for the entropy
    shannonEnt = 0.0

    # once all samples are counted, compute each class's share of the total
    for key in labelCounts:
        # proportion of samples in this class
        prob = float(labelCounts[key]) / numEntries
        # entropy term, using a base-2 logarithm
        shannonEnt -= prob * log(prob, 2)

    # return the entropy of the data set
    return shannonEnt
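
# Usage sketch for calcShannonEnt (a hypothetical toy data set, not the lenses file):
# with 2 'yes' and 3 'no' labels, H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ~ 0.971
#   myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
#   calcShannonEnt(myDat)   # -> 0.9709...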

# Split the data set on a given feature; used later to compute the conditional entropy
def splitDataSet(dataSet, axis, value):
    # new variable holding the rows of the split data set
    retDataSet = []

    # iterate over every row of the data set
    for featVec in dataSet:
        # extract the rows that match and store them in retDataSet
        if featVec[axis] == value:
            # keep the whole row except the chosen feature column `axis`
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            # store the reduced row
            retDataSet.append(reducedFeatVec)

    # return the matching rows with the feature column removed,
    # ready for computing the conditional entropy of this split
    return retDataSet
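
# Usage sketch for splitDataSet (same hypothetical myDat as above):
# keep the rows where feature 0 equals the given value, dropping that column
#   splitDataSet(myDat, 0, 1)   # -> [[1, 'yes'], [1, 'yes'], [0, 'no']]
#   splitDataSet(myDat, 0, 0)   # -> [[1, 'no'], [1, 'no']]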

# Choose the best way to split the data set; returns the index of the best feature.
# The best feature is the one with the largest information gain.
def chooseBestFeatureToSplit(dataSet):

    # number of features; the last column is the class label, hence the -1
    numFeatures = len(dataSet[0]) - 1
    # Shannon entropy of the whole data set
    baseEntropy = calcShannonEnt(dataSet)
    # largest information gain seen so far
    bestInfoGain = 0.0
    # feature with the largest information gain
    bestFeature = -1

    # loop over every feature in the data set
    for i in range(numFeatures):

        # collect the values in the current feature's column
        featList = [example[i] for example in dataSet]
        # deduplicate, so each feature value appears once
        uniqueVals = set(featList)

        # conditional entropy for splitting on this feature
        newEntropy = 0.0
        # iterate over the feature's values
        for value in uniqueVals:
            # split off the subset with the current feature value
            subDataSet = splitDataSet(dataSet, i, value)
            # ratio of subset size to total size
            prob = len(subDataSet) / float(len(dataSet))
            # weighted entropy of the subset
            newEntropy += prob * calcShannonEnt(subDataSet)

        # information gain = entropy of the data set - entropy after the split
        infoGain = baseEntropy - newEntropy

        # the best feature is the one with the largest information gain
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i

    # return the best feature
    return bestFeature
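
# Usage sketch for chooseBestFeatureToSplit (same hypothetical myDat):
# feature 0 yields the larger gain (~0.420 vs. ~0.171 for feature 1)
#   chooseBestFeatureToSplit(myDat)   # -> 0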

# Build the decision tree recursively.
# When all features have been used up, majority voting decides a leaf's class:
# the leaf is assigned whichever class has the most samples in it.
def majorityCnt(classList):
    # number of occurrences of each class
    classCount = {}
    # iterate over the classes in the data set
    for vote in classList:
        # first time this class is seen: add it to the dictionary
        if vote not in classCount.keys(): classCount[vote] = 0
        # count the occurrence
        classCount[vote] += 1

    # after counting, sort by count in descending order
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    # return the most frequent class
    return sortedClassCount[0][0]

# Build the tree from (data set, feature names)
def createTree(dataSet, labels):

    # take the data set's last column: the class labels of the training data
    classList = [example[-1] for example in dataSet]

    # stop splitting once all classes are identical
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # if the data set has no features left, return its most frequent class
    if len(dataSet[0]) == 1: return majorityCnt(classList)

    # pick the feature that best splits the data set
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # that feature's name becomes this subtree's root node
    bestFeatLabel = labels[bestFeat]

    #print 'bestFeatLabel:' + str(bestFeatLabel)

    # initialize the decision tree
    myTree = {bestFeatLabel: {}}
    # delete the chosen feature's name
    del(labels[bestFeat])

    # collect the values of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # deduplicate

    # iterate over the deduplicated feature values
    for value in uniqueVals:
        # the names of the remaining features, after the deletion above
        subLabels = labels[:]

        print "best splitting feature: " + bestFeatLabel + " value: " + str(value)
        print 'remaining features: ' + str(subLabels)

        '''
        Split off the subset for the current feature value, recurse on that
        subset, and attach the result as one branch of this tree node.
        '''
        myTree[bestFeatLabel][value] = \
            createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        print "branch: " + str(myTree[bestFeatLabel][value])
        print "tree: " + str(myTree) + "\n"

    return myTree
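
# Usage sketch for createTree (same hypothetical myDat; the feature names are
# made up for illustration). createTree deletes from `labels`, so pass a copy:
#   myTree = createTree(myDat, ['no surfacing', 'flippers'])
#   # -> {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}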

# Classify with the decision tree: (tree, feature names, test data vector)
def classify(inputTree, featLabels, testVec):

    # the tree's root node
    firstStr = inputTree.keys()[0]
    # the branches below the root
    secondDict = inputTree[firstStr]
    # position of the root's feature in the feature-name list
    featIndex = featLabels.index(firstStr)

    # default in case the test value matches no branch
    classLabel = None
    # search the branches for the test vector's value
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                # the branch is a subtree: keep descending
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                # the branch is a leaf: take its class label
                classLabel = secondDict[key]

    return classLabel
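
# Usage sketch for classify (tree from the createTree sketch above);
# `featLabels` must be the original, unmodified feature-name list:
#   classify(myTree, ['no surfacing', 'flippers'], [1, 0])   # -> 'no'
#   classify(myTree, ['no surfacing', 'flippers'], [1, 1])   # -> 'yes'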

# Save the decision tree to disk with pickle
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

# Load a pickled decision tree from disk
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)



import matplotlib.pyplot as plt

# node and arrow styles for the tree plot
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Count the tree's leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # if a node's value is a dict it is a subtree, otherwise a leaf
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs

# Compute the tree's depth
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # if a node's value is a dict it is a subtree, otherwise a leaf
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
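
# Usage sketch (tree from the createTree sketch above):
#   getNumLeafs(myTree)    # -> 3
#   getTreeDepth(myTree)   # -> 2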

# Draw a node, with an arrow pointing to it from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Draw the text label on the line between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

# Draw the tree structure recursively
def plotTree(myTree, parentPt, nodeTxt):
    # the number of leaves determines the x-width of this subtree
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    # the first key tells you what feature this node was split on,
    # and it becomes the text label for the node
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        # test whether the child is a dictionary: if not, it is a leaf node
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:
            # it's a leaf node: draw it and label the connecting line
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# Create the decision tree figure
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks, for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


Test script

trees_test.py

# exec(open('trees_test.py').read())

import trees

# read the tab-separated lenses data
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

# build and print the tree (note: createTree deletes entries from lensesLabels)
lensesTree = trees.createTree(lenses, lensesLabels)
print lensesTree

# classify one test vector (a fresh label list, since the one above was mutated)
print trees.classify(lensesTree, ['age', 'prescript', 'astigmatic', 'tearRate'],
                     ['presbyopic', 'hyper', 'no', 'reduced'])

# draw the tree
trees.createPlot(lensesTree)
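
The persistence helpers in trees.py can be exercised the same way; a minimal sketch ('classifierStorage.txt' is an arbitrary file name, not part of the original script):

trees.storeTree(lensesTree, 'classifierStorage.txt')
print trees.grabTree('classifierStorage.txt')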

Data file

  • UCI data set (contact lenses); fields are tab-separated, matching the test script's split('\t')
    lenses.txt
    young myope no reduced no_lenses
    young myope no normal soft
    young myope yes reduced no_lenses
    young myope yes normal hard
    young hyper no reduced no_lenses
    young hyper no normal soft
    young hyper yes reduced no_lenses
    young hyper yes normal hard
    pre myope no reduced no_lenses
    pre myope no normal soft
    pre myope yes reduced no_lenses
    pre myope yes normal hard
    pre hyper no reduced no_lenses
    pre hyper no normal soft
    pre hyper yes reduced no_lenses
    pre hyper yes normal no_lenses
    presbyopic myope no reduced no_lenses
    presbyopic myope no normal no_lenses
    presbyopic myope yes reduced no_lenses
    presbyopic myope yes normal hard
    presbyopic hyper no reduced no_lenses
    presbyopic hyper no normal soft
    presbyopic hyper yes reduced no_lenses
    presbyopic hyper yes normal no_lenses