25
25
26
26
void BuildProgram (pir::Builder &builder) { // NOLINT
27
27
paddle::dialect::FullOp full_input_op1 =
28
- builder.Build <paddle::dialect::FullOp>(std::vector< int64_t >{ 1 , 512 , 64 },
29
- 1.5 );
28
+ builder.Build <paddle::dialect::FullOp>(
29
+ std::vector< int64_t >{ 1 , 512 , 64 }, 1.5 , phi::DataType::FLOAT16 );
30
30
// linear 1
31
31
paddle::dialect::FullOp full_weight_op1 =
32
- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 , 64 }, 1.5 );
32
+ builder.Build <paddle::dialect::FullOp>(
33
+ std::vector<int64_t >{64 , 64 }, 1.5 , phi::DataType::FLOAT16);
33
34
paddle::dialect::FullOp full_bias_op1 =
34
- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 }, 1.0 );
35
+ builder.Build <paddle::dialect::FullOp>(
36
+ std::vector<int64_t >{64 }, 1.0 , phi::DataType::FLOAT16);
35
37
paddle::dialect::MatmulOp matmul_op1 =
36
38
builder.Build <paddle::dialect::MatmulOp>(full_input_op1.out (),
37
39
full_weight_op1.out ());
38
40
paddle::dialect::AddOp add_op1 = builder.Build <paddle::dialect::AddOp>(
39
41
matmul_op1.out (), full_bias_op1.out ());
40
42
// linear 2
41
43
paddle::dialect::FullOp full_weight_op2 =
42
- builder.Build <paddle::dialect::FullOp>(std::vector< int64_t >{ 64 , 128 },
43
- 1.5 );
44
+ builder.Build <paddle::dialect::FullOp>(
45
+ std::vector< int64_t >{ 64 , 128 }, 1.5 , phi::DataType::FLOAT16 );
44
46
paddle::dialect::FullOp full_bias_op2 =
45
- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{128 }, 1.0 );
47
+ builder.Build <paddle::dialect::FullOp>(
48
+ std::vector<int64_t >{128 }, 1.0 , phi::DataType::FLOAT16);
46
49
paddle::dialect::MatmulOp matmul_op2 =
47
50
builder.Build <paddle::dialect::MatmulOp>(add_op1.out (),
48
51
full_weight_op2.out ());
@@ -52,10 +55,11 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
52
55
builder.Build <paddle::dialect::ReluOp>(add_op2.out ());
53
56
// linear 3
54
57
paddle::dialect::FullOp full_weight_op3 =
55
- builder.Build <paddle::dialect::FullOp>(std::vector< int64_t >{ 128 , 64 },
56
- 1.5 );
58
+ builder.Build <paddle::dialect::FullOp>(
59
+ std::vector< int64_t >{ 128 , 64 }, 1.5 , phi::DataType::FLOAT16 );
57
60
paddle::dialect::FullOp full_bias_op3 =
58
- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 }, 1.0 );
61
+ builder.Build <paddle::dialect::FullOp>(
62
+ std::vector<int64_t >{64 }, 1.0 , phi::DataType::FLOAT16);
59
63
paddle::dialect::MatmulOp matmul_op3 =
60
64
builder.Build <paddle::dialect::MatmulOp>(relu_op.out (),
61
65
full_weight_op3.out ());
@@ -65,9 +69,11 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
65
69
builder.Build <paddle::dialect::GeluOp>(add_op3.out ());
66
70
// linear 4
67
71
paddle::dialect::FullOp full_weight_op4 =
68
- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 , 64 }, 1.5 );
72
+ builder.Build <paddle::dialect::FullOp>(
73
+ std::vector<int64_t >{64 , 64 }, 1.5 , phi::DataType::FLOAT16);
69
74
paddle::dialect::FullOp full_bias_op4 =
70
- builder.Build <paddle::dialect::FullOp>(std::vector<int64_t >{64 }, 1.0 );
75
+ builder.Build <paddle::dialect::FullOp>(
76
+ std::vector<int64_t >{64 }, 1.0 , phi::DataType::FLOAT16);
71
77
paddle::dialect::MatmulOp matmul_op4 =
72
78
builder.Build <paddle::dialect::MatmulOp>(gelu_op1.out (),
73
79
full_weight_op4.out ());
@@ -78,7 +84,7 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
78
84
79
85
// backward
80
86
paddle::dialect::FullOp full_grad_op = builder.Build <paddle::dialect::FullOp>(
81
- std::vector<int64_t >{1 , 512 , 64 }, 1.0 );
87
+ std::vector<int64_t >{1 , 512 , 64 }, 1.0 , phi::DataType::FLOAT16 );
82
88
83
89
paddle::dialect::GeluGradOp gelu_op2_grad =
84
90
builder.Build <paddle::dialect::GeluGradOp>(
0 commit comments