diff --git a/notebooks/08.CNN-for-Text-Classification.ipynb b/notebooks/08.CNN-for-Text-Classification.ipynb
index 020827f..c1c7d3d 100644
--- a/notebooks/08.CNN-for-Text-Classification.ipynb
+++ b/notebooks/08.CNN-for-Text-Classification.ipynb
@@ -26,10 +26,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 58,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 1,
+   "metadata": {},
    "outputs": [],
    "source": [
     "import torch\n",
@@ -51,14 +49,13 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "USE_CUDA = torch.cuda.is_available()\n",
     "gpus = [0]\n",
-    "torch.cuda.set_device(gpus[0])\n",
+    "if USE_CUDA:\n",
+    "    torch.cuda.set_device(gpus[0])\n",
     "\n",
     "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
     "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
@@ -68,9 +65,7 @@
   {
    "cell_type": "code",
    "execution_count": 3,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "def getBatch(batch_size, train_data):\n",
@@ -91,10 +86,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 110,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 4,
+   "metadata": {},
    "outputs": [],
    "source": [
     "def pad_to_batch(batch):\n",
@@ -111,14 +104,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 5,
+   "metadata": {},
    "outputs": [],
    "source": [
-    "def prepare_sequence(seq, to_index):\n",
+    "def prepare_sequence(seq, to_index, min_len = 5):\n",
+    "    if len(seq) < min_len:\n",
+    "        seq.extend( (min_len - len(seq)) * [\"<PAD>\"])\n",
     "    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"<UNK>\"], seq))\n",
+    "    \n",
     "    return Variable(LongTensor(idxs))"
    ]
   },
@@ -148,32 +142,70 @@
   {
    "cell_type": "code",
-   "execution_count": 53,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 6,
+   "metadata": {},
    "outputs": [],
    "source": [
-    "data = open('../dataset/train_5500.label.txt', 'r', encoding='latin-1').readlines()"
+    "data_lines = open('../dataset/train_5500.label.txt', 'r', encoding='latin-1').readlines()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 54,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 7,
+   "metadata": {},
    "outputs": [],
    "source": [
-    "data = [[d.split(':')[1][:-1], d.split(':')[0]] for d in data]"
+    "data = []\n",
+    "for line in data_lines:\n",
+    "    fields = line.strip().split()\n",
+    "    qtype = fields[0]  # e.g. DESC:manner\n",
+    "    qtype_main = qtype.split(':')[0]  # We use only main question type, not sub-type\n",
+    "    text = ' '.join(fields[1:-1])  # we exclude the last punctuation sign\n",
+    "    data.append( (text, qtype_main) )"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 61,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Data size: 5452\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "[('What films featured the character Popeye Doyle', 'ENTY'),\n",
+       " (\"How can I find a list of celebrities ' real names\", 'DESC'),\n",
+       " ('What fowl grabs the spotlight after the Chinese Year of the Monkey',\n",
+       "  'ENTY'),\n",
+       " ('What is the full form of .com', 'ABBR'),\n",
+       " ('What contemptible scoundrel stole the cork from my lunch', 'HUM'),\n",
+       " (\"What team did baseball 's St. Louis Browns become\", 'HUM'),\n",
+       " ('What is the oldest profession', 'HUM'),\n",
+       " ('What are liver enzymes', 'DESC'),\n",
+       " ('Name the scar-faced bounty hunter of The Old West', 'HUM')]"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Let's look at the data!\n",
+    "print('Data size:', len(data))\n",
+    "data[1:10]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
    "outputs": [],
    "source": [
     "X, y = list(zip(*data))\n",
@@ -196,10 +228,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 62,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 10,
+   "metadata": {},
    "outputs": [],
    "source": [
     "for i, x in enumerate(X):\n",
@@ -215,10 +245,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 63,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 11,
+   "metadata": {},
    "outputs": [],
    "source": [
     "vocab = list(set(flatten(X)))"
    ]
   },
@@ -226,18 +254,16 @@
   {
    "cell_type": "code",
-   "execution_count": 64,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 12,
+   "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "9117"
+       "9201"
       ]
      },
-     "execution_count": 64,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
@@ -248,10 +274,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 13,
+   "metadata": {},
    "outputs": [
     {
      "data": {
@@ -259,7 +283,7 @@
       "text/plain": [
        "6"
       ]
      },
-     "execution_count": 31,
+     "execution_count": 13,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
@@ -270,10 +294,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 94,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 14,
+   "metadata": {},
    "outputs": [],
    "source": [
     "word2index={'<PAD>': 0, '<UNK>': 1}\n",
@@ -295,10 +317,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 95,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 15,
+   "metadata": {},
    "outputs": [],
    "source": [
     "X_p, y_p = [], []\n",
@@ -329,10 +349,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 16,
+   "metadata": {},
    "outputs": [],
    "source": [
     "import gensim"
    ]
   },
@@ -340,21 +358,26 @@
   {
    "cell_type": "code",
-   "execution_count": 43,
-   "metadata": {
-    "collapsed": true
-   },
-   "outputs": [],
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/anaconda3/lib/python3.6/site-packages/smart_open/smart_open_lib.py:398: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n",
+      "  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n"
+     ]
+    }
+   ],
    "source": [
     "model = gensim.models.KeyedVectors.load_word2vec_format('../dataset/GoogleNews-vectors-negative300.bin', binary=True)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 18,
+   "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
@@ -362,7 +385,7 @@
        "3000000"
       ]
      },
-     "execution_count": 48,
+     "execution_count": 18,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
@@ -373,10 +396,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 96,
-   "metadata": {
-    "collapsed": true
-   },
+   "execution_count": 19,
+   "metadata": {},
    "outputs": [],
    "source": [
     "pretrained = []\n",
@@ -401,16 +422,14 @@
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "<img src=\"...\">\n",
+    "<img src=\"images/...\">\n",
     "<center>borrowed image from http://www.aclweb.org/anthology/D14-1181</center>"
" ] }, { "cell_type": "code", - "execution_count": 117, - "metadata": { - "collapsed": true - }, + "execution_count": 20, + "metadata": {}, "outputs": [], "source": [ "class CNNClassifier(nn.Module):\n", @@ -456,15 +475,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It takes for a while if you use just cpu." + "It is quite fast even on CPU." ] }, { "cell_type": "code", - "execution_count": 145, - "metadata": { - "collapsed": true - }, + "execution_count": 21, + "metadata": {}, "outputs": [], "source": [ "EPOCH = 5\n", @@ -476,10 +493,8 @@ }, { "cell_type": "code", - "execution_count": 146, - "metadata": { - "collapsed": true - }, + "execution_count": 22, + "metadata": {}, "outputs": [], "source": [ "model = CNNClassifier(len(word2index), 300, len(target2index), KERNEL_DIM, KERNEL_SIZES)\n", @@ -494,20 +509,18 @@ }, { "cell_type": "code", - "execution_count": 147, - "metadata": { - "collapsed": false - }, + "execution_count": 23, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[0/5] mean_loss : 2.13\n", - "[1/5] mean_loss : 0.12\n", - "[2/5] mean_loss : 0.08\n", - "[3/5] mean_loss : 0.02\n", - "[4/5] mean_loss : 0.05\n" + "[0/5] mean_loss : 2.07\n", + "[1/5] mean_loss : 0.68\n", + "[2/5] mean_loss : 0.42\n", + "[3/5] mean_loss : 0.25\n", + "[4/5] mean_loss : 0.11\n" ] } ], @@ -521,7 +534,7 @@ " preds = model(inputs, True)\n", " \n", " loss = loss_function(preds, targets)\n", - " losses.append(loss.data.tolist()[0])\n", + " losses.append(loss.data.item())\n", " loss.backward()\n", " \n", " #for param in model.parameters():\n", @@ -543,35 +556,26 @@ }, { "cell_type": "code", - "execution_count": 150, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "accuracy = 0" - ] - }, - { - "cell_type": "code", - "execution_count": 151, - "metadata": { - "collapsed": false - }, + "execution_count": 24, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "97.61904761904762\n" + "85.8974358974359\n" ] } ], "source": [ + "accuracy = 0\n", + "qty = 0\n", + "\n", "for test in test_data:\n", " pred = model(test[0]).max(1)[1]\n", " pred = pred.data.tolist()[0]\n", " target = test[1].data.tolist()[0][0]\n", + "\n", " if pred == target:\n", " accuracy += 1\n", "\n", @@ -610,15 +614,6 @@ "* Bag of Tricks for Efficient Text Classification\n", "* Which Encoding is the Best for Text Classification in Chinese, English, Japanese and Korean?" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] } ], "metadata": { @@ -637,7 +632,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2" + "version": "3.6.5" } }, "nbformat": 4, diff --git a/notebooks/images b/notebooks/images new file mode 120000 index 0000000..e4c5bd0 --- /dev/null +++ b/notebooks/images @@ -0,0 +1 @@ +../images/ \ No newline at end of file diff --git a/script/prepare_dataset.sh b/script/prepare_dataset.sh index f0b1a2a..e9bdbeb 100755 --- a/script/prepare_dataset.sh +++ b/script/prepare_dataset.sh @@ -18,7 +18,7 @@ wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt" wget "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt" -P "$OUT_DIR/ptb" echo "download TREC question dataset..." 
-curl -o "$OUT_DIR/train_5500.label.txt" "http://cogcomp.org/Data/QA/QC/train_5500.label" +curl -o "$OUT_DIR/train_5500.label.txt" "https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label" echo "download Stanford sentment treebank..." curl -o "$OUT_DIR/trainDevTestTrees_PTB.zip" "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip" @@ -38,4 +38,4 @@ echo "download dependency parser dataset... (clone from https://github.com/rguth mkdir -v -p "$OUT_DIR/dparser" wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/train.txt" -P "$OUT_DIR/dparser" wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/vocab.txt" -P "$OUT_DIR/dparser" -wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/dev.txt" -P "$OUT_DIR/dparser" \ No newline at end of file +wget "https://raw.githubusercontent.com/rguthrie3/DeepDependencyParsingProblemSet/master/data/dev.txt" -P "$OUT_DIR/dparser"
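Note on the dataset URL change above: the re-hosted TREC file should parse exactly as before. An illustrative sanity check (not part of the repository) that each downloaded line still has the "TYPE:subtype question tokens ... ?" shape the notebook's new parsing loop relies on:

def check_trec_file(path='../dataset/train_5500.label.txt'):
    # Each line: a MAIN:subtype label, then the question tokens.
    with open(path, encoding='latin-1') as f:
        lines = f.readlines()
    for n, line in enumerate(lines, 1):
        fields = line.strip().split()
        assert len(fields) >= 2, 'no question text on line %d' % n
        assert ':' in fields[0], 'bad label on line %d: %r' % (n, line)
    return len(lines)

Despite its name, train_5500.label contains 5452 questions, which matches the "Data size: 5452" output in the notebook cell above.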