在数据科学领域,尤其是机器学习中,多重共线性是一个不可忽视的问题。本文将从统计学的角度出发,详细解释多重共线性的概念、成因、检测方法以及解决方案,并提供Python代码实现。
多重共线性指的是在回归模型中,两个或多个自变量(预测变量)之间存在高度相关性。这意味着一个自变量可以通过另一个自变量来预测。例如,身高和体重、学生消费和父亲收入、年龄和经验、汽车里程和价格等。
以学生消费为例,如果使用零花钱和父亲收入作为自变量来预测学生消费,由于这两个自变量高度相关(父亲收入增加时,零花钱也增加;父亲收入减少时,零花钱也减少),无法准确或单独评估每个自变量对因变量(学生消费)的影响,这就是多重共线性问题。
在回归模型中,目标是找出每个预测变量对目标变量的单独影响,这也是普通最小二乘法(Ordinary Least Squares, OLS)的一个假设。因此,为了实现研究目标,必须解决多重共线性问题,这对预测也至关重要。
假设有如下线性方程:Y = a0 + a1*X1 + a2*X2,其中X1和X2是自变量。在多重共线性存在的情况下,自变量高度相关,如果改变X1,X2也会随之改变,无法看到它们对Y的单独影响,这与研究目标相悖。
“这使得X1对Y的影响难以与X2对Y的影响区分开来。”
注意:多重共线性可能不会显著影响模型的准确性,但可能会失去确定模型中各个独立特征对因变量影响的可靠性,这在需要解释模型时可能会成为问题。
多重共线性可能由以下原因引起:
1. 数据集在创建时存在的问题,如实验设计不当、高度观察性数据或无法操纵数据(称为数据相关多重共线性)。
2. 在数据预处理或特征工程中创建的新变量依赖于其他变量(称为结构相关多重共线性)。
3. 数据集中存在相同的变量。
4. 编码分类特征为数值特征时,错误使用虚拟变量也可能导致多重共线性(称为虚拟变量陷阱)。
5. 在某些情况下,数据不足也可能导致多重共线性问题。
让尝试使用VIF(方差膨胀因子)来检测数据集中的多重共线性,以了解可能发生的问题。
虽然相关矩阵和散点图也可以用来发现多重共线性,但它们的发现只显示自变量之间的二元关系。VIF更受青睐,因为它可以显示一个变量与一组其他变量之间的相关性。
VIF衡量自变量之间的相关性强度。它是通过将一个变量与其他所有变量回归来预测的。
R^2值用于确定一个自变量如何被其他自变量描述。高R^2值意味着该变量与其他变量高度相关,这被VIF捕获,如下所示:
VIF = 1 / (1 - R^2)
因此,R^2值越接近1,VIF值越高,特定自变量的多重共线性越高。VIF从1开始(当R^2=0时,VIF=1——VIF的最小值),没有上限。
VIF = 1,表示自变量与其他变量之间没有相关性。VIF超过5或10表示该自变量与其他变量之间存在高度多重共线性。
一些研究人员认为VIF>5是模型的严重问题,而另一些研究人员认为VIF>10是严重问题,这因人而异。
1. 方差膨胀因子(VIF)
- 如果VIF=1;没有多重共线性
- 如果VIF<5;低多重共线性或中等相关性
- 如果VIF>=5;高多重共线性或高度相关性
2. 容忍度(VIF的倒数)
- 如果VIF高,则容忍度低,即高多重共线性。
- 如果VIF低,则容忍度高,即低多重共线性。
1. 删除导致问题的变量。
2. 如果保留所有X变量,则避免对个别参数进行推断。
3. 重新编码自变量的形式。
import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.datasets import load_boston
boston = load_boston()
boston.DESCR #查看完整数据集的描述。
X = boston["data"] #自变量
Y = boston["target"] #因变量
names = list(boston["feature_names"]) #所有属性的名称
df = pd.DataFrame(X, columns=names) #制作用于数据分析的pandas数据框
print(df.head()) #查看数据框中的前5个样本
for index in range(0, len(names)):
y = df.loc[:, df.columns == names[index]]
x = df.loc[:, df.columns != names[index]]
model = sm.OLS(y, x) #拟合普通最小二乘法
results = model.fit()
rsq = results.rsquared
vif = round(1 / (1 - rsq), 2)
print("R Square值的{}列是{},将所有其他列作为自变量".format(
names[index], (round(rsq, 2))
))
print("Variance Inflation Factor的{}列是{}".format(
names[index], vif
))