Skip to content

Commit beeb3d9

Browse files
pd: support dpa2 (#4418)
Support DPA-2 in paddle backend. This PR will be updated after #4414 is merged. ### Training curve: ![training_curves_comparison_dpa2](https://github.com/user-attachments/assets/29bdeffa-cf2d-4586-afcf-7df0569997c3) ### Accuracy test(left: paddle, right: torch): ![image](https://github.com/user-attachments/assets/5bff55f3-1c39-4b95-93f0-68783e794716) Ralated optimization of Paddle framework: - [x] PaddlePaddle/Paddle#69349 - [x] PaddlePaddle/Paddle#69333 - [x] PaddlePaddle/Paddle#69479 - [x] PaddlePaddle/Paddle#69515 - [x] PaddlePaddle/Paddle#69487 - [x] PaddlePaddle/Paddle#69661 - [x] PaddlePaddle/Paddle#69660 - [x] PaddlePaddle/Paddle#69596 - [x] PaddlePaddle/Paddle#69556 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced new classes for molecular descriptors: `DescrptDPA2`, `DescrptBlockRepformers`, `DescrptSeTTebd`, and `DescrptBlockSeTTebd`. - Added new functions for tensor operations and descriptor management, enhancing the capabilities of the module. - Updated JSON configurations for multitask models to refine selection criteria and data paths. - **Bug Fixes** - Improved error handling and parameter validation across various descriptor classes. - **Documentation** - Enhanced test coverage for new descriptor functionalities and configurations. - **Tests** - Added new test classes to validate the functionality of `DescrptDPA2` and multitask training scenarios. - Expanded test capabilities for descriptor classes based on installed dependencies. - Updated existing tests to support new configurations and functionalities. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f9f1759 commit beeb3d9

29 files changed

+4987
-63
lines changed

deepmd/pd/model/descriptor/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,34 @@
99
DescrptBlockSeAtten,
1010
DescrptDPA1,
1111
)
12+
from .dpa2 import (
13+
DescrptDPA2,
14+
)
1215
from .env_mat import (
1316
prod_env_mat,
1417
)
18+
from .repformers import (
19+
DescrptBlockRepformers,
20+
)
1521
from .se_a import (
1622
DescrptBlockSeA,
1723
DescrptSeA,
1824
)
25+
from .se_t_tebd import (
26+
DescrptBlockSeTTebd,
27+
DescrptSeTTebd,
28+
)
1929

2030
__all__ = [
2131
"BaseDescriptor",
2232
"DescriptorBlock",
33+
"DescrptBlockRepformers",
2334
"DescrptBlockSeA",
2435
"DescrptBlockSeAtten",
36+
"DescrptBlockSeTTebd",
2537
"DescrptDPA1",
38+
"DescrptDPA2",
2639
"DescrptSeA",
40+
"DescrptSeTTebd",
2741
"prod_env_mat",
2842
]

0 commit comments

Comments
 (0)