在数据分析和机器学习领域,预测未来事件是一项重要任务。贝叶斯学习作为一种评估假设数据真实性的工具,可以帮助基于历史数据预测未来事件。本文将探讨如何使用PyMC库进行贝叶斯学习,以预测未来的需求。
本文将介绍如何使用贝叶斯学习技术,特别是PyMC库,来预测未来的需求。将创建一个模型,分析模型结果,并基于训练好的模型预测未来的需求。
在开始之前,需要准备数据集。将使用pandas和numpy库创建一个包含销售数量和销售日期的合成销售数据集,数据遵循正态分布。以下是创建数据集的代码:
import pandas as pd
import numpy as np
import seaborn as sns
# 创建日期范围
date_range = pd.date_range(start='2022-05-01', end='2022-12-31', freq='D')
# 生成正态分布的日销售额
mean = 50 # 销售均值
std = 10 # 销售标准差
sales = pd.Series(np.random.normal(loc=mean, scale=std, size=len(date_range)), index=date_range)
# 创建包含日销售的DataFrame
sales_data = pd.DataFrame({'sales': sales})
# 显示DataFrame的前5行
sales_data.head()
从直方图中可以看出,数据遵循正态分布。
将使用PyMC库创建模型。PyMC是一个Python模块,允许应用贝叶斯推断。贝叶斯推断是一种统计方法,用于基于观测数据和概率模型估计未知参数的概率分布。PyMC将允许创建一个与数据集中的“销售”变量类似的正态分布模型。
在观察数据集的销售之前,可能已经对这些参数有了先验知识,这些知识可以通过业务专家获得。这种先验知识称为先验。重要的是要知道,先验知识是不确定的,因为可能没有足够的信息来完全了解参数。此外,即使有这些信息,它也可能是不正确的。
因此,在使用先验知识时需要考虑这种不确定性。先验的选择在贝叶斯推断中非常重要,因为它将影响后验分布,以至于后验参数更多地受到先验的影响,而不是数据集参数,如果它们接近观测数据的值。后验分布是在更新先验后,考虑到通过观测数据获得的知识,概率分布将如何变化。
在给定的例子中,均值的先验被定义为均值为80,标准差为20,这可能意味着某个领域知识渊博的人认为这些应该是数据的参数值。
为了执行贝叶斯推断,还需要指定数据的可能性,并从分布中采样以获得参数的后验分布。可能性是给定模型中正态分布的均值和标准差下观测销售数据的概率分布。
采样必须谨慎进行,以确保收集的样本代表后验分布。在选择样本大小时要小心,因为非常大的样本会使过程非常耗时,而非常小的样本可能不适用于进行预测。在示例中,将生成5000个样本,首先执行1000次拟合迭代,以避免低概率样本。
import pymc as pm
# 创建模型
with pm.Model() as model:
# 定义均值的先验
mu = pm.Normal('mu', 80, 20)
# 定义标准差的先验
sigma = pm.HalfNormal('sigma', 10)
# 定义可能性
sales = pm.Normal('sales', mu, sigma, observed=sales_data)
# 采样
trace = pm.sample(5000, tune=1000)
生成采样后,将通过绘图分析结果。这种诊断将用于验证采样的质量。
import arviz as az
# 分析结果
with model:
az.plot_trace(trace)
# 绘制后验直方图的均值
az.plot_posterior(trace)
# 摘要
pm.summary(trace)
可以看出,先验被定义为与观测数据的均值和标准差相差很远的值,这意味着引入了错误的先验信息。记住,先验的影响可能因观测数据的数量、数据的质量以及先验与观测数据之间的差异而有所不同。因此,应该谨慎选择。
# 从后验生成样本
mu_samples = trace['posterior']['mu']
sigma_samples = trace['posterior']['sigma']
# 预测未来需求
future_sales = np.random.normal(mu_samples.mean(), sigma_samples.mean(), size=90)
# 绘制预测的未来需求
import matplotlib.pyplot as plt
plt.plot(future_sales)
plt.title('预测未来需求')
plt.xlabel('天')
plt.ylabel('销售')
plt.show()