File tree 4 files changed +13
-3
lines changed
4 files changed +13
-3
lines changed Original file line number Diff line number Diff line change 14
14
15
15
runner :
16
16
sync_mode : " geo" # 可选, string: sync/async/geo
17
- with_coodinator : 1
17
+ # with_coodinator: 1
18
18
geo_step : 100 # 可选, int, 在geo模式下控制本地的迭代次数
19
19
split_file_list : True # 可选, bool, 若每个节点上都拥有全量数据,则需设置为True
20
20
thread_num : 1 # 多线程配置
@@ -39,6 +39,13 @@ runner:
39
39
infer_load_path : " output_model_ncf"
40
40
infer_start_epoch : 2
41
41
infer_end_epoch : 3
42
+
43
+ need_dump : True
44
+ dump_fields_path : " /home/wangbin/the_one_ps/ziyoujiyi_PaddleRec/PaddleRec/models/recall/ncf"
45
+ dump_fields : ['item_input', 'user_input']
46
+ dump_param : []
47
+ local_sparse : ['embedding_0.w_0']
48
+ remote_sparse : ['embedding_1.w_0']
42
49
43
50
hyper_parameters :
44
51
optimizer :
Original file line number Diff line number Diff line change 10
10
* 在 PaddleRec/datasets/movielens_pinterest_NCF/fl_data 中新建目录 fl_test_data 和 fl_train_data,用于存放每个 client 上的训练数据集和测试数据集
11
11
* 在 PaddleRec/datasets/movielens_pinterest_NCF/fl_data 目录中执行: python gen_heter_data.py,生成 10 份训练数据
12
12
* 总样本数 4970844(按 1:4 补充负样本):0 - 518095,1 - 520165,2 - 373605,3 - 315550,4 - 483779,5 - 495635,6 - 402810,7 - 354590,8 - 262710,9 - 1243905
13
+ * 样本数据每一行的字段依次为:物品 id、用户 id、标签
13
14
14
15
# 3、运行命令
15
16
1 . 不带 coordinator 版本
Original file line number Diff line number Diff line change @@ -111,7 +111,7 @@ def init_network(self):
111
111
self .model = get_model (self .config )
112
112
self .input_data = self .model .create_feeds ()
113
113
self .metrics = self .model .net (self .input_data )
114
- self .model .create_optimizer (get_strategy (self .config ))
114
+ self .model .create_optimizer (get_strategy (self .config )) # get_strategy(config) builds the strategy from the runner.* settings
115
115
if self .pure_bf16 :
116
116
self .model .optimizer .amp_init (self .place )
117
117
Original file line number Diff line number Diff line change @@ -63,7 +63,9 @@ def get_strategy(config):
63
63
"dump_fields_path" : config .get ("runner.dump_fields_path" , "" ),
64
64
"dump_fields" : config .get ("runner.dump_fields" , []),
65
65
"dump_param" : config .get ("runner.dump_param" , []),
66
- "stat_var_names" : config .get ("stat_var_names" , [])
66
+ "stat_var_names" : config .get ("stat_var_names" , []),
67
+ "local_sparse" : config .get ("runner.local_sparse" , []),
68
+ "remote_sparse" : config .get ("runner.remote_sparse" , [])
67
69
}
68
70
print ("strategy:" , strategy .trainer_desc_configs )
69
71
You can’t perform that action at this time.
0 commit comments