梯度下降算法详解

梯度下降算法是一种在机器学习深度学习中广泛使用的优化算法,主要用于调整模型参数以最小化特定函数值,从而达到局部最小值。在线性回归中,该算法用于寻找权重和偏置,而在深度学习的反向传播中也采用这种方法。其核心目标是识别模型参数,如权重和偏置,以减少模型在训练数据上的误差。

梯度是什么?

梯度是函数输出对输入变量微小变化的敏感度量。在机器学习中,梯度是一个多输入变量函数的导数,也称为函数的斜率,它衡量了所有权重对误差变化的影响。

学习率的重要性

学习率是算法设计者可以设置的一个参数。如果使用的学习率太小,会导致更新非常缓慢,需要更多的迭代次数才能获得更好的解决方案。

梯度下降的类型

梯度下降主要有三种流行类型,它们主要的区别在于使用的数据量不同:

批量梯度下降,也称为传统梯度下降,计算训练数据集中每个样本的错误,但在评估完所有训练样本后才更新模型。整个过程称为一个周期,也称为训练周期。

批量梯度下降的优点包括计算效率高,能够产生稳定的误差梯度和稳定的收敛。缺点是稳定的误差梯度有时会导致模型收敛到非最优解,同时需要整个训练数据集都在内存中供算法使用。

class GDRegressor: def __init__(self, learning_rate=0.01, epochs=100): self.coef_ = None self.intercept_ = None self.lr = learning_rate self.epochs = epochs def fit(self, X_train, y_train): self.intercept_ = 0 self.coef_ = np.ones(X_train.shape[1]) for i in range(self.epochs): y_hat = np.dot(X_train, self.coef_) + self.intercept_ intercept_der = -2 * np.mean(y_train - y_hat) self.intercept_ -= (self.lr * intercept_der) coef_der = -2 * np.dot((y_train - y_hat), X_train) / X_train.shape[0] self.coef_ -= (self.lr * coef_der) def predict(self, X_test): return np.dot(X_test, self.coef_) + self.intercept_ from sklearn.linear_model import SGDClassifier X = [[0., 0.], [1., 1.]] y = [0, 1] clf = SGDClassifier(loss="hinge", penalty="l2", max_iter=5) clf.fit(X, y) class MBGDRegressor: def __init__(self, batch_size, learning_rate=0.01, epochs=100): self.coef_ = None self.intercept_ = None self.lr = learning_rate self.epochs = epochs self.batch_size = batch_size def fit(self, X_train, y_train): self.intercept_ = 0 self.coef_ = np.ones(X_train.shape[1]) for i in range(self.epochs): for j in range(int(X_train.shape[0] / self.batch_size)): idx = random.sample(range(X_train.shape[0]), self.batch_size) y_hat = np.dot(X_train[idx], self.coef_) + self.intercept_ intercept_der = -2 * np.mean(y_train[idx] - y_hat) self.intercept_ -= (self.lr * intercept_der) coef_der = -2 * np.dot((y_train[idx] - y_hat), X_train[idx]) self.coef_ -= (self.lr * coef_der) def predict(self, X_test): return np.dot(X_test, self.coef_) + self.intercept_
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485