This repository was archived by the owner on Jul 4, 2023. It is now read-only.

Commit 86a44fd

Merge pull request #84 from PetrochukM/index_to_token
PyTorch-NLP 0.5.0
2 parents 49bb1d7 + 7f82397 commit 86a44fd


74 files changed: +1124 additions, -855 deletions

.gitignore

Lines changed: 4 additions & 1 deletion
```diff
@@ -75,7 +75,7 @@ target/
 logs/**
 
 # Coverage
-coverage/**
+coverage/**
 cover/
 
 # Data
@@ -94,3 +94,6 @@ data/**
 
 # ReadTheDocs build files
 docs/_build
+
+# Python's virtual env
+venv
```

.travis.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -1,12 +1,18 @@
 language: python
 matrix:
   include:
+  - python: 3.5
+    dist: xenial
+    sudo: true
+    env: RUN_DOCTEXT=false # Python 3.5 prints differently from Python 3.6
   - python: 3.6
     dist: xenial
     sudo: true
+    env: RUN_DOCTEXT=true
   - python: 3.7
     dist: xenial
     sudo: true
+    env: RUN_DOCTEXT=true
 
 cache: pip
 
```
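The `RUN_DOCTEXT=false` comment above is worth unpacking: dict `repr` output is not insertion-ordered on Python 3.5, so doctests that print dicts can fail there even when the code is correct. A minimal sketch of the failure mode (illustrative only, not part of this commit):

```python
def tag_counts():
    """Count part-of-speech tags in a toy corpus.

    The expected output below assumes dicts preserve insertion order,
    which CPython provides from 3.6 on (and the language guarantees
    from 3.7), so this doctest can fail on Python 3.5.

    >>> tag_counts()
    {'NOUN': 2, 'VERB': 1}
    """
    counts = {}
    for tag in ['NOUN', 'VERB', 'NOUN']:
        counts[tag] = counts.get(tag, 0) + 1
    return counts


if __name__ == '__main__':
    import doctest
    doctest.testmod()
```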

README.md

Lines changed: 126 additions & 42 deletions
````diff
@@ -1,43 +1,45 @@
 <p align="center"><img width="55%" src="docs/_static/img/logo.svg" /></p>
 
-<h3 align="center">Supporting Rapid Prototyping with a Deep Learning NLP Toolkit&nbsp;&nbsp;
-<a href="https://twitter.com/intent/tweet?text=Supporting%20rapid%20prototyping%20for%20research,%20PyTorch-NLP%20has%20LAUNCHED,%20a%20deep%20learning%20natural%20language%20processing%20(NLP)%20toolkit!%20&url=https://github.com/PetrochukM/PyTorch-NLP&hashtags=pytorch,nlp,research">
-<img style='vertical-align: text-bottom !important;' src="https://img.shields.io/twitter/url/http/shields.io.svg?style=social" alt="Tweet">
-</a>
-</h3>
+<h3 align="center">Basic Utilities for PyTorch NLP Software</h3>
 
-PyTorch-NLP, or torchnlp for short, is a library of neural network layers, text processing modules and datasets designed to accelerate Natural Language Processing (NLP) research.
-
-Join our community, add datasets and neural network layers! Chat with us on [Gitter](https://gitter.im/PyTorch-NLP/Lobby) and join the [Google Group](https://groups.google.com/forum/#!forum/pytorch-nlp), we're eager to collaborate with you.
+PyTorch-NLP, or `torchnlp` for short, is a library of basic utilities for PyTorch
+Natural Language Processing (NLP). `torchnlp` extends PyTorch to provide you with
+basic text data processing functions.
 
 ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pytorch-nlp.svg?style=flat-square)
 [![Codecov](https://img.shields.io/codecov/c/github/PetrochukM/PyTorch-NLP/master.svg?style=flat-square)](https://codecov.io/gh/PetrochukM/PyTorch-NLP)
 [![Downloads](http://pepy.tech/badge/pytorch-nlp)](http://pepy.tech/project/pytorch-nlp)
-[![Documentation Status]( https://img.shields.io/readthedocs/pytorchnlp/latest.svg?style=flat-square)](http://pytorchnlp.readthedocs.io/en/latest/?badge=latest&style=flat-square)
+[![Documentation Status](https://img.shields.io/readthedocs/pytorchnlp/latest.svg?style=flat-square)](http://pytorchnlp.readthedocs.io/en/latest/?badge=latest&style=flat-square)
 [![Build Status](https://img.shields.io/travis/PetrochukM/PyTorch-NLP/master.svg?style=flat-square)](https://travis-ci.org/PetrochukM/PyTorch-NLP)
+[![Twitter: PetrochukM](https://img.shields.io/twitter/follow/MPetrochuk.svg?style=social)](https://twitter.com/MPetrochuk)
 
-_Logo by [Chloe Yeo](http://www.yeochloe.com/)_
+_Logo by [Chloe Yeo](http://www.yeochloe.com/), Corporate Sponsorship by [WellSaid Labs](https://wellsaidlabs.com/)_
 
-## Installation
+## Installation 🐾
 
-Make sure you have Python 3.6+ and PyTorch 1.0+. You can then install `pytorch-nlp` using
+Make sure you have Python 3.5+ and PyTorch 1.0+. You can then install `pytorch-nlp` using
 pip:
 
-pip install pytorch-nlp
+```python
+pip install pytorch-nlp
+```
 
 Or to install the latest code via:
 
-pip install git+https://github.com/PetrochukM/PyTorch-NLP.git
+```python
+pip install git+https://github.com/PetrochukM/PyTorch-NLP.git
+```
 
-## Docs 📖
+## Docs
 
-The complete documentation for PyTorch-NLP is available via [our ReadTheDocs website](https://pytorchnlp.readthedocs.io).
+The complete documentation for PyTorch-NLP is available
+via [our ReadTheDocs website](https://pytorchnlp.readthedocs.io).
 
-## Basics
+## Get Started
 
-Add PyTorch-NLP to your project by following one of the common use cases:
+Within an NLP data pipeline, you'll want to implement these basic steps:
 
-### Load a [Dataset](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.datasets.html)
+### Load Your Data 🐿
 
 Load the IMDB dataset, for example:
 
@@ -49,51 +51,133 @@ train = imdb_dataset(train=True)
 train[0] # RETURNS: {'text': 'For a movie that gets..', 'sentiment': 'pos'}
 ```
 
-### Apply [Neural Networks](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html) Layers
+Load a custom dataset, for example:
+
+```python
+from pathlib import Path
+
+from torchnlp.download import download_file_maybe_extract
+
+directory_path = Path('data/')
+train_file_path = Path('trees/train.txt')
+
+download_file_maybe_extract(
+    url='http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip',
+    directory=directory_path,
+    check_files=[train_file_path])
+
+open(directory_path / train_file_path)
+```
+
+Don't worry we'll handle caching for you!
 
-For example, from the neural network package, apply state-of-the-art LockedDropout:
+### Text To Tensor
+
+Tokenize and encode your text as a tensor. For example, a `WhitespaceEncoder` breaks
+text into terms whenever it encounters a whitespace character.
+
+```python
+from torchnlp.encoders.text import WhitespaceEncoder
+
+loaded_data = ["now this ain't funny", "so don't you dare laugh"]
+encoder = WhitespaceEncoder(loaded_data)
+encoded_data = [encoder.encode(example) for example in loaded_data]
+```
+
+### Tensor To Batch
+
+With your loaded and encoded data in hand, you'll want to batch your dataset.
 
 ```python
 import torch
-from torchnlp.nn import LockedDropout
+from torchnlp.samplers import BucketBatchSampler
+from torchnlp.utils import collate_tensors
+from torchnlp.encoders.text import stack_and_pad_tensors
 
-input_ = torch.randn(6, 3, 10)
-dropout = LockedDropout(0.5)
+encoded_data = [torch.randn(2), torch.randn(3), torch.randn(4), torch.randn(5)]
 
-# Apply a LockedDropout to `input_`
-dropout(input_) # RETURNS: torch.FloatTensor (6x3x10)
+train_sampler = torch.utils.data.sampler.SequentialSampler(encoded_data)
+train_batch_sampler = BucketBatchSampler(
+    train_sampler, batch_size=2, drop_last=False, sort_key=lambda i: encoded_data[i].shape[0])
+
+batches = [[encoded_data[i] for i in batch] for batch in train_batch_sampler]
+batches = [collate_tensors(batch, stack_tensors=stack_and_pad_tensors) for batch in batches]
 ```
 
-### [Encode Text](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.encoders.text.html)
+PyTorch-NLP builds on top of PyTorch's existing `torch.utils.data.sampler`, `torch.stack`
+and `default_collate` to support sequential inputs of varying lengths!
+
+### Your Good To Go!
+
+With your batch in hand, you can use PyTorch to develop and train your model using gradient descent.
+
+### Last But Not Least
 
-Tokenize and encode text as a tensor. For example, a `WhitespaceEncoder` breaks text into terms whenever it encounters a whitespace character.
+PyTorch-NLP has a couple more NLP focused utility packages to support you! 🤗
+
+#### Deterministic Functions
+
+Now you've setup your pipeline, you may want to ensure that some functions run deterministically.
+Wrap any code that's random, with `fork_rng` and you'll be good to go, like so:
 
 ```python
+import random
+import numpy
+import torch
+
+from torchnlp.random import fork_rng
+
+with fork_rng(seed=123): # Ensure determinism
+    print('Random:', random.randint(1, 2**31))
+    print('Numpy:', numpy.random.randint(1, 2**31))
+    print('Torch:', int(torch.randint(1, 2**31, (1,))))
+```
+
+This will always print:
+
+```text
+Random: 224899943
+Numpy: 843828735
+Torch: 843828736
+```
+
+#### Pre-Trained Word Vectors
+
+Now that you've computed your vocabulary, you may want to make use of
+pre-trained word vectors, like so:
+
+```python
+import torch
 from torchnlp.encoders.text import WhitespaceEncoder
+from torchnlp.word_to_vector import GloVe
 
-# Create a `WhitespaceEncoder` with a corpus of text
 encoder = WhitespaceEncoder(["now this ain't funny", "so don't you dare laugh"])
 
-# Encode and decode phrases
-encoder.encode("this ain't funny.") # RETURNS: torch.Tensor([6, 7, 1])
-encoder.decode(encoder.encode("This ain't funny.")) # RETURNS: "this ain't funny."
+vocab = set(encoder.vocab)
+pretrained_embedding = GloVe(name='6B', dim=100, is_include=lambda w: w in vocab)
+embedding_weights = torch.Tensor(encoder.vocab_size, pretrained_embedding.dim)
+for i, token in enumerate(encoder.vocab):
+    embedding_weights[i] = pretrained_embedding[token]
 ```
 
-### Load [Word Vectors](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.word_to_vector.html)
+#### Neural Networks Layers
 
-For example, load FastText, state-of-the-art English word vectors:
+For example, from the neural network package, apply the state-of-the-art `LockedDropout`:
 
 ```python
-from torchnlp.word_to_vector import FastText
+import torch
+from torchnlp.nn import LockedDropout
+
+input_ = torch.randn(6, 3, 10)
+dropout = LockedDropout(0.5)
 
-vectors = FastText()
-# Load vectors for any word as a `torch.FloatTensor`
-vectors['hello'] # RETURNS: [torch.FloatTensor of size 300]
+# Apply a LockedDropout to `input_`
+dropout(input_) # RETURNS: torch.FloatTensor (6x3x10)
 ```
 
-### Compute [Metrics](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.metrics.html)
+#### Metrics
 
-Finally, compute common metrics such as the BLEU score.
+Compute common NLP metrics such as the BLEU score.
 
 ```python
 from torchnlp.metrics import get_moses_multi_bleu
@@ -131,8 +215,8 @@ AllenNLP is designed to be a platform for research. PyTorch-NLP is designed to b
 
 ## Authors
 
-* [Michael Petrochuk](https://github.com/PetrochukM/) — Developer
-* [Chloe Yeo](http://www.yeochloe.com/) — Logo Design
+- [Michael Petrochuk](https://github.com/PetrochukM/) — Developer
+- [Chloe Yeo](http://www.yeochloe.com/) — Logo Design
 
 ## Citing
 
````
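The new "Tensor To Batch" recipe above materializes batches eagerly in a list comprehension; in practice you would usually hand the same pieces to a `DataLoader`. A minimal sketch under that assumption (the `collate_tensors` and `stack_and_pad_tensors` names are taken from the README diff; the wiring itself is not part of this commit):

```python
from functools import partial

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from torchnlp.samplers import BucketBatchSampler
from torchnlp.utils import collate_tensors
from torchnlp.encoders.text import stack_and_pad_tensors

encoded_data = [torch.randn(2), torch.randn(3), torch.randn(4), torch.randn(5)]

# Bucket indices by sequence length so each batch needs minimal padding.
batch_sampler = BucketBatchSampler(
    SequentialSampler(encoded_data), batch_size=2, drop_last=False,
    sort_key=lambda i: encoded_data[i].shape[0])

loader = DataLoader(
    encoded_data,
    batch_sampler=batch_sampler,
    collate_fn=partial(collate_tensors, stack_tensors=stack_and_pad_tensors))

for padded, lengths in loader:
    print(padded.shape, lengths)  # padded batch plus the original lengths
```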

build_tools/travis/install.sh

Lines changed: 2 additions & 0 deletions
```diff
@@ -27,6 +27,8 @@ pip install -r requirements.txt --progress-bar off
 pip install spacy --progress-bar off
 pip install nltk --progress-bar off
 pip install sacremoses --progress-bar off
+pip install pandas --progress-bar off
+pip install requests --progress-bar off
 
 # SpaCy English web model
 python -m spacy download en
```

build_tools/travis/test_script.sh

Lines changed: 4 additions & 1 deletion
```diff
@@ -21,10 +21,13 @@ if [[ "$RUN_FLAKE8" == "true" ]]; then
 fi
 
 run_tests() {
-    TEST_CMD="python -m pytest tests/ torchnlp/ --verbose --durations=20 --cov=torchnlp --doctest-modules"
+    TEST_CMD="python -m pytest tests/ torchnlp/ -c /dev/null --verbose --durations=10 --cov=torchnlp"
     if [[ "$RUN_SLOW" == "true" ]]; then
         TEST_CMD="$TEST_CMD --runslow"
     fi
+    if [[ "$RUN_DOCTEXT" == "true" ]]; then
+        TEST_CMD="$TEST_CMD --doctest-modules"
+    fi
     $TEST_CMD
 }
```
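To reproduce the gated CI command locally, something like the following sketch should be equivalent (the flag simply mirrors `RUN_DOCTEXT` from `.travis.yml`; this script is not part of the commit):

```python
import subprocess

run_doctext = True  # mirrors RUN_DOCTEXT in .travis.yml

cmd = ['python', '-m', 'pytest', 'tests/', 'torchnlp/', '-c', '/dev/null',
       '--verbose', '--durations=10', '--cov=torchnlp']
if run_doctext:
    # Only enabled on Python 3.6+, where doctest output matches expectations.
    cmd.append('--doctest-modules')

subprocess.run(cmd, check=True)
```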

examples/snli/train.py

Lines changed: 16 additions & 12 deletions
```diff
@@ -1,18 +1,19 @@
 from functools import partial
 
+import glob
+import itertools
 import os
 import time
-import glob
 
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SequentialSampler
 
 import torch
 import torch.optim as optim
 import torch.nn as nn
 
 from torchnlp.samplers import BucketBatchSampler
 from torchnlp.datasets import snli_dataset
-from torchnlp.utils import datasets_iterator
 from torchnlp.encoders.text import WhitespaceEncoder
 from torchnlp.encoders import LabelEncoder
 from torchnlp import word_to_vector
@@ -29,20 +30,20 @@
 train, dev, test = snli_dataset(train=True, dev=True, test=True)
 
 # Preprocess
-for row in datasets_iterator(train, dev, test):
+for row in itertools.chain(train, dev, test):
     row['premise'] = row['premise'].lower()
     row['hypothesis'] = row['hypothesis'].lower()
 
 # Make Encoders
-sentence_corpus = [row['premise'] for row in datasets_iterator(train, dev, test)]
-sentence_corpus += [row['hypothesis'] for row in datasets_iterator(train, dev, test)]
+sentence_corpus = [row['premise'] for row in itertools.chain(train, dev, test)]
+sentence_corpus += [row['hypothesis'] for row in itertools.chain(train, dev, test)]
 sentence_encoder = WhitespaceEncoder(sentence_corpus)
 
-label_corpus = [row['label'] for row in datasets_iterator(train, dev, test)]
+label_corpus = [row['label'] for row in itertools.chain(train, dev, test)]
 label_encoder = LabelEncoder(label_corpus)
 
 # Encode
-for row in datasets_iterator(train, dev, test):
+for row in itertools.chain(train, dev, test):
     row['premise'] = sentence_encoder.encode(row['premise'])
     row['hypothesis'] = sentence_encoder.encode(row['hypothesis'])
     row['label'] = label_encoder.encode(row['label'])
@@ -88,11 +89,12 @@
 for epoch in range(args.epochs):
     n_correct, n_total = 0, 0
 
-    train_sampler = BucketBatchSampler(
-        train, args.batch_size, True, sort_key=lambda r: len(row['premise']))
+    train_sampler = SequentialSampler(train)
+    train_batch_sampler = BucketBatchSampler(
+        train_sampler, args.batch_size, True, sort_key=lambda r: len(row['premise']))
     train_iterator = DataLoader(
         train,
-        batch_sampler=train_sampler,
+        batch_sampler=train_batch_sampler,
         collate_fn=collate_fn,
         pin_memory=torch.cuda.is_available(),
         num_workers=0)
@@ -139,11 +141,13 @@
 
     # calculate accuracy on validation set
     n_dev_correct, dev_loss = 0, 0
-    dev_sampler = BucketBatchSampler(
+
+    dev_sampler = SequentialSampler(train)
+    dev_batch_sampler = BucketBatchSampler(
        dev, args.batch_size, True, sort_key=lambda r: len(row['premise']))
    dev_iterator = DataLoader(
        dev,
-        batch_sampler=dev_sampler,
+        batch_sampler=dev_batch_sampler,
        collate_fn=partial(collate_fn, train=False),
        pin_memory=torch.cuda.is_available(),
        num_workers=0)
```
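The substantive change here: `BucketBatchSampler` no longer takes the dataset directly but wraps an index sampler and yields batches of indices. (Note in passing that `sort_key=lambda r: len(row['premise'])` in this example closes over the preprocessing loop's `row` rather than its own argument `r`.) A minimal sketch of the new convention, with toy data and the batch order left unspecified since the sampler may shuffle buckets:

```python
import torch
from torch.utils.data.sampler import SequentialSampler
from torchnlp.samplers import BucketBatchSampler

data = [torch.randn(n) for n in (2, 5, 3, 4)]

# 0.5.0 convention: wrap an index sampler, then bucket indices by length.
sampler = SequentialSampler(data)
batch_sampler = BucketBatchSampler(
    sampler, batch_size=2, drop_last=False,
    sort_key=lambda index: data[index].shape[0])

for batch_indices in batch_sampler:
    print(batch_indices)  # lists of indices grouped by similar sequence length
```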

examples/snli/util.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -4,7 +4,7 @@
 
 import torch
 
-from torchnlp.encoders.text import pad_batch
+from torchnlp.encoders.text import stack_and_pad_tensors
 
 
 def makedirs(name):
@@ -55,8 +55,8 @@ def get_args():
 
 def collate_fn(batch, train=True):
     """ list of tensors to a batch tensors """
-    premise_batch, _ = pad_batch([row['premise'] for row in batch])
-    hypothesis_batch, _ = pad_batch([row['hypothesis'] for row in batch])
+    premise_batch, _ = stack_and_pad_tensors([row['premise'] for row in batch])
+    hypothesis_batch, _ = stack_and_pad_tensors([row['hypothesis'] for row in batch])
     label_batch = torch.stack([row['label'] for row in batch])
 
     # PyTorch RNN requires batches to be transposed for speed and integration with CUDA
```
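For reference, a sketch of the renamed helper's contract as used in `collate_fn` above, assuming the 0.5.0 API: it pads a list of variable-length tensors to a common length and also returns the original lengths, which the code discards with `_`:

```python
import torch
from torchnlp.encoders.text import stack_and_pad_tensors

batch = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]

padded, lengths = stack_and_pad_tensors(batch)
print(padded)   # expected: tensor([[1, 2, 0], [3, 4, 5]]), assuming 0 is the pad value
print(lengths)  # expected: the original lengths, e.g. tensor([2, 3])
```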
