在本教程中,我們將使用Flask來部署PyTorch模型,並用講解用於模型推斷的 REST API。特別是,我們將部署一個預訓練的DenseNet 121模型來檢測圖像。
備註:
可在https://github.com/avinassh/pytorch-flask-api上獲取本文用到的完整代碼
這是在生產中部署PyTorch模型的系列教程中的第一篇。到目前為止,以這種方式使用Flask是開始為PyTorch模型提供服務的最簡單方法,但不適用於具有高性能要求的用例。因此:
我們將首先定義API端點、請求和響應類型。我們的API端點將位於/ predict,它接受帶有包含圖像的file參數的HTTP POST請求。響應將是包含預測的JSON響應:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
運行下面的命令來下載我們需要的依賴:
$ pip install Flask==1.0.3 torchvision-0.3.0
以下是一個簡單的Web伺服器,摘自Flask文檔
from flask import Flaskapp = Flask(__name__)@app.route('/')def hello(): return 'Hello World!'
將以上代碼段保存在名為app.py的文件中,您現在可以通過輸入以下內容來運行Flask開發伺服器:
$ FLASK_ENV=development FLASK_APP=app.py flask run
當您在web瀏覽器中訪問http://localhost:5000/時,您會收到文本Hello World的問候!
我們將對以上代碼片段進行一些更改,以使其適合我們的API定義。首先,我們將重命名predict方法。我們將端點路徑更新為/predict。由於圖像文件將通過HTTP POST請求發送,因此我們將對其進行更新,使其也僅接受POST請求:
@app.route('/predict', methods=['POST'])def predict(): return 'Hello World!'
我們還將更改響應類型,以使其返回包含ImageNet類的id和name的JSON響應。更新後的app.py文件現在為:
from flask import Flask, jsonifyapp = Flask(__name__)@app.route('/predict', methods=['POST'])def predict(): return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
在下一部分中,我們將重點介紹編寫推理代碼。這將涉及兩部分,第一部分是準備圖像,以便可以將其饋送到DenseNet;第二部分,我們將編寫代碼以從模型中獲取實際的預測。
DenseNet模型要求圖像為尺寸為224 x 224的 3 通道RGB圖像。我們還將使用所需的均值和標準偏差值對圖像張量進行歸一化。你可以點擊https://pytorch.org/docs/stable/torchvision/models.html來了解更多關於它的內容。
我們將使用來自torchvision庫的transforms來建立轉換管道,該轉換管道可根據需要轉換圖像。您可以在https://pytorch.org/docs/stable/torchvision/transforms.html閱讀有關轉換的更多信息。
import ioimport torchvision.transforms as transformsfrom PIL import Imagedef transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)
上面的方法以位元組為單位獲取圖像數據,應用一系列變換並返回張量。要測試上述方法,請以位元組模式讀取圖像文件(首先將../_static/img/sample_file.jpeg替換為計算機上文件的實際路徑),然後查看是否獲得了張量:
with open("../_static/img/sample_file.jpeg", 'rb') as f: image_bytes = f.read() tensor = transform_image(image_bytes=image_bytes) print(tensor)
tensor([[[[ 0.4508, 0.4166, 0.3994, ..., -1.3473, -1.3302, -1.3473], [ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302], [ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644], ..., [ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667], [ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468], [ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]], [[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529], [ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354], [ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704], ..., [ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508], [ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283], [ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]], [[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476], [ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650], [ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999], ..., [ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346], [ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125], [ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]])
現在將使用預訓練的DenseNet 121模型來預測圖像的類別。我們將使用torchvision庫中的一個庫,加載模型並進行推斷。在此示例中,我們將使用預訓練的模型,但您可以對自己的模型使用相同的方法。在這個https://pytorch.org/tutorials/beginner/saving_loading_models.html中了解有關加載模型的更多信息。
from torchvision import models# 確保使用`pretrained`作為`True`來使用預訓練的權重:model = models.densenet121(pretrained=True)# 由於我們僅將模型用於推理,因此請切換到「eval」模式:model.eval()def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) return y_hat
張量y_hat將包含預測的類的id的索引。但是,我們需要一個易於閱讀的類名。為此,我們需要一個類id來命名映射。將https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json下載為imagenet_class_index.json並記住它的保存位置(或者,如果您按照本教程中的確切步驟操作,請將其保存在tutorials/_static中)。此文件包含ImageNet類的id到ImageNet類的name的映射。我們將加載此JSON文件並獲取預測索引的類的name。
import jsonimagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]
在使用字典imagenet_class_index之前,首先我們將張量值轉換為字符串值,因為字典imagenet_class_index中的keys是字符串。我們將測試上述方法:
with open("../_static/img/sample_file.jpeg", 'rb') as f: image_bytes = f.read() print(get_prediction(image_bytes=image_bytes))
['n02124075', 'Egyptian_cat']
你會得到這樣的一個響應:
['n02124075', 'Egyptian_cat']
數組中的第一項是ImageNet類的id,第二項是人類可讀的name。
注意:您是否注意到模型變量不是get_prediction方法的一部分?或者為什麼模型是全局變量?就內存和計算而言,加載模型可能是
一項昂貴的操作。如果將模型加載到get_prediction方法中,則每次調用該方法時都會不必要地加載該模型。由於我們正在構建Web服務
器,因此每秒可能有成千上萬的請求,因此我們不應該浪費時間為每個推斷重複加載模型。因此,我們僅將模型加載到內存中一次。在生
產系統中,必須有效利用計算以能夠大規模處理請求,因此通常應在處理請求之前加載模型。
在最後一部分中,我們將模型添加到Flask API伺服器中。由於我們的API伺服器應該獲取圖像文件,因此我們將更新predict方法以從請求中讀取文件:
from flask import request@app.route('/predict', methods=['POST'])def predict(): if request.method == 'POST': # 從請求中獲得文件 file = request.files['file'] # 轉化為位元組 img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name})
app.py文件現已完成。以下是完整版本;將路徑替換為保存文件的路徑,它的運行應是如下:
import ioimport jsonfrom torchvision import modelsimport torchvision.transforms as transformsfrom PIL import Imagefrom flask import Flask, jsonify, requestapp = Flask(__name__)imagenet_class_index = json.load(open('/imagenet_class_index.json'))model = models.densenet121(pretrained=True)model.eval()def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])def predict(): if request.method == 'POST': file = request.files['file'] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name})if __name__ == '__main__': app.run()
讓我們測試一下我們的web伺服器,運行:
$ FLASK_ENV=development FLASK_APP=app.py flask run
我們可以使用https://pypi.org/project/requests/庫來發送一個POST請求到我們的app:
import requestsresp = requests.post("http://localhost:5000/predict", files={"file": open('/cat.jpg','rb')})
列印resp.json()會顯示下面的結果:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
我們編寫的伺服器非常瑣碎,可能無法完成生產應用程式所需的一切。因此,您可以採取一些措施來改善它: