Skip to content

Commit 2429381

Browse files
TensorFlow Datasets Teamcopybara-github
authored andcommitted
Remove unnecessary <s> and </s> tokens. And add a test.
PiperOrigin-RevId: 246080224
1 parent dbda817 commit 2429381

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

tensorflow_datasets/text/cnn_dailymail.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,6 @@ def _subset_filenames(dl_paths, split):
145145

146146
DM_SINGLE_CLOSE_QUOTE = u'\u2019' # unicode
147147
DM_DOUBLE_CLOSE_QUOTE = u'\u201d'
148-
SENTENCE_START = '<s>'
149-
SENTENCE_END = '</s>'
150148
# acceptable ways to end a sentence
151149
END_TOKENS = ['.', '!', '?', '...', "'", '`', '"',
152150
DM_SINGLE_CLOSE_QUOTE, DM_DOUBLE_CLOSE_QUOTE, ')']
@@ -201,31 +199,31 @@ def fix_missing_period(line):
201199

202200
# Make abstract into a single string, putting <s> and </s> tags around
203201
# the sentences.
204-
abstract = ' '.join(['%s %s %s' % (SENTENCE_START, sent,
205-
SENTENCE_END) for sent in highlights])
202+
abstract = ' '.join(highlights)
206203

207204
return article, abstract
208205

209206

210207
class CnnDailymail(tfds.core.GeneratorBasedBuilder):
211208
"""CNN/DailyMail non-anonymized summarization dataset."""
209+
# 0.0.2 is like 0.0.1 but without special tokens <s> and </s>.
212210
BUILDER_CONFIGS = [
213211
CnnDailymailConfig(
214212
name='plain_text',
215-
version='0.0.1',
213+
version='0.0.2',
216214
description='Plain text',
217215
),
218216
CnnDailymailConfig(
219217
name='bytes',
220-
version='0.0.1',
218+
version='0.0.2',
221219
description=('Uses byte-level text encoding with '
222220
'`tfds.features.text.ByteTextEncoder`'),
223221
text_encoder_config=tfds.features.text.TextEncoderConfig(
224222
encoder=tfds.features.text.ByteTextEncoder()),
225223
),
226224
CnnDailymailConfig(
227225
name='subwords32k',
228-
version='0.0.1',
226+
version='0.0.2',
229227
description=('Uses `tfds.features.text.SubwordTextEncoder` with '
230228
'32k vocab size'),
231229
text_encoder_config=tfds.features.text.TextEncoderConfig(
@@ -260,8 +258,7 @@ def _split_generators(self, dl_manager):
260258
# Generate shared vocabulary
261259
# maybe_build_from_corpus uses SubwordTextEncoder if that's configured
262260
self.info.features[_ARTICLE].maybe_build_from_corpus(
263-
self._vocab_text_gen(train_files),
264-
reserved_tokens=[SENTENCE_START, SENTENCE_END])
261+
self._vocab_text_gen(train_files))
265262
encoder = self.info.features[_ARTICLE].encoder
266263
# Use maybe_set_encoder because the encoder may have been restored from
267264
# package data.

tensorflow_datasets/text/cnn_dailymail_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,29 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import tempfile
22+
2123
import tensorflow_datasets.testing as tfds_test
2224
from tensorflow_datasets.text import cnn_dailymail
2325

2426

27+
_STORY_FILE = b"""Some article.
28+
This is some article text.
29+
30+
@highlight
31+
32+
highlight text
33+
34+
@highlight
35+
36+
Highlight two
37+
38+
@highlight
39+
40+
highlight Three
41+
"""
42+
43+
2544
class CnnDailymailTest(tfds_test.DatasetBuilderTestCase):
2645
DATASET_CLASS = cnn_dailymail.CnnDailymail
2746
SPLITS = {
@@ -35,6 +54,19 @@ class CnnDailymailTest(tfds_test.DatasetBuilderTestCase):
3554
'train_urls': 'all_train.txt',
3655
'val_urls': 'all_val.txt'}
3756

57+
def test_get_art_abs(self):
58+
with tempfile.NamedTemporaryFile(delete=True) as f:
59+
f.write(_STORY_FILE)
60+
f.flush()
61+
article, abstract = cnn_dailymail._get_art_abs(f.name)
62+
self.assertEqual('some article. this is some article text.',
63+
article)
64+
# This is a bit weird, but the original code at
65+
# https://github.com/abisee/cnn-dailymail/ adds space before period
66+
# for abstracts and we retain this behavior.
67+
self.assertEqual('highlight text . highlight two . highlight three .',
68+
abstract)
69+
3870

3971
if __name__ == '__main__':
4072
tfds_test.test_main()

0 commit comments

Comments
 (0)