scikit-learn 1.6 官网中文文档翻译已经上线了,欢迎使用:http://www.aidoczh.com/scikit-learn/


在训练完一个 scikit-learn 模型后,希望有一种方法可以将模型保存下来以备将来使用,而不必重新训练。下面的章节将为您提供一些关于如何持久化 scikit-learn 模型的提示。

9.1. Python 特定的序列化

可以使用 Python 内置的持久化模块 pickle 来保存 scikit-learn 模型:

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
SVC()

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])
array([0])
>>> y[0]
0

在 scikit-learn 的特定情况下,使用 joblib 替代 pickle(dumpload)可能更好,因为它对于内部包含大型 numpy 数组的对象更高效,而这在拟合的 scikit-learn 估计器中经常发生,但它只能将模型保存到磁盘而不能保存到字符串:

>>> from joblib import dump, load
>>> dump(clf, 'filename.joblib')

稍后,您可以使用以下代码在其他 Python 进程中加载已保存的模型:

>>> clf = load('filename.joblib')

注意:

dumpload 函数也接受文件对象而不是文件名。有关使用 Joblib 进行数据持久化的更多信息,请参阅此处

当使用与保存模型的 scikit-learn 版本不一致的 scikit-learn 版本来反序列化估计器时,会引发 InconsistentVersionWarning。可以捕获此警告以获取估计器的原始版本:

from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)

try:
    est = pickle.loads("model_from_prevision_version.pickle")
except InconsistentVersionWarning as w:
    print(w.original_sklearn_version)

9.1.1. 安全性和可维护性限制

pickle(以及通过扩展的 joblib)在可维护性和安全性方面存在一些问题。因此,

  • 永远不要对不受信任的数据进行反序列化,因为这可能导致在加载时执行恶意代码。
  • 尽管使用一个版本的 scikit-learn 保存的模型可能在其他版本中加载,但这完全不受支持且不可取。还应该记住,对这些数据执行的操作可能会产生不同和意外的结果。

为了能够使用未来版本的 scikit-learn 重新构建类似的模型,应该在序列化的模型中保存附加的元数据:

  • 训练数据,例如对不可变快照的引用
  • 用于生成模型的 Python 源代码
  • scikit-learn 及其依赖的版本
  • 在训练数据上获得的交叉验证分数

这样可以检查交叉验证分数是否与之前的范围相同。

除了一些例外情况外,pickle 的模型在相同版本的依赖项和 Python 使用的情况下应该可以在不同架构之间移植。如果遇到无法移植的估计器,请在 GitHub 上提交问题。为了冻结环境和依赖项,经常使用容器(如 Docker)在生产中部署 pickle 的模型。

如果您想了解更多关于这些问题并探索其他可能的序列化方法,请参考 Alex Gaynor 的这个演讲

9.1.2. 更安全的格式:skops

skops 通过 skops.io 模块提供了一种更安全的格式。它避免使用 pickle,并且只加载具有默认或用户信任的类型和函数引用的文件。其 API 与 pickle 非常相似,您可以使用 skops.io.dumpskops.io.dumps 将模型持久化,如文档所述:

import skops.io as sio
obj = sio.dumps(clf)

您可以使用 skops.io.loadskops.io.loads 将它们加载回来。但是,您需要指定您信任的类型。您可以使用 skops.io.get_untrusted_types 获取转储对象/文件中的现有未知类型,并在检查其内容后将其传递给加载函数:

unknown_types = sio.get_untrusted_types(data=obj)
clf = sio.loads(obj, trusted=unknown_types)

如果您信任文件/对象的来源,可以传递 trusted=True

clf = sio.loads(obj, trusted=True)

请在 skops 问题跟踪器上报告与此格式相关的问题和功能请求。

9.2. 互操作格式

为了满足可重现性和质量控制的需求,当考虑到不同的架构和环境时,将模型导出为 Open Neural Network Exchange 格式或 Predictive Model Markup Language (PMML) 格式可能比仅使用 pickle 更好。这些格式对于您可能希望在与训练模型不同的环境中使用模型进行预测非常有帮助。

ONNX 是模型的二进制序列化表示。它的开发旨在提高数据模型的可互操作性。它旨在促进在不同的机器学习框架之间转换数据模型,并提高在不同计算架构上的可移植性。有关详细信息,请参阅ONNX 教程。要将 scikit-learn 模型转换为 ONNX,可以使用特定工具 sklearn-onnx

PMML 是实现了用于表示数据模型及其生成数据的 XML 文档标准。作为人类和机器可读的格式,PMML 是在不同平台上进行模型验证和长期存档的良好选择。然而,与一般的 XML 一样,其冗长性在生产环境中对性能至关重要。要将 scikit-learn 模型转换为 PMML,您可以使用例如 sklearn2pmml(在 Affero GPLv3 许可下分发)。

Logo

欢迎加入 MCP 技术社区!与志同道合者携手前行,一同解锁 MCP 技术的无限可能!

更多推荐