Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions Decision Tree/Decision Tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from math import log
import operator
import pickle
from collections import defaultdict

"""
函数说明:计算给定数据集的经验熵(香农熵)
Expand All @@ -21,7 +22,7 @@
"""
def calcShannonEnt(dataSet):
numEntires = len(dataSet) #返回数据集的行数
labelCounts = {} #保存每个标签(Label)出现次数的字典
labelCounts = defaultdict(int) #简化初始化
for featVec in dataSet: #对每组特征向量进行统计
currentLabel = featVec[-1] #提取标签(Label)信息
if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去
Expand Down Expand Up @@ -204,7 +205,7 @@ def getNumLeafs(myTree):
firstStr = next(iter(myTree)) #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
secondDict = myTree[firstStr] #获取下一组字典
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
if isinstance(secondDict[key],dict): #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
Expand All @@ -228,7 +229,7 @@ def getTreeDepth(myTree):
firstStr = next(iter(myTree)) #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
secondDict = myTree[firstStr] #获取下一个字典
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
if isinstance(secondDict[key],dict): #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth #更新层数
Expand All @@ -252,11 +253,21 @@ def getTreeDepth(myTree):
2017-07-24
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
arrow_args = dict(arrowstyle="<-") #定义箭头格式
font = FontProperties(fname=r"c:\windows\fonts\simsunb.ttf", size=14) #设置中文字体
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', #绘制结点
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)
arrow_args = dict(arrowstyle="<-") # 定义箭头格式
font = FontProperties(fname=r"c:\windows\fonts\simsunb.ttf", size=14) # 设置中文字体

createPlot.ax1.annotate(
nodeTxt,
xy=parentPt,
xycoords='axes fraction',
xytext=centerPt,
textcoords='axes fraction',
va="center",
ha="center",
bbox=nodeType,
arrowprops=arrow_args,
fontproperties=font # 直接传递 FontProperties 对象
)

"""
函数说明:标注有向边属性值
Expand Down Expand Up @@ -361,7 +372,7 @@ def classify(inputTree, featLabels, testVec):
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
if isinstance(secondDict[key],dict):
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
Expand Down Expand Up @@ -407,9 +418,10 @@ def grabTree(filename):
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
labels_copy = labels.copy() # 创建副本
myTree = createTree(dataSet, labels, featLabels)
createPlot(myTree)
testVec = [0,1] #测试数据
testVec = [0,1,1,0] #测试数据
result = classify(myTree, featLabels, testVec)
if result == 'yes':
print('放贷')
Expand Down
Loading