Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ topic_model.fit(X, anchors=[[0, 2], 1], anchor_strength=2)

anchors the features of columns 0 and 2 to the first topic, and feature 1 to the second topic.

Different anchor strengths can be specified for each topic, and further for each word in a topic. The below example uses different anchor strengths for "dog" and "cat" in the first topic, and the same anchor strength for "apple" and "pear" in the second:

```python
topic_model.fit(X, words=words, anchors=[['dog','cat'], ['apple','pear'], 'building'], anchor_strength=[[2, 2.5], 2, 1.5])
```

### Anchoring Strategies

In our TACL paper, we explore several anchoring strategies:
Expand Down
14 changes: 12 additions & 2 deletions corextopic/corextopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,18 @@ def fit_transform(self, X, y=None, anchors=None, anchor_strength=1, words=None,
if anchors is not None:
for a in flatten(anchors):
self.alpha[:, a] = 0
for ia, a in enumerate(anchors):
self.alpha[ia, a] = anchor_strength
if type(anchor_strength) != list:
for ia, a in enumerate(anchors):
self.alpha[ia, a] = anchor_strength
else:
assert len(anchors) == len(anchor_strength), 'Number of topics and number of anchor strengths do not match'
for ia, (a, a_s) in enumerate(zip(anchors, anchor_strength)):
if type(a_s) == list:
if len(a_s) == 1:
a_s = a_s[0]
else:
assert len(a_s) == len(a), 'Number of anchor strengths does not match number of seeds for topic number ' + str(ia + 1)
self.alpha[ia, a] = a_s

p_y_given_x, _, log_z = self.calculate_latent(X, self.theta)

Expand Down
46 changes: 45 additions & 1 deletion corextopic/example/corex_topic_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"source": [
"**Author:** [Ryan J. Gallagher](http://ryanjgallagher.github.io/) \n",
"\n",
"**Last updated:** 07/21/2018"
"**Last updated:** 05/22/2020"
]
},
{
Expand Down Expand Up @@ -936,6 +936,50 @@
"**Note:** If you do not specify the column labels through `words`, then you can still anchor by specifying the column indices of the features you wish to anchor on. You may also specify anchors using a mix of strings and indices if desired."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The anchor strength can either be set globally for all topics and words, as above, or independently for each topic. \n",
"In the latter case, each word can be given its own strength or the same value can be used for each word in the topic.\n",
"\n",
"In the example below, the second topic anchoring \"nasa\" and \"politics\" weights \"politics\" more highly than \"nasa\", while the third topic weights them equally."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"anchor_strengths = [6, [6, 7], 7, 6]\n",
"\n",
"anchored_topic_model = ct.Corex(n_hidden=50, seed=2)\n",
"anchored_topic_model.fit(doc_word, words=words, anchors=anchor_words, anchor_strength=anchor_strengths);"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0: nasa,space,orbit,technology,shuttle,launch,development,moon,commercial,mission\n",
"1: nasa,politics,gov,ames,jsc,comprehensive,arc,lewis,larc,shafer\n",
"2: news,nasa,insisting,pasadena,advertising,edwards,hal,llnl,part1,cfv\n",
"3: war,israel,armenians,israeli,armenian,jews,soldiers,military,killed,argic\n"
]
}
],
"source": [
"for n in range(len(anchor_words)):\n",
" topic_words,_ = zip(*anchored_topic_model.get_topics(topic=n))\n",
" print('{}: '.format(n) + ','.join(topic_words))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down