跳转至

11.8.保存模型

Windows 10
Python 3.7.3 @ MSC v.1915 64 bit (AMD64)
Latest build date 2020.05.10
sklearn version:  0.22.1
import pickle
import joblib
from sklearn import svm
from sklearn import datasets

clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)
SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None,
coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale',
kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

pickle

可以通过使用 Python 的内置持久化模型将训练好的模型保存在 scikit 中,它名为 pickle:

# 序列化为 bytes
model = pickle.dumps(clf)

# 从 bytes 加载模型
clf2 = pickle.loads(data=model, fix_imports=True,
                   encoding='ASCII', errors='strict')


# 保存到文件
svm.model = pickle.dump(obj=clf, file="svm.model", 
                        protocol=None, fix_imports=True)

# 从文件加载模型
clf2 = pickle.load(file="svm.model", fix_imports=True,
                   encoding='ASCII', errors='strict')

joblib

使用 joblib 来替换 pickle 可能会更有意思,这对于内部带有 numpy 数组的对象来说更为高效, 通常情况下适合 scikit-learn estimators(预估器),但是也只能是使用 pickle 保存模型到硬盘的情况,而不是pickle 保存模型到字符串。

joblib.dump(value=clf, filename="svm.model",
            compress=0, protocol=None, cache_size=None)

joblib.load(filename="svm.model")

安全性和可维护性的局限性

pickle(和通过扩展的 joblib),在安全性和可维护性方面存在一些问题。 有以下原因:

  • 绝对不要使用未经 pickle 的不受信任的数据,因为它可能会在加载时执行恶意代码。
  • 虽然一个版本的 scikit-learn 模型可以在其他版本中加载,但这完全不建议并且也是不可取的。 还应该了解到,对于这些数据执行的操作可能会产生不同及意想不到的结果。

为了用以后版本的 scikit-learn 来重构类似的模型, 额外的元数据应该随着 pickled model 一起被保存:

  • 训练数据,例如:引用不可变的快照
  • 用于生成模型的 python 源代码
  • scikit-learn 的各版本以及各版本对应的依赖包
  • 在训练数据的基础上获得的交叉验证得分

这样可以检查交叉验证得分是否与以前相同。

由于模型内部表示可能在两种不同架构上不一样,因此不支持在一个架构上转储模型并将其加载到另一个体系架构上。

如果您想要了解更多关于这些问题以及其它可能的序列化方法,请参阅这个 Alex Gaynor 的演讲