21
21
import csv
22
22
23
23
from paddlehub .dataset import InputExample , BaseDataset
24
+ from paddlehub .common .logger import logger
24
25
25
26
26
27
class BaseNLPDataset (BaseDataset ):
@@ -51,6 +52,7 @@ def __init__(self,
51
52
52
53
def _read_file (self , input_file , phase = None ):
53
54
"""Reads a tab separated value file."""
55
+ has_warned = False
54
56
with io .open (input_file , "r" , encoding = "UTF-8" ) as file :
55
57
reader = csv .reader (file , delimiter = "\t " , quotechar = None )
56
58
examples = []
@@ -59,22 +61,38 @@ def _read_file(self, input_file, phase=None):
59
61
ncol = len (line )
60
62
if self .if_file_with_header [phase ]:
61
63
continue
62
- if ncol == 1 :
63
- if phase == "predict" :
64
- example = InputExample (guid = i , text_a = line [0 ])
65
- else :
64
+ if phase != "predict" :
65
+ if ncol == 1 :
66
66
raise Exception (
67
67
"the %s file: %s only has one column but it is not a predict file"
68
68
% (phase , input_file ))
69
- elif ncol == 2 :
70
- example = InputExample (
71
- guid = i , text_a = line [0 ], label = line [1 ])
72
- elif ncol == 3 :
73
- example = InputExample (
74
- guid = i , text_a = line [0 ], text_b = line [1 ], label = line [2 ])
69
+ elif ncol == 2 :
70
+ example = InputExample (
71
+ guid = i , text_a = line [0 ], label = line [1 ])
72
+ elif ncol == 3 :
73
+ example = InputExample (
74
+ guid = i ,
75
+ text_a = line [0 ],
76
+ text_b = line [1 ],
77
+ label = line [2 ])
78
+ else :
79
+ raise Exception (
80
+ "the %s file: %s has too many columns (should <=3)"
81
+ % (phase , input_file ))
75
82
else :
76
- raise Exception (
77
- "the %s file: %s has too many columns (should <=3)" %
78
- (phase , input_file ))
83
+ if ncol == 1 :
84
+ example = InputExample (guid = i , text_a = line [0 ])
85
+ elif ncol == 2 :
86
+ if not has_warned :
87
+ logger .warning (
88
+ "the predict file: %s has 2 columns, as it is a predict file, the second one will be regarded as text_b"
89
+ % (input_file ))
90
+ has_warned = True
91
+ example = InputExample (
92
+ guid = i , text_a = line [0 ], text_b = line [1 ])
93
+ else :
94
+ raise Exception (
95
+ "the predict file: %s has too many columns (should <=2)"
96
+ % (input_file ))
79
97
examples .append (example )
80
98
return examples
0 commit comments