多分类Logistic回归ROC曲线
首先回顾一下在二分类问题中,ROC曲线的概念和绘制方法。
ROC曲线是通过可视化的方法实现模型好坏的评估,依赖于两个指标值进行绘制,x轴为1-Specificity即负例错判率;y轴为Sensitivity,即正例覆盖率。ROC曲线的绘制思想是,通过考虑不同阈值(即分类判别标准)下Sensitivity和1-Specificity之间的组合变化,然后通过这些不同组合所构成的点所组成的折线,并计算折线下所围成部分的面积(AUC),面积越大,说明模型的效果越好;一般而言,AUC只要在0.8以上,模型就可以基本接受。
简而言之,在二分类问题中,要绘制ROC曲线首先要需要确定一个分类判别的概率标准。比如给定一个阈值0.6,在此判别标准下利用所构建的模型对样本测试集进行预测,所得到的样本测试集中概率大于0.6的我们视其为1分类,小于0.6的将其视作0分类;于是在阈值为0.6时,根据混淆矩阵的定义,可以得出此种条件下模型的正例覆盖率(True Positive Rate)和负例错判率(False Positive Rate)。以此类推,当阈值发生变化时,可以得到一系列随着阈值而变化的TPR和FPR;将FPR值的大小作为x轴,TPR的值的大小作为y轴,将上述过程中所得到的所有的点画出,即成为了模型的ROC曲线,而AUC即是该线段下的面积。
对于多分类问题,假设因变量取值为n种分类,所得的模型的预测值为在n种不同分类下的概率值。也就是对于样本测试集中的每一个样本来说,存在着n个预测出的概率值,分别对应着n个种类,其意义为通过模型预测出属于该分类的概率。而通常认为,在n个预测之中最大值对应的那个分类,就是该样本所属的分类。可以构造一个标签矩阵,根据模型预测值矩阵,对每一行中最大值所在的列记为1,表示该样本属于该种类;其余元素记为0,表示该样本不属于此种类。标签矩阵的行数为样本测试集的行数,列数为种类数n。
因此,在多分类问题中如何去确定TPR和FPR的值成为了绘制ROC曲线的关键。这里介绍两种方法:Micro-average和Macro-average。
Micro-average方法
指通过所构建的标签矩阵和预测值矩阵,将其分别按行展开,得到一个标签数组和预测值(概率)数组,由于标签数组由元素1和0组成,于是可以将这两个数组以列的形式构造成一个矩阵,成为一个二分类问题的预测结果。
以一个n=3,样本数为2的测试集为例有:
概率矩阵:
构造标签矩阵:
经上述过程构造得到:
因此可以按照二分类问题中的ROC曲线的绘制方法去处理得到对应的ROC曲线(利用sklearn.metrics.roc_auc_score计算对应的TPR和FPR,sklearn.metrics.auc计算AUC值)
Macro-average方法
指通过所构建的标签矩阵和预测值矩阵,在每种类别下,都可以得到测试样本属于此类别的概率。所以,根据概率矩阵以及标签矩阵中对应的每一列,可以计算出各个分类下的FPR和TPR的一系列组合值,可以得到不同分类下的ROC曲线,共有n条;之后对n条ROC曲线的TPR值取算数平均值作为最终的TPR值,即可得到最终的ROC曲线。
在计算过程中,需将不同分类下的FPR值整合到一个最终的FPR值中,在计算其对应的TPR平均值时,由于可能缺失FPR所对应的TPR,可先利用插补法(numpy.interp)得到在其对应的TPR值。
代码案例
# 引入必要的库
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 将标签二值化
y = label_binarize(y, classes=[0, 1, 2])
# 设置种类
n_classes = y.shape[1]
# 训练模型并预测
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
# shuffle and split training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=.5,random_state=0)
# Learn to predict each class against the other
classifier = OneVsRestClassifier(svm.SVC(kernel='linear',probability=True,
random_state=random_state))
y_score = classifier.fit(X_train, y_train).decision_function(X_test)
# 计算每一类的ROC
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ =roc_curve(y_test[:, i], y_score[:, i])
roc_auc[i] = auc(fpr[i],tpr[i])
# Compute micro-average ROC curve and ROC area(方法二)
fpr["micro"], tpr["micro"], _ =roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"],tpr["micro"])
# Compute macro-average ROC curve and ROC area(方法一)
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i inrange(n_classes)]))
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"],tpr["macro"])
# Plot all ROC curves
lw=2
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
label='micro-averageROC curve (area = {0:0.2f})'
''.format(roc_auc["micro"]),
color='deeppink',linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"],
label='macro-averageROC curve (area = {0:0.2f})'
''.format(roc_auc["macro"]),
color='navy',linestyle=':', linewidth=4)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes), colors):
plt.plot(fpr[i], tpr[i],color=color, lw=lw,
label='ROC curveof class {0} (area = {1:0.2f})'
''.format(i,roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic tomulti-class')
plt.legend(loc="lower right")
plt.show()
参考文献
ROC原理介绍及利用python实现二分类和多分类的ROC曲线.闰土不用叉.CSDN博客(https://blog.csdn.net/xyz1584172808/article/details/81839230)