scikit-learn 提供的绘图工具
Scikit-learn定义了一个简单的API,创建用于机器学习的可视化对象。该API的特点是无需重新计算即可进行快速绘图和视觉调整。在以下示例中,我们绘制了利用支持向量机算法产生的ROC曲线:
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)
svc_disp = plot_roc_curve(svc, X_test, y_test)
返回的svc_disp
对象使我们可以在以后的图中继续使用已经计算出的SVC的ROC曲线。在本例中,svc_disp
是一个 RocCurveDisplay
,它将计算得到的值储存到称作roc_auc
,fpr
,和tpr
的属性中。接下来,我们训练一个随机森林分类器,并使用Display
对象的plot
方法再次绘制先前计算的roc曲线。
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
rfc = RandomForestClassifier(random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
请注意,我们传递alpha=0.8
给绘图函数来调整曲线的透明度。
例子: |
---|
带有可视化API的ROC曲线 局部依赖的高级绘图 显示对象的可视化 |
5.1.1 函数
inspection.plot_partial_dependence (…[, …]) |
部分依赖图。 |
---|---|
metrics.plot_confusion_matrix (estimator, X, …) |
绘制混淆矩阵。 |
metrics.plot_precision_recall_curve (…[, …]) |
绘制二元分类器的精确度、召回率曲线。 |
metrics.plot_roc_curve (estimator, X, y, *) |
绘制受试者工作特性(ROC)曲线。 |
5.1.2 可视化对象
inspection.PartialDependenceDisplay (…) |
部分依赖图(PDP)可视化。 |
---|---|
metrics.ConfusionMatrixDisplay (…[, …]) |
混淆矩阵可视化。 |
metrics.PrecisionRecallDisplay (precision, …) |
精确度、召回率可视化。 |
metrics.RocCurveDisplay (*, fpr, tpr[, …]) |
ROC曲线可视化。 |