模型验证的重要性与实现

机器学习领域,模型验证是一个至关重要的步骤。验证的目的是确保模型不仅在训练数据集上表现良好,而且在未见过的数据上也能有良好的预测性能。如果模型在训练数据上表现优异,但在测试数据上表现不佳,称之为过拟合。过拟合意味着模型过于复杂,以至于它学习到了训练数据中的噪声和细节,而没有捕捉到数据的真实分布。为了避免这种情况,需要使用未包含在训练数据集中的测试数据来验证模型。

在本系列的前两篇文章中,已经介绍了如何使用CNTKC#来训练前馈神经网络模型。现在,将进入模型验证的实现。模型训练完成后,需要将模型和训练器传递给评估方法。评估方法将加载测试数据,并使用传递的模型计算输出。然后,它将计算出的(预测的)值与测试数据集的输出进行比较,并计算准确率。

以下是评估实现的C#代码示例:

private static void EvaluateIrisModel(Function ffnn_model, Trainer trainer, DeviceDescriptor device) { var dataFolder = "Data"; var trainPath = Path.Combine(dataFolder, "testIris_cntk.txt"); var featureStreamName = "features"; var labelsStreamName = "label"; var feature = ffnn_model.Arguments[0]; var label = ffnn_model.Output; var streamConfig = new StreamConfiguration[] { new StreamConfiguration(featureStreamName, feature.Shape[0]), new StreamConfiguration(labelsStreamName, label.Shape[0]) }; var testMinibatchSource = MinibatchSource.TextFormatMinibatchSource( trainPath, streamConfig, MinibatchSource.InfinitelyRepeat, true); var featureStreamInfo = testMinibatchSource.StreamInfo(featureStreamName); var labelStreamInfo = testMinibatchSource.StreamInfo(labelsStreamName); int batchSize = 20; int miscountTotal = 0, totalCount = 20; while (true) { var minibatchData = testMinibatchSource.GetNextMinibatch((uint)batchSize, device); if (minibatchData == null || minibatchData.Count == 0) break; totalCount += (int)minibatchData[featureStreamInfo].numberOfSamples; var labelData = minibatchData[labelStreamInfo].data.GetDenseData(label); var expectedLabels = labelData.Select(l => l.IndexOf(l.Max())).ToList(); var inputDataMap = new Dictionary<Variable, Value>() { { feature, minibatchData[featureStreamInfo].data } }; var outputDataMap = new Dictionary<Variable, Value>() { { label, null } }; ffnn_model.Evaluate(inputDataMap, outputDataMap, device); var outputData = outputDataMap[label].GetDenseData(label); var actualLabels = outputData.Select(l => l.IndexOf(l.Max())).ToList(); int misMatches = actualLabels.Zip(expectedLabels, (a, b) => a.Equals(b) ? 0 : 1).Sum(); miscountTotal += misMatches; Console.WriteLine($"Validating Model: Total Samples = {totalCount}, Mis-classify Count = {miscountTotal}"); if (totalCount >= 20) break; } Console.WriteLine("---------------"); Console.WriteLine("------TESTING SUMMARY--------"); float accuracy = (1.0F - miscountTotal / totalCount); Console.WriteLine($"Model Accuracy = {accuracy}"); }

在上述代码中,首先定义了数据文件夹和测试数据文件的路径。然后,设置了特征和标签的流配置,并准备了测试数据。接下来,使用一个循环来获取小批量数据,并计算预测值与实际值之间的差异。最后,计算模型的准确率,并输出验证结果。

在前一篇文章中,已经介绍了如何训练模型。现在,将调用评估方法来验证模型:

EvaluateIrisModel(ffnn_model, trainer, device);
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485