在机器学习中,模型的拟合状态是一个重要的概念。一个未经过训练(拟合)的模型是无法进行预测的。因此,需要一种机制来验证模型是否已经拟合。本文将介绍如何通过检查模型的属性来验证其拟合状态,并在模型未拟合时抛出异常。
在Python的scikit-learn库中,提供了一个名为check_is_fitted
的函数,用于检查模型是否已经拟合。这个函数通过检查模型实例中是否存在以下划线结尾的属性来判断模型是否已经拟合。如果模型没有设置任何以下划线结尾的属性,可以通过定义一个名为__sklearn_is_fitted__
的方法来指定模型是否已经拟合。
如果模型没有传入任何属性,那么只要模型实例中存在一个以下划线结尾且不以双下划线开头的属性,就认为模型已经拟合。此外,模型可以通过设置requires_fit
标签来表明自己是无状态的。关于requires_fit
标签的更多信息,可以参考Estimator Tags
文档。需要注意的是,如果传入了属性,requires_fit
标签将被忽略。
estimator
:要进行检查的模型实例。
attributes
:字符串、字符串列表或元组,默认为None。要检查的属性名称。例如:["coef_", "estimator_"]
或"coef_"
。如果为None,则认为模型已经拟合,如果存在一个以下划线结尾且不以双下划线开头的属性。
msg
:字符串,默认为None。自定义错误消息。如果消息字符串中包含%(name)s
,则会被替换为模型名称。例如:"Estimator, %(name)s, must be fitted before sparsifying"
。
all_or_any
:可调用的,{all, any},默认为all。指定是否所有或任何一个给定的属性必须存在。
TypeError
:如果传入的模型是一个类而不是模型实例。
NotFittedError
:如果未找到属性。
以下是一个使用check_is_fitted
函数的示例代码:
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_is_fitted
from sklearn.exceptions import NotFittedError
# 创建一个逻辑回归模型实例
lr = LogisticRegression()
try:
# 尝试检查模型是否已经拟合
check_is_fitted(lr)
except NotFittedError as exc:
# 如果模型未拟合,打印错误信息
print(f"Model is not fitted yet.")
# 使用数据训练模型
lr.fit([[1, 2], [1, 3]], [1, 0])
# 再次检查模型是否已经拟合
check_is_fitted(lr)
在这个示例中,首先创建了一个逻辑回归模型实例,然后尝试使用check_is_fitted
函数检查模型是否已经拟合。由于模型尚未训练,因此会抛出NotFittedError
异常。接着,使用数据训练模型,再次检查模型是否已经拟合,这次检查将通过。