|
1 | 1 | {
|
2 | 2 | "cells": [
|
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "f0c2a6e5", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Image Classification with PyTorch\n", |
| 9 | + "Pytorch has been both researcher's and engineer's preferred choice of framework for DL development but when it comes to productionizing pytorch models, there still hasn't been a consensus on what to use. This guide run you through building a simple image classification model using Pytorch and then deploying that to RedisAI" |
| 10 | + ] |
| 11 | + }, |
| 12 | + { |
| 13 | + "cell_type": "code", |
| 14 | + "execution_count": 12, |
| 15 | + "id": "1e657632", |
| 16 | + "metadata": {}, |
| 17 | + "outputs": [], |
| 18 | + "source": [ |
| 19 | + "import torchvision.models as models\n", |
| 20 | + "import torch" |
| 21 | + ] |
| 22 | + }, |
| 23 | + { |
| 24 | + "cell_type": "code", |
| 25 | + "execution_count": 13, |
| 26 | + "id": "56edea97", |
| 27 | + "metadata": {}, |
| 28 | + "outputs": [], |
| 29 | + "source": [ |
| 30 | + "model = models.resnet50(pretrained=True)\n", |
| 31 | + "model.eval()\n", |
| 32 | + "\n", |
| 33 | + "scripted_model = torch.jit.script(model)\n", |
| 34 | + "torch.jit.save(scripted_model, 'resnet50.pt')" |
| 35 | + ] |
| 36 | + }, |
| 37 | + { |
| 38 | + "cell_type": "code", |
| 39 | + "execution_count": 14, |
| 40 | + "id": "4ffd3d48", |
| 41 | + "metadata": {}, |
| 42 | + "outputs": [], |
| 43 | + "source": [ |
| 44 | + "import json\n", |
| 45 | + "import time\n", |
| 46 | + "from redisai import Client\n", |
| 47 | + "import ml2rt\n", |
| 48 | + "from skimage import io" |
| 49 | + ] |
| 50 | + }, |
| 51 | + { |
| 52 | + "cell_type": "code", |
| 53 | + "execution_count": 15, |
| 54 | + "id": "59b6599a", |
| 55 | + "metadata": {}, |
| 56 | + "outputs": [], |
| 57 | + "source": [ |
| 58 | + "import os\n", |
| 59 | + "from redisai import Client\n", |
| 60 | + "\n", |
| 61 | + "REDIS_HOST = os.getenv(\"REDIS_HOST\", \"localhost\")\n", |
| 62 | + "REDIS_PORT = int(os.getenv(\"REDIS_PORT\", 6379))" |
| 63 | + ] |
| 64 | + }, |
| 65 | + { |
| 66 | + "cell_type": "code", |
| 67 | + "execution_count": 16, |
| 68 | + "id": "97bbea43", |
| 69 | + "metadata": {}, |
| 70 | + "outputs": [], |
| 71 | + "source": [ |
| 72 | + "con = Client(host=REDIS_HOST, port=REDIS_PORT)" |
| 73 | + ] |
| 74 | + }, |
| 75 | + { |
| 76 | + "cell_type": "code", |
| 77 | + "execution_count": 17, |
| 78 | + "id": "95325dca", |
| 79 | + "metadata": {}, |
| 80 | + "outputs": [ |
| 81 | + { |
| 82 | + "data": { |
| 83 | + "text/plain": [ |
| 84 | + "True" |
| 85 | + ] |
| 86 | + }, |
| 87 | + "execution_count": 17, |
| 88 | + "metadata": {}, |
| 89 | + "output_type": "execute_result" |
| 90 | + } |
| 91 | + ], |
| 92 | + "source": [ |
| 93 | + "con.ping()" |
| 94 | + ] |
| 95 | + }, |
| 96 | + { |
| 97 | + "cell_type": "code", |
| 98 | + "execution_count": 18, |
| 99 | + "id": "f7ddde68", |
| 100 | + "metadata": {}, |
| 101 | + "outputs": [ |
| 102 | + { |
| 103 | + "data": { |
| 104 | + "text/plain": [ |
| 105 | + "'OK'" |
| 106 | + ] |
| 107 | + }, |
| 108 | + "execution_count": 18, |
| 109 | + "metadata": {}, |
| 110 | + "output_type": "execute_result" |
| 111 | + } |
| 112 | + ], |
| 113 | + "source": [ |
| 114 | + "model = ml2rt.load_model(\"resnet50.pt\")\n", |
| 115 | + "con.modelstore(\"pytorch_model\", backend=\"TORCH\", device=\"CPU\", data=model)" |
| 116 | + ] |
| 117 | + }, |
| 118 | + { |
| 119 | + "cell_type": "code", |
| 120 | + "execution_count": 28, |
| 121 | + "id": "50bb90b1", |
| 122 | + "metadata": {}, |
| 123 | + "outputs": [ |
| 124 | + { |
| 125 | + "data": { |
| 126 | + "text/plain": [ |
| 127 | + "'OK'" |
| 128 | + ] |
| 129 | + }, |
| 130 | + "execution_count": 28, |
| 131 | + "metadata": {}, |
| 132 | + "output_type": "execute_result" |
| 133 | + } |
| 134 | + ], |
| 135 | + "source": [ |
| 136 | + "script = \"\"\"\n", |
| 137 | + "def pre_process(tensors: List[Tensor], keys: List[str], args: List[str]):\n", |
| 138 | + " image = tensors[0]\n", |
| 139 | + " mean = torch.zeros(3).float().to(image.device)\n", |
| 140 | + " std = torch.zeros(3).float().to(image.device)\n", |
| 141 | + " mean[0], mean[1], mean[2] = 0.485, 0.456, 0.406\n", |
| 142 | + " std[0], std[1], std[2] = 0.229, 0.224, 0.225\n", |
| 143 | + " mean = mean.unsqueeze(1).unsqueeze(1)\n", |
| 144 | + " std = std.unsqueeze(1).unsqueeze(1)\n", |
| 145 | + " temp = image.float().div(255).permute(2, 0, 1)\n", |
| 146 | + " return temp.sub(mean).div(std).unsqueeze(0)\n", |
| 147 | + "\n", |
| 148 | + "\n", |
| 149 | + "def post_process(tensors: List[Tensor], keys: List[str], args: List[str]):\n", |
| 150 | + " output = tensors[0]\n", |
| 151 | + " return output.max(1)[1]\n", |
| 152 | + "\"\"\"\n", |
| 153 | + "con.scriptstore(\"processing_script\", device=\"CPU\", script=script, entry_points=(\"pre_process\", \"post_process\"))" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "cell_type": "code", |
| 158 | + "execution_count": 29, |
| 159 | + "id": "f24ce05d", |
| 160 | + "metadata": {}, |
| 161 | + "outputs": [], |
| 162 | + "source": [ |
| 163 | + "image = io.imread(\"../data/cat.jpg\")" |
| 164 | + ] |
| 165 | + }, |
| 166 | + { |
| 167 | + "cell_type": "code", |
| 168 | + "execution_count": 30, |
| 169 | + "id": "40e02215", |
| 170 | + "metadata": {}, |
| 171 | + "outputs": [ |
| 172 | + { |
| 173 | + "data": { |
| 174 | + "text/plain": [ |
| 175 | + "'OK'" |
| 176 | + ] |
| 177 | + }, |
| 178 | + "execution_count": 30, |
| 179 | + "metadata": {}, |
| 180 | + "output_type": "execute_result" |
| 181 | + } |
| 182 | + ], |
| 183 | + "source": [ |
| 184 | + "con.tensorset('image', image)" |
| 185 | + ] |
| 186 | + }, |
| 187 | + { |
| 188 | + "cell_type": "code", |
| 189 | + "execution_count": 31, |
| 190 | + "id": "cacc9eb6", |
| 191 | + "metadata": {}, |
| 192 | + "outputs": [ |
| 193 | + { |
| 194 | + "ename": "ResponseError", |
| 195 | + "evalue": "The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last): File \"<string>\", line 10, in pre_process mean = mean.unsqueeze(1).unsqueeze(1) std = std.unsqueeze(1).unsqueeze(1) temp = image.float().div(255).permute(2, 0, 1) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return temp.sub(mean).div(std).unsqueeze(0) RuntimeError: number of dims don't match in permute ", |
| 196 | + "output_type": "error", |
| 197 | + "traceback": [ |
| 198 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 199 | + "\u001b[0;31mResponseError\u001b[0m Traceback (most recent call last)", |
| 200 | + "\u001b[0;32m/var/folders/66/g3bgwk8s0mq9fmm1d32nmb8c0000gq/T/ipykernel_4521/4111896467.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscriptexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'processing_script'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'pre_process'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'image'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'processed'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", |
| 201 | + "\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redisai/client.py\u001b[0m in \u001b[0;36mscriptexecute\u001b[0;34m(self, key, function, keys, inputs, args, outputs, timeout)\u001b[0m\n\u001b[1;32m 786\u001b[0m \"\"\"\n\u001b[1;32m 787\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbuilder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscriptexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 788\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 789\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_postprocess\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mprocessor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscriptexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", |
| 202 | + "\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redis/client.py\u001b[0m in \u001b[0;36mexecute_command\u001b[0;34m(self, *args, **options)\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 900\u001b[0m \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 901\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_response\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcommand_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 902\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mConnectionError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTimeoutError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisconnect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 203 | + "\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redis/client.py\u001b[0m in \u001b[0;36mparse_response\u001b[0;34m(self, connection, command_name, **options)\u001b[0m\n\u001b[1;32m 913\u001b[0m \u001b[0;34m\"Parses a response from the Redis server\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 914\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 915\u001b[0;31m \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconnection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_response\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 916\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mResponseError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 917\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mEMPTY_RESPONSE\u001b[0m \u001b[0;32min\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 204 | + "\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redis/connection.py\u001b[0m in \u001b[0;36mread_response\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mResponseError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 757\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", |
| 205 | + "\u001b[0;31mResponseError\u001b[0m: The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last): File \"<string>\", line 10, in pre_process mean = mean.unsqueeze(1).unsqueeze(1) std = std.unsqueeze(1).unsqueeze(1) temp = image.float().div(255).permute(2, 0, 1) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return temp.sub(mean).div(std).unsqueeze(0) RuntimeError: number of dims don't match in permute " |
| 206 | + ] |
| 207 | + } |
| 208 | + ], |
| 209 | + "source": [ |
| 210 | + "con.scriptexecute('processing_script', 'pre_process', 'image', 'processed')" |
| 211 | + ] |
| 212 | + }, |
3 | 213 | {
|
4 | 214 | "cell_type": "code",
|
5 | 215 | "execution_count": null,
|
6 |
| - "id": "c63c7ff8", |
| 216 | + "id": "848627bd", |
7 | 217 | "metadata": {},
|
8 | 218 | "outputs": [],
|
9 | 219 | "source": []
|
|
0 commit comments