在本文中,将探讨如何开发一个能够预测比特币价格并检测异常的应用程序。将使用时间序列数据和人工智能技术来实现这一目标。本文假设读者已经具备一定的Python、机器学习和Keras框架的基础知识。整个项目可以在GitHub上找到,同时还有交互式的Jupyter笔记本供参考。
在本文中,将使用LSTM神经网络作为回归器,不使用自动编码器架构。还将在项目笔记本中比较LSTM和卷积神经网络(ConvNets)。虽然ConvNets超出了本系列的范围,但如果LSTM在场景中效果不佳,它们对未来可能会有用。
让通过以下代码创建LSTM NN回归器:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.optimizers import Adam
# 创建模型
regressor = Sequential()
regressor.add(LSTM(256, activation='relu', return_sequences=True, input_shape=(timesteps, nfeatures), dropout=0.2))
regressor.add(LSTM(256, activation='relu', dropout=0.2))
regressor.add(Dense(1))
regressor.compile(loss='mse', optimizer=Adam())
使用均方误差(MSE)作为损失函数,因为正在构建一个回归器而不是分类器。另一方面,模型的最后一层没有激活函数。只设置了一个神经元,因为数据集中只有一个特征,但也在尝试获取给定序列中的下一个值。如果想要获取多个值,建议查看这个笔记本,其中将相同的方法应用于天气数据和其他一些变化。使用下表查看根据任务应该使用哪些激活和损失函数:
要训练模型,请执行以下命令:
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint('/kaggle/working/regressor.hdf5', monitor='val_loss', verbose=1, save_best_only=True, mode='auto', period=1)
history2 = regressor.fit(X_train, y_train, epochs=30, batch_size=128, verbose=1, validation_data=(X_test, y_test), callbacks=[checkpoint], shuffle=False)
训练过程结束后,应该得到类似于以下结果:
Epoch 30/30
175/175 [==============================] - ETA: 0s - loss: 3.6093e-04
Epoch 00030: val_loss did not improve from 0.00024
175/175 [==============================] - 15s 87ms/step - loss: 3.6093e-04 - val_loss: 0.0023
一直在使用keras的ModelCheckpoint方法来保存训练过程中获得的最佳模型。让从现在开始使用它,并在测试集上评估它:
from tensorflow.keras.models import load_model
regressor = load_model('regressor.hdf5')
regressor.evaluate(X_test, y_test)
输出结果:
174/174 [==============================] - 1s 6ms/step - loss: 2.3974e-04
0.0002397427597315982
不错!在评估回归器时,输出值越小,误差越小。误差越小,模型就越好。在这种情况下,得到一个非常接近零的数字,所以相当简单的模型架构足以获得良好的结果。然而,可能会遇到需要设计更复杂模型或投入数小时进行超参数调整的场景,这完全没问题。这一步更多的是试错过程,而不是“冲刺”过程。
让看看实际预测是什么样子的。为此,让取X_test数据集的前24小时,并调用回归器来预测下一个小时的值:
test = regressor.predict(X_test[0].reshape(1, 24, 1))
scaler.inverse_transform(test)
array([[6507.955]], dtype=float32)