Building an API Service for a Machine Learning Model

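This article explores how to build an API service for a machine learning model that meets level 2 of Google's MLOps maturity model. It assumes you are already familiar with Python, deep learning, Docker, DevOps, and Flask. The previous article covered the unit-testing step of the ML CI/CD pipeline; this one shows how to build the model API that serves predictions.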

Project Flowchart

The flowchart below shows where this step fits in the overall project.

Code Structure

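The code in this article is almost identical to the code in the previous article, so only the differences are covered here. The snippets shown are simplified; the complete code is available in the project repository.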

The task.py file orchestrates the program's execution inside the container:

```python
import os
import threading
import time

import jsonpickle
import numpy as np
import requests
import tensorflow as tf
from flask import Flask, Response, jsonify, request
from tensorflow.keras.models import load_model

import data_utils
import email_notifications

# IMPORTANT
# To run this container locally and reach the API from a local browser,
# use http://172.17.0.2:5000/

app = Flask(__name__)

# General variable declarations
model_name = 'best_model.hdf5'
bucket_name = 'automatictrainingcicd-aiplatform'
model = None  # set by the background loading job below


# NOTE: before_first_request was deprecated in Flask 2.2 and removed in 2.3,
# so this code requires Flask < 2.3.
@app.before_first_request
def before_first_request():
    def initialize_job():
        global model
        if len(tf.config.experimental.list_physical_devices('GPU')) > 0:
            tf.config.set_soft_device_placement(True)
            tf.debugging.set_log_device_placement(True)
        # Check whether a model exists in the production registry on GCS
        model_gcs = data_utils.previous_model(bucket_name, model_name)
        # If it exists, download and load it so the API can serve it.
        # sys.exit() would only end this worker thread, so os._exit(1) is used
        # to stop the whole process (and therefore the container) on failure.
        if model_gcs[0] is True:
            model_gcs = data_utils.load_model(bucket_name, model_name)
            if model_gcs[0] is True:
                try:
                    model = load_model(model_name)
                except Exception as e:
                    email_notifications.exception('Error while trying to load the production model. Exception: ' + str(e))
                    os._exit(1)
            else:
                email_notifications.exception('Error while trying to load the production model. Exception: ' + str(model_gcs[1]))
                os._exit(1)
        elif model_gcs[0] is False:
            email_notifications.send_update('There are no artifacts in the model registry. Please check GCP for more information.')
            os._exit(1)
        elif model_gcs[0] is None:
            email_notifications.exception('Error while checking whether a production model exists. Exception: ' + str(model_gcs[1]) + '. Aborting execution.')
            os._exit(1)

    thread = threading.Thread(target=initialize_job)
    thread.start()


@app.route('/init', methods=['GET', 'POST'])
def init():
    # Lightweight endpoint; the first request to it triggers before_first_request
    message = {'message': 'API initialized.'}
    response = jsonpickle.encode(message)
    return Response(response=response, status=200, mimetype='application/json')


@app.route('/', methods=['POST'])  # Flask itself answers 405 for other methods
def index():
    try:
        # Convert the raw request bytes (one 128x128 RGB image) into uint8 values
        image = np.frombuffer(request.data, np.uint8)
        # Add a batch dimension: (1, 128, 128, 3)
        image = image.reshape((1, 128, 128, 3)).astype(np.float16)
        # model is None until the loading job finishes; predict() then raises
        result = np.argmax(model.predict(image))
        return jsonify({'message': str(result)})
    except Exception as e:
        email_notifications.exception('Error while trying to make a prediction through the production API. Exception: ' + str(e))
        return jsonify({'message': 'Error'}), 500


def self_initialize():
    # Poll /init in the background until the server answers, so the model
    # loads at startup instead of on the first client request
    def initialization():
        started = False
        while not started:
            try:
                server_response = requests.get('http://127.0.0.1:5000/init')
                if server_response.status_code == 200:
                    print('API has started successfully, quitting initialization job.')
                    started = True
            except requests.exceptions.RequestException:
                print('API has not started. Still attempting to initialize it.')
            if not started:
                time.sleep(3)

    thread = threading.Thread(target=initialization)
    thread.start()


if __name__ == '__main__':
    self_initialize()
    # debug=True is meant for development; disable it in production
    app.run(host='0.0.0.0', debug=True, threaded=True)
```
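The `/` endpoint reads `request.data` as uint8 values and reshapes them to `(128, 128, 3)`, so a client has to send exactly 128 × 128 × 3 = 49,152 raw bytes per request. The sketch below shows one way to call it using only the standard library. The function names `build_request` and `predict` and the example URL are illustrative assumptions, not part of the project. The explicit `application/octet-stream` header matters: without it, `urllib` labels the body as form data, and Flask would parse it as a form instead of leaving the raw bytes in `request.data`.

```python
import json
import urllib.request

# Shape assumed from task.py: image.reshape((128, 128, 3)) on uint8 data
IMAGE_SHAPE = (128, 128, 3)
EXPECTED_BYTES = IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]  # 49152


def build_request(url: str, pixels: bytes) -> urllib.request.Request:
    """Wrap raw uint8 pixel bytes (row-major, height x width x channels) in a POST request."""
    if len(pixels) != EXPECTED_BYTES:
        raise ValueError(f'expected {EXPECTED_BYTES} bytes, got {len(pixels)}')
    # Without an explicit Content-Type, urllib sends application/x-www-form-urlencoded
    # and Flask would parse the body as a form, leaving request.data empty.
    return urllib.request.Request(
        url,
        data=pixels,
        headers={'Content-Type': 'application/octet-stream'},
        method='POST',
    )


def predict(url: str, pixels: bytes, timeout: float = 10.0) -> int:
    """Send one image to the prediction API and return the predicted class index."""
    with urllib.request.urlopen(build_request(url, pixels), timeout=timeout) as resp:
        body = json.loads(resp.read().decode('utf-8'))
    # The API answers {"message": "<class index>"}; int() raises if it reports an error
    return int(body['message'])
```

For example, `predict('http://172.17.0.2:5000/', pixels)` would return the class index as an `int`. Here `pixels` could come from `cv2.resize(cv2.imread(path), (128, 128)).tobytes()`, assuming the model was trained on images preprocessed the same way.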

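The only change in data_utils.py compared with the previous version is that it loads the model from the production registry instead of the testing one. The differences are: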

```diff
- status = storage.Blob(bucket=bucket, name='{}/{}'.format('testing', model_filename)).exists(storage_client)
+ status = storage.Blob(bucket=bucket, name='{}/{}'.format('production', model_filename)).exists(storage_client)

- blob1 = bucket.blob('{}/{}'.format('testing', model_filename))
+ blob1 = bucket.blob('{}/{}'.format('production', model_filename))
```
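Because the two versions differ only in the registry prefix (`testing` vs. `production`), the stage could be made a parameter instead of being edited by hand. The sketch below is a hypothetical helper, not the project's `data_utils` code: `registry_path`, `model_exists`, and the `stage` argument are illustrative names. It returns the same three-state `(status, error)` pair that task.py checks (`True`, `False`, or `None` plus the exception), and it takes any bucket-like object so it can be exercised without GCP credentials.

```python
from typing import Optional, Protocol, Tuple


class BlobLike(Protocol):
    def exists(self) -> bool: ...


class BucketLike(Protocol):
    def blob(self, name: str) -> BlobLike: ...


REGISTRY_STAGES = ('testing', 'production')


def registry_path(stage: str, model_filename: str) -> str:
    """Object name of a model in the registry, e.g. 'production/best_model.hdf5'."""
    if stage not in REGISTRY_STAGES:
        raise ValueError(f'unknown registry stage: {stage!r}')
    return '{}/{}'.format(stage, model_filename)


def model_exists(bucket: BucketLike, model_filename: str,
                 stage: str = 'production') -> Tuple[Optional[bool], Optional[Exception]]:
    """Three-state result in the shape task.py expects:
    (True, None) if the model exists, (False, None) if it does not,
    (None, error) if the check itself failed (credentials, network, ...)."""
    path = registry_path(stage, model_filename)  # a bad stage is a config bug: let it raise
    try:
        return bucket.blob(path).exists(), None
    except Exception as e:
        return None, e
```

With the real client library, `bucket` could be `storage.Client().bucket('automatictrainingcicd-aiplatform')` from `google.cloud.storage`, whose `Bucket.blob(name)` and `Blob.exists()` calls match the interface used here.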

In the Dockerfile, change the repository that is cloned:

```diff
- RUN git clone https://github.com/sergiovirahonda/AutomaticTraining-UnitTesting.git
+ RUN git clone https://github.com/sergiovirahonda/AutomaticTraining-PredictionAPI.git
```