@@ -1118,6 +1118,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
1118
1118
1119
1119
def _split_table_grad_and_add_send_vars (self , program , pserver_endpoints ):
1120
1120
# 2. add split_ids_op and send_op to send gradient to pservers
1121
+
1121
1122
# there should only be one table_name
1122
1123
all_ops = program .global_block ().ops
1123
1124
table_grad_name = grad_var_name (self .table_name )
@@ -1142,7 +1143,7 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
1142
1143
if self .sync_mode else []
1143
1144
},
1144
1145
attrs = {
1145
- "sync_mode" : self .sync_mode ,
1146
+ "sync_mode" : not self .sync_mode ,
1146
1147
"epmap" : pserver_endpoints ,
1147
1148
RPC_OP_ROLE_ATTR_NAME : RPC_OP_ROLE_ATTR_VALUE ,
1148
1149
OP_ROLE_VAR_ATTR_NAME : [
@@ -1188,7 +1189,15 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
1188
1189
def _create_table_optimize_block (self , pserver_index , pserver_program ,
1189
1190
pre_block_idx , grad_to_block_id ):
1190
1191
# STEP: create table optimize block
1192
+ table_opt_block = pserver_program ._create_block (pre_block_idx )
1191
1193
# create table param and grad var in pserver program
1194
+ # create table optimize block in pserver program
1195
+ table_opt_op = [
1196
+ op for op in self .optimize_ops
1197
+ if 'Param' in op .input_names and op .input ("Param" )[0 ] ==
1198
+ self .table_name
1199
+ ][0 ]
1200
+
1192
1201
origin_param_var = self .origin_program .global_block ().vars [
1193
1202
self .table_name ]
1194
1203
@@ -1204,19 +1213,16 @@ def _create_table_optimize_block(self, pserver_index, pserver_program,
1204
1213
dtype = origin_param_var .dtype ,
1205
1214
type = core .VarDesc .VarType .SELECTED_ROWS ,
1206
1215
persistable = True )
1216
+
1207
1217
# parameter must be selected rows
1208
1218
param_var .desc .set_type (core .VarDesc .VarType .SELECTED_ROWS )
1209
1219
grad_var = pserver_program .global_block ()._clone_variable (
1210
1220
self .origin_program .global_block ().vars [grad_var_name (
1211
1221
self .table_name )])
1212
1222
1213
- # create table optimize block in pserver program
1214
- table_opt_op = [
1215
- op for op in self .optimize_ops
1216
- if 'Param' in op .input_names and op .input ("Param" )[0 ] ==
1217
- self .table_name
1218
- ][0 ]
1219
- table_opt_block = pserver_program ._create_block (pre_block_idx )
1223
+ lr_var = pserver_program .global_block ()._clone_variable (
1224
+ self .origin_program .global_block ().vars [table_opt_op .input (
1225
+ "LearningRate" )[0 ]])
1220
1226
1221
1227
if self .sync_mode :
1222
1228
# create grad vars in pserver program
@@ -1248,8 +1254,6 @@ def _create_table_optimize_block(self, pserver_index, pserver_program,
1248
1254
grad_var = pserver_program .global_block ()._rename_var (
1249
1255
origin_grad_name , splited_grad_name )
1250
1256
1251
- lr_var = pserver_program .global_block ().vars [table_opt_op .input (
1252
- "LearningRate" )[0 ]]
1253
1257
inputs = {
1254
1258
"Param" : [param_var ],
1255
1259
"Grad" : [grad_var ],
0 commit comments