@@ -161,6 +161,7 @@ def run_evaluation(self, output_dir, current_step):
         eval_df = self.evaluator.run_with_ground_truth()
 
         if dist.get_rank() == 0 and eval_df is not None:
+            eval_df = eval_df.sort_values("__row_idx__").reset_index(drop=True)
             eval_df.to_csv(
                 f"{output_dir}/evaluation_results_step_{current_step}.csv",
                 index=False,
@@ -173,67 +174,97 @@ def run(self, output_dir):
173174 """Run the DRO Trainer for the specified number of optimization steps."""
174175 log_df = pd .DataFrame ()
175176
176- logger .info (
177- f"Breaking down the training dataset into { len (self .dataloader )} batches."
178- )
177+ self ._log_dataset_info ()
179178
180179 current_step = 0
181180 while current_step < self .total_optimization_steps :
182181 for batch_number , batch in enumerate (iter (self .dataloader )):
183- self .value .train ()
184- self .policy .train ()
182+ current_step = self ._train_step (batch , current_step , batch_number )
+                log_df = self._log_step(log_df, output_dir, current_step, batch_number)
+                dist.barrier()
 
-                self.policy_optimizer.zero_grad()
-                self.value_optimizer.zero_grad()
+                if self._should_checkpoint(current_step):
+                    self._maybe_save_models(output_dir, current_step)
+                    self._maybe_run_evaluation(output_dir, current_step)
 
-                tokenized_batch = self.collator(batch)
+                if current_step >= self.total_optimization_steps:
+                    break
 
-                policy_loss, value_loss = self.model_optimizer.calculate_loss(
-                    tokenized_batch
-                )
+        self._final_save(output_dir)
+        return log_df
 
-                value_loss.backward()
-                policy_loss.backward()
+    def _log_dataset_info(self):
+        dl = self.dataloader
+        world = (
+            dist.get_world_size()
+            if dist.is_available() and dist.is_initialized()
+            else 1
+        )
+        sampler = getattr(dl, "sampler", None)
 
-                self.policy_optimizer.step()
-                self.value_optimizer.step()
+        per_rank_samples = len(sampler) if sampler is not None else len(dl.dataset)
+        per_rank_batches = len(dl)
 
-                current_step += 1
+        logger.info(
+            f"Per-rank: {per_rank_samples} samples → {per_rank_batches} batches "
+            f"(batch size={dl.batch_size}, drop_last={dl.drop_last}); "
+            f"Global: world_size={world}, effective batch size={dl.batch_size * world}, "
+            f"batches/epoch={per_rank_batches * world}."
+        )
 
-                logger.info(
-                    f"Step {current_step}, Batch {batch_number + 1}: "
-                    f"Policy Loss: {policy_loss.item():.4f}, "
-                    f"Value Loss: {value_loss.item():.4f}"
-                )
+    def _train_step(self, batch, current_step, batch_number):
+        """Perform a single training step on the provided batch."""
+        self.value.train()
+        self.policy.train()
 
-                if dist.get_rank() == 0:
-                    step_log_df = pd.DataFrame.from_dict(
-                        {
-                            "step": [current_step],
-                            "policy_loss": [policy_loss.item()],
-                            "value_loss": [value_loss.item()],
-                        }
-                    )
+        self.policy_optimizer.zero_grad()
+        self.value_optimizer.zero_grad()
 
-                    log_df = pd.concat([log_df, step_log_df])
-                    log_df.to_csv(f"{output_dir}/dro_trainer_log.csv", index=False)
-                dist.barrier()
+        tokenized_batch = self.collator(batch)
 
-                if (current_step % self.check_point_freq == 0) and (current_step > 0):
+        policy_loss, value_loss = self.model_optimizer.calculate_loss(tokenized_batch)
 
-                    if self.config["trainer"].get("save_models", True):
-                        if dist.get_rank() == 0:
-                            self.save_models(output_dir, current_step)
-                        dist.barrier()
+        value_loss.backward()
+        policy_loss.backward()
 
-                    # Run online evaluation if configured
-                    if self.config["trainer"].get("evaluate_during_training", False):
-                        self.run_evaluation(output_dir, current_step)
+        self.policy_optimizer.step()
+        self.value_optimizer.step()
 
-                if current_step >= self.total_optimization_steps:
-                    break
+        logger.info(
+            f"Step {current_step + 1}, Batch {batch_number + 1}: Policy Loss: {policy_loss.item():.4f}, Value Loss: {value_loss.item():.4f}"
+        )
 
-        # Final save after training completes
-        self.policy.module.save(output_dir / "models/final")
+        # Stash the latest losses so _log_step can report them on rank 0.
+        self._last_policy_loss = policy_loss
+        self._last_value_loss = value_loss
 
-        return log_df
+        return current_step + 1
+
+    def _log_step(self, log_df, output_dir, current_step, batch_number):
+        if dist.get_rank() == 0:
+            step_log_df = pd.DataFrame.from_dict(
+                {
+                    "step": [current_step],
+                    "policy_loss": [self._last_policy_loss.item()],
+                    "value_loss": [self._last_value_loss.item()],
+                }
+            )
+            log_df = pd.concat([log_df, step_log_df])
+            log_df.to_csv(f"{output_dir}/dro_trainer_log.csv", index=False)
+        return log_df
+
+    def _should_checkpoint(self, current_step):
+        return (current_step % self.check_point_freq == 0) and (current_step > 0)
+
+    def _maybe_save_models(self, output_dir, current_step):
+        if self.config["trainer"].get("save_models", True):
+            if dist.get_rank() == 0:
+                self.save_models(output_dir, current_step)
+            dist.barrier()
+
+    def _maybe_run_evaluation(self, output_dir, current_step):
+        if self.config["trainer"].get("evaluate_during_training", False):
+            self.run_evaluation(output_dir, current_step)
+
+    def _final_save(self, output_dir):
+        self.policy.module.save(output_dir / "models/final")
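
Below is a minimal, self-contained sketch (pandas only) of the two loop patterns this refactor leans on: the rank-0 logger concatenates the per-step frame onto the running log and returns it, so the caller rebinds log_df and the CSV always holds the full history; checkpointing is gated by a step-frequency predicate. RANK, CHECKPOINT_FREQ, log_step, should_checkpoint, and the dummy loss values are illustrative stand-ins, not trainer APIs.

import pandas as pd

RANK = 0               # stand-in for dist.get_rank(); only rank 0 writes the CSV
CHECKPOINT_FREQ = 3    # stand-in for self.check_point_freq

def log_step(log_df, step, policy_loss, value_loss, path="dro_trainer_log.csv"):
    """Append one step's losses to the running log, rewrite the CSV, return the log."""
    if RANK == 0:
        step_df = pd.DataFrame(
            {"step": [step], "policy_loss": [policy_loss], "value_loss": [value_loss]}
        )
        log_df = pd.concat([log_df, step_df], ignore_index=True)
        log_df.to_csv(path, index=False)
    return log_df  # caller must rebind: log_df = log_step(log_df, ...)

def should_checkpoint(step, freq=CHECKPOINT_FREQ):
    return step > 0 and step % freq == 0

log_df = pd.DataFrame()
for step in range(1, 7):
    log_df = log_step(log_df, step, policy_loss=0.1 * step, value_loss=0.2 * step)
    if should_checkpoint(step):
        print(f"checkpointing at step {step}")  # model saving / evaluation would hook in here

The return-and-rebind shape is why run() assigns log_df = self._log_step(...) instead of discarding the result; otherwise the concatenation inside the helper would be lost between steps.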