codecamp

AI人工智能 决策树分类器

决策树基本上是一个二叉树流程图,其中每个节点根据某个特征变量分割一组观察值。

在这里,我们正在构建一个用于预测男性或女性的决策树分类器。这里将采取一个非常小的数据集,有 19 个样本。 这些样本将包含两个特征 - “身高”和“头发长度”。

前提条件

为了构建以下分类器,我们需要安装 pydotplusgraphviz。 基本上,graphviz 是使用点文件绘制图形的工具,pydotplus 是 Graphviz 的 Dot 语言模块。 它可以与包管理器或使用pip来安装。

现在,可以在以下 Python 代码的帮助下构建决策树分类器 -

首先,导入一些重要的库如下 -

import pydotplus
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report
from sklearn import cross_validation
import collections

现在,提供如下数据集 -

X = [[165,19],[175,32],[136,35],[174,65],[141,28],[176,15],[131,32],
[166,6],[128,32],[179,10],[136,34],[186,2],[126,25],[176,28],[112,38],
[169,9],[171,36],[116,25],[196,25]]


Y = ['Man','Woman','Woman','Man','Woman','Man','Woman','Man','Woman',
'Man','Woman','Man','Woman','Woman','Woman','Man','Woman','Woman','Man']
data_feature_names = ['height','length of hair']


X_train, X_test, Y_train, Y_test = cross_validation.train_test_split
(X, Y, test_size=0.40, random_state=5)

在提供数据集之后,需要拟合可以如下完成的模型 -

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X,Y)

预测可以使用以下 Python 代码来完成 -

prediction = clf.predict([[133,37]])
print(prediction)

使用以下 Python 代码来实现可视化决策树 -

dot_data = tree.export_graphviz(clf,feature_names = data_feature_names,
            out_file = None,filled = True,rounded = True)
graph = pydotplus.graph_from_dot_data(dot_data)
colors = ('orange', 'yellow')
edges = collections.defaultdict(list)


for edge in graph.get_edge_list():
edges[edge.get_source()].append(int(edge.get_destination()))


for edge in edges: edges[edge].sort()


for i in range(2):dest = graph.get_node(str(edges[edge][i]))[0]
dest.set_fillcolor(colors[i])
graph.write_png('Decisiontree16.png')

它会将上述代码的预测作为[‘Woman’]并创建以下决策树 -

img

可以改变预测中的特征值来测试它。

AI人工智能 逻辑回归
AI人工智能 随机森林分类器
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

AI人工智能监督学习(回归)

AI人工智能无监督学习:聚类

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }