|
28 | 28 | "source": [
|
29 | 29 | "#hide\n",
|
30 | 30 | "import torch\n",
|
31 |
| - "import torch.nn as nn\n", |
32 |
| - "\n", |
33 |
| - "from nbdev.showdoc import show_doc\n", |
34 |
| - "from IPython.display import Markdown, display" |
35 |
| - ] |
36 |
| - }, |
37 |
| - { |
38 |
| - "cell_type": "code", |
39 |
| - "execution_count": null, |
40 |
| - "metadata": {}, |
41 |
| - "outputs": [], |
42 |
| - "source": [ |
43 |
| - "# hide\n", |
44 |
| - "def print_doc(func_name):\n", |
45 |
| - " doc = show_doc(func_name, title_level=4, disp=False)\n", |
46 |
| - " display(Markdown(doc))" |
| 31 | + "import torch.nn as nn" |
47 | 32 | ]
|
48 | 33 | },
|
49 | 34 | {
|
|
59 | 44 | "metadata": {},
|
60 | 45 | "outputs": [],
|
61 | 46 | "source": [
|
62 |
| - "#hide\n", |
63 | 47 | "from model_constructor.yaresnet import YaResBlock"
|
64 | 48 | ]
|
65 | 49 | },
|
66 |
| - { |
67 |
| - "cell_type": "code", |
68 |
| - "execution_count": null, |
69 |
| - "metadata": {}, |
70 |
| - "outputs": [], |
71 |
| - "source": [ |
72 |
| - "#hide_input\n", |
73 |
| - "# print_doc(YaResBlock)" |
74 |
| - ] |
75 |
| - }, |
76 | 50 | {
|
77 | 51 | "cell_type": "code",
|
78 | 52 | "execution_count": null,
|
|
341 | 315 | " (se): SEModule(\n",
|
342 | 316 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n",
|
343 | 317 | " (excitation): Sequential(\n",
|
344 |
| - " (fc_reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
| 318 | + " (reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
345 | 319 | " (se_act): ReLU(inplace=True)\n",
|
346 |
| - " (fc_expand): Linear(in_features=32, out_features=512, bias=True)\n", |
| 320 | + " (expand): Linear(in_features=32, out_features=512, bias=True)\n", |
347 | 321 | " (se_gate): Sigmoid()\n",
|
348 | 322 | " )\n",
|
349 | 323 | " )\n",
|
|
443 | 417 | " (se): SEModule(\n",
|
444 | 418 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n",
|
445 | 419 | " (excitation): Sequential(\n",
|
446 |
| - " (fc_reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
| 420 | + " (reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
447 | 421 | " (se_act): ReLU(inplace=True)\n",
|
448 |
| - " (fc_expand): Linear(in_features=32, out_features=512, bias=True)\n", |
| 422 | + " (expand): Linear(in_features=32, out_features=512, bias=True)\n", |
449 | 423 | " (se_gate): Sigmoid()\n",
|
450 | 424 | " )\n",
|
451 | 425 | " )\n",
|
|
468 | 442 | ],
|
469 | 443 | "source": [
|
470 | 444 | "#collapse_output\n",
|
471 |
| - "bl = YaResBlock(4, 64, 128, stride=2, pool=pool, act_fn=nn.LeakyReLU(), dw=True,\n", |
472 |
| - " se=SEModule, sa=SimpleSelfAttention)\n", |
| 445 | + "bl = YaResBlock(\n", |
| 446 | + " 4, 64, 128,\n", |
| 447 | + " stride=2,\n", |
| 448 | + " pool=pool,\n", |
| 449 | + " act_fn=nn.LeakyReLU(),\n", |
| 450 | + " dw=True,\n", |
| 451 | + " se=SEModule,\n", |
| 452 | + " sa=SimpleSelfAttention)\n", |
473 | 453 | "bl"
|
474 | 454 | ]
|
475 | 455 | },
|
|
528 | 508 | {
|
529 | 509 | "data": {
|
530 | 510 | "text/plain": [
|
531 |
| - "([64, 64, 128, 256, 512], [2, 2, 2, 2])" |
| 511 | + "([64, 128, 256, 512], [2, 2, 2, 2])" |
532 | 512 | ]
|
533 | 513 | },
|
534 | 514 | "execution_count": null,
|
|
982 | 962 | " (se): SEModule(\n",
|
983 | 963 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n",
|
984 | 964 | " (excitation): Sequential(\n",
|
985 |
| - " (fc_reduce): Linear(in_features=64, out_features=4, bias=True)\n", |
| 965 | + " (reduce): Linear(in_features=64, out_features=4, bias=True)\n", |
986 | 966 | " (se_act): ReLU(inplace=True)\n",
|
987 |
| - " (fc_expand): Linear(in_features=4, out_features=64, bias=True)\n", |
| 967 | + " (expand): Linear(in_features=4, out_features=64, bias=True)\n", |
988 | 968 | " (se_gate): Sigmoid()\n",
|
989 | 969 | " )\n",
|
990 | 970 | " )\n",
|
|
1003 | 983 | "yaresnet.body.l_0.bl_0"
|
1004 | 984 | ]
|
1005 | 985 | },
|
| 986 | + { |
| 987 | + "cell_type": "markdown", |
| 988 | + "metadata": {}, |
| 989 | + "source": [ |
| 990 | + "# YaResnet34, YaResnet50" |
| 991 | + ] |
| 992 | + }, |
| 993 | + { |
| 994 | + "cell_type": "code", |
| 995 | + "execution_count": null, |
| 996 | + "metadata": {}, |
| 997 | + "outputs": [], |
| 998 | + "source": [ |
| 999 | + "from model_constructor.yaresnet import YaResNet34, YaResNet50" |
| 1000 | + ] |
| 1001 | + }, |
| 1002 | + { |
| 1003 | + "cell_type": "code", |
| 1004 | + "execution_count": null, |
| 1005 | + "metadata": {}, |
| 1006 | + "outputs": [ |
| 1007 | + { |
| 1008 | + "data": { |
| 1009 | + "text/plain": [ |
| 1010 | + "YaResnet34 constructor\n", |
| 1011 | + " in_chans: 3, num_classes: 1000\n", |
| 1012 | + " expansion: 1, groups: 1, dw: False, div_groups: None\n", |
| 1013 | + " sa: False, se: False\n", |
| 1014 | + " stem sizes: [3, 32, 32, 64], stride on 0\n", |
| 1015 | + " body sizes [64, 128, 256, 512]\n", |
| 1016 | + " layers: [3, 4, 6, 3]" |
| 1017 | + ] |
| 1018 | + }, |
| 1019 | + "execution_count": null, |
| 1020 | + "metadata": {}, |
| 1021 | + "output_type": "execute_result" |
| 1022 | + } |
| 1023 | + ], |
| 1024 | + "source": [ |
| 1025 | + "yaresnet34 = YaResNet34()\n", |
| 1026 | + "yaresnet34" |
| 1027 | + ] |
| 1028 | + }, |
| 1029 | + { |
| 1030 | + "cell_type": "code", |
| 1031 | + "execution_count": null, |
| 1032 | + "metadata": {}, |
| 1033 | + "outputs": [ |
| 1034 | + { |
| 1035 | + "data": { |
| 1036 | + "text/plain": [ |
| 1037 | + "YaResnet50 constructor\n", |
| 1038 | + " in_chans: 3, num_classes: 1000\n", |
| 1039 | + " expansion: 4, groups: 1, dw: False, div_groups: None\n", |
| 1040 | + " sa: False, se: False\n", |
| 1041 | + " stem sizes: [3, 32, 32, 64], stride on 0\n", |
| 1042 | + " body sizes [64, 128, 256, 512]\n", |
| 1043 | + " layers: [3, 4, 6, 3]" |
| 1044 | + ] |
| 1045 | + }, |
| 1046 | + "execution_count": null, |
| 1047 | + "metadata": {}, |
| 1048 | + "output_type": "execute_result" |
| 1049 | + } |
| 1050 | + ], |
| 1051 | + "source": [ |
| 1052 | + "yaresnet50 = YaResNet50()\n", |
| 1053 | + "yaresnet50" |
| 1054 | + ] |
| 1055 | + }, |
1006 | 1056 | {
|
1007 | 1057 | "cell_type": "markdown",
|
1008 | 1058 | "metadata": {},
|
|
0 commit comments