Skip to content

Commit b3f0f3d

Browse files
authored
Merge pull request #766 from qingqing01/sentiment
Support predicting the samples from sys.stdin
2 parents dad11db + c5c295d commit b3f0f3d

File tree

4 files changed

+68
-65
lines changed

4 files changed

+68
-65
lines changed

demo/sentiment/predict.py

+32-31
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
15+
import os, sys
1616
import numpy as np
1717
from optparse import OptionParser
1818
from py_paddle import swig_paddle, DataProviderConverter
@@ -66,35 +66,27 @@ def load_label(self, label_file):
6666
for v in open(label_file, 'r'):
6767
self.label[int(v.split('\t')[1])] = v.split('\t')[0]
6868

69-
def get_data(self, data_file):
69+
def get_index(self, data):
7070
"""
71-
Get input data of paddle format.
71+
transform word into integer index according to the dictionary.
7272
"""
73-
with open(data_file, 'r') as fdata:
74-
for line in fdata:
75-
words = line.strip().split()
76-
word_slot = [
77-
self.word_dict[w] for w in words if w in self.word_dict
78-
]
79-
if not word_slot:
80-
print "all words are not in dictionary: %s", line
81-
continue
82-
yield [word_slot]
73+
words = data.strip().split()
74+
word_slot = [
75+
self.word_dict[w] for w in words if w in self.word_dict
76+
]
77+
return word_slot
8378

84-
def predict(self, data_file):
85-
"""
86-
data_file: file name of input data.
87-
"""
88-
input = self.converter(self.get_data(data_file))
79+
def batch_predict(self, data_batch):
80+
input = self.converter(data_batch)
8981
output = self.network.forwardTest(input)
9082
prob = output[0]["value"]
91-
lab = np.argsort(-prob)
92-
if self.label is None:
93-
print("%s: predicting label is %d" % (data_file, lab[0][0]))
94-
else:
95-
print("%s: predicting label is %s" %
96-
(data_file, self.label[lab[0][0]]))
97-
83+
labs = np.argsort(-prob)
84+
for idx, lab in enumerate(labs):
85+
if self.label is None:
86+
print("predicting label is %d" % (lab[0]))
87+
else:
88+
print("predicting label is %s" %
89+
(self.label[lab[0]]))
9890

9991
def option_parser():
10092
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
@@ -119,11 +111,13 @@ def option_parser():
119111
default=None,
120112
help="dictionary file")
121113
parser.add_option(
122-
"-i",
123-
"--data",
114+
"-c",
115+
"--batch_size",
116+
type="int",
124117
action="store",
125-
dest="data",
126-
help="data file to predict")
118+
dest="batch_size",
119+
default=1,
120+
help="the batch size for prediction")
127121
parser.add_option(
128122
"-w",
129123
"--model",
@@ -137,14 +131,21 @@ def option_parser():
137131
def main():
138132
options, args = option_parser()
139133
train_conf = options.train_conf
140-
data = options.data
134+
batch_size = options.batch_size
141135
dict_file = options.dict_file
142136
model_path = options.model_path
143137
label = options.label
144138
swig_paddle.initPaddle("--use_gpu=0")
145139
predict = SentimentPrediction(train_conf, dict_file, model_path, label)
146-
predict.predict(data)
147140

141+
batch = []
142+
for line in sys.stdin:
143+
batch.append([predict.get_index(line)])
144+
if len(batch) == batch_size:
145+
predict.batch_predict(batch)
146+
batch=[]
147+
if len(batch) > 0:
148+
predict.batch_predict(batch)
148149

149150
if __name__ == '__main__':
150151
main()

demo/sentiment/predict.sh

+6-6
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ set -e
1919
model=model_output/pass-00002/
2020
config=trainer_config.py
2121
label=data/pre-imdb/labels.list
22-
python predict.py \
23-
-n $config\
24-
-w $model \
25-
-b $label \
26-
-d ./data/pre-imdb/dict.txt \
27-
-i ./data/aclImdb/test/pos/10007_10.txt
22+
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
23+
--tconf=$config\
24+
--model=$model \
25+
--label=$label \
26+
--dict=./data/pre-imdb/dict.txt \
27+
--batch_size=1

doc/tutorials/sentiment_analysis/index_en.md

+15-14
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,21 @@ predict.sh:
293293
model=model_output/pass-00002/
294294
config=trainer_config.py
295295
label=data/pre-imdb/labels.list
296-
python predict.py \
297-
-n $config\
298-
-w $model \
299-
-b $label \
300-
-d data/pre-imdb/dict.txt \
301-
-i data/aclImdb/test/pos/10007_10.txt
302-
```
303-
304-
* `predict.py`: predicting interface.
305-
* -n $config : set network configure.
306-
* -w $model: set model path.
307-
* -b $label: set dictionary about corresponding relation between integer label and string label.
308-
* -d data/pre-imdb/dict.txt: set dictionary.
309-
* -i data/aclImdb/test/pos/10014_7.txt: set one example file to predict.
296+
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
297+
--tconf=$config\
298+
--model=$model \
299+
--label=$label \
300+
--dict=./data/pre-imdb/dict.txt \
301+
--batch_size=1
302+
```
303+
304+
* `cat ./data/aclImdb/test/pos/10007_10.txt` : the input sample.
305+
* `predict.py` : predicting interface.
306+
* `--tconf=$config` : set network configure.
307+
* ` --model=$model` : set model path.
308+
* `--label=$label` : set dictionary about corresponding relation between integer label and string label.
309+
* `--dict=data/pre-imdb/dict.txt` : set dictionary.
310+
* `--batch_size=1` : set batch size.
310311

311312
Note you should make sure the default model path `model_output/pass-00002`
312313
exists or change the model path.

doc_cn/demo/sentiment_analysis/sentiment_analysis.md

+15-14
Original file line numberDiff line numberDiff line change
@@ -291,20 +291,21 @@ predict.sh:
291291
model=model_output/pass-00002/
292292
config=trainer_config.py
293293
label=data/pre-imdb/labels.list
294-
python predict.py \
295-
-n $config\
296-
-w $model \
297-
-b $label \
298-
-d data/pre-imdb/dict.txt \
299-
-i data/aclImdb/test/pos/10007_10.txt
300-
```
301-
302-
* `predict.py`: 预测接口脚本。
303-
* -n $config : 设置网络配置。
304-
* -w $model: 设置模型路径。
305-
* -b $label: 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
306-
* -d data/pre-imdb/dict.txt: 设置字典文件。
307-
* -i data/aclImdb/test/pos/10014_7.txt: 设置一个要预测的示例文件。
294+
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
295+
--tconf=$config\
296+
--model=$model \
297+
--label=$label \
298+
--dict=./data/pre-imdb/dict.txt \
299+
--batch_size=1
300+
```
301+
302+
* `cat ./data/aclImdb/test/pos/10007_10.txt` : 输入预测样本。
303+
* `predict.py` : 预测接口脚本。
304+
* `--tconf=$config` : 设置网络配置。
305+
* `--model=$model` : 设置模型路径。
306+
* `--label=$label` : 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
307+
* `--dict=data/pre-imdb/dict.txt` : 设置字典文件。
308+
* `--batch_size=1` : 设置batch size。
308309

309310
注意应该确保默认模型路径`model_output / pass-00002`存在或更改为其它模型路径。
310311

0 commit comments

Comments
 (0)