Skip to content

Commit 0f2192e

Browse files
committed
fix little error
1 parent d23b74f commit 0f2192e

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

models/recall/ncf/config_fl.yaml

+8-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
runner:
1616
sync_mode: "geo" # 可选, string: sync/async/geo
17-
with_coodinator: 1
17+
#with_coodinator: 1
1818
geo_step: 100 # 可选, int, 在geo模式下控制本地的迭代次数
1919
split_file_list: True # 可选, bool, 若每个节点上都拥有全量数据,则需设置为True
2020
thread_num: 1 # 多线程配置
@@ -39,6 +39,13 @@ runner:
3939
infer_load_path: "output_model_ncf"
4040
infer_start_epoch: 2
4141
infer_end_epoch: 3
42+
43+
need_dump: True
44+
dump_fields_path: "/home/wangbin/the_one_ps/ziyoujiyi_PaddleRec/PaddleRec/models/recall/ncf"
45+
dump_fields: ['item_input', 'user_input']
46+
dump_param: []
47+
local_sparse: ['embedding_0.w_0']
48+
remote_sparse: ['embedding_1.w_0']
4249

4350
hyper_parameters:
4451
optimizer:

models/recall/ncf/fl_ps_help.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
* 在 PaddleRec/datasets/movielens_pinterest_NCF/fl_data 中新建目录 fl_test_data 和 fl_train_data,用于存放每个 client 上的训练数据集和测试数据集
1111
* 在 PaddleRec/datasets/movielens_pinterest_NCF/fl_data 目录中执行: python gen_heter_data.py,生成 10 份训练数据
1212
* 总样本数 4970844(按 1:4 补充负样本):0 - 518095,1 - 520165,2 - 373605,3 - 315550,4 - 483779,5 - 495635,6 - 402810,7 - 354590,8 - 262710,9 - 1243905
13+
* 样本数据每一行表示:物品 id,用户 id,标签
1314

1415
# 3、运行命令
1516
1. 不带 coordinator 版本

tools/static_fl_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def init_network(self):
111111
self.model = get_model(self.config)
112112
self.input_data = self.model.create_feeds()
113113
self.metrics = self.model.net(self.input_data)
114-
self.model.create_optimizer(get_strategy(self.config))
114+
self.model.create_optimizer(get_strategy(self.config)) ## get_strategy
115115
if self.pure_bf16:
116116
self.model.optimizer.amp_init(self.place)
117117

tools/utils/static_ps/program_helper.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def get_strategy(config):
6363
"dump_fields_path": config.get("runner.dump_fields_path", ""),
6464
"dump_fields": config.get("runner.dump_fields", []),
6565
"dump_param": config.get("runner.dump_param", []),
66-
"stat_var_names": config.get("stat_var_names", [])
66+
"stat_var_names": config.get("stat_var_names", []),
67+
"local_sparse": config.get("runner.local_sparse", []),
68+
"remote_sparse": config.get("runner.remote_sparse", [])
6769
}
6870
print("strategy:", strategy.trainer_desc_configs)
6971

0 commit comments

Comments
 (0)