
Conversation

@YuanRisheng (Collaborator) commented Jul 31, 2025

The new loader now supports loading Ernie models.
Quantization is not supported yet, so testing the 300B model is inconvenient; results for the 21B model on an H20 are as follows:

|                            | Old loader | New loader |
| -------------------------- | ---------- | ---------- |
| Single-thread memory usage | 80 GB      | 8 GB       |
| Single-thread loading time | 46 s       | 18 s       |
| 4-thread memory usage      | 320 GB     | 32 GB      |
| 4-thread loading time      | 34 s       | 35 s       |

Compared with the old loader, the new loader reduces memory usage by roughly 90%; single-GPU loading is about 60% faster, and multi-GPU loading performance is essentially on par.


paddle-bot bot commented Jul 31, 2025

Thanks for your contribution!

@YuanRisheng YuanRisheng changed the title New Loader Support 0.3B New Loader Support Ernie Aug 8, 2025
```python
    for expert_id in range(self.fd_config.model_config.moe_num_experts)
    for param_name, weight_name, shard_id in param_name_maping
]
expert_params_mapping.append(
```
Collaborator:

Is it necessary to add this separately? Wouldn't adding it directly to general_params_mapping be enough?

YuanRisheng (Author):

Changed to load with the default loader. Still, I think it is more reasonable to keep it here, since this function returns the mapping used by the MoE.

Comment on lines +197 to +200

```python
if isinstance(loaded_weight, np.ndarray):
    size = loaded_weight.shape[-1]
else:
    size = loaded_weight.get_shape()[-1]
```
Collaborator:

Can both cases occur here, ndarray and tensor?

YuanRisheng (Author):

There are two cases: ndarray and pyslice.
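
For context, a minimal sketch of how the two cases can arise when weights are read lazily through safetensors' safe_open/get_slice interface; the file name and tensor key below are hypothetical, not from this PR:

```python
import numpy as np
from safetensors import safe_open

def last_dim_size(loaded_weight):
    # An eagerly loaded weight is an np.ndarray with .shape; a lazily loaded
    # safetensors slice only exposes .get_shape() until it is indexed.
    if isinstance(loaded_weight, np.ndarray):
        return loaded_weight.shape[-1]
    return loaded_weight.get_shape()[-1]

# Eager case: the whole tensor is already in memory.
dense = np.zeros((4, 8), dtype=np.float32)
assert last_dim_size(dense) == 8

# Lazy case: get_slice() returns a slice handle; only metadata is read here.
# The file name and tensor key are placeholders for illustration.
with safe_open("model-00001-of-00004.safetensors", framework="np") as f:
    lazy = f.get_slice("ernie.layers.0.mlp.up_gate_proj.weight")
    print(last_dim_size(lazy))
```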

```python
param.copy_(loaded_weight, False)
if loaded_shard_id is None:
    # Loaded weight is already fused on disk.
    if self.fd_config.model_config.pretrained_config.tensor_parallel_degree != 1:
```
Collaborator:

Couldn't this check just use the class member variable, i.e. self.nranks?

YuanRisheng (Author):

done
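
For illustration, a runnable mini-sketch of the suggested refactor; the class and config stubs are hypothetical, assuming self.nranks is set from tensor_parallel_degree at construction:

```python
from types import SimpleNamespace

class LinearSketch:
    """Hypothetical minimal layer: cache the tensor-parallel degree as
    self.nranks at construction, then branch on it in the weight loader
    instead of reaching through the nested fd_config each time."""

    def __init__(self, fd_config):
        self.fd_config = fd_config
        self.nranks = fd_config.model_config.pretrained_config.tensor_parallel_degree

    def needs_sharding(self):
        # Equivalent to the original nested check, but via the cached member.
        return self.nranks != 1

# Usage with a stubbed config object:
cfg = SimpleNamespace(
    model_config=SimpleNamespace(
        pretrained_config=SimpleNamespace(tensor_parallel_degree=4)
    )
)
print(LinearSketch(cfg).needs_sharding())  # True when tensor_parallel_degree != 1
```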

```python
param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
if loaded_shard_id is None:
    # Loaded weight is already fused on disk
    if self.fd_config.model_config.pretrained_config.tensor_parallel_degree != 1:
```
Collaborator:

Same as above.

YuanRisheng (Author):

done

@yuanlehome yuanlehome changed the title New Loader Support Ernie [V1 Loader] Support Ernie text Aug 11, 2025
@yuanlehome (Collaborator) commented:

Ernie has text-only / multimodal / MoE / non-MoE variants. Does this PR support all of them? If not, the PR description and title should be made more precise.

@@ -428,6 +429,80 @@ def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]

```python
        else:
            self.lm_head.load_state_dict(state_dict)

    def make_expert_params_mapping(self):
```
Collaborator:

How about changing this class method to support building the mapping for fused weights, so this model wouldn't need to write its own expert mapping?

YuanRisheng (Author):

I think it's more reasonable for each model to manage its own expert mapping; putting them all in moe.py doesn't feel right.

YuanRisheng (Author):

Updated per the review suggestion.
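
For reference, a minimal sketch of what a shared expert-mapping builder along these lines might look like; the checkpoint key patterns and the (param, weight, expert_id, shard_id) tuple layout are assumptions inferred from the snippets in this thread, not the actual FastDeploy signature:

```python
def make_expert_params_mapping(num_experts, ckpt_expert_prefix="mlp.experts"):
    """Hypothetical helper: expand per-expert checkpoint keys into
    (model_param_name, checkpoint_weight_name, expert_id, shard_id) tuples,
    covering fused up/gate weights, so each model need not define its own."""
    param_name_mapping = [
        # (param name in the model, checkpoint key template, shard id)
        ("experts.up_gate_proj_weight", ckpt_expert_prefix + ".{}.up_gate_proj.weight", "gate_up"),
        ("experts.down_proj_weight", ckpt_expert_prefix + ".{}.down_proj.weight", "down"),
    ]
    expert_params_mapping = [
        (param_name, weight_name.format(expert_id), expert_id, shard_id)
        for expert_id in range(num_experts)
        for param_name, weight_name, shard_id in param_name_mapping
    ]
    # The gate correction bias is a single shared tensor, not per-expert.
    expert_params_mapping.append(
        ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias")
    )
    return expert_params_mapping

for entry in make_expert_params_mapping(2):
    print(entry)
```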

```python
    for param_name, weight_name, shard_id in param_name_maping
]
expert_params_mapping.append(
    ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias")
```
Collaborator:

gate_correction_bias could be placed under general_params_mapping. The original design intent was for FusedMoE to contain only the gate/up/down weights, and gate_correction_bias was meant to be split out.

YuanRisheng (Author):

general_params_mapping holds mappings shared by both MoE and dense models. Would gate_correction_bias also exist in a dense model? I don't think so.

@YuanRisheng YuanRisheng changed the title [V1 Loader] Support Ernie text [V1 Loader] Support Ernie text(moe and dense) Aug 14, 2025
@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 09c979f into PaddlePaddle:develop Aug 14, 2025
11 of 14 checks passed