模型部署与实时预测应用

如果已经构建了用于实时预测机器学习模型,并希望将模型部署为网络应用以提高其可访问性,那么本文正是所需要的。将探讨如何将已经构建的机器学习或深度学习模型部署到网络上。

模型部署概念理解

在开始之前,需要理解模型部署的概念。部署意味着将训练好的模型应用到实际环境中,以便用户可以通过网络访问并与之交互。本文将通过一个狗品种分类器的例子,展示如何使用Streamlit来部署深度学习分类器模型。

为什么选择Streamlit

Streamlit是一个轻量级的框架,允许使用简单的代码为机器学习项目创建应用。它支持热重载,这意味着可以在编辑并保存文件时实时更新应用。使用Streamlit创建应用非常简单,添加一个控件就像声明一个变量一样简单。不需要编写后端代码,也不需要定义不同的路由或处理HTTP请求。更多信息可以在他们的官方网站上找到:。

部署狗品种分类器模型

假设已经通过提供的链接学习了如何创建狗品种分类器,现在是时候开始部署部分了。将使用在不同狗品种上训练过的模型。请按照上述链接训练模型并保存以下文件:feature_extractor.h5、dog_breed.h5和dog_breeds_category.pickle。feature_extractor.h5是一个保存模型,用于从图像中提取特征;dog_breed.h5是另一个用于预测的保存模型;dog_breeds_category.pickle文件用于将类编号转换为类标签。

使用Streamlit进行模型部署

一旦有了所有必需的文件,让开始Streamlit的安装过程并构建一个网络应用。

pip install streamlit

运行上述命令将安装所有依赖项,并在Python环境中设置Streamlit。

虽然创建目录树不是必需的,但组织文件和文件夹是一个好习惯。首先创建一个project_folder,在project_folder中创建另一个名为static的文件夹,并将所有下载的文件放入static中。同时,在static中创建一个名为images的文件夹。现在创建一个空的main.py和helper.py文件,并将其放置在项目目录中。

创建一个predictor函数,该函数接受上传图片的路径作为输入,并输出不同的狗品种类别。predictor函数将处理所有图像处理和模型加载,这些是预测所需的。predictor函数将在helper.py中编码,以保持结构有序。

import cv2 import os import numpy as np import pickle import tensorflow as tf from tensorflow.keras import layers from tensorflow.keras import models, utils import pandas as pd from tensorflow.keras.models import load_model from tensorflow.keras.preprocessing.image import load_img, img_to_array from tensorflow.python.keras import utils

在上述代码块中,使用pickle文件加载了不同类别的狗品种,然后加载了包含训练权重的权重文件(.h5文件)。现在将定义一个predictor函数,该函数接受图像的路径作为输入,并返回预测。

def predictor(img_path): img = load_img(img_path, target_size=(331,331)) img = img_to_array(img) img = np.expand_dims(img, axis=0) features = feature_extractor.predict(img) prediction = predictor_model.predict(features)*100 prediction = pd.DataFrame(np.round(prediction,1),columns = dog_breeds).transpose() prediction.columns = ['values'] prediction = prediction.nlargest(5, 'values') prediction = prediction.reset_index() prediction.columns = ['name', 'values'] return(prediction)

在上述代码块中,执行了以下操作:首先将图像路径传递给predictor函数。该函数以(331,331)的尺寸读取图像并将其转换为数组。然后,它将图像数组转换为用于预测的张量(4-d)数组。然后,它将张量传递给feature_extractor函数以获取用于预测的提取特征,现在提取的特征将成为predictor_model的输入。最后,它将提取的特征传递给predictor_model并获得最终预测,然后将其转换为数据框以获得所需的预测格式。predictor函数返回具有预测置信度的前5个检测到的狗品种的数据框。

目标是创建一个网络应用,可以在其中上传图片,然后将该图片保存在static/images目录中以进行预测部分。

from helper import * import streamlit as st import os import matplotlib.pyplot as plt import seaborn as sns sns.set_theme(style="darkgrid") sns.set() from PIL import Image st.title('狗品种分类器')

在上述代码中,首先导入了所有依赖项,然后创建了一个标题为“狗品种分类器”的应用。

def save_uploaded_file(uploaded_file): try: with open(os.path.join('static/images',uploaded_file.name),'wb') as f: f.write(uploaded_file.getbuffer()) return 1 except: return 0

此函数将上传的图片保存到static/images文件夹中。

uploaded_file = st.file_uploader("上传图片") if uploaded_file is not None: if save_uploaded_file(uploaded_file): display_image = Image.open(uploaded_file) st.image(display_image) prediction = predictor(os.path.join('static/images',uploaded_file.name)) os.remove('static/images/'+uploaded_file.name) st.text('预测结果:') fig, ax = plt.subplots() ax = sns.barplot(y = 'name',x='values', data = prediction,order = prediction.sort_values('values',ascending=False).name) ax.set(xlabel='置信度 %', ylabel='品种') st.pyplot(fig)
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485