21
21
import csv
22
22
23
23
from paddlehub .dataset import InputExample , BaseDataset
24
+ from paddlehub .common .logger import logger
24
25
25
26
26
class BaseNLPDataset(BaseDataset):
    """Dataset base class for NLP tasks whose data files are tab separated.

    The constructor forwards every argument unchanged to ``BaseDataset``;
    this class adds ``_read_file``, which turns the rows of one data file
    into ``InputExample`` objects.
    """

    def __init__(self,
                 base_path,
                 train_file=None,
                 dev_file=None,
                 test_file=None,
                 predict_file=None,
                 label_file=None,
                 label_list=None,
                 train_file_with_header=False,
                 dev_file_with_header=False,
                 test_file_with_header=False,
                 predict_file_with_header=False):
        """Forward all arguments to ``BaseDataset.__init__``.

        The ``*_file_with_header`` flags are presumably what populates
        ``self.if_file_with_header`` (read by ``_read_file``) in the base
        class -- TODO confirm against ``BaseDataset``.
        """
        super(BaseNLPDataset, self).__init__(
            base_path=base_path,
            train_file=train_file,
            dev_file=dev_file,
            test_file=test_file,
            predict_file=predict_file,
            label_file=label_file,
            label_list=label_list,
            train_file_with_header=train_file_with_header,
            dev_file_with_header=dev_file_with_header,
            test_file_with_header=test_file_with_header,
            predict_file_with_header=predict_file_with_header)

    def _read_file(self, input_file, phase=None):
        """Reads a tab separated value file.

        The number of columns is taken from the first row of the file and
        decides how every row is interpreted:

        * ``phase != "predict"``: 2 columns -> (text_a, label),
          3 columns -> (text_a, text_b, label).
        * ``phase == "predict"``: 1 column -> (text_a,),
          2 columns -> (text_a, text_b); a warning is logged once.

        Args:
            input_file: Path of the file to read (opened as UTF-8).
            phase: Key into ``self.if_file_with_header``; when the value
                stored there is truthy the first row is skipped. The value
                ``"predict"`` selects the label-free layout.

        Returns:
            A list of ``InputExample``, one per data row, with ``guid`` set
            to the zero-based row index in the file.

        Raises:
            Exception: If the column count of the first row is not valid
                for the given phase.
            IndexError: If a row has fewer columns than the first row.
        """
        has_warned = False
        is_predict = phase == "predict"
        with io.open(input_file, "r", encoding="UTF-8") as file:
            # csv only accepts a single-character delimiter: a lone tab.
            reader = csv.reader(file, delimiter="\t", quotechar=None)
            examples = []
            for (i, line) in enumerate(reader):
                if i == 0:
                    # The first row (header or data) fixes the layout.
                    ncol = len(line)
                    if self.if_file_with_header[phase]:
                        continue

                # Validate the layout before touching any column, so that
                # layout errors keep taking precedence over row errors.
                if not is_predict:
                    if ncol == 1:
                        raise Exception(
                            "the %s file: %s only has one column but it is not a predict file"
                            % (phase, input_file))
                    if ncol not in (2, 3):
                        raise Exception(
                            "the %s file: %s has too many columns (should <=3)"
                            % (phase, input_file))
                else:
                    if ncol not in (1, 2):
                        raise Exception(
                            "the predict file: %s has too many columns (should <=2)"
                            % (input_file))
                    if ncol == 2 and not has_warned:
                        logger.warning(
                            "the predict file: %s has 2 columns, as it is a predict file, the second one will be regarded as text_b"
                            % (input_file))
                        has_warned = True

                # A short (or blank) row would otherwise fail with a bare
                # "list index out of range"; keep the exception type but say
                # which file and row is malformed.
                if len(line) < ncol:
                    raise IndexError(
                        "the %s file: %s has %d column(s) in row %d but %d were expected"
                        % (phase, input_file, len(line), i + 1, ncol))

                if not is_predict:
                    if ncol == 2:
                        example = InputExample(
                            guid=i, text_a=line[0], label=line[1])
                    else:
                        example = InputExample(
                            guid=i,
                            text_a=line[0],
                            text_b=line[1],
                            label=line[2])
                elif ncol == 1:
                    example = InputExample(guid=i, text_a=line[0])
                else:
                    example = InputExample(
                        guid=i, text_a=line[0], text_b=line[1])
                examples.append(example)
            return examples
0 commit comments