codecamp

scikit-learn 提供的绘图工具

Scikit-learn定义了一个简单的API,创建用于机器学习的可视化对象。该API的特点是无需重新计算即可进行快速绘图和视觉调整。在以下示例中,我们绘制了利用支持向量机算法产生的ROC曲线:

from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

svc_disp = plot_roc_curve(svc, X_test, y_test)


返回的svc_disp对象使我们可以在以后的图中继续使用已经计算出的SVC的ROC曲线。在本例中,svc_disp是一个 RocCurveDisplay,它将计算得到的值储存到称作roc_aucfpr,和tpr的属性中。接下来,我们训练一个随机森林分类器,并使用Display对象的plot 方法再次绘制先前计算的roc曲线。

import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier(random_state=42)
rfc.fit(X_train, y_train)

ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)


请注意,我们传递alpha=0.8给绘图函数来调整曲线的透明度。

例子:
带有可视化API的ROC曲线
局部依赖的高级绘图
显示对象的可视化

5.1.1 函数

inspection.plot_partial_dependence(…[, …]) 部分依赖图。
metrics.plot_confusion_matrix(estimator, X, …) 绘制混淆矩阵。
metrics.plot_precision_recall_curve(…[, …]) 绘制二元分类器的精确度、召回率曲线。
metrics.plot_roc_curve(estimator, X, y, *) 绘制受试者工作特性(ROC)曲线。

5.1.2 可视化对象

inspection.PartialDependenceDisplay(…) 部分依赖图(PDP)可视化。
metrics.ConfusionMatrixDisplay(…[, …]) 混淆矩阵可视化。
metrics.PrecisionRecallDisplay(precision, …) 精确度、召回率可视化。
metrics.RocCurveDisplay(*, fpr, tpr[, …]) ROC曲线可视化。


scikit-learn 基于排列的特征重要性
scikit-learn 管道和复合估算器
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

scikit-learn 用户指南

scikit-learn 5.可视化

scikit-learn 7.数据集加载实用程序

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }