From cd72dfe2e22bb85430da0680ab1722c7220cfef8 Mon Sep 17 00:00:00 2001 From: Guy Aglionby Date: Fri, 22 May 2020 18:28:25 +0100 Subject: [PATCH] More flexibility in setting anchor strength (fixes #16) --- README.md | 6 +++ corextopic/corextopic.py | 14 +++++- corextopic/example/corex_topic_example.ipynb | 46 +++++++++++++++++++- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 359aa4a..7ec88b8 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,12 @@ topic_model.fit(X, anchors=[[0, 2], 1], anchor_strength=2) anchors the features of columns 0 and 2 to the first topic, and feature 1 to the second topic. +Different anchor strengths can be specified for each topic and, within a topic, for each word. The example below uses different anchor strengths for "dog" and "cat" in the first topic, and the same anchor strength for "apple" and "pear" in the second: + +```python +topic_model.fit(X, words=words, anchors=[['dog','cat'], ['apple','pear'], 'building'], anchor_strength=[[2, 2.5], 2, 1.5]) +``` + ### Anchoring Strategies In our TACL paper, we explore several anchoring strategies: diff --git a/corextopic/corextopic.py b/corextopic/corextopic.py index 0055b07..b10204f 100644 --- a/corextopic/corextopic.py +++ b/corextopic/corextopic.py @@ -198,8 +198,18 @@ def fit_transform(self, X, y=None, anchors=None, anchor_strength=1, words=None, if anchors is not None: for a in flatten(anchors): self.alpha[:, a] = 0 - for ia, a in enumerate(anchors): - self.alpha[ia, a] = anchor_strength + if type(anchor_strength) != list: + for ia, a in enumerate(anchors): + self.alpha[ia, a] = anchor_strength + else: + assert len(anchors) == len(anchor_strength), 'Number of topics and number of anchor strengths do not match' + for ia, (a, a_s) in enumerate(zip(anchors, anchor_strength)): + if type(a_s) == list: + if len(a_s) == 1: + a_s = a_s[0] + else: + assert len(a_s) == len(a), 'Number of anchor strengths does not match number of seeds for topic 
number ' + str(ia + 1) + self.alpha[ia, a] = a_s p_y_given_x, _, log_z = self.calculate_latent(X, self.theta) diff --git a/corextopic/example/corex_topic_example.ipynb b/corextopic/example/corex_topic_example.ipynb index 4cf1af8..2df3a39 100644 --- a/corextopic/example/corex_topic_example.ipynb +++ b/corextopic/example/corex_topic_example.ipynb @@ -13,7 +13,7 @@ "source": [ "**Author:** [Ryan J. Gallagher](http://ryanjgallagher.github.io/) \n", "\n", - "**Last updated:** 07/21/2018" + "**Last updated:** 05/22/2020" ] }, { @@ -936,6 +936,50 @@ "**Note:** If you do not specify the column labels through `words`, then you can still anchor by specifying the column indices of the features you wish to anchor on. You may also specify anchors using a mix of strings and indices if desired." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The anchor strength can either be set globally for all topics and words, as above, or independently for each topic. \n", + "In the latter case, each word can be given its own strength or the same value can be used for each word in the topic.\n", + "\n", + "In the example below, the second topic anchoring \"nasa\" and \"politics\" weights \"politics\" more highly than \"nasa\", while the third topic weights them equally." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "anchor_strengths = [6, [6, 7], 7, 6]\n", + "\n", + "anchored_topic_model = ct.Corex(n_hidden=50, seed=2)\n", + "anchored_topic_model.fit(doc_word, words=words, anchors=anchor_words, anchor_strength=anchor_strengths);" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0: nasa,space,orbit,technology,shuttle,launch,development,moon,commercial,mission\n", + "1: nasa,politics,gov,ames,jsc,comprehensive,arc,lewis,larc,shafer\n", + "2: news,nasa,insisting,pasadena,advertising,edwards,hal,llnl,part1,cfv\n", + "3: war,israel,armenians,israeli,armenian,jews,soldiers,military,killed,argic\n" + ] + } + ], + "source": [ + "for n in range(len(anchor_words)):\n", + " topic_words,_ = zip(*anchored_topic_model.get_topics(topic=n))\n", + " print('{}: '.format(n) + ','.join(topic_words))" + ] + }, { "cell_type": "markdown", "metadata": {},