@@ -23,6 +23,7 @@ template <typename T>
 class LarsMomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const bool merge_operation = ctx.Attr<bool>("merge_operation");
     auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
     auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
     auto param = ctx.MultiInput<framework::LoDTensor>("Param");
@@ -38,39 +39,74 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
                           framework::ToTypeName(grad_var[0]->Type())));
     auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
 
-    param_out[0]->mutable_data<T>(ctx.GetPlace());
-    velocity_out[0]->mutable_data<T>(ctx.GetPlace());
-
     T mu = static_cast<T>(ctx.Attr<float>("mu"));
     T lars_coeff = ctx.Attr<float>("lars_coeff");
-    T lars_weight_decay = ctx.Attr<std::vector<float>>("lars_weight_decay")[0];
     T epsilon = ctx.Attr<float>("epsilon");
 
-    auto p_out = framework::EigenVector<T>::Flatten(*(param_out[0]));
-    auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[0]));
+    if (!merge_operation) {
+      auto* lr = learning_rate[0]->data<T>();
+      T lars_weight_decay =
+          ctx.Attr<std::vector<float>>("lars_weight_decay")[0];
+      param_out[0]->mutable_data<T>(ctx.GetPlace());
+      velocity_out[0]->mutable_data<T>(ctx.GetPlace());
+
+      auto p_out = framework::EigenVector<T>::Flatten(*(param_out[0]));
+      auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[0]));
+      auto p = framework::EigenVector<T>::Flatten(*(param[0]));
+      auto v = framework::EigenVector<T>::Flatten(*(velocity[0]));
+      auto g = framework::EigenVector<T>::Flatten(*(grad[0]));
+
+      framework::Tensor p_norm_t, g_norm_t;
+      p_norm_t.Resize({1});
+      g_norm_t.Resize({1});
+      p_norm_t.mutable_data<T>(ctx.GetPlace());
+      g_norm_t.mutable_data<T>(ctx.GetPlace());
+      auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+      auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+      ep_norm = p.square().sum().sqrt();
+      eg_norm = g.square().sum().sqrt();
+
+      T local_lr = lr[0];
+      if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
+        local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                   (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+      }
+      v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+      p_out = p - v_out;
+    } else {
+      int op_num = param.size();
+      auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
+      for (int i = 0; i < op_num; ++i) {
+        auto* lr = learning_rate[i]->data<T>();
+        T lars_weight_decay = weight_decay_arr[i];
+        param_out[i]->mutable_data<T>(ctx.GetPlace());
+        velocity_out[i]->mutable_data<T>(ctx.GetPlace());
 
-    auto p = framework::EigenVector<T>::Flatten(*(param[0]));
-    auto v = framework::EigenVector<T>::Flatten(*(velocity[0]));
-    auto g = framework::EigenVector<T>::Flatten(*(grad[0]));
-    auto* lr = learning_rate[0]->data<T>();
+        auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
+        auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
+        auto p = framework::EigenVector<T>::Flatten(*(param[i]));
+        auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
+        auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
 
-    framework::Tensor p_norm_t, g_norm_t;
-    p_norm_t.Resize({1});
-    g_norm_t.Resize({1});
-    p_norm_t.mutable_data<T>(ctx.GetPlace());
-    g_norm_t.mutable_data<T>(ctx.GetPlace());
-    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
-    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+        framework::Tensor p_norm_t, g_norm_t;
+        p_norm_t.Resize({1});
+        g_norm_t.Resize({1});
+        p_norm_t.mutable_data<T>(ctx.GetPlace());
+        g_norm_t.mutable_data<T>(ctx.GetPlace());
+        auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+        auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+        ep_norm = p.square().sum().sqrt();
+        eg_norm = g.square().sum().sqrt();
 
-    ep_norm = p.square().sum().sqrt();
-    eg_norm = g.square().sum().sqrt();
-    T local_lr = lr[0];
-    if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
-      local_lr = lr[0] * lars_coeff * ep_norm(0) /
-                 (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+        T local_lr = lr[0];
+        if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
+          local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                     (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+        }
+        v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+        p_out = p - v_out;
+      }
     }
-    v_out = v * mu + local_lr * (g + lars_weight_decay * p);
-    p_out = p - v_out;
   }
 };
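For reference, both branches of the rewritten `Compute` apply the same per-tensor LARS rule; `merge_operation` only decides whether it runs once or in a loop over all merged parameters. Reading the update off the code above (learning rate `lr` as eta, momentum `mu`, `lars_coeff` as c, `lars_weight_decay` as lambda):

```latex
\mathrm{local\_lr} =
    \eta \, c \, \frac{\lVert p \rVert_2}
                      {\lVert g \rVert_2 + \lambda \lVert p \rVert_2 + \epsilon},
\qquad
v \leftarrow \mu v + \mathrm{local\_lr}\,(g + \lambda p),
\qquad
p \leftarrow p - v
```

with `local_lr` falling back to the plain learning rate when `lars_weight_decay` is zero or either norm vanishes.

Below is a minimal, framework-free sketch of that same update on plain buffers, handy as a reference when checking the kernel; the name `lars_update` and its signature are illustrative only, not part of the Paddle API:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// One LARS momentum step for a single parameter tensor.
template <typename T>
void lars_update(std::vector<T>* param, std::vector<T>* velocity,
                 const std::vector<T>& grad, T lr, T mu, T lars_coeff,
                 T lars_weight_decay, T epsilon) {
  // ||p||_2 and ||g||_2, matching ep_norm / eg_norm in the kernel.
  T p_norm = T(0), g_norm = T(0);
  for (std::size_t i = 0; i < param->size(); ++i) {
    p_norm += (*param)[i] * (*param)[i];
    g_norm += grad[i] * grad[i];
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);

  // Layer-wise adaptive rate; keeps the global lr when weight decay is
  // disabled or either norm is zero, mirroring the kernel's guard.
  T local_lr = lr;
  if (lars_weight_decay > 0 && p_norm > 0 && g_norm > 0) {
    local_lr = lr * lars_coeff * p_norm /
               (g_norm + lars_weight_decay * p_norm + epsilon);
  }

  // Momentum accumulation followed by the parameter step.
  for (std::size_t i = 0; i < param->size(); ++i) {
    (*velocity)[i] = (*velocity)[i] * mu +
                     local_lr * (grad[i] + lars_weight_decay * (*param)[i]);
    (*param)[i] -= (*velocity)[i];
  }
}
```

In the merged branch the kernel runs this logic once per entry of `Param`/`Grad`/`Velocity` with a per-tensor `lars_weight_decay[i]`, so a single fused op can stand in for many small `lars_momentum` ops, which is presumably the point of the merge.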