自定义文本分类应用教程

随着人工智能技术的发展,越来越多的平台开始利用AI技术来识别和自动删除不雅或冒犯性帖子。受到启发,决定尝试做类似的事情,并为此编写了一个端到端的教程,帮助构建自己的语料库,用于训练文本分类模型,并将模型导出并部署到Android应用中使用。所有这些操作都将基于选择的自定义数据集进行。

是否对构建自己的文本分类应用感到兴奋?如果是,让开始吧。在开始之前,请允许告诉,将在Google Colab上完成所有模型的超参数配置和训练工作。对于构建Android应用,需要安装Android Studio。如果还没有安装,请下载。

步骤1:构建语料库

正在构建一个文本分类器,该分类器将输入文本分类为冒犯性或非冒犯性。为了简化,将文本分类模型限制为两个类别。可以根据自己的需要创建自定义数据集,并包含任意数量的类别。整个过程将完全相同。

在数据收集方面,让坦白。尽管Kaggle上有大量数据集,但没有找到适合应用的数据集。因此,通过在线抓取文本和引用,手动构建了自己的数据集,认为这些文本可以被最好地分类为冒犯性或非冒犯性。收集了大约1000个句子,并将它们标记为“冒犯性”和“非冒犯性”,这取决于每个句子中使用的语言的性质。训练数据和测试数据的比例为7:3。

以下是的一个快照。创建自己的数据集时,请创建一个包含两列的CSV文件——“句子”和“标签”。在这两个列中添加文本及其相应的标签。第一列是索引列,稍后将由Pandas库自动创建。在创建数据集时,请忽略它。

如果数据集包含许多标签,请确保为每个标签收集足够的数据,以防止数据集变得有偏见。一旦收集了足够的文本用于语料库,请将数据集分成两个不同的CSV文件——“train.csv”和“test.csv”。理想比例为7:3。正如在中反复强调的,“垃圾进,垃圾出”。因此,请花时间构建语料库,不要急于求成。

步骤2:训练模型

克隆仓库到本地计算机。在此仓库中,将找到一个名为“Custom_Text_Classification.ipynb”的笔记本文件。使用Google账户登录,并在Google Colab上打开笔记本并连接到运行时。上传在步骤1中创建的“train.csv”和“test.csv”文件;如下所示。

!pip install -q tflite-model-maker import numpy as np from numpy.random import RandomState import pandas as pd import os from tflite_model_maker import model_spec from tflite_model_maker import text_classifier from tflite_model_maker.config import ExportFormat from tflite_model_maker.config import QuantizationConfig from tflite_model_maker.text_classifier import AverageWordVecSpec from tflite_model_maker.text_classifier import DataLoader import tensorflow as tf assert tf.__version__.startswith('2') tf.get_logger().setLevel('ERROR')

导入数据集并查看数据集,确保数据集已正确导入。

df_train = pd.read_csv('train.csv', error_bad_lines=False, engine="python") df_test = pd.read_csv('test.csv', error_bad_lines=False, engine="python") df_train.head() df_test.head()

选择模型架构。根据选择注释其他模型架构。每个模型架构都与其他架构不同,并将产生不同的结果。MobileBERT模型训练时间较长,因为它的架构相当复杂。然而,可以随意尝试不同的架构,直到找到最佳结果。

spec = model_spec.get('average_word_vec') # spec = model_spec.get('mobilebert_classifier') # spec = model_spec.get('bert_classifier') # spec = AverageWordVecSpec(wordvec_dim=32)

这些模型的架构和工作原理超出了本文的范围。然而,如果想了解更多关于这些模型的信息,请访问。

自定义MobileBERT模型的超参数(可选)。如果选择了“mobilebert_classifier”架构,请仅运行此单元格。

# spec.seq_len = 256

加载训练和测试数据。加载训练和测试数据CSV文件,为模型训练过程做准备。确保test_data的is_training参数设置为False。

train_data = DataLoader.from_csv( filename='train.csv', text_column='sentence', label_column='label', model_spec=spec, is_training=True) test_data = DataLoader.from_csv( filename='test.csv', text_column='sentence', label_column='label', model_spec=spec, is_training=False)

开始在训练数据集上训练模型。随意尝试不同的epoch数量,直到找到最佳结果。

model = text_classifier.create(train_data, model_spec=spec, epochs=100)

检查模型结构——神经网络的层。

model.summary()

评估模型。在测试数据上评估模型的准确性,并亲自查看模型是否需要进行调整,例如增加数据集或调整超参数以提高准确性。

loss, acc = model.evaluate(test_data)

导出TF Lite模型。最终模型将被导出为TF Lite模型,可以下载并直接部署到Android应用中。

model.export(export_dir='average_word_vec')

运行上述单元格后,将创建一个名为“average_word_vec”的文件夹,其中包含名为model.tflite的TF Lite模型。请将此模型下载到本地计算机。

步骤3:创建Android应用

已经根据官方TensorFlowLite文本分类应用构建了一个Android应用,并根据需要进行了定制,以便可以直观地表示预测结果。可以在之前克隆的仓库中的Android_App文件夹中找到此应用。

现在,让将模型部署到应用中。将model.tflite文件复制到Custom-Text-Classification-on-Android-using-TF-Lite/Android_App/lib_task_api/src/main/assets目录中。

一旦将模型复制到指定目录,打开Android Studio中的项目并让其构建一段时间。在项目构建的同时,让去散步并喝一杯饮料。更喜欢在任何给定的一天喝茶而不是咖啡。🙂

沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485