@@ -45,7 +45,6 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                 f"Training iteration {solver.global_step + 1}"
             )  # Training iteration
 
-        total_loss = 0.0
         total_batch_size = 0
         reader_cost = 0.0
         batch_cost = 0.0
@@ -106,31 +105,30 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
         if solver.nvtx_flag:  # only for nsight analysis
             core.nvprof_nvtx_push("Loss aggregator")
 
+        total_loss = solver.loss_aggregator(
+            constraint_losses, solver.global_step
+        )
+        if solver.update_freq > 1:
+            total_loss = total_loss / solver.update_freq
+
         for i, _constraint in enumerate(solver.constraint.values()):
-            total_loss += constraint_losses[i]
-            loss_dict[_constraint.name] += (
+            loss_dict[_constraint.name] = (
                 float(constraint_losses[i]) / solver.update_freq
             )
-        if solver.update_freq > 1:
-            total_loss = total_loss / solver.update_freq
+        loss_dict["loss"] = float(total_loss)
 
         if solver.nvtx_flag:  # only for nsight analysis
             core.nvprof_nvtx_pop()  # Loss aggregator
 
-        loss_dict["loss"] = float(total_loss)
-
         # backward
         if solver.nvtx_flag:  # only for nsight analysis
             core.nvprof_nvtx_push("Loss backward")
 
-        if solver.loss_aggregator is None:
-            if solver.use_amp:
-                total_loss_scaled = solver.scaler.scale(total_loss)
-                total_loss_scaled.backward()
-            else:
-                total_loss.backward()
+        if solver.use_amp:
+            total_loss_scaled = solver.scaler.scale(total_loss)
+            total_loss_scaled.backward()
         else:
-            solver.loss_aggregator(constraint_losses, solver.global_step).backward()
+            total_loss.backward()
 
         if solver.nvtx_flag:  # only for nsight analysis
             core.nvprof_nvtx_pop()  # Loss backward
@@ -233,7 +231,6 @@ def closure() -> paddle.Tensor:
             Returns:
                 paddle.Tensor: Computed loss scalar.
             """
-            total_loss = 0
             with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
                 with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
                     # forward for every constraint, including model and equation expression
@@ -248,20 +245,18 @@ def closure() -> paddle.Tensor:
                         label_dicts,
                         weight_dicts,
                     )
+
+                    total_loss = solver.loss_aggregator(
+                        constraint_losses, solver.global_step
+                    )
                     # accumulate all losses
                     for i, _constraint in enumerate(solver.constraint.values()):
-                        total_loss += constraint_losses[i]
                         loss_dict[_constraint.name] = float(constraint_losses[i])
                     loss_dict["loss"] = float(total_loss)
 
                 # backward
                 solver.optimizer.clear_grad()
-                if solver.loss_aggregator is None:
-                    total_loss.backward()
-                else:
-                    solver.loss_aggregator(
-                        constraint_losses, solver.global_step
-                    ).backward()
+                total_loss.backward()
 
             if solver.world_size > 1:
                 # fuse + allreduce manually before optimization if use DDP model
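
Note on the refactor: both epoch functions now take total_loss directly from solver.loss_aggregator instead of summing constraint losses by hand, so the loss_aggregator-is-None branches at backward time disappear. This assumes the solver always carries an aggregator. A minimal sketch, assuming a plain summing aggregator serves as the default (SumAggregator is a hypothetical name for illustration, not the repository's actual class):

    from typing import Sequence

    import paddle


    class SumAggregator:
        """Hypothetical default aggregator: returns the plain sum of the
        per-constraint losses, reproducing the removed total_loss += ... loop.
        """

        def __call__(self, losses: Sequence[paddle.Tensor], step: int) -> paddle.Tensor:
            # `step` is unused here; adaptive aggregators (e.g. loss-balancing
            # schemes) can use it to re-weight terms over training.
            total = losses[0]
            for loss in losses[1:]:
                total = total + loss
            return total

With a default like this in place, total_loss.backward() in the new code paths matches the behavior of the removed manual-summation branches, while non-trivial aggregators slot in without any extra branching.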