
Commit 94d2016

Authored by Neural-Link Team, committed by tensorflow-copybara

Updates to IMDB-based graph-NSL tutorial.

- Use 2 neighbors for graph regularization (demonstrates use of more than 1 neighbor).
- Increase the number of training epochs from 4 to 10 (reduces variance and instability).
- Recompute accuracy values for the base and graph-regularized models at various supervision ratio values, for both the Bi-LSTM and feed-forward NN architectures.

PiperOrigin-RevId: 266485487

1 parent b7dae99 commit 94d2016

File tree

1 file changed: +12 −11 lines changed


g3doc/tutorials/graph_keras_lstm_imdb.ipynb

Lines changed: 12 additions & 11 deletions
@@ -743,13 +743,13 @@
   " ### neural graph learning parameters\n",
   " self.distance_type = nsl.configs.DistanceType.L2\n",
   " self.graph_regularization_multiplier = 0.1\n",
-  " self.num_neighbors = 1\n",
+  " self.num_neighbors = 2\n",
   " ### model architecture\n",
   " self.num_embedding_dims = 16\n",
   " self.num_lstm_dims = 64\n",
   " self.num_fc_units = 64\n",
   " ### training parameters\n",
-  " self.train_epochs = 4\n",
+  " self.train_epochs = 10\n",
   " self.batch_size = 128\n",
   " ### eval parameters\n",
   " self.eval_steps = None # All instances in the test set are evaluated.\n",
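The hunk above bumps `num_neighbors` from 1 to 2: graph regularization penalizes the distance (here `DistanceType.L2`) between a sample's representation and those of its graph neighbors, averaged over the neighbors used. As a rough illustration of what these two hyperparameters control, here is a minimal pure-Python sketch (hypothetical helper names; this is not NSL's actual implementation):

```python
# Pure-Python sketch of a graph-regularization penalty (illustrative only,
# not the NSL implementation): average the squared L2 distance between a
# sample's embedding and each of its num_neighbors graph neighbors, then
# scale by graph_regularization_multiplier.

def l2_sq(a, b):
    # Squared L2 distance between two equal-length vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))

def graph_reg_penalty(embedding, neighbor_embeddings, multiplier=0.1):
    # neighbor_embeddings holds up to num_neighbors neighbors (2 in the
    # updated tutorial config).
    if not neighbor_embeddings:
        return 0.0
    avg = sum(l2_sq(embedding, n) for n in neighbor_embeddings) / len(neighbor_embeddings)
    return multiplier * avg

sample = [1.0, 0.0]
neighbors = [[1.0, 1.0], [0.0, 0.0]]
print(graph_reg_penalty(sample, neighbors))  # → 0.1
```

Using 2 neighbors instead of 1 means each labeled sample's penalty draws on more of the graph structure, which is the point the commit message calls out.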
@@ -1459,11 +1459,12 @@
   "# Accuracy values for both the Bi-LSTM model and the feed forward NN model have\n",
   "# been precomputed for the following supervision ratios.\n",
   "\n",
-  "supervision_ratios = [0.3, 0.15, 0.05, 0.03, 0.01]\n",
+  "supervision_ratios = [0.3, 0.15, 0.05, 0.03, 0.02, 0.01, 0.005]\n",
   "\n",
   "model_tags = ['Bi-LSTM model', 'Feed Forward NN model']\n",
-  "base_model_accs = [[85, 85, 62, 58, 50], [85, 79, 61, 53, 50]]\n",
-  "graph_reg_model_accs = [[85, 84, 76, 63, 51], [85, 79, 73, 62, 50]]\n",
+  "base_model_accs = [[84, 84, 83, 80, 65, 52, 50], [87, 86, 76, 74, 67, 52, 51]]\n",
+  "graph_reg_model_accs = [[84, 84, 83, 83, 65, 63, 50],\n",
+  "                        [87, 86, 80, 75, 67, 52, 50]]\n",
   "\n",
   "plt.clf() # clear figure\n",
   "\n",
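The recomputed Bi-LSTM accuracies in this hunk can be compared directly. A small sketch, using the values copied from the diff above, that tabulates the gain of the graph-regularized model over the base model at each supervision ratio:

```python
# Compare the precomputed Bi-LSTM accuracies from the updated notebook cell.
# Values are copied verbatim from the diff; the gain column shows how many
# percentage points graph regularization adds at each supervision ratio.

supervision_ratios = [0.3, 0.15, 0.05, 0.03, 0.02, 0.01, 0.005]
base_bilstm = [84, 84, 83, 80, 65, 52, 50]
graph_bilstm = [84, 84, 83, 83, 65, 63, 50]

for ratio, b, g in zip(supervision_ratios, base_bilstm, graph_bilstm):
    print(f"ratio={ratio}: base={b}%, graph-reg={g}%, gain={g - b} pts")
```

The largest gain (11 percentage points) occurs at supervision ratio 0.01, which is the data point the updated tutorial text highlights.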
@@ -1498,12 +1499,12 @@
   "It can be observed that as the superivision ratio decreases, model accuracy also\n",
   "decreases. This is true for both the base model and for the graph-regularized\n",
   "model, regardless of the model architecture used. However, notice that the\n",
-  "graph-regularized model is consistenly better than the base model -- sometimes\n",
-  "by as much as 15% -- and further, as the supervision ratio decreases, the\n",
-  "decrease in accuracy is much less for the graph-regularized model than the base\n",
-  "model. This is primarily because of semi-supervised learning for the\n",
-  "graph-regularized model, where structural similarity among training samples is\n",
-  "used in addition to the training samples themselves."
+  "graph-regularized model performs better than the base model for both the\n",
+  "architectures. In particular, for the Bi-LSTM model, when the supervision ratio\n",
+  "is 0.01, the accuracy of the graph-regularized model is **~20%** higher than\n",
+  "that of the base model. This is primarily because of semi-supervised learning\n",
+  "for the graph-regularized model, where structural similarity among training\n",
+  "samples is used in addition to the training samples themselves."
   ]
  },
  {
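The "~20% higher" figure added in this hunk appears to be a relative gain, not a percentage-point gap: at supervision ratio 0.01 the Bi-LSTM base model reaches 52% and the graph-regularized model 63% (values from the updated cell earlier in the diff). A quick check of that arithmetic:

```python
# Verifying the "~20% higher" claim at supervision ratio 0.01, using the
# Bi-LSTM accuracies from the updated notebook cell in this commit.
base_acc = 52        # base model accuracy (%)
graph_reg_acc = 63   # graph-regularized model accuracy (%)

relative_gain = (graph_reg_acc - base_acc) / base_acc
print(f"{relative_gain:.1%}")  # → 21.2%
```

So the absolute gap is 11 percentage points, and the relative improvement of roughly 21% matches the "~20%" wording.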
