-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Add API: Switch global program #5260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -538,5 +538,20 @@ def __init__(self, block, shape, dtype, **kwargs): | |
|
||
|
||
# program is a global instance. | ||
g_program = Program() | ||
g_init_program = Program() | ||
g_program_dict = dict() | ||
|
||
|
||
def switch_g_program(prog, init_prog): | ||
g_program_dict['program'] = prog | ||
g_program_dict['init_program'] = init_prog | ||
|
||
|
||
def g_program(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. g_program => main_program
|
||
return g_program_dict['program'] | ||
|
||
|
||
def g_init_program(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. g_init_program => startup_program
|
||
return g_program_dict['init_program'] | ||
|
||
|
||
switch_g_program(Program(), Program()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,24 @@ | ||
import numpy as np | ||
import paddle.v2.framework.core as core | ||
|
||
import paddle.v2 as paddle | ||
import paddle.v2.framework.layers as layers | ||
import paddle.v2.framework.core as core | ||
import paddle.v2.framework.optimizer as optimizer | ||
|
||
from paddle.v2.framework.framework import Program, g_program | ||
from paddle.v2.framework.io import save_persistables, load_persistables | ||
from paddle.v2.framework.executor import Executor | ||
|
||
import numpy as np | ||
from paddle.v2.framework.framework import Program, switch_g_program | ||
from paddle.v2.framework.io import save_persistables, load_persistables | ||
|
||
init_program = Program() | ||
program = Program() | ||
x = layers.data( | ||
name='x', | ||
shape=[13], | ||
data_type='float32', | ||
program=program, | ||
init_program=init_program) | ||
switch_g_program(program, init_program) | ||
x = layers.data(name='x', shape=[13], data_type='float32') | ||
|
||
y_predict = layers.fc(input=x, | ||
size=1, | ||
act=None, | ||
program=program, | ||
init_program=init_program) | ||
y_predict = layers.fc(input=x, size=1, act=None) | ||
|
||
y = layers.data( | ||
name='y', | ||
shape=[1], | ||
data_type='float32', | ||
program=program, | ||
init_program=init_program) | ||
y = layers.data(name='y', shape=[1], data_type='float32') | ||
|
||
cost = layers.square_error_cost( | ||
input=y_predict, label=y, program=program, init_program=init_program) | ||
avg_cost = layers.mean(x=cost, program=program, init_program=init_program) | ||
cost = layers.square_error_cost(input=y_predict, label=y) | ||
avg_cost = layers.mean(x=cost) | ||
|
||
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) | ||
opts = sgd_optimizer.minimize(avg_cost) | ||
|
@@ -52,8 +37,8 @@ | |
|
||
PASS_NUM = 100 | ||
for pass_id in range(PASS_NUM): | ||
save_persistables(exe, "./fit_a_line.model/", program=program) | ||
load_persistables(exe, "./fit_a_line.model/", program=program) | ||
save_persistables(exe, "./fit_a_line.model/") | ||
load_persistables(exe, "./fit_a_line.model/") | ||
for data in train_reader(): | ||
x_data = np.array(map(lambda x: x[0], data)).astype("float32") | ||
y_data = np.array(map(lambda x: x[1], data)).astype("float32") | ||
|
@@ -71,6 +56,7 @@ | |
fetch_list=[avg_cost]) | ||
out = np.array(outs[0]) | ||
|
||
print out | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove print? |
||
if out[0] < 10.0: | ||
exit(0) # if avg cost less than 10.0, we think our code is good. | ||
exit(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this API supposed to exposed to PaddlePaddle users? It might be dangerous to expose the fact that we are having two programs -- the main one and the initializer -- to the user at this moment.
I think it is reasonable to have the main program and the initializer program -- the former is like the
main
function in C/C++, and the latter the C/C++ runtime entry point that initializes the global variables. It is just that we might expose only the main program to the users.