ML.NET入门:二分类问题的解决方案

ML.NET是一个开源的.NET机器学习库,它为.NET开发者提供了一个与机器学习库之间的桥梁。本文将介绍如何在Visual Studio2017中使用ML.NET库来解决一个简单的二分类问题。

假设有两个点(在二维空间中)的集合,分别是红色和蓝色,目标是根据这些点的坐标(x和y)来预测一个点是属于红色组还是蓝色组。训练数据可能如下所示:

3 -2 红色 -2 3 红色 -1 -4 红色 2 3 红色 3 4 红色 -1 9 蓝色 2 14 蓝色 1 17 蓝色 3 12 蓝色 0 8 蓝色

有十个点。每一行的前两个值是每个点的坐标(x和y),第三个值是该点所属的组。因为只有两个输出,即蓝色或红色,所以问题是一个二分类问题。有许多不同的机器学习技术可以解决二分类问题,本文将使用逻辑回归,因为它是最简单的机器学习算法。

创建.NET应用程序并安装ML.NET

为了简单起见,将创建一个名为MyFirstMLDOTNET的控制台应用程序C#(.NET Framework)。在解决方案资源管理器窗口中,将Program.cs重命名为MyFirstMLDOTNET.cs:

可以通过右键单击MyFirstMLDOTNET项目并选择“管理NuGet包”来安装ML.NET:

在NuGet窗口中,选择浏览标签页,并在搜索字段中输入'ML.NET'。最后,选择Microsoft.ML并点击安装按钮:

点击预览更改,然后点击“接受”许可证接受。几秒钟后,Visual Studio将在输出窗口中显示一条消息。

此时,如果尝试运行应用程序,可能会得到以下错误消息:

通过右键单击MyFirstMLDOTNET项目并选择属性来解决此错误。在属性窗口中,在左侧选择构建项,并将平台目标项中的Any CPU更改为x64:

还需要选择.NET Framework的4.7版本(或更高版本),因为会在早期版本中遇到一些错误。可以通过在左侧选择应用程序项并选择目标框架项中的版本来选择.NET Framework的版本。如果没有4.7版本(或更高版本),可以选择安装其他框架,将被重定向到Microsoft页面下载并安装.NET Framework包:

到目前为止,可以尝试再次运行应用程序,它是成功的。

使用代码

在创建机器学习模型之前,必须通过右键单击MyFirstMLDOTNET项目并选择“添加”>“新建项”,选择文本文件类型,并在名称字段中输入myMLData.txt来创建训练数据文件:

点击“添加”按钮。在myMLData.txt窗口中,输入(或复制上面的)训练数据:

3 -2 红色 -2 3 红色 -1 -4 红色 2 3 红色 3 4 红色 -1 9 蓝色 2 14 蓝色 1 17 蓝色 3 12 蓝色 0 8 蓝色

点击“保存”并关闭myMLData.txt窗口。

创建训练数据文件后,还需要创建数据类。一个名为myData的类定义了训练数据的结构(两个坐标(x和y)和一个标签(红色或蓝色)):

public class myData { [Column(ordinal: "0", name: "XCoord")] public float x; [Column(ordinal: "1", name: "YCoord")] public float y; [Column(ordinal: "2", name: "Label")] public string Label; }

一个名为myPrediction的类保存预测信息:

public class myPrediction { [ColumnName("PredictedLabel")] public string PredictedLabels; }

可以创建机器学习模型并训练它:

var pipeline = new LearningPipeline(); string dataPath = "..\\..\\myMLData.txt"; pipeline.Add(new TextLoader(dataPath).CreateFrom<myData>(separator: ',')); pipeline.Add(new Dictionarizer("Label")); pipeline.Add(new ColumnConcatenator("Features", "XCoord", "YCoord")); pipeline.Add(new LogisticRegressionBinaryClassifier()); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); Console.WriteLine("\nStarting training \n"); var model = pipeline.Train<myData, myPrediction>();

可以这样评估机器学习模型:

var testData = new TextLoader(dataPath).CreateFrom<myData>(separator: ','); var evaluator = new BinaryClassificationEvaluator(); var metrics = evaluator.Evaluate(model, testData); double acc = metrics.Accuracy * 100; Console.WriteLine("Model accuracy = " + acc.ToString("F2") + "%");

最后,可以用一个新的点来测试模型:

myData newPoint = new myData() { x = 5f, y = -7f }; myPrediction prediction = model.Predict(newPoint); string result = prediction.PredictedLabels; Console.WriteLine("Prediction = " + result); using System; using Microsoft.ML.Runtime.Api; using System.Threading.Tasks; using Microsoft.ML.Legacy; using Microsoft.ML.Legacy.Data; using Microsoft.ML.Legacy.Transforms; using Microsoft.ML.Legacy.Trainers; using Microsoft.ML.Legacy.Models; namespace MyFirstMLDOTNET { class MyFirstMLDOTNET { public class myData { [Column(ordinal: "0", name: "XCoord")] public float x; [Column(ordinal: "1", name: "YCoord")] public float y; [Column(ordinal: "2", name: "Label")] public string Label; } public class myPrediction { [ColumnName("PredictedLabel")] public string PredictedLabels; } static void Main(string[] args) { var pipeline = new LearningPipeline(); string dataPath = "..\\..\\myMLData.txt"; pipeline.Add(new TextLoader(dataPath).CreateFrom<myData>(separator: ',')); pipeline.Add(new Dictionarizer("Label")); pipeline.Add(new ColumnConcatenator("Features", "XCoord", "YCoord")); pipeline.Add(new LogisticRegressionBinaryClassifier()); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); Console.WriteLine("\nStarting training \n"); var model = pipeline.Train<myData, myPrediction>(); var testData = new TextLoader(dataPath).CreateFrom<myData>(separator: ','); var evaluator = new BinaryClassificationEvaluator(); var metrics = evaluator.Evaluate(model, testData); double acc = metrics.Accuracy * 100; Console.WriteLine("Model accuracy = " + acc.ToString("F2") + "%"); myData newPoint = new myData() { x = 5f, y = -7f }; myPrediction prediction = model.Predict(newPoint); string result = prediction.PredictedLabels; Console.WriteLine("Prediction = " + result); Console.WriteLine("\nEnd ML.NET demo"); Console.ReadLine(); } } }
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485