
Commit ce09788

Merge pull request #286 from wanghaoshuang/understand_sentiment
Add demo of understand_sentiment on paddlecloud
2 parents ccce04f + 185541c commit ce09788

File tree

1 file changed: +222 −0 lines changed

demo/understand_sentiment/train.py

+222
@@ -0,0 +1,222 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2 as paddle
import paddle.v2.dataset.common as common
import os
import sys
import glob
import pickle

# NOTE: must change this to your own username on paddlecloud.
USERNAME = "demo"
DC = os.getenv("PADDLE_CLOUD_CURRENT_DATACENTER")
common.DATA_HOME = "/pfs/%s/home/%s" % (DC, USERNAME)
TRAIN_FILES_PATH = os.path.join(common.DATA_HOME, "imdb")
TEST_FILES_PATH = os.path.join(common.DATA_HOME, "imdb")

TRAINER_ID = int(os.getenv("PADDLE_INIT_TRAINER_ID", "-1"))
TRAINER_COUNT = int(os.getenv("PADDLE_INIT_NUM_GRADIENT_SERVERS", "-1"))
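# A value of -1 means the paddlecloud environment variables are absent,
# i.e. the script is running outside the cluster; __main__ below refuses
# to run in that case.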

def prepare_dataset():
    word_dict = paddle.dataset.imdb.word_dict()
    # convert also splits the dataset into shards by line count
    common.convert(TRAIN_FILES_PATH,
                   lambda: paddle.dataset.imdb.train(word_dict),
                   1000, "train")
    common.convert(TEST_FILES_PATH,
                   lambda: paddle.dataset.imdb.test(word_dict),
                   1000, "test")

def cluster_reader_recordio(trainer_id, trainer_count, flag):
    '''
    Read from a cloud dataset stored in RecordIO format;
    each trainer reads a subset of the files of the whole dataset.
    '''
    import recordio

    def reader():
        PATTERN_STR = "%s-*" % flag
        FILES_PATTERN = os.path.join(TRAIN_FILES_PATH, PATTERN_STR)
        file_list = glob.glob(FILES_PATTERN)
        file_list.sort()
        my_file_list = []
        # keep only the files belonging to the current trainer_id
        for idx, f in enumerate(file_list):
            if idx % trainer_count == trainer_id:
                my_file_list.append(f)
        for f in my_file_list:
            print "processing ", f
            rec_reader = recordio.reader(f)
            record_raw = rec_reader.read()
            while record_raw:
                yield pickle.loads(record_raw)
                record_raw = rec_reader.read()
            rec_reader.close()

    return reader
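# Sorting the shard list before taking every trainer_count-th file gives
# each trainer a disjoint, deterministic subset of the dataset, so no two
# trainers read the same records within a pass.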


def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128):
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(input_dim))
    emb = paddle.layer.embedding(input=data, size=emb_dim)
    conv_3 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=3, hidden_size=hid_dim)
    conv_4 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=4, hidden_size=hid_dim)
    output = paddle.layer.fc(
        input=[conv_3, conv_4], size=class_dim, act=paddle.activation.Softmax())
    lbl = paddle.layer.data("label", paddle.data_type.integer_value(2))
    cost = paddle.layer.classification_cost(input=output, label=lbl)
    return cost
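# The two sequence_conv_pool branches act as text-CNN feature extractors
# with context windows of 3 and 4 words; the final fc layer concatenates
# their pooled features and applies a softmax over the two classes.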


def stacked_lstm_net(input_dim,
                     class_dim=2,
                     emb_dim=128,
                     hid_dim=512,
                     stacked_num=3):
    """
    A wrapper for the sentiment classification task.
    This network uses a bi-directional recurrent network
    consisting of three LSTM layers. The configuration follows
    the paper below, but uses fewer layers:
    http://www.aclweb.org/anthology/P15-1109

    input_dim: size of the word dictionary.
    class_dim: number of categories.
    emb_dim: dimension of the word embedding.
    hid_dim: dimension of the hidden layer.
    stacked_num: number of stacked lstm-hidden layers.
    """
    assert stacked_num % 2 == 1

    layer_attr = paddle.attr.Extra(drop_rate=0.5)
    fc_para_attr = paddle.attr.Param(learning_rate=1e-3)
    lstm_para_attr = paddle.attr.Param(initial_std=0., learning_rate=1.)
    para_attr = [fc_para_attr, lstm_para_attr]
    bias_attr = paddle.attr.Param(initial_std=0., l2_rate=0.)
    relu = paddle.activation.Relu()
    linear = paddle.activation.Linear()

    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(input_dim))
    emb = paddle.layer.embedding(input=data, size=emb_dim)

    fc1 = paddle.layer.fc(
        input=emb, size=hid_dim, act=linear, bias_attr=bias_attr)
    lstm1 = paddle.layer.lstmemory(
        input=fc1, act=relu, bias_attr=bias_attr, layer_attr=layer_attr)

    inputs = [fc1, lstm1]
    for i in range(2, stacked_num + 1):
        fc = paddle.layer.fc(
            input=inputs,
            size=hid_dim,
            act=linear,
            param_attr=para_attr,
            bias_attr=bias_attr)
        lstm = paddle.layer.lstmemory(
            input=fc,
            reverse=(i % 2) == 0,
            act=relu,
            bias_attr=bias_attr,
            layer_attr=layer_attr)
        inputs = [fc, lstm]

    fc_last = paddle.layer.pooling(
        input=inputs[0], pooling_type=paddle.pooling.Max())
    lstm_last = paddle.layer.pooling(
        input=inputs[1], pooling_type=paddle.pooling.Max())
    output = paddle.layer.fc(
        input=[fc_last, lstm_last],
        size=class_dim,
        act=paddle.activation.Softmax(),
        bias_attr=bias_attr,
        param_attr=para_attr)

    lbl = paddle.layer.data("label", paddle.data_type.integer_value(2))
    cost = paddle.layer.classification_cost(input=output, label=lbl)
    return cost
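# stacked_num must be odd: successive lstmemory layers alternate direction
# (reverse=True when i is even), so an odd count ends the stack on a
# forward layer and the network covers both directions as described above.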


def main():
    # init
    paddle.init()

    # data
    print 'load dictionary...'
    word_dict = paddle.dataset.imdb.word_dict()
    dict_dim = len(word_dict)
    class_dim = 2
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            cluster_reader_recordio(TRAINER_ID, TRAINER_COUNT, "train"),
            buf_size=1000),
        batch_size=100)
    test_reader = paddle.batch(
        cluster_reader_recordio(TRAINER_ID, TRAINER_COUNT, "test"),
        batch_size=100)

    feeding = {'word': 0, 'label': 1}
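    # feeding maps each data layer name to the position of its field in the
    # tuples yielded by the reader: (word_ids, label).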

    # network config
    # Choose how to build the network by uncommenting the corresponding line.
    cost = convolution_net(dict_dim, class_dim=class_dim)
    # cost = stacked_lstm_net(dict_dim, class_dim=class_dim, stacked_num=3)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=2e-3,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))
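    # Adam is combined here with L2 regularization and model averaging,
    # which maintains a moving average of the parameters during training.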

    # end-batch and end-pass event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print "\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_reader, feeding=feeding)
            print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)

    # create trainer
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=adam_optimizer)

    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=2)


if __name__ == '__main__':
    usage = "python train.py [prepare|train]"
    if len(sys.argv) != 2:
        print usage
        exit(1)

    if TRAINER_ID == -1 or TRAINER_COUNT == -1:
        print "no cloud environment found; this demo must run on paddlecloud"
        exit(1)

    if sys.argv[1] == "prepare":
        prepare_dataset()
    elif sys.argv[1] == "train":
        main()
    else:
        print usage
        exit(1)
