from pandas import DataFrame
Windows 10
Python 3.7.3 @ MSC v.1915 64 bit (AMD64)
Latest build date 2020.05.05
sklearn version: 0.22.1
模型验证函数
Model validation
sklearn 提供了6个函数用来验证模型:
函数 | 说明 |
---|---|
cross_validate(estimator, X) |
通过交叉验证得到评估指标,并记录拟合/得分时间 |
cross_val_predict(estimator, X) |
为每个输入数据点生成交叉验证的估计 |
cross_val_score(estimator, X) |
通过交叉验证得到评估指标 |
learning_curve(estimator, X, y) |
学习曲线 |
permutation_test_score(…) |
通过重排列样本的label,估计交叉验证的显著性 |
validation_curve(estimator, …) |
验证曲线 |
cross_val_score:交叉评分
cross_val_score使用拆分器把数据拆分为k份,使用k-1份数据进行训练,再使用剩下的1份数据进行测试评分,返回每次测试评分值。
cross_val_score(estimator, X, y=None, groups=None, scoring=None,
cv=None, n_jobs=1,verbose=0, fit_params=None,
pre_dispatch='2*n_jobs')
y
:在监督学习中,尝试预测的目标变量groups
:like-array,with shape (n _samples,), optional。分组数组集的样本组标签。scoring
:-
cv
:交叉验证拆分数据集的策略。可能的取值为:对于integer或None的输入,如果估计器是一个分类器,同时y是either binary 或者多分类,将会默认使用StratifiedKFold拆分器,其他情况,将使用KFold拆分器。
-
None
,默认使用3-fold cross vaildation -
integer
,指定使用何种KFold -
CV splitter
,交叉验证生成器对象 -
一个可产生训练集、测试集的迭代器
-
-
n_jobs
:integer,用于计算的CPU数量,-1表示使用所有CPU。 verbose
: integer, optional,The verbosity level。fit_params
:dict,要传递给估计器的参数。
from sklearn.model_selection import cross_val_score
from sklearn.datasets import load_wine
from sklearn.svm import SVC
wine = load_wine()
svc = SVC(kernel='linear')
scores = cross_val_score(svc, wine.data, wine.target)
print('交叉验证得分:{}'.format(scores))
交叉验证得分:[0.88888889 0.94444444 0.97222222 1. 1. ]
cross_validate:交叉评分
cross_validate
和cross_val_score
的不同在于scoring
参数和返回值:
- 评估的时候可以选择多个指标
- 返回一个包含拟合时间、打分时间、测试分数、训练分数的字典。
参数:
-
scoring:
-
单个字符串(使用预定义的评分方法,参见 The scoring parameter: defining model evaluation rules)
-
可引用对象(可自定义评分方法,参见 Defining your scoring strategy from metric functions)
-
单个字符串:
from sklearn.model_selection import cross_validate
from sklearn.datasets import load_wine
from sklearn.svm import SVC
import pprint
wine = load_wine()
svc = SVC(kernel='linear')
scores = cross_validate(svc, wine.data, wine.target, scoring="precision_macro")
print(type(scores))
DataFrame(scores)
<class 'dict'>
fit_time score_time test_score
0 0.104000 0.002001 0.897436
1 0.119005 0.002005 0.944056
2 0.140989 0.002004 0.977778
3 0.121999 0.002005 1.000000
4 0.099990 0.002003 1.000000
多指标指定的时候可以是一个list,tuple或者set,元素是预先定义好的一些指标名称:
from sklearn.metrics.scorer import make_scorer
from sklearn.metrics import recall_score
scoring = {'prec_macro': 'precision_macro',
'rec_micro': make_scorer(recall_score, average='macro')}
scores = cross_validate(svc, wine.data, wine.target, scoring=scoring)
print(type(scores))
DataFrame(scores)
<class 'dict'>
fit_time score_time test_prec_macro test_rec_micro
0 0.112014 0.001981 0.897436 0.904762
1 0.074024 0.002000 0.944056 0.952381
2 0.124978 0.001999 0.977778 0.972222
3 0.097982 0.004998 1.000000 1.000000
4 0.119018 0.002000 1.000000 1.000000
也可以是一个字典,从指标名称对应到预先定义好的评分函数:
scoring = {'prec_macro': 'precision_macro',
'rec_micro': make_scorer(recall_score, average='macro')}
scores = cross_validate(svc, wine.data, wine.target, scoring=scoring)
print(type(scores))
DataFrame(scores)
<class 'dict'>
fit_time score_time test_prec_macro test_rec_micro
0 0.078000 0.003024 0.897436 0.904762
1 0.078000 0.002000 0.944056 0.952381
2 0.095998 0.001001 0.977778 0.972222
3 0.087999 0.002000 1.000000 1.000000
4 0.090978 0.186004 1.000000 1.000000
cross_val_predict:交叉预测
cross_val_predict
的使用方法类似于cross_val_score
,但cross_val_predict
返回的是一个使用交叉验证得到的预测值,而不是评分标准。
它的运行过程是这样的:使用交叉验证的方法计算出每次划分中测试集的预测值,直到所有数据都有了预测值。假如数据划分为$[1,2,3,4,5]$份,它先用$[1,2,3,4]$训练模型,计算出来第$5$份的预测值,然后用$[1,2,3,5]$计算出第4份的预测值,直到都结束为止。
from sklearn import datasets, linear_model
from sklearn.model_selection import cross_val_predict
diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()
y_pred = cross_val_predict(lasso, X, y, cv=3)
permutation_test_score:显著性检验
为了检验交叉验证的得分是否有意义,sklearn提供了一个显著性检验的函数,其原理如下:
- 先进行交叉验证,得出模型评分。
- 打乱 y(分类标签) 的排序,
在标签随机排列后重复分类过程的技术。 然后 p 值由得到的分数大于最初得到的分类分数的分数百分比给出。
# https://scikit-learn.org/stable/auto_examples/feature_selection/plot_permutation_test_for_classification.html#sphx-glr-auto-examples-feature-selection-plot-permutation-test-for-classification-py
# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
# License: BSD 3 clause
# print(__doc__)
# import numpy as np
# import matplotlib.pyplot as plt
# from sklearn.svm import SVC
# from sklearn.model_selection import StratifiedKFold
# from sklearn.model_selection import permutation_test_score
# from sklearn import datasets
# # #############################################################################
# # Loading a dataset
# iris = datasets.load_iris()
# X = iris.data
# y = iris.target
# n_classes = np.unique(y).size
# # Some noisy data not correlated
# random = np.random.RandomState(seed=0)
# E = random.normal(size=(len(X), 2200))
# # Add noisy data to the informative features for make the task harder
# X = np.c_[X, E]
# svm = SVC(kernel='linear')
# cv = StratifiedKFold(2)
# score, permutation_scores, pvalue = permutation_test_score(
# svm, X, y, scoring="accuracy", cv=cv, n_permutations=100, n_jobs=1)
# print("Classification score %s (pvalue : %s)" % (score, pvalue))
# # #############################################################################
# # View histogram of permutation scores
# plt.hist(permutation_scores, 20, label='Permutation scores',
# edgecolor='black')
# ylim = plt.ylim()
# # BUG: vlines(..., linestyle='--') fails on older versions of matplotlib
# # plt.vlines(score, ylim[0], ylim[1], linestyle='--',
# # color='g', linewidth=3, label='Classification Score'
# # ' (pvalue %s)' % pvalue)
# # plt.vlines(1.0 / n_classes, ylim[0], ylim[1], linestyle='--',
# # color='k', linewidth=3, label='Luck')
# plt.plot(2 * [score], ylim, '--g', linewidth=3,
# label='Classification Score'
# ' (pvalue %s)' % pvalue)
# plt.plot(2 * [1. / n_classes], ylim, '--k', linewidth=3, label='Luck')
# plt.ylim(ylim)
# plt.legend()
# plt.xlabel('Score')
# plt.show()
validation_curve
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_validation_curve.html#sphx-glr-auto-examples-model-selection-plot-validation-curve-py
learning_curve
# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.learning_curve.html#sklearn.model_selection.learning_curve