|
9 | 9 | "# 使用LeNet在MNIST数据集实现图像分类\n",
|
10 | 10 | "\n",
|
11 | 11 | "**作者:** [PaddlePaddle](https://github.com/PaddlePaddle) <br>\n",
|
12 |
| - "**日期:** 2022.4 <br>\n", |
| 12 | + "**日期:** 2022.5 <br>\n", |
13 | 13 | "**摘要:** 本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。"
|
14 | 14 | ]
|
15 | 15 | },
|
|
21 | 21 | "source": [
|
22 | 22 | "## 一、环境配置\n",
|
23 | 23 | "\n",
|
24 |
| - "本教程基于PaddlePaddle 2.3.0-rc0 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) PaddlePaddle 2.3.0-rc0。" |
| 24 | + "本教程基于PaddlePaddle 2.3.0 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) PaddlePaddle 2.3.0。" |
25 | 25 | ]
|
26 | 26 | },
|
27 | 27 | {
|
28 | 28 | "cell_type": "code",
|
29 |
| - "execution_count": 1, |
| 29 | + "execution_count": null, |
30 | 30 | "metadata": {
|
31 | 31 | "collapsed": false
|
32 | 32 | },
|
|
35 | 35 | "name": "stdout",
|
36 | 36 | "output_type": "stream",
|
37 | 37 | "text": [
|
38 |
| - "2.3.0-rc0\n" |
| 38 | + "2.3.0\n" |
39 | 39 | ]
|
40 | 40 | }
|
41 | 41 | ],
|
|
58 | 58 | },
|
59 | 59 | {
|
60 | 60 | "cell_type": "code",
|
61 |
| - "execution_count": 2, |
| 61 | + "execution_count": null, |
62 | 62 | "metadata": {
|
63 | 63 | "collapsed": false
|
64 | 64 | },
|
|
87 | 87 | },
|
88 | 88 | {
|
89 | 89 | "cell_type": "code",
|
90 |
| - "execution_count": 3, |
| 90 | + "execution_count": null, |
91 | 91 | "metadata": {
|
92 | 92 | "collapsed": false
|
93 | 93 | },
|
|
98 | 98 | "text": [
|
99 | 99 | "train_data0 label is: [5]\n"
|
100 | 100 | ]
|
| 101 | + }, |
| 102 | + { |
| 103 | + "data": { |
| 104 | + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAACHBJREFUeJzt3V1oVOkZB/D/42j8ql9pZInZYBYVIRT8INYWi0atH13Q4E2JilZZWC/8aMFgTb3QCy+KQi803ixWUrGmFGvYtSwEXcyFuEgSDDbZNasuxs3i1yJq0QtdeXsxx+k8B5M5mXnmnDOZ/w9Czv+cZM4LPr7zzjmTZ8Q5B6JcjYp6ADQysJDIBAuJTLCQyAQLiUywkMgEC4lMsJDIRE6FJCJrRaRPRG6LyH6rQVHhkWyvbItIAsA3AFYBGADQAWCjc+6rwX6nrKzMVVVVZXU+ikZXV9cPzrnpmX5udA7n+DmA2865bwFARP4BoA7AoIVUVVWFzs7OHE5JYROR/iA/l8tTWwWA79LygLfPP5CPRaRTRDofP36cw+kozvK+2HbOfeKcq3HO1UyfnnGGpAKVSyF9D6AyLb/v7aMilEshdQCYIyIfiEgJgHoAn9kMiwpN1ott59yPIrILQBuABIBTzrles5FRQcnlVRucc58D+NxoLFTAeGWbTLCQyAQLiUywkMgEC4lMsJDIBAuJTLCQyAQLiUywkMgEC4lM5HSvrZi8efNG5WfPngX+3aamJpVfvnypcl9fn8onTpxQuaGhQeWWlhaVx40bp/L+/f9/+/zBgwcDjzMXnJHIBAuJTLCQyETRrJHu3bun8qtXr1S+evWqyleuXFH56dOnKp87d85sbJWVlSrv3r1b5dbWVpUnTZqk8rx581RetmyZ2diC4oxEJlhIZIKFRCZG7Brp+vXrKq9YsULl4VwHspZIJFQ+fPiwyhMnTlR58+bNKs+YMUPladOmqTx37txchzhsnJHIBAuJTLCQyMSIXSPNnDlT5bKyMpUt10iLFy9W2b9muXz5ssolJSUqb9myxWwsUeGMRCZYSGSChUQmRuwaqbS0VOWjR4+qfOHCBZUXLFig8p49e4Z8/Pnz56e2L126pI75rwP19PSofOzYsSEfuxBxRiITGQtJRE6JyCMR6UnbVyoiF0Xklvd92lCPQSNfkBmpGcBa3779AL5wzs0B8IWXqYgFao8sIlUA/u2c+5mX+wDUOufui0g5gHbnXMYbPDU1NS4uXW2fP3+usv89Pjt27FD55MmTKp85cya1vWnTJuPRxYeIdDnnajL9XLZrpPecc/e97QcA3svycWiEyHmx7ZJT2qDTGtsjF4dsC+mh95QG7/ujwX6Q7ZGLQ7bXkT4D8DsAf/a+f2o2opBMnjx5yONTpkwZ8nj6mqm+vl4dGzWq+K6qBHn53wLgSwBzRWRARD5CsoBWicgtAL/2MhWxjDOSc27jIIdWGo+FCljxzcGUFyP2XluuDh06pHJXV5fK7e3tqW3/vbbVq1fna1ixxRmJTLCQyAQLiUxk/VGk2YjTvbbhunPnjsoLFy5MbU+dOlUdW758uco1NfpW1c6dO1UWEYsh5kW+77URKSwkMsGX/wHNmjVL5ebm5tT29u3b1bHTp08PmV+8eKHy1q1bVS4vL892mJHhjEQmWEhkgoVEJrhGytKGDRtS27Nnz1bH9u7dq7L/FkpjY6PK/f39Kh84cEDlioqKrMcZFs5IZIKFRCZYSGSCt0jywN9K2f/n4du2bVPZ/2+wcqV+z+DFixftBjdMvEVCoWIhkQkWEpngGikCY8eOVfn169cqjxkzRuW2tjaVa2tr8zKud+EaiULFQiITLCQywXttBm7cuKGy/yO4Ojo6VPavifyqq6tVXrp0aQ6jCwdnJDLBQiITLCQywTVSQP6PVD9+/Hhq+/z58+rYgwcPhvXYo0frfwb/e7YLoU1O/EdIBSFIf6RKEbksIl+JSK+I/N7bzxbJlBJkRvoRwF7nXDWAXwDYKSLVYItkShOk0dZ9APe97f+KyNcAKgDUAaj1fuxvANoB/DEvowyBf11z9uxZlZuamlS+e/du1udatGiRyv73aK9fvz7rx47KsNZIXr/tBQCugS2SKU3gQhKRnwD4F4A/OOdUt/OhWiSzPXJxCFRIIjIGySL6u3Pu7WvdQC2S2R65OGRcI0my58pfAXztnPtL2qGCapH88OFDlXt7e1XetWuXyjdv3sz6XP6PJt23b5/KdXV1KhfCdaJMglyQXAJgC4D/iEi3t+9PSBbQP712yf0AfpufIVIhCPKq7QqAwTpBsUUyAeCVbTIyYu61PXnyRGX/x2R1d3er7G/lN1xLlixJbfv/1n/NmjUqjx8/PqdzFQLOSGSChUQmWEhkoqDWSNeuXUttHzlyRB3zvy96YGAgp3NNmDBBZf/Ht6ffH/N/PHsx4oxEJlhIZKKgntpaW1vfuR2E/0981q1bp3IikVC5oaFBZX93f9I4I5EJFhKZYCGRCba1oSGxrQ2FioVEJlhIZIKFRCZYSGSChUQmWEhkgoVEJlhIZIKFRCZYSGQi1HttIvIYyb/KLQPwQ2gnHp64ji2qcc10zmVs2hBqIaVOKtIZ5EZgFOI6triO6y0+tZEJFhKZiKqQPonovEHEdWxxHReAiNZINPLwqY1MhFpIIrJWRPpE5LaIRNpOWUROicgjEelJ2xeL3uGF2Ns8tEISkQSAEwB+A6AawEavX3dUmgGs9e2LS+/wwutt7pwL5QvALwG0peVGAI1hnX+QMVUB6EnLfQDKve1yAH1Rji9tXJ8CWBXX8TnnQn1qqwDwXVoe8PbFSex6hxdKb3Mutgfhkv/tI31Jm21v8yiEWUjfA6hMy+97++IkUO/wMOTS2zwKYRZSB4A5IvKBiJQAqEeyV3ecvO0dDkTYOzxAb3Mgbr3NQ140fgjgGwB3AByIeAHbguSH9bxGcr32EYCfIvlq6BaASwBKIxrbr5B82roBoNv7+jAu43vXF69skwkutskEC4lMsJDIBAuJTLCQyAQLiUywkMgEC4lM/A+jN2A4bkW+2gAAAABJRU5ErkJggg==\n", |
| 105 | + "text/plain": [ |
| 106 | + "<Figure size 144x144 with 1 Axes>" |
| 107 | + ] |
| 108 | + }, |
| 109 | + "metadata": {}, |
| 110 | + "output_type": "display_data" |
101 | 111 | }
|
102 | 112 | ],
|
103 | 113 | "source": [
|
|
122 | 132 | },
|
123 | 133 | {
|
124 | 134 | "cell_type": "code",
|
125 |
| - "execution_count": 4, |
| 135 | + "execution_count": null, |
126 | 136 | "metadata": {
|
127 | 137 | "collapsed": false
|
128 | 138 | },
|
|
178 | 188 | },
|
179 | 189 | {
|
180 | 190 | "cell_type": "code",
|
181 |
| - "execution_count": 5, |
| 191 | + "execution_count": null, |
182 | 192 | "metadata": {
|
183 | 193 | "collapsed": false
|
184 | 194 | },
|
185 |
| - "outputs": [ |
186 |
| - { |
187 |
| - "name": "stderr", |
188 |
| - "output_type": "stream", |
189 |
| - "text": [ |
190 |
| - "W0422 18:56:10.020583 19533 gpu_context.cc:244] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", |
191 |
| - "W0422 18:56:10.026566 19533 gpu_context.cc:272] device: 0, cuDNN Version: 7.6.\n" |
192 |
| - ] |
193 |
| - } |
194 |
| - ], |
| 195 | + "outputs": [], |
195 | 196 | "source": [
|
196 | 197 | "from paddle.metric import Accuracy\n",
|
197 | 198 | "model = paddle.Model(LeNet()) # 用Model封装模型\n",
|
|
207 | 208 | },
|
208 | 209 | {
|
209 | 210 | "cell_type": "code",
|
210 |
| - "execution_count": 6, |
| 211 | + "execution_count": null, |
211 | 212 | "metadata": {
|
212 | 213 | "collapsed": false
|
213 | 214 | },
|
214 |
| - "outputs": [ |
215 |
| - { |
216 |
| - "name": "stdout", |
217 |
| - "output_type": "stream", |
218 |
| - "text": [ |
219 |
| - "The loss value printed in the log is the current step, and the metric is the average value of previous steps.\n", |
220 |
| - "Epoch 1/2\n", |
221 |
| - "step 20/938 [..............................] - loss: 1.4646 - acc: 0.3828 - ETA: 17s - 19ms/ste" |
222 |
| - ] |
223 |
| - }, |
224 |
| - { |
225 |
| - "name": "stdout", |
226 |
| - "output_type": "stream", |
227 |
| - "text": [ |
228 |
| - "step 30/938 [..............................] - loss: 1.1068 - acc: 0.4672 - ETA: 14s - 16ms/stepstep 938/938 [==============================] - loss: 0.1653 - acc: 0.9273 - 11ms/step \n", |
229 |
| - "Epoch 2/2\n", |
230 |
| - "step 938/938 [==============================] - loss: 0.0199 - acc: 0.9767 - 11ms/step \n" |
231 |
| - ] |
232 |
| - } |
233 |
| - ], |
| 215 | + "outputs": [], |
234 | 216 | "source": [
|
235 | 217 | "# 训练模型\n",
|
236 | 218 | "model.fit(train_dataset,\n",
|
|
251 | 233 | },
|
252 | 234 | {
|
253 | 235 | "cell_type": "code",
|
254 |
| - "execution_count": 7, |
| 236 | + "execution_count": null, |
255 | 237 | "metadata": {
|
256 | 238 | "collapsed": false
|
257 | 239 | },
|
|
261 | 243 | "output_type": "stream",
|
262 | 244 | "text": [
|
263 | 245 | "Eval begin...\n",
|
264 |
| - "step 157/157 [==============================] - loss: 0.0048 - acc: 0.9780 - 8ms/step \n", |
| 246 | + "step 157/157 [==============================] - loss: 4.2854e-04 - acc: 0.9841 - 7ms/step \n", |
265 | 247 | "Eval samples: 10000\n"
|
266 | 248 | ]
|
267 | 249 | },
|
268 | 250 | {
|
269 | 251 | "data": {
|
270 | 252 | "text/plain": [
|
271 |
| - "{'loss': [0.0047780997], 'acc': 0.978}" |
| 253 | + "{'loss': [0.00042853763], 'acc': 0.9841}" |
272 | 254 | ]
|
273 | 255 | },
|
274 |
| - "execution_count": 7, |
| 256 | + "execution_count": null, |
275 | 257 | "metadata": {},
|
276 | 258 | "output_type": "execute_result"
|
277 | 259 | }
|
|
303 | 285 | },
|
304 | 286 | {
|
305 | 287 | "cell_type": "code",
|
306 |
| - "execution_count": 8, |
| 288 | + "execution_count": null, |
307 | 289 | "metadata": {
|
308 | 290 | "collapsed": false
|
309 | 291 | },
|
|
312 | 294 | "name": "stdout",
|
313 | 295 | "output_type": "stream",
|
314 | 296 | "text": [
|
315 |
| - "epoch: 0, batch_id: 0, loss is: [3.7514806], acc is: [0.21875]\n", |
316 |
| - "epoch: 0, batch_id: 300, loss is: [0.19029362], acc is: [0.953125]\n", |
317 |
| - "epoch: 0, batch_id: 600, loss is: [0.12201739], acc is: [0.953125]\n", |
318 |
| - "epoch: 0, batch_id: 900, loss is: [0.03218058], acc is: [0.984375]\n", |
319 |
| - "epoch: 1, batch_id: 0, loss is: [0.114471], acc is: [0.953125]\n", |
320 |
| - "epoch: 1, batch_id: 300, loss is: [0.00857661], acc is: [1.]\n", |
321 |
| - "epoch: 1, batch_id: 600, loss is: [0.10740176], acc is: [0.96875]\n", |
322 |
| - "epoch: 1, batch_id: 900, loss is: [0.19590104], acc is: [0.9375]\n" |
| 297 | + "epoch: 0, batch_id: 0, loss is: [2.9878871], acc is: [0.140625]\n", |
| 298 | + "epoch: 0, batch_id: 300, loss is: [0.22775462], acc is: [0.921875]\n", |
| 299 | + "epoch: 0, batch_id: 600, loss is: [0.06251755], acc is: [0.984375]\n", |
| 300 | + "epoch: 0, batch_id: 900, loss is: [0.1097075], acc is: [0.96875]\n", |
| 301 | + "epoch: 1, batch_id: 0, loss is: [0.04311676], acc is: [0.984375]\n", |
| 302 | + "epoch: 1, batch_id: 300, loss is: [0.00150577], acc is: [1.]\n", |
| 303 | + "epoch: 1, batch_id: 600, loss is: [0.08764459], acc is: [0.96875]\n", |
| 304 | + "epoch: 1, batch_id: 900, loss is: [0.14419323], acc is: [0.9375]\n" |
323 | 305 | ]
|
324 | 306 | }
|
325 | 307 | ],
|
|
361 | 343 | },
|
362 | 344 | {
|
363 | 345 | "cell_type": "code",
|
364 |
| - "execution_count": 9, |
| 346 | + "execution_count": null, |
365 | 347 | "metadata": {
|
366 | 348 | "collapsed": false
|
367 | 349 | },
|
|
370 | 352 | "name": "stdout",
|
371 | 353 | "output_type": "stream",
|
372 | 354 | "text": [
|
373 |
| - "batch_id: 0, loss is: [0.04440754], acc is: [0.984375]\n", |
374 |
| - "batch_id: 20, loss is: [0.19196557], acc is: [0.9375]\n", |
375 |
| - "batch_id: 40, loss is: [0.09817676], acc is: [0.984375]\n", |
376 |
| - "batch_id: 60, loss is: [0.16782945], acc is: [0.953125]\n", |
377 |
| - "batch_id: 80, loss is: [0.05786889], acc is: [0.96875]\n", |
378 |
| - "batch_id: 100, loss is: [0.00799548], acc is: [1.]\n", |
379 |
| - "batch_id: 120, loss is: [0.00511317], acc is: [1.]\n", |
380 |
| - "batch_id: 140, loss is: [0.01672031], acc is: [1.]\n" |
| 355 | + "batch_id: 0, loss is: [0.01201783], acc is: [1.]\n", |
| 356 | + "batch_id: 20, loss is: [0.09013407], acc is: [0.984375]\n", |
| 357 | + "batch_id: 40, loss is: [0.07025866], acc is: [0.96875]\n", |
| 358 | + "batch_id: 60, loss is: [0.08602518], acc is: [0.984375]\n", |
| 359 | + "batch_id: 80, loss is: [0.00779913], acc is: [1.]\n", |
| 360 | + "batch_id: 100, loss is: [0.00508764], acc is: [1.]\n", |
| 361 | + "batch_id: 120, loss is: [0.00401443], acc is: [1.]\n", |
| 362 | + "batch_id: 140, loss is: [0.03930391], acc is: [0.96875]\n" |
381 | 363 | ]
|
382 | 364 | }
|
383 | 365 | ],
|
|
0 commit comments