Skip to content

Commit 1eabdba

Browse files
author
hhsecond
committed
pytorch example wip
1 parent d9f9279 commit 1eabdba

File tree

3 files changed

+214
-32
lines changed

3 files changed

+214
-32
lines changed

DecisionTreeWithApacheSpark.ipynb

Lines changed: 0 additions & 31 deletions
This file was deleted.

ImageClassificationWithPytorch.ipynb

Lines changed: 211 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,219 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "f0c2a6e5",
6+
"metadata": {},
7+
"source": [
8+
"# Image Classification with PyTorch\n",
9+
"Pytorch has been both researcher's and engineer's preferred choice of framework for DL development but when it comes to productionizing pytorch models, there still hasn't been a consensus on what to use. This guide run you through building a simple image classification model using Pytorch and then deploying that to RedisAI"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 12,
15+
"id": "1e657632",
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"import torchvision.models as models\n",
20+
"import torch"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": 13,
26+
"id": "56edea97",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"model = models.resnet50(pretrained=True)\n",
31+
"model.eval()\n",
32+
"\n",
33+
"scripted_model = torch.jit.script(model)\n",
34+
"torch.jit.save(scripted_model, 'resnet50.pt')"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 14,
40+
"id": "4ffd3d48",
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"import json\n",
45+
"import time\n",
46+
"from redisai import Client\n",
47+
"import ml2rt\n",
48+
"from skimage import io"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": 15,
54+
"id": "59b6599a",
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"import os\n",
59+
"from redisai import Client\n",
60+
"\n",
61+
"REDIS_HOST = os.getenv(\"REDIS_HOST\", \"localhost\")\n",
62+
"REDIS_PORT = int(os.getenv(\"REDIS_PORT\", 6379))"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 16,
68+
"id": "97bbea43",
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"con = Client(host=REDIS_HOST, port=REDIS_PORT)"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": 17,
78+
"id": "95325dca",
79+
"metadata": {},
80+
"outputs": [
81+
{
82+
"data": {
83+
"text/plain": [
84+
"True"
85+
]
86+
},
87+
"execution_count": 17,
88+
"metadata": {},
89+
"output_type": "execute_result"
90+
}
91+
],
92+
"source": [
93+
"con.ping()"
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": 18,
99+
"id": "f7ddde68",
100+
"metadata": {},
101+
"outputs": [
102+
{
103+
"data": {
104+
"text/plain": [
105+
"'OK'"
106+
]
107+
},
108+
"execution_count": 18,
109+
"metadata": {},
110+
"output_type": "execute_result"
111+
}
112+
],
113+
"source": [
114+
"model = ml2rt.load_model(\"resnet50.pt\")\n",
115+
"con.modelstore(\"pytorch_model\", backend=\"TORCH\", device=\"CPU\", data=model)"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": 28,
121+
"id": "50bb90b1",
122+
"metadata": {},
123+
"outputs": [
124+
{
125+
"data": {
126+
"text/plain": [
127+
"'OK'"
128+
]
129+
},
130+
"execution_count": 28,
131+
"metadata": {},
132+
"output_type": "execute_result"
133+
}
134+
],
135+
"source": [
136+
"script = \"\"\"\n",
137+
"def pre_process(tensors: List[Tensor], keys: List[str], args: List[str]):\n",
138+
" image = tensors[0]\n",
139+
" mean = torch.zeros(3).float().to(image.device)\n",
140+
" std = torch.zeros(3).float().to(image.device)\n",
141+
" mean[0], mean[1], mean[2] = 0.485, 0.456, 0.406\n",
142+
" std[0], std[1], std[2] = 0.229, 0.224, 0.225\n",
143+
" mean = mean.unsqueeze(1).unsqueeze(1)\n",
144+
" std = std.unsqueeze(1).unsqueeze(1)\n",
145+
" temp = image.float().div(255).permute(2, 0, 1)\n",
146+
" return temp.sub(mean).div(std).unsqueeze(0)\n",
147+
"\n",
148+
"\n",
149+
"def post_process(tensors: List[Tensor], keys: List[str], args: List[str]):\n",
150+
" output = tensors[0]\n",
151+
" return output.max(1)[1]\n",
152+
"\"\"\"\n",
153+
"con.scriptstore(\"processing_script\", device=\"CPU\", script=script, entry_points=(\"pre_process\", \"post_process\"))"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": 29,
159+
"id": "f24ce05d",
160+
"metadata": {},
161+
"outputs": [],
162+
"source": [
163+
"image = io.imread(\"../data/cat.jpg\")"
164+
]
165+
},
166+
{
167+
"cell_type": "code",
168+
"execution_count": 30,
169+
"id": "40e02215",
170+
"metadata": {},
171+
"outputs": [
172+
{
173+
"data": {
174+
"text/plain": [
175+
"'OK'"
176+
]
177+
},
178+
"execution_count": 30,
179+
"metadata": {},
180+
"output_type": "execute_result"
181+
}
182+
],
183+
"source": [
184+
"con.tensorset('image', image)"
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": 31,
190+
"id": "cacc9eb6",
191+
"metadata": {},
192+
"outputs": [
193+
{
194+
"ename": "ResponseError",
195+
"evalue": "The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last): File \"<string>\", line 10, in pre_process mean = mean.unsqueeze(1).unsqueeze(1) std = std.unsqueeze(1).unsqueeze(1) temp = image.float().div(255).permute(2, 0, 1) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return temp.sub(mean).div(std).unsqueeze(0) RuntimeError: number of dims don't match in permute ",
196+
"output_type": "error",
197+
"traceback": [
198+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
199+
"\u001b[0;31mResponseError\u001b[0m Traceback (most recent call last)",
200+
"\u001b[0;32m/var/folders/66/g3bgwk8s0mq9fmm1d32nmb8c0000gq/T/ipykernel_4521/4111896467.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscriptexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'processing_script'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'pre_process'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'image'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'processed'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
201+
"\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redisai/client.py\u001b[0m in \u001b[0;36mscriptexecute\u001b[0;34m(self, key, function, keys, inputs, args, outputs, timeout)\u001b[0m\n\u001b[1;32m 786\u001b[0m \"\"\"\n\u001b[1;32m 787\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbuilder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscriptexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 788\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 789\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_postprocess\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mprocessor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscriptexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
202+
"\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redis/client.py\u001b[0m in \u001b[0;36mexecute_command\u001b[0;34m(self, *args, **options)\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 900\u001b[0m \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 901\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_response\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcommand_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 902\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mConnectionError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTimeoutError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisconnect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
203+
"\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redis/client.py\u001b[0m in \u001b[0;36mparse_response\u001b[0;34m(self, connection, command_name, **options)\u001b[0m\n\u001b[1;32m 913\u001b[0m \u001b[0;34m\"Parses a response from the Redis server\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 914\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 915\u001b[0;31m \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconnection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_response\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 916\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mResponseError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 917\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mEMPTY_RESPONSE\u001b[0m \u001b[0;32min\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
204+
"\u001b[0;32m~/asgard/redisai-examples/venv/lib/python3.8/site-packages/redis/connection.py\u001b[0m in \u001b[0;36mread_response\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mResponseError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 757\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
205+
"\u001b[0;31mResponseError\u001b[0m: The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last): File \"<string>\", line 10, in pre_process mean = mean.unsqueeze(1).unsqueeze(1) std = std.unsqueeze(1).unsqueeze(1) temp = image.float().div(255).permute(2, 0, 1) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return temp.sub(mean).div(std).unsqueeze(0) RuntimeError: number of dims don't match in permute "
206+
]
207+
}
208+
],
209+
"source": [
210+
"con.scriptexecute('processing_script', 'pre_process', 'image', 'processed')"
211+
]
212+
},
3213
{
4214
"cell_type": "code",
5215
"execution_count": null,
6-
"id": "c63c7ff8",
216+
"id": "848627bd",
7217
"metadata": {},
8218
"outputs": [],
9219
"source": []

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,6 @@ webencodings==0.5.1
9898
widgetsnbextension==3.5.2
9999
wrapt==1.13.3
100100
zipp==3.6.0
101+
torch==1.10.0
102+
torchvision==0.11.1
103+
scikit-image==0.18.3

0 commit comments

Comments
 (0)