11.8.保存模型
Windows 10
Python 3.7.3 @ MSC v.1915 64 bit (AMD64)
Latest build date 2020.05.10
sklearn version: 0.22.1
import pickle
import joblib
from sklearn import svm
from sklearn import datasets
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)
SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None,
coef0=0.0,
decision_function_shape='ovr', degree=3, gamma='scale',
kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
pickle
可以通过使用 Python 的内置持久化模型将训练好的模型保存在 scikit 中,它名为 pickle:
# 序列化为 bytes
model = pickle.dumps(clf)
# 从 bytes 加载模型
clf2 = pickle.loads(data=model, fix_imports=True,
encoding='ASCII', errors='strict')
# 保存到文件
svm.model = pickle.dump(obj=clf, file="svm.model",
protocol=None, fix_imports=True)
# 从文件加载模型
clf2 = pickle.load(file="svm.model", fix_imports=True,
encoding='ASCII', errors='strict')
joblib
使用 joblib 来替换 pickle 可能会更有意思,这对于内部带有 numpy 数组的对象来说更为高效, 通常情况下适合 scikit-learn estimators(预估器),但是也只能是使用 pickle 保存模型到硬盘的情况,而不是pickle 保存模型到字符串。
joblib.dump(value=clf, filename="svm.model",
compress=0, protocol=None, cache_size=None)
joblib.load(filename="svm.model")
安全性和可维护性的局限性
pickle(和通过扩展的 joblib),在安全性和可维护性方面存在一些问题。 有以下原因:
- 绝对不要使用未经 pickle 的不受信任的数据,因为它可能会在加载时执行恶意代码。
- 虽然一个版本的 scikit-learn 模型可以在其他版本中加载,但这完全不建议并且也是不可取的。 还应该了解到,对于这些数据执行的操作可能会产生不同及意想不到的结果。
为了用以后版本的 scikit-learn 来重构类似的模型, 额外的元数据应该随着 pickled model 一起被保存:
- 训练数据,例如:引用不可变的快照
- 用于生成模型的 python 源代码
- scikit-learn 的各版本以及各版本对应的依赖包
- 在训练数据的基础上获得的交叉验证得分
这样可以检查交叉验证得分是否与以前相同。
由于模型内部表示可能在两种不同架构上不一样,因此不支持在一个架构上转储模型并将其加载到另一个体系架构上。
如果您想要了解更多关于这些问题以及其它可能的序列化方法,请参阅这个 Alex Gaynor 的演讲。