-
Notifications
You must be signed in to change notification settings - Fork 79
Run fit-a-line demo with fault tolerant mode #278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
604c19a
d0880fd
1b1cd4d
11cf895
9ad8886
267a10d
93969de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import paddle.v2 as paddle | ||
import os | ||
import gzip | ||
from paddle.v2.reader.creator import cloud_reader | ||
import paddle.v2.dataset.uci_housing as uci_housing | ||
|
||
etcd_ip = os.getenv("ETCD_IP") | ||
etcd_endpoint = "http://" + etcd_ip + ":" + "2379" | ||
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID")) | ||
|
||
def main(): | ||
# init | ||
paddle.init() | ||
|
||
# network config | ||
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) | ||
y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear()) | ||
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) | ||
cost = paddle.layer.mse_cost(input=y_predict, label=y) | ||
|
||
# create parameters | ||
parameters = paddle.parameters.create(cost) | ||
|
||
# create optimizer | ||
optimizer = paddle.optimizer.Momentum(momentum=0) | ||
|
||
trainer = paddle.trainer.SGD( | ||
cost=cost, | ||
parameters=parameters, | ||
update_equation=optimizer, | ||
is_local=False, | ||
pserver_spec=etcd_endpoint, | ||
use_etcd=True) | ||
|
||
feeding = {'x': 0, 'y': 1} | ||
|
||
# event_handler to print training and testing info | ||
def event_handler(event): | ||
if isinstance(event, paddle.event.EndIteration): | ||
if event.batch_id % 100 == 0: | ||
print "Pass %d, Batch %d, Cost %f" % ( | ||
event.pass_id, event.batch_id, event.cost) | ||
|
||
if isinstance(event, paddle.event.EndPass): | ||
result = trainer.test( | ||
reader=paddle.batch(uci_housing.test(), batch_size=2), | ||
feeding=feeding) | ||
print "Test %d, Cost %f" % (event.pass_id, result.cost) | ||
if trainer_id == "0": | ||
with gzip.open("fit-a-line_pass_%05d.tar.gz" % event.pass_id, | ||
"w") as f: | ||
parameters.to_tar(f) | ||
# training | ||
trainer.train( | ||
reader=paddle.batch( | ||
paddle.reader.shuffle(cloud_reader( | ||
["/pfs/dlnel/public/dataset/uci_housing/uci_housing_train-*"], | ||
etcd_endpoint), buf_size=500), | ||
batch_size=2), | ||
feeding=feeding, | ||
event_handler=event_handler, | ||
num_passes=30) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,15 +42,10 @@ def fetch_pserver_ips(): | |
return ",".join(pserver_ips) | ||
|
||
def fetch_master_ip(): | ||
while True: | ||
label_selector = "paddle-job-master=%s" % PADDLE_JOB_NAME | ||
pod_list = fetch_pods_info(label_selector) | ||
master_ip = "" | ||
if len(pod_list) >=1: | ||
master_ip = pod_list[0][1] | ||
if master_ip: | ||
return master_ip | ||
time.sleep(5) | ||
label_selector = "paddle-job-master=%s" % PADDLE_JOB_NAME | ||
pod_list = fetch_pods_info(label_selector) | ||
master_ips = [item[1] for item in pod_list] | ||
return master_ips[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May return None when master is still not ready? Curious why need to remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As the code:https://github.com/PaddlePaddle/cloud/pull/278/files#diff-548cc24bbda04b9e89570c52ae5df9c7R48, we waiting for the master pod until the state became RUNNING. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. |
||
|
||
def fetch_trainer_id(): | ||
label_selector = "paddle-job=%s" % PADDLE_JOB_NAME | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use
paddle.v2.reader.creator.recordio
here? (it wrapscloud_reader
, providing a uniform interface for local and cloud recordio file).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nevermind, I could not find
paddle.v2.reader.creator.recordio
, I remember we have it. Probably I got confused.