Skip to content

Commit a1224fe

Browse files
committed
enhance predict dataset
1 parent 0f352ca commit a1224fe

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

paddlehub/dataset/base_nlp_dataset.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import csv
2222

2323
from paddlehub.dataset import InputExample, BaseDataset
24+
from paddlehub.common.logger import logger
2425

2526

2627
class BaseNLPDataset(BaseDataset):
@@ -51,6 +52,7 @@ def __init__(self,
5152

5253
def _read_file(self, input_file, phase=None):
5354
"""Reads a tab separated value file."""
55+
has_warned = False
5456
with io.open(input_file, "r", encoding="UTF-8") as file:
5557
reader = csv.reader(file, delimiter="\t", quotechar=None)
5658
examples = []
@@ -59,22 +61,38 @@ def _read_file(self, input_file, phase=None):
5961
ncol = len(line)
6062
if self.if_file_with_header[phase]:
6163
continue
62-
if ncol == 1:
63-
if phase == "predict":
64-
example = InputExample(guid=i, text_a=line[0])
65-
else:
64+
if phase != "predict":
65+
if ncol == 1:
6666
raise Exception(
6767
"the %s file: %s only has one column but it is not a predict file"
6868
% (phase, input_file))
69-
elif ncol == 2:
70-
example = InputExample(
71-
guid=i, text_a=line[0], label=line[1])
72-
elif ncol == 3:
73-
example = InputExample(
74-
guid=i, text_a=line[0], text_b=line[1], label=line[2])
69+
elif ncol == 2:
70+
example = InputExample(
71+
guid=i, text_a=line[0], label=line[1])
72+
elif ncol == 3:
73+
example = InputExample(
74+
guid=i,
75+
text_a=line[0],
76+
text_b=line[1],
77+
label=line[2])
78+
else:
79+
raise Exception(
80+
"the %s file: %s has too many columns (should <=3)"
81+
% (phase, input_file))
7582
else:
76-
raise Exception(
77-
"the %s file: %s has too many columns (should <=3)" %
78-
(phase, input_file))
83+
if ncol == 1:
84+
example = InputExample(guid=i, text_a=line[0])
85+
elif ncol == 2:
86+
if not has_warned:
87+
logger.warning(
88+
"the predict file: %s has 2 columns, as it is a predict file, the second one will be regarded as text_b"
89+
% (input_file))
90+
has_warned = True
91+
example = InputExample(
92+
guid=i, text_a=line[0], text_b=line[1])
93+
else:
94+
raise Exception(
95+
"the predict file: %s has too many columns (should <=2)"
96+
% (input_file))
7997
examples.append(example)
8098
return examples

0 commit comments

Comments
 (0)