AI在时尚设计行业的应用

时尚设计领域,人工智能(AI)技术的应用正逐渐成为推动行业发展的新动力。本文将介绍一个基于深度学习的AI系统,该系统能够通过分析客户需求来革新时尚设计行业。将使用Jupyter Notebook作为集成开发环境(IDE),并结合TensorFlow 2.0、NumPy、Matplotlib等库,以及DeepFashion数据集的一个自定义子集,以减少计算和内存开销。

项目概述

本项目的目标是使用VGG16模型对15种不同的服装类别进行分类,并评估模型的性能。在前一篇文章中,已经展示了如何加载DeepFashion数据集,并如何调整VGG16模型以适应服装分类任务。在本文中,将训练VGG16模型,并评估其性能。

训练VGG16模型

VGG16的迁移学习从冻结模型权重开始,这些权重是通过在如ImageNet这样的大型数据集上训练模型获得的。这些学习到的权重和过滤器为网络提供了出色的特征提取能力,这将有助于在训练模型对服装类别进行分类时提高其性能。因此,只有全连接(FC)层会被训练,而模型的特征提取部分几乎被冻结(通过设置非常低的学习率,如0.001)。让通过将它们设置为False来冻结特征提取层:

for layer in conv_model.layers: layer.trainable = False

接下来,可以编译模型,同时选择学习率(0.001)和优化器(Adamax):

full_model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.Adamax(lr=0.001), metrics=['acc'])

编译完成后,可以使用fit_generator函数开始模型训练,因为使用了ImageDataGenerator来加载数据。将使用train_dataset和val_dataset分别训练和验证网络。将训练三个周期,但这个数字可以根据网络性能增加。

history = full_model.fit_generator(train_dataset, validation_data=val_dataset, workers=0, epochs=3)

运行上述代码将产生以下输出:

现在,为了绘制网络的学习曲线和损失曲线,让添加plot_history函数:

def plot_history(history, yrange): acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(len(acc)) plt.plot(epochs, acc) plt.plot(epochs, val_acc) plt.title('Training and validation accuracy') plt.ylim(yrange) plt.figure() plt.plot(epochs, loss) plt.plot(epochs, val_loss) plt.title('Training and validation loss') plt.show() plot_history(history, yrange=(0.9, 1))

这个函数将生成以下两个图表:

在新图像上评估VGG16

网络在训练过程中表现良好。那么在测试集上也应该表现良好,对吧?将在测试集上测试它。首先,让加载测试集,然后使用model.evaluate函数将测试图像传递给模型以测量网络的准确性。

from tensorflow.keras.preprocessing.image import ImageDataGenerator test_dir = r'C:\Users\abdul\Desktop\ContentLab\P2\DeepFashion\Test' test_datagen = ImageDataGenerator() test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=3, class_mode='categorical') X_test, y_test = next(test_generator) Testresults = full_model.evaluate(test_generator) print("test loss, test acc:", Testresults)
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485